# Strassen’s Matrix Multiplication Algorithm | Implementation

Strassen’s method of matrix multiplication is a typical divide-and-conquer algorithm. Strassen’s algorithm itself was discussed in an earlier post; here we briefly revisit the idea behind the divide-and-conquer approach and then implement it.

Prerequisite: Read the earlier post on Strassen’s algorithm before continuing.

Implementation

## C++

```cpp
// CPP program to implement Strassen's Matrix
// Multiplication Algorithm
#include <cstdio>
using namespace std;
typedef long long lld;

/* Strassen's Algorithm for matrix multiplication
Complexity: O(n^2.808) */

// Standard multiplication of an n x l matrix a by an l x m matrix b.
// The caller owns the returned n x m matrix.
inline lld** MatrixMultiply(lld** a, lld** b, int n,
                            int l, int m)
{
    lld** c = new lld*[n];
    for (int i = 0; i < n; i++)
        c[i] = new lld[m];

    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            c[i][j] = 0;
            for (int k = 0; k < l; k++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

// Multiplies an n x l matrix a by an l x m matrix b using Strassen's
// seven-product scheme. The caller owns the returned n x m matrix.
inline lld** Strassen(lld** a, lld** b, int n,
                      int l, int m)
{
    // Base case: fall back to the standard algorithm
    // as soon as any dimension reaches 1.
    if (n == 1 || l == 1 || m == 1)
        return MatrixMultiply(a, b, n, l, m);

    lld** c = new lld*[n];
    for (int i = 0; i < n; i++)
        c[i] = new lld[m];

    // Half sizes, rounded up so that odd dimensions are zero-padded.
    int adjN = (n >> 1) + (n & 1);
    int adjL = (l >> 1) + (l & 1);
    int adjM = (m >> 1) + (m & 1);

    // As[x][y] is the (x, y) quadrant of a, of size adjN x adjL.
    lld**** As = new lld***[2];
    for (int x = 0; x < 2; x++) {
        As[x] = new lld**[2];
        for (int y = 0; y < 2; y++) {
            As[x][y] = new lld*[adjN];
            for (int i = 0; i < adjN; i++) {
                As[x][y][i] = new lld[adjL];
                for (int j = 0; j < adjL; j++) {
                    int I = i + (x & 1) * adjN;
                    int J = j + (y & 1) * adjL;
                    As[x][y][i][j] = (I < n && J < l) ? a[I][J] : 0;
                }
            }
        }
    }

    // Bs[x][y] is the (x, y) quadrant of b, of size adjL x adjM.
    lld**** Bs = new lld***[2];
    for (int x = 0; x < 2; x++) {
        Bs[x] = new lld**[2];
        for (int y = 0; y < 2; y++) {
            Bs[x][y] = new lld*[adjL]; // adjL rows, matching the loop below
            for (int i = 0; i < adjL; i++) {
                Bs[x][y][i] = new lld[adjM];
                for (int j = 0; j < adjM; j++) {
                    int I = i + (x & 1) * adjL;
                    int J = j + (y & 1) * adjM;
                    Bs[x][y][i][j] = (I < l && J < m) ? b[I][J] : 0;
                }
            }
        }
    }

    // The ten intermediate sums and differences. Even indices 0 and odd
    // indices 3, 5, 7, 9 come from b (adjL x adjM); the rest from a
    // (adjN x adjL).
    lld*** s = new lld**[10];
    for (int i = 0; i < 10; i++) {
        switch (i) {
        case 0:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[0][1][j][k] - Bs[1][1][j][k];
                }
            }
            break;
        case 1:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][0][j][k] + As[0][1][j][k];
                }
            }
            break;
        case 2:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[1][0][j][k] + As[1][1][j][k];
                }
            }
            break;
        case 3:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[1][0][j][k] - Bs[0][0][j][k];
                }
            }
            break;
        case 4:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][0][j][k] + As[1][1][j][k];
                }
            }
            break;
        case 5:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[0][0][j][k] + Bs[1][1][j][k];
                }
            }
            break;
        case 6:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][1][j][k] - As[1][1][j][k];
                }
            }
            break;
        case 7:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[1][0][j][k] + Bs[1][1][j][k];
                }
            }
            break;
        case 8:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][0][j][k] - As[1][0][j][k];
                }
            }
            break;
        case 9:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[0][0][j][k] + Bs[0][1][j][k];
                }
            }
            break;
        }
    }

    // The seven recursive products, each of size adjN x adjM.
    lld*** p = new lld**[7];
    p[0] = Strassen(As[0][0], s[0], adjN, adjL, adjM);
    p[1] = Strassen(s[1], Bs[1][1], adjN, adjL, adjM);
    p[2] = Strassen(s[2], Bs[0][0], adjN, adjL, adjM);
    p[3] = Strassen(As[1][1], s[3], adjN, adjL, adjM);
    p[4] = Strassen(s[4], s[5], adjN, adjL, adjM);
    p[5] = Strassen(s[6], s[7], adjN, adjL, adjM);
    p[6] = Strassen(s[8], s[9], adjN, adjL, adjM);

    // Combine the products into the four quadrants of c, skipping
    // the padded rows and columns that fall outside the n x m result.
    for (int i = 0; i < adjN; i++) {
        for (int j = 0; j < adjM; j++) {
            c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
            if (j + adjM < m)
                c[i][j + adjM] = p[0][i][j] + p[1][i][j];
            if (i + adjN < n)
                c[i + adjN][j] = p[2][i][j] + p[3][i][j];
            if (i + adjN < n && j + adjM < m)
                c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
        }
    }

    // Release all temporaries.
    for (int x = 0; x < 2; x++) {
        for (int y = 0; y < 2; y++) {
            for (int i = 0; i < adjN; i++) {
                delete[] As[x][y][i];
            }
            delete[] As[x][y];
        }
        delete[] As[x];
    }
    delete[] As;

    for (int x = 0; x < 2; x++) {
        for (int y = 0; y < 2; y++) {
            for (int i = 0; i < adjL; i++) {
                delete[] Bs[x][y][i];
            }
            delete[] Bs[x][y];
        }
        delete[] Bs[x];
    }
    delete[] Bs;

    for (int i = 0; i < 10; i++) {
        switch (i) {
        case 0:
        case 3:
        case 5:
        case 7:
        case 9:
            for (int j = 0; j < adjL; j++) {
                delete[] s[i][j];
            }
            break;
        case 1:
        case 2:
        case 4:
        case 6:
        case 8:
            for (int j = 0; j < adjN; j++) {
                delete[] s[i][j];
            }
            break;
        }
        delete[] s[i];
    }
    delete[] s;

    for (int i = 0; i < 7; i++) {
        // Each product has adjN rows (not n >> 1, which differs for odd n).
        for (int j = 0; j < adjN; j++) {
            delete[] p[i][j];
        }
        delete[] p[i];
    }
    delete[] p;

    return c;
}

int main()
{
    // 2 x 3 matrix A
    lld** matA;
    matA = new lld*[2];
    for (int i = 0; i < 2; i++)
        matA[i] = new lld[3];
    matA[0][0] = 1;
    matA[0][1] = 2;
    matA[0][2] = 3;
    matA[1][0] = 4;
    matA[1][1] = 5;
    matA[1][2] = 6;

    // 3 x 2 matrix B
    lld** matB;
    matB = new lld*[3];
    for (int i = 0; i < 3; i++)
        matB[i] = new lld[2];
    matB[0][0] = 7;
    matB[0][1] = 8;
    matB[1][0] = 9;
    matB[1][1] = 10;
    matB[2][0] = 11;
    matB[2][1] = 12;

    // C = A * B is 2 x 2
    lld** matC = Strassen(matA, matB, 2, 3, 2);
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%lld ", matC[i][j]);
        }
        printf("\n");
    }

    return 0;
}
```
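
For reference, the ten sums and seven products computed in the code above can be written out in block notation. The symbols are introduced here for illustration only and do not appear in the article: $A_{11}, A_{12}, A_{21}, A_{22}$ stand for `As[0][0]`, `As[0][1]`, `As[1][0]`, `As[1][1]` (likewise for $B$ and `Bs`), $S_{k}$ stands for `s[k-1]`, and $P_{k}$ stands for `p[k-1]`.

$$
\begin{aligned}
S_1 &= B_{12} - B_{22}, & S_2 &= A_{11} + A_{12}, & S_3 &= A_{21} + A_{22}, & S_4 &= B_{21} - B_{11}, & S_5 &= A_{11} + A_{22}, \\
S_6 &= B_{11} + B_{22}, & S_7 &= A_{12} - A_{22}, & S_8 &= B_{21} + B_{22}, & S_9 &= A_{11} - A_{21}, & S_{10} &= B_{11} + B_{12}
\end{aligned}
$$

$$
\begin{aligned}
P_1 &= A_{11} S_1, & P_2 &= S_2 B_{22}, & P_3 &= S_3 B_{11}, & P_4 &= A_{22} S_4, \\
P_5 &= S_5 S_6, & P_6 &= S_7 S_8, & P_7 &= S_9 S_{10}
\end{aligned}
$$

$$
\begin{aligned}
C_{11} &= P_5 + P_4 - P_2 + P_6, & C_{12} &= P_1 + P_2, \\
C_{21} &= P_3 + P_4, & C_{22} &= P_5 + P_1 - P_3 - P_7
\end{aligned}
$$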

Output

```
58 64
139 154
```
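
One way to sanity-check the output above is to compare it against a plain triple-loop product. The sketch below is not part of the article's program: `Matrix` and `multiplyNaive` are hypothetical names, and `std::vector` is used instead of raw pointers purely to keep the check short.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical alias for illustration; the article's code uses lld** instead.
using Matrix = std::vector<std::vector<long long>>;

// Hypothetical reference implementation: schoolbook product of an
// n x l matrix a and an l x m matrix b. Assumes the shapes are compatible.
Matrix multiplyNaive(const Matrix& a, const Matrix& b)
{
    const std::size_t n = a.size();
    const std::size_t l = b.size();
    const std::size_t m = b.empty() ? 0 : b[0].size();

    Matrix c(n, std::vector<long long>(m, 0));
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t j = 0; j < m; j++)
            for (std::size_t k = 0; k < l; k++)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}
```

Calling `multiplyNaive` on the same 2 x 3 and 3 x 2 inputs used in `main` should give the same four values that the Strassen routine prints.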

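Time Complexity: O(n^2.808), or more precisely O(n^(log2 7)) with log2 7 ≈ 2.807. If any dimension of the input matrices is 1, the algorithm falls back to standard matrix multiplication. Otherwise, it splits each matrix into 4 submatrices (padding odd dimensions with zeros), performs 7 recursive multiplications on the half-size blocks, and combines the products with additions and subtractions to obtain the final result. For square n x n inputs this gives the recurrence T(n) = 7T(n/2) + O(n^2), which solves to the bound above.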
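
To make the effect of using 7 recursive products instead of 8 concrete, the following sketch counts scalar multiplications only. Both function names are hypothetical, and the Strassen count assumes a square matrix whose size is a power of two with recursion all the way down to 1 x 1 blocks, which is a simplification of the rectangular, zero-padded implementation above.

```cpp
#include <cstdint>

// Hypothetical helper: scalar multiplications used by the standard
// triple-loop algorithm on two n x n matrices.
std::uint64_t naiveMultiplications(std::uint64_t n)
{
    return n * n * n;
}

// Hypothetical helper: scalar multiplications used when each level
// performs 7 recursive products on half-size blocks.
// Assumes n is a power of two and recursion stops at 1 x 1.
std::uint64_t strassenMultiplications(std::uint64_t n)
{
    if (n <= 1)
        return 1;
    return 7 * strassenMultiplications(n / 2);
}
```

For n = 1024 these return 282,475,249 and 1,073,741,824 multiplications respectively, roughly a 3.8x reduction. The count ignores the extra additions and memory allocations, which can outweigh the savings for small matrices.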

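Space Complexity: O(n^2). Each recursion level allocates O(n^2) of temporary submatrices, and the sizes shrink geometrically with depth, so the total stays within O(n^2).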
