
# Easy way to remember Strassen’s Matrix Equation

Strassen’s matrix multiplication is a divide-and-conquer method for multiplying two matrices of size n × n.
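For reference, here is what is being computed. Split each matrix into four n/2 × n/2 blocks, named with the same letters the rules below use. Ordinary block multiplication then needs eight block products:

$$
X = \begin{pmatrix} A & B \\ C & D \end{pmatrix},\qquad
Y = \begin{pmatrix} E & F \\ G & H \end{pmatrix},\qquad
XY = \begin{pmatrix} AE + BG & AF + BH \\ CE + DG & CF + DH \end{pmatrix}
$$

Strassen’s method obtains the same four blocks from only seven products P1 … P7 plus a constant number of block additions and subtractions. Assuming n is a power of 2, its running time therefore satisfies:

$$
T(n) = 7\,T\!\left(\tfrac{n}{2}\right) + O(n^2) \quad\Longrightarrow\quad T(n) = O\!\left(n^{\log_2 7}\right) \approx O\!\left(n^{2.81}\right)
$$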

If you are not yet familiar with Strassen’s method, read Divide and Conquer | Set 5 (Strassen’s Matrix Multiplication) first. The method requires memorizing a few equations, so here is a simple way to remember them.

You just need to remember four rules:

• AHED (Learn it as ‘Ahead’)
• Diagonal
• Last CR
• First CR

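Here X is the first matrix, with A, B in its first row and C, D in its second row, and Y is the second matrix, with E, F in its first row and G, H in its second row. Treat X as the row matrix (Row +) and Y as the column matrix (Column –): elements taken from a row are added, and elements taken from a column are subtracted. Follow the steps: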

• Write down the AHED letters as the starting point of the first four products: P1 = A, P2 = H, P3 = E, P4 = D.
• For P5, use the Diagonal rule: (sum of the diagonal elements of X) * (sum of the diagonal elements of Y), which gives P5 = (A + D) * (E + H).
• For P6, use the Last CR rule: take the last Column of X and the last Row of Y. Since Row + and Column –, this gives P6 = (B – D) * (G + H).
• For P7, use the First CR rule: take the first Column of X and the first Row of Y. Again Row + and Column –, so P7 = (A – C) * (E + F).
• Come back to P1: it holds A, whose adjacent (same-position) element in Y is E. Since Y is the column matrix, choose the column of Y that does not contain E, namely F, H, and multiply A by (F – H). So P1 = A * (F – H).
• Come back to P2: it holds H, whose adjacent element in X is D. Since X is the row matrix, choose the row of X that does not contain D, namely A, B, and multiply (A + B) by H. So P2 = (A + B) * H.
• Come back to P3: it holds E, whose adjacent element in X is A. Since X is the row matrix, choose the row of X that does not contain A, namely C, D, and multiply (C + D) by E. So P3 = (C + D) * E.
• Come back to P4: it holds D, whose adjacent element in Y is H. Since Y is the column matrix, choose the column of Y that does not contain H, namely E, G, and multiply D by (G – E). So P4 = D * (G – E).
• Now fill the result. Number its positions C1 (top left), C2 (top right), C3 (bottom left) and C4 (bottom right). Counting in order, write P1 + P2 at C2.
• Write P3 + P4 at its diagonal position, i.e. at C3.
• At the first position, write P4 + P5 + P6 and subtract P2, i.e. C1 = P4 + P5 + P6 – P2.
• At the last position, write the odd-numbered products with alternating – and + signs: P1 P3 P5 P7 becomes C4 = P1 – P3 + P5 – P7.
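The mnemonic can be checked mechanically. The sketch below uses a hypothetical helper, `strassen_2x2`, that works on plain numbers rather than blocks. It builds P1 … P7 exactly as the steps above describe, assembles C1 … C4, and compares the result with ordinary 2 × 2 multiplication.

```python
def strassen_2x2(X, Y):
    """Multiply two 2x2 matrices using Strassen's seven products.

    X = [[A, B], [C, D]] and Y = [[E, F], [G, H]], as in the rules above.
    """
    (A, B), (C, D) = X
    (E, F), (G, H) = Y

    # AHED: each letter is completed with a row of X (+) or a column of Y (-)
    P1 = A * (F - H)
    P2 = (A + B) * H
    P3 = (C + D) * E
    P4 = D * (G - E)
    # Diagonal rule: both diagonals summed
    P5 = (A + D) * (E + H)
    # Last CR: last column of X (-), last row of Y (+)
    P6 = (B - D) * (G + H)
    # First CR: first column of X (-), first row of Y (+)
    P7 = (A - C) * (E + F)

    C1 = P4 + P5 + P6 - P2   # top left
    C2 = P1 + P2             # top right
    C3 = P3 + P4             # bottom left (diagonal position of C2)
    C4 = P1 - P3 + P5 - P7   # bottom right (odd P's, alternating signs)
    return [[C1, C2], [C3, C4]]


def naive_2x2(X, Y):
    """Ordinary row-by-column product, used for comparison."""
    return [[sum(X[i][k] * Y[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


X = [[1, 2], [3, 4]]
Y = [[5, 6], [7, 8]]
print(strassen_2x2(X, Y))  # [[19, 22], [43, 50]]
assert strassen_2x2(X, Y) == naive_2x2(X, Y)
```

None of the formulas swaps the order of a product: a piece of X always stays on the left and a piece of Y on the right. For that reason, the same identities also hold when A … H are n/2 × n/2 blocks, which is how the recursive implementations below use them.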

Implementation:

## C++

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>
using namespace std;

typedef vector<int> vi;
typedef vector<vi> vii;

/* finding the next power of 2 that is >= k */
int nextPowerOf2(int k)
{
    return int(pow(2, int(ceil(log2(k)))));
}

// printing matrix
void display(vii C, int m, int n)
{
    for (int i = 0; i < m; i++)
    {
        cout << "|" << " ";
        for (int j = 0; j < n; j++)
        {
            cout << C[i][j] << " ";
        }
        cout << "|" << endl;
    }
}

//! addition and subtraction of two size x size matrices: C = A + B, C = A - B
void add(vii &A, vii &B, vii &C, int size)
{
    for (int i = 0; i < size; i++)
        for (int j = 0; j < size; j++)
            C[i][j] = A[i][j] + B[i][j];
}

void sub(vii &A, vii &B, vii &C, int size)
{
    for (int i = 0; i < size; i++)
        for (int j = 0; j < size; j++)
            C[i][j] = A[i][j] - B[i][j];
}

//!-----------------------------
// C = A * B for size x size matrices, size a power of 2
void Strassen_algorithm(vii &A, vii &B, vii &C, int size)
{
    if (size == 1)
    {
        C[0][0] = A[0][0] * B[0][0];
        return;
    }

    int newSize = size / 2;
    vi z(newSize);
    vii a(newSize, z), b(newSize, z), c(newSize, z), d(newSize, z),
        e(newSize, z), f(newSize, z), g(newSize, z), h(newSize, z),
        c11(newSize, z), c12(newSize, z), c21(newSize, z), c22(newSize, z),
        p1(newSize, z), p2(newSize, z), p3(newSize, z), p4(newSize, z),
        p5(newSize, z), p6(newSize, z), p7(newSize, z), fResult(newSize, z),
        sResult(newSize, z);
    int i, j;

    //! divide the matrices into four equal parts
    for (i = 0; i < newSize; i++)
    {
        for (j = 0; j < newSize; j++)
        {
            a[i][j] = A[i][j];
            b[i][j] = A[i][j + newSize];
            c[i][j] = A[i + newSize][j];
            d[i][j] = A[i + newSize][j + newSize];

            e[i][j] = B[i][j];
            f[i][j] = B[i][j + newSize];
            g[i][j] = B[i + newSize][j];
            h[i][j] = B[i + newSize][j + newSize];
        }
    }
    /*
          A         B           C
        [a b]   * [e f]   =  [c11 c12]
        [c d]     [g h]      [c21 c22]
        p1,p2,p3,p4 = AHED: a row of A is added (+), a column of B is subtracted (-)
        p5 = Diagonal: both diagonals summed (+)
        p6 = Last CR: last column of A (-), last row of B (+)
        p7 = First CR: first column of A (-), first row of B (+)
    */
    //! calculating all Strassen formulas
    //*p1=a*(f-h)
    sub(f, h, sResult, newSize);
    Strassen_algorithm(a, sResult, p1, newSize);

    //*p2=(a+b)*h
    add(a, b, fResult, newSize);
    Strassen_algorithm(fResult, h, p2, newSize);

    //*p3=(c+d)*e
    add(c, d, fResult, newSize);
    Strassen_algorithm(fResult, e, p3, newSize);

    //*p4=d*(g-e)
    sub(g, e, sResult, newSize);
    Strassen_algorithm(d, sResult, p4, newSize);

    //*p5=(a+d)*(e+h)
    add(a, d, fResult, newSize);
    add(e, h, sResult, newSize);
    Strassen_algorithm(fResult, sResult, p5, newSize);

    //*p6=(b-d)*(g+h)
    sub(b, d, fResult, newSize);
    add(g, h, sResult, newSize);
    Strassen_algorithm(fResult, sResult, p6, newSize);

    //*p7=(a-c)*(e+f)
    sub(a, c, fResult, newSize);
    add(e, f, sResult, newSize);
    Strassen_algorithm(fResult, sResult, p7, newSize);

    /* calculating all quarters of C from p1 ... p7
       c11=p4+p5+p6-p2
       c12=p1+p2
       c21=p3+p4
       c22=p1-p3+p5-p7
    */
    add(p1, p2, c12, newSize); //!
    add(p3, p4, c21, newSize); //!

    add(p4, p5, fResult, newSize);
    add(fResult, p6, sResult, newSize);
    sub(sResult, p2, c11, newSize); //!

    sub(p1, p3, fResult, newSize);
    add(fResult, p5, sResult, newSize);
    sub(sResult, p7, c22, newSize); //!

    // Grouping the results obtained in a single matrix:
    for (i = 0; i < newSize; i++)
    {
        for (j = 0; j < newSize; j++)
        {
            C[i][j] = c11[i][j];
            C[i][j + newSize] = c12[i][j];
            C[i + newSize][j] = c21[i][j];
            C[i + newSize][j + newSize] = c22[i][j];
        }
    }
}

/* pads A (r1 x c1) and B (r2 x c2) with zeros to a square power-of-2 size,
   multiplies them, and prints the r1 x c2 result (requires c1 == r2) */
void ConvertToSquareMat(vii &A, vii &B, int r1, int c1, int r2, int c2)
{
    int maxSize = max({r1, c1, r2, c2});
    int size = nextPowerOf2(maxSize);

    vi z(size);
    vii Aa(size, z), Bb(size, z), Cc(size, z);

    for (int i = 0; i < r1; i++)
        for (int j = 0; j < c1; j++)
            Aa[i][j] = A[i][j];
    for (int i = 0; i < r2; i++)
        for (int j = 0; j < c2; j++)
            Bb[i][j] = B[i][j];

    Strassen_algorithm(Aa, Bb, Cc, size);

    vi temp1(c2);
    vii C(r1, temp1);
    for (int i = 0; i < r1; i++)
        for (int j = 0; j < c2; j++)
            C[i][j] = Cc[i][j];
    display(C, r1, c2); // the product has r1 rows and c2 columns
}

int main()
{
    vii a = {
        {1, 2, 3},
        {1, 2, 3},
        {0, 0, 2}};
    vii b = {
        {1, 0, 0},
        {0, 1, 0},
        {0, 0, 1}};
    ConvertToSquareMat(a, b, 3, 3, 3, 3); // A[][],B[][],R1,C1,R2,C2
    return 0;
}
```

## Python3

```python
# Python equivalent of the above code
import math


# finding the next power of 2 that is >= k
def nextPowerOf2(k):
    return pow(2, int(math.ceil(math.log2(k))))


# printing matrix
def display(C, m, n):
    for i in range(m):
        print('| ', end='')
        for j in range(n):
            print(C[i][j], end=' ')
        print('|')


# a size x size matrix filled with zeros
def zeros(size):
    return [[0 for j in range(size)] for i in range(size)]


# addition and subtraction of two size x size matrices: C = A + B, C = A - B
def add(A, B, C, size):
    for i in range(size):
        for j in range(size):
            C[i][j] = A[i][j] + B[i][j]


def sub(A, B, C, size):
    for i in range(size):
        for j in range(size):
            C[i][j] = A[i][j] - B[i][j]


# Strassen algorithm: C = A * B for size x size matrices, size a power of 2
def Strassen_algorithm(A, B, C, size):
    if size == 1:
        C[0][0] = A[0][0] * B[0][0]
        return

    newSize = size // 2
    a, b, c, d = zeros(newSize), zeros(newSize), zeros(newSize), zeros(newSize)
    e, f, g, h = zeros(newSize), zeros(newSize), zeros(newSize), zeros(newSize)
    c11, c12, c21, c22 = zeros(newSize), zeros(newSize), zeros(newSize), zeros(newSize)
    p1, p2, p3, p4 = zeros(newSize), zeros(newSize), zeros(newSize), zeros(newSize)
    p5, p6, p7 = zeros(newSize), zeros(newSize), zeros(newSize)
    fResult, sResult = zeros(newSize), zeros(newSize)

    # divide the matrices into four equal parts
    for i in range(newSize):
        for j in range(newSize):
            a[i][j] = A[i][j]
            b[i][j] = A[i][j + newSize]
            c[i][j] = A[i + newSize][j]
            d[i][j] = A[i + newSize][j + newSize]

            e[i][j] = B[i][j]
            f[i][j] = B[i][j + newSize]
            g[i][j] = B[i + newSize][j]
            h[i][j] = B[i + newSize][j + newSize]

    # calculating all Strassen formulas
    # p1=a*(f-h)
    sub(f, h, sResult, newSize)
    Strassen_algorithm(a, sResult, p1, newSize)

    # p2=(a+b)*h
    add(a, b, fResult, newSize)
    Strassen_algorithm(fResult, h, p2, newSize)

    # p3=(c+d)*e
    add(c, d, fResult, newSize)
    Strassen_algorithm(fResult, e, p3, newSize)

    # p4=d*(g-e)
    sub(g, e, sResult, newSize)
    Strassen_algorithm(d, sResult, p4, newSize)

    # p5=(a+d)*(e+h)
    add(a, d, fResult, newSize)
    add(e, h, sResult, newSize)
    Strassen_algorithm(fResult, sResult, p5, newSize)

    # p6=(b-d)*(g+h)
    sub(b, d, fResult, newSize)
    add(g, h, sResult, newSize)
    Strassen_algorithm(fResult, sResult, p6, newSize)

    # p7=(a-c)*(e+f)
    sub(a, c, fResult, newSize)
    add(e, f, sResult, newSize)
    Strassen_algorithm(fResult, sResult, p7, newSize)

    # calculating all quarters of C from p1 ... p7
    # c12=p1+p2
    # c21=p3+p4
    add(p1, p2, c12, newSize)
    add(p3, p4, c21, newSize)

    # c11=p4+p5+p6-p2
    add(p4, p5, fResult, newSize)
    add(fResult, p6, sResult, newSize)
    sub(sResult, p2, c11, newSize)

    # c22=p1-p3+p5-p7
    sub(p1, p3, fResult, newSize)
    add(fResult, p5, sResult, newSize)
    sub(sResult, p7, c22, newSize)

    # Grouping the results obtained in a single matrix:
    for i in range(newSize):
        for j in range(newSize):
            C[i][j] = c11[i][j]
            C[i][j + newSize] = c12[i][j]
            C[i + newSize][j] = c21[i][j]
            C[i + newSize][j + newSize] = c22[i][j]


# pads A (r1 x c1) and B (r2 x c2) with zeros to a square power-of-2 size,
# multiplies them, and prints the r1 x c2 result (requires c1 == r2)
def ConvertToSquareMat(A, B, r1, c1, r2, c2):
    maxSize = max(r1, c1, r2, c2)
    size = nextPowerOf2(maxSize)

    Aa, Bb, Cc = zeros(size), zeros(size), zeros(size)
    for i in range(r1):
        for j in range(c1):
            Aa[i][j] = A[i][j]
    for i in range(r2):
        for j in range(c2):
            Bb[i][j] = B[i][j]

    Strassen_algorithm(Aa, Bb, Cc, size)

    C = [[0 for j in range(c2)] for i in range(r1)]
    for i in range(r1):
        for j in range(c2):
            C[i][j] = Cc[i][j]
    display(C, r1, c2)  # the product has r1 rows and c2 columns


# Driver code
A = [[1, 2, 3], [1, 2, 3], [0, 0, 2]]
B = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
ConvertToSquareMat(A, B, 3, 3, 3, 3)  # A[][],B[][],R1,C1,R2,C2
```

Output

```
| 1 2 3 |
| 1 2 3 |
| 0 0 2 |
```
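Because B is the identity matrix, the expected product is A itself, which is what the program prints. For other inputs, one quick way to check the result is to compare it with the ordinary triple-loop product. Here is a minimal sketch; the helper `naive_multiply` is hypothetical and not part of the article's code.

```python
def naive_multiply(X, Y):
    """Ordinary O(n^3) product of a rows x inner matrix and an inner x cols matrix."""
    rows, inner, cols = len(X), len(Y), len(Y[0])
    # the number of columns of X must equal the number of rows of Y
    assert all(len(row) == inner for row in X), "inner dimensions must match"
    return [[sum(X[i][k] * Y[k][j] for k in range(inner)) for j in range(cols)]
            for i in range(rows)]


A = [[1, 2, 3], [1, 2, 3], [0, 0, 2]]
B = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
print(naive_multiply(A, B))  # [[1, 2, 3], [1, 2, 3], [0, 0, 2]]
```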

This article is contributed by Mohit Gupta 🙂.