Strassen’s method of matrix multiplication is a typical divide-and-conquer algorithm. We have discussed Strassen’s Algorithm here. However, let’s revisit the idea behind the divide-and-conquer approach and implement it.
Prerequisite: Read this post first; the rest of this article builds on it.
Implementation
// CPP program to implement Strassen’s Matrix
// Multiplication Algorithm
#include <bits/stdc++.h>
using namespace std;
typedef long long lld;

/* Allocates a rows x cols matrix on the heap with every entry set to zero.
   The caller owns the result and releases it with FreeMatrix. */
inline lld** AllocMatrix(int rows, int cols)
{
    lld** mat = new lld*[rows];
    for (int i = 0; i < rows; i++)
        mat[i] = new lld[cols](); // "()" value-initializes the row to zeros
    return mat;
}

/* Releases a matrix with the given number of rows that was returned by
   AllocMatrix, MatrixMultiply or Strassen. */
inline void FreeMatrix(lld** mat, int rows)
{
    for (int i = 0; i < rows; i++)
        delete[] mat[i];
    delete[] mat;
}

/* Standard matrix multiplication.
   a is n x l, b is l x m; returns a newly allocated n x m product that the
   caller must release (one delete[] per row, then delete[] the row table). */
inline lld** MatrixMultiply(lld** a, lld** b, int n, int l, int m)
{
    lld** c = AllocMatrix(n, m);
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            for (int k = 0; k < l; k++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

/* Copies the rows x cols window of src whose top-left corner is
   (rowOff, colOff) into a new matrix. Cells that fall outside the
   srcRows x srcCols source are left at zero, which pads odd dimensions. */
inline lld** CopyQuadrant(lld** src, int srcRows, int srcCols,
                          int rowOff, int colOff, int rows, int cols)
{
    lld** quad = AllocMatrix(rows, cols);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            const int I = i + rowOff;
            const int J = j + colOff;
            if (I < srcRows && J < srcCols)
                quad[i][j] = src[I][J];
        }
    }
    return quad;
}

/* Returns a new rows x cols matrix holding x + sign * y
   (sign is +1 for a sum, -1 for a difference). */
inline lld** Combine(lld** x, lld** y, int rows, int cols, int sign)
{
    lld** out = AllocMatrix(rows, cols);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            out[i][j] = x[i][j] + sign * y[i][j];
        }
    }
    return out;
}

/* Strassen's Algorithm for matrix multiplication
   Complexity: O(n^2.808)

   a is n x l, b is l x m; returns a newly allocated n x m product that the
   caller must release. Odd dimensions are handled by rounding each half up
   and zero-padding the quadrants. */
inline lld** Strassen(lld** a, lld** b, int n, int l, int m)
{
    // Nothing left to split once any dimension is 1 (or the input is empty).
    if (n <= 1 || l <= 1 || m <= 1)
        return MatrixMultiply(a, b, n, l, m);

    // Half sizes, rounded up.
    const int adjN = (n >> 1) + (n & 1);
    const int adjL = (l >> 1) + (l & 1);
    const int adjM = (m >> 1) + (m & 1);

    // Quadrants of a are adjN x adjL; quadrants of b are adjL x adjM.
    lld** As[2][2];
    lld** Bs[2][2];
    for (int x = 0; x < 2; x++) {
        for (int y = 0; y < 2; y++) {
            As[x][y] = CopyQuadrant(a, n, l, x * adjN, y * adjL, adjN, adjL);
            Bs[x][y] = CopyQuadrant(b, l, m, x * adjL, y * adjM, adjL, adjM);
        }
    }

    // The ten sums/differences that feed the seven recursive products.
    lld** s[10];
    s[0] = Combine(Bs[0][1], Bs[1][1], adjL, adjM, -1);
    s[1] = Combine(As[0][0], As[0][1], adjN, adjL, +1);
    s[2] = Combine(As[1][0], As[1][1], adjN, adjL, +1);
    s[3] = Combine(Bs[1][0], Bs[0][0], adjL, adjM, -1);
    s[4] = Combine(As[0][0], As[1][1], adjN, adjL, +1);
    s[5] = Combine(Bs[0][0], Bs[1][1], adjL, adjM, +1);
    s[6] = Combine(As[0][1], As[1][1], adjN, adjL, -1);
    s[7] = Combine(Bs[1][0], Bs[1][1], adjL, adjM, +1);
    s[8] = Combine(As[0][0], As[1][0], adjN, adjL, -1);
    s[9] = Combine(Bs[0][0], Bs[0][1], adjL, adjM, +1);
    // Row counts of s[0..9], needed to free them: B-shaped ones have adjL rows.
    const int sRows[10] = { adjL, adjN, adjN, adjL, adjN,
                            adjL, adjN, adjL, adjN, adjL };

    // Seven half-size products, each adjN x adjM.
    lld** p[7];
    p[0] = Strassen(As[0][0], s[0], adjN, adjL, adjM);
    p[1] = Strassen(s[1], Bs[1][1], adjN, adjL, adjM);
    p[2] = Strassen(s[2], Bs[0][0], adjN, adjL, adjM);
    p[3] = Strassen(As[1][1], s[3], adjN, adjL, adjM);
    p[4] = Strassen(s[4], s[5], adjN, adjL, adjM);
    p[5] = Strassen(s[6], s[7], adjN, adjL, adjM);
    p[6] = Strassen(s[8], s[9], adjN, adjL, adjM);

    // Assemble the result; the bounds checks drop the zero padding again.
    lld** c = AllocMatrix(n, m);
    for (int i = 0; i < adjN; i++) {
        for (int j = 0; j < adjM; j++) {
            c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
            if (j + adjM < m)
                c[i][j + adjM] = p[0][i][j] + p[1][i][j];
            if (i + adjN < n)
                c[i + adjN][j] = p[2][i][j] + p[3][i][j];
            if (i + adjN < n && j + adjM < m)
                c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
        }
    }

    // Release every temporary with the row count it was allocated with.
    for (int x = 0; x < 2; x++) {
        for (int y = 0; y < 2; y++) {
            FreeMatrix(As[x][y], adjN);
            FreeMatrix(Bs[x][y], adjL);
        }
    }
    for (int i = 0; i < 10; i++)
        FreeMatrix(s[i], sRows[i]);
    for (int i = 0; i < 7; i++)
        FreeMatrix(p[i], adjN);

    return c;
}

/* Demo: multiplies a 2 x 3 matrix by a 3 x 2 matrix and prints the product. */
int main()
{
    lld** matA = AllocMatrix(2, 3);
    matA[0][0] = 1;
    matA[0][1] = 2;
    matA[0][2] = 3;
    matA[1][0] = 4;
    matA[1][1] = 5;
    matA[1][2] = 6;

    lld** matB = AllocMatrix(3, 2);
    matB[0][0] = 7;
    matB[0][1] = 8;
    matB[1][0] = 9;
    matB[1][1] = 10;
    matB[2][0] = 11;
    matB[2][1] = 12;

    lld** matC = Strassen(matA, matB, 2, 3, 2);
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%lld ", matC[i][j]);
        }
        printf("\n");
    }

    FreeMatrix(matA, 2);
    FreeMatrix(matB, 3);
    FreeMatrix(matC, 2);
    return 0;
}
public class Main {
    /** Demo: multiplies a 2 x 3 matrix by a 3 x 2 matrix and prints the product. */
    public static void main(String[] args) {
        // Initialize matrices A and B
        int[][] matA = {{1, 2, 3}, {4, 5, 6}};
        int[][] matB = {{7, 8}, {9, 10}, {11, 12}};

        // Pad both matrices with zeros up to a common power-of-two square size
        int n = Math.max(Math.max(matA.length, matA[0].length),
                         Math.max(matB.length, matB[0].length));
        int m = 1;
        while (m < n) m *= 2;
        int[][] newMatA = padded(matA, m);
        int[][] newMatB = padded(matB, m);

        // Perform Strassen's algorithm on matrices A and B
        int[][] matC = strassen(newMatA, newMatB);

        // Print only the part of the result that belongs to the unpadded product
        for (int i = 0; i < matA.length; i++) {
            for (int j = 0; j < matB[0].length; j++) {
                System.out.print(matC[i][j] + " ");
            }
            System.out.println();
        }
    }

    /**
     * Multiplies two square matrices of the same size with Strassen's algorithm.
     *
     * @param A left operand, n x n
     * @param B right operand, n x n
     * @return a new n x n matrix holding A * B (empty when n is 0)
     */
    public static int[][] strassen(int[][] A, int[][] B) {
        int n = A.length;

        // Empty input: nothing to multiply (and nothing to recurse on).
        if (n == 0) {
            return new int[0][0];
        }
        // base case: 1x1 matrix
        if (n == 1) {
            return new int[][] {{A[0][0] * B[0][0]}};
        }
        // Odd size: halving would drop the last row/column, so pad with one
        // zero row and column, multiply, and trim the result back to n x n.
        if (n % 2 != 0) {
            int[][] product = strassen(padded(A, n + 1), padded(B, n + 1));
            int[][] trimmed = new int[n][n];
            for (int i = 0; i < n; i++) {
                System.arraycopy(product[i], 0, trimmed[i], 0, n);
            }
            return trimmed;
        }

        int half = n / 2;

        // split input matrices into quarters
        int[][] A11 = new int[half][half];
        int[][] A12 = new int[half][half];
        int[][] A21 = new int[half][half];
        int[][] A22 = new int[half][half];
        int[][] B11 = new int[half][half];
        int[][] B12 = new int[half][half];
        int[][] B21 = new int[half][half];
        int[][] B22 = new int[half][half];
        split(A, A11, 0, 0);
        split(A, A12, 0, half);
        split(A, A21, half, 0);
        split(A, A22, half, half);
        split(B, B11, 0, 0);
        split(B, B12, 0, half);
        split(B, B21, half, 0);
        split(B, B22, half, half);

        // Calculate p1 to p7
        int[][] P1 = strassen(add(A11, A22), add(B11, B22));
        int[][] P2 = strassen(add(A21, A22), B11);
        int[][] P3 = strassen(A11, subtract(B12, B22));
        int[][] P4 = strassen(A22, subtract(B21, B11));
        int[][] P5 = strassen(add(A11, A12), B22);
        int[][] P6 = strassen(subtract(A21, A11), add(B11, B12));
        int[][] P7 = strassen(subtract(A12, A22), add(B21, B22));

        // Calculate the 4 quarters of the resulting matrix
        int[][] C11 = add(subtract(add(P1, P4), P5), P7);
        int[][] C12 = add(P3, P5);
        int[][] C21 = add(P2, P4);
        int[][] C22 = add(subtract(add(P1, P3), P2), P6);

        // Combine the 4 quarters into a single result matrix
        int[][] R = new int[n][n];
        join(C11, R, 0, 0);
        join(C12, R, 0, half);
        join(C21, R, half, 0);
        join(C22, R, half, half);
        return R;
    }

    /** Returns a new size x size matrix with src copied into its top-left corner and zeros elsewhere. */
    private static int[][] padded(int[][] src, int size) {
        int[][] out = new int[size][size];
        for (int i = 0; i < src.length; i++) {
            System.arraycopy(src[i], 0, out[i], 0, src[i].length);
        }
        return out;
    }

    // Function to add two square matrices of the same size
    public static int[][] add(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }

    // Function to subtract two square matrices of the same size (A - B)
    public static int[][] subtract(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }

    // Function to split parent matrix into child matrices:
    // fills square child C from parent P, starting at P[iB][jB]
    public static void split(int[][] P, int[][] C, int iB, int jB) {
        for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++) {
            for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++) {
                C[i1][j1] = P[i2][j2];
            }
        }
    }

    // Function to join child matrices into parent matrix:
    // writes square child C into parent P, starting at P[iB][jB]
    public static void join(int[][] C, int[][] P, int iB, int jB) {
        for (int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++) {
            for (int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++) {
                P[i2][j2] = C[i1][j1];
            }
        }
    }
}
using System;
public class Program
{
    /* Creates a rows x cols jagged matrix; new arrays start out zero-filled. */
    static long[][] NewMatrix(int rows, int cols)
    {
        long[][] mat = new long[rows][];
        for (int i = 0; i < rows; i++)
            mat[i] = new long[cols];
        return mat;
    }

    /* Standard matrix multiplication.
       a is n x l, b is l x m; returns a new n x m product. */
    public static long[][] MatrixMultiply(long[][] a, long[][] b, int n, int l, int m)
    {
        long[][] c = NewMatrix(n, m);
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < m; j++)
            {
                for (int k = 0; k < l; k++)
                {
                    c[i][j] += a[i][k] * b[k][j];
                }
            }
        }
        return c;
    }

    /* Copies the rows x cols window of src whose top-left corner is
       (rowOff, colOff). Cells outside the srcRows x srcCols source stay
       zero, which pads odd dimensions. */
    static long[][] Quadrant(long[][] src, int srcRows, int srcCols,
                             int rowOff, int colOff, int rows, int cols)
    {
        long[][] quad = NewMatrix(rows, cols);
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                int I = i + rowOff;
                int J = j + colOff;
                if (I < srcRows && J < srcCols)
                    quad[i][j] = src[I][J];
            }
        }
        return quad;
    }

    /* Returns a new rows x cols matrix holding x + sign * y
       (sign is +1 for a sum, -1 for a difference). */
    static long[][] Combine(long[][] x, long[][] y, int rows, int cols, int sign)
    {
        long[][] result = NewMatrix(rows, cols);
        for (int i = 0; i < rows; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                result[i][j] = x[i][j] + sign * y[i][j];
            }
        }
        return result;
    }

    /* Strassen's Algorithm for matrix multiplication
       Complexity: O(n^2.808)

       a is n x l, b is l x m; returns a new n x m product. Odd dimensions
       are handled by rounding each half up and zero-padding the quadrants.
       Temporaries are simply dropped; the garbage collector reclaims them. */
    public static long[][] Strassen(long[][] a, long[][] b, int n, int l, int m)
    {
        // Nothing left to split once any dimension is 1 (or the input is empty).
        if (n <= 1 || l <= 1 || m <= 1)
            return MatrixMultiply(a, b, n, l, m);

        // Half sizes, rounded up.
        int adjN = (n >> 1) + (n & 1);
        int adjL = (l >> 1) + (l & 1);
        int adjM = (m >> 1) + (m & 1);

        // Quadrants of a are adjN x adjL; quadrants of b are adjL x adjM.
        long[][][][] As = new long[2][][][];
        long[][][][] Bs = new long[2][][][];
        for (int x = 0; x < 2; x++)
        {
            As[x] = new long[2][][];
            Bs[x] = new long[2][][];
            for (int y = 0; y < 2; y++)
            {
                As[x][y] = Quadrant(a, n, l, x * adjN, y * adjL, adjN, adjL);
                Bs[x][y] = Quadrant(b, l, m, x * adjL, y * adjM, adjL, adjM);
            }
        }

        // The ten sums/differences that feed the seven recursive products.
        long[][][] s = new long[10][][];
        s[0] = Combine(Bs[0][1], Bs[1][1], adjL, adjM, -1);
        s[1] = Combine(As[0][0], As[0][1], adjN, adjL, +1);
        s[2] = Combine(As[1][0], As[1][1], adjN, adjL, +1);
        s[3] = Combine(Bs[1][0], Bs[0][0], adjL, adjM, -1);
        s[4] = Combine(As[0][0], As[1][1], adjN, adjL, +1);
        s[5] = Combine(Bs[0][0], Bs[1][1], adjL, adjM, +1);
        s[6] = Combine(As[0][1], As[1][1], adjN, adjL, -1);
        s[7] = Combine(Bs[1][0], Bs[1][1], adjL, adjM, +1);
        s[8] = Combine(As[0][0], As[1][0], adjN, adjL, -1);
        s[9] = Combine(Bs[0][0], Bs[0][1], adjL, adjM, +1);

        // Seven half-size products, each adjN x adjM.
        long[][][] p = new long[7][][];
        p[0] = Strassen(As[0][0], s[0], adjN, adjL, adjM);
        p[1] = Strassen(s[1], Bs[1][1], adjN, adjL, adjM);
        p[2] = Strassen(s[2], Bs[0][0], adjN, adjL, adjM);
        p[3] = Strassen(As[1][1], s[3], adjN, adjL, adjM);
        p[4] = Strassen(s[4], s[5], adjN, adjL, adjM);
        p[5] = Strassen(s[6], s[7], adjN, adjL, adjM);
        p[6] = Strassen(s[8], s[9], adjN, adjL, adjM);

        // Assemble the result; the bounds checks drop the zero padding again.
        long[][] c = NewMatrix(n, m);
        for (int i = 0; i < adjN; i++)
        {
            for (int j = 0; j < adjM; j++)
            {
                c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
                if (j + adjM < m)
                    c[i][j + adjM] = p[0][i][j] + p[1][i][j];
                if (i + adjN < n)
                    c[i + adjN][j] = p[2][i][j] + p[3][i][j];
                if (i + adjN < n && j + adjM < m)
                    c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
            }
        }
        return c;
    }

    /* Demo: multiplies a 2 x 3 matrix by a 3 x 2 matrix and prints the product. */
    public static void Main(string[] args)
    {
        long[][] matA = new long[2][];
        matA[0] = new long[] { 1, 2, 3 };
        matA[1] = new long[] { 4, 5, 6 };

        long[][] matB = new long[3][];
        matB[0] = new long[] { 7, 8 };
        matB[1] = new long[] { 9, 10 };
        matB[2] = new long[] { 11, 12 };

        long[][] matC = Strassen(matA, matB, 2, 3, 2);
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                Console.Write(matC[i][j] + " ");
            }
            Console.WriteLine();
        }
    }
}
// Strassen's Algorithm for matrix multiplication
// Complexity: O(n^2.808)

// Builds a rows x cols matrix filled with zeros.
function zeroMatrix(rows, cols) {
    return Array.from(Array(rows), () => new Array(cols).fill(0));
}

// Function to perform standard matrix multiplication.
// a is n x l, b is l x m; returns a new n x m product.
function MatrixMultiply(a, b, n, l, m) {
    // Initialize result matrix with zeros
    let c = zeroMatrix(n, m);
    // Perform standard matrix multiplication
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < m; j++) {
            for (let k = 0; k < l; k++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

// Copies the rows x cols window of src whose top-left corner is
// (rowOff, colOff). Cells outside the srcRows x srcCols source stay zero,
// which pads odd dimensions.
function quadrant(src, srcRows, srcCols, rowOff, colOff, rows, cols) {
    let quad = zeroMatrix(rows, cols);
    for (let i = 0; i < rows; i++) {
        for (let j = 0; j < cols; j++) {
            let I = i + rowOff;
            let J = j + colOff;
            if (I < srcRows && J < srcCols) {
                quad[i][j] = src[I][J];
            }
        }
    }
    return quad;
}

// Returns a new rows x cols matrix holding x + sign * y
// (sign is +1 for a sum, -1 for a difference).
function combine(x, y, rows, cols, sign) {
    let out = zeroMatrix(rows, cols);
    for (let i = 0; i < rows; i++) {
        for (let j = 0; j < cols; j++) {
            out[i][j] = x[i][j] + sign * y[i][j];
        }
    }
    return out;
}

// Function to perform Strassen's matrix multiplication.
// a is n x l, b is l x m; returns a new n x m product.
function Strassen(a, b, n, l, m) {
    // If the matrices are small enough, perform standard multiplication
    if (n <= 1 || l <= 1 || m <= 1)
        return MatrixMultiply(a, b, n, l, m);

    // Calculate the size of the submatrices (halves, rounded up)
    let adjN = (n >> 1) + (n & 1);
    let adjL = (l >> 1) + (l & 1);
    let adjM = (m >> 1) + (m & 1);

    // Build the submatrices of a (adjN x adjL) and b (adjL x adjM)
    let As = [[null, null], [null, null]];
    let Bs = [[null, null], [null, null]];
    for (let x = 0; x < 2; x++) {
        for (let y = 0; y < 2; y++) {
            As[x][y] = quadrant(a, n, l, x * adjN, y * adjL, adjN, adjL);
            Bs[x][y] = quadrant(b, l, m, x * adjL, y * adjM, adjL, adjM);
        }
    }

    // Build the auxiliary matrices s (ten sums/differences of submatrices)
    let s = [
        combine(Bs[0][1], Bs[1][1], adjL, adjM, -1),
        combine(As[0][0], As[0][1], adjN, adjL, +1),
        combine(As[1][0], As[1][1], adjN, adjL, +1),
        combine(Bs[1][0], Bs[0][0], adjL, adjM, -1),
        combine(As[0][0], As[1][1], adjN, adjL, +1),
        combine(Bs[0][0], Bs[1][1], adjL, adjM, +1),
        combine(As[0][1], As[1][1], adjN, adjL, -1),
        combine(Bs[1][0], Bs[1][1], adjL, adjM, +1),
        combine(As[0][0], As[1][0], adjN, adjL, -1),
        combine(Bs[0][0], Bs[0][1], adjL, adjM, +1),
    ];

    // Compute the seven recursive products p, each adjN x adjM
    let p = [
        Strassen(As[0][0], s[0], adjN, adjL, adjM),
        Strassen(s[1], Bs[1][1], adjN, adjL, adjM),
        Strassen(s[2], Bs[0][0], adjN, adjL, adjM),
        Strassen(As[1][1], s[3], adjN, adjL, adjM),
        Strassen(s[4], s[5], adjN, adjL, adjM),
        Strassen(s[6], s[7], adjN, adjL, adjM),
        Strassen(s[8], s[9], adjN, adjL, adjM),
    ];

    // Calculate the result matrix c; the bounds checks drop the zero padding
    let c = zeroMatrix(n, m);
    for (let i = 0; i < adjN; i++) {
        for (let j = 0; j < adjM; j++) {
            c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
            if (j + adjM < m)
                c[i][j + adjM] = p[0][i][j] + p[1][i][j];
            if (i + adjN < n)
                c[i + adjN][j] = p[2][i][j] + p[3][i][j];
            if (i + adjN < n && j + adjM < m)
                c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
        }
    }
    return c;
}

// Test the function with two matrices
let matA = [[1, 2, 3], [4, 5, 6]];
let matB = [[7, 8], [9, 10], [11, 12]];
let matC = Strassen(matA, matB, 2, 3, 2);
for (let i = 0; i < 2; i++) {
    console.log(matC[i].join(' '));
}
import numpy as np
def strassen(A, B):
    """Multiply two square matrices of the same size with Strassen's algorithm.

    A and B are 2-D NumPy arrays, both n x n. Returns their n x n matrix
    product. Odd sizes are handled by zero-padding one row and column and
    trimming the result.
    """
    n = A.shape[0]
    # base case: 1x1 matrix (an empty matrix is returned unchanged in shape)
    if n <= 1:
        return A * B
    # Odd size: the quadrants would have different shapes, so pad to the next
    # even size, multiply, and cut the padding off again.
    if n % 2:
        A = np.pad(A, ((0, 1), (0, 1)), mode='constant')
        B = np.pad(B, ((0, 1), (0, 1)), mode='constant')
        return strassen(A, B)[:n, :n]
    # split input matrices into quarters
    mid = n // 2
    A11, A12, A21, A22 = A[:mid, :mid], A[:mid, mid:], A[mid:, :mid], A[mid:, mid:]
    B11, B12, B21, B22 = B[:mid, :mid], B[:mid, mid:], B[mid:, :mid], B[mid:, mid:]
    # calculate p1 to p7
    P1 = strassen(A11 + A22, B11 + B22)
    P2 = strassen(A21 + A22, B11)
    P3 = strassen(A11, B12 - B22)
    P4 = strassen(A22, B21 - B11)
    P5 = strassen(A11 + A12, B22)
    P6 = strassen(A21 - A11, B11 + B12)
    P7 = strassen(A12 - A22, B21 + B22)
    # calculate the 4 quarters of the resulting matrix
    C11 = P1 + P4 - P5 + P7
    C12 = P3 + P5
    C21 = P2 + P4
    C22 = P1 + P3 - P2 + P6
    # combine the 4 quarters into a single result matrix
    return np.vstack((np.hstack((C11, C12)), np.hstack((C21, C22))))
def main():
    """Demo: multiply a 2 x 3 matrix by a 3 x 2 matrix and print the product."""
    # Initialize matrices A and B
    A = np.array([[1, 2, 3], [4, 5, 6]])
    B = np.array([[7, 8], [9, 10], [11, 12]])
    # Remember the shape of the true product before padding changes the shapes
    rows, cols = A.shape[0], B.shape[1]
    # Pad the matrices with zeros up to a common power-of-two square size
    n = max(A.shape[0], A.shape[1], B.shape[0], B.shape[1])
    m = 1
    while m < n:
        m *= 2
    A = np.pad(A, ((0, m - A.shape[0]), (0, m - A.shape[1])), mode='constant')
    B = np.pad(B, ((0, m - B.shape[0]), (0, m - B.shape[1])), mode='constant')
    # Perform Strassen's algorithm on matrices A and B
    C = strassen(A, B)
    # Print the resulting matrix without the padding
    print(C[:rows, :cols])
# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
Output:
58 64
139 154
Time Complexity: $O(n^{\log_2 7}) \approx O(n^{2.808})$. When the matrices are small enough (size 1), the algorithm returns the result of a standard matrix multiplication. Otherwise, it divides each matrix into 4 submatrices, performs 7 half-size matrix multiplications recursively, and combines the results with $O(n^2)$ additions and subtractions. This gives the recurrence $T(n) = 7\,T(n/2) + O(n^2)$, which solves to the bound above.
Space complexity : O(n^2)