Open In App
Related Articles

Strassen’s Matrix Multiplication Algorithm | Implementation

Improve
Improve
Improve
Like Article
Like
Save Article
Save
Report issue
Report

Strassen’s method of matrix multiplication is a typical divide-and-conquer algorithm. We have discussed Strassen’s Algorithm here. However, let’s go over once more what is behind the divide-and-conquer approach and then implement it. 

Prerequisite: Please read this post first; the rest of this article builds on it.

Implementation

C++

// CPP program to implement Strassen’s Matrix
// Multiplication Algorithm
#include <bits/stdc++.h>
using namespace std;
typedef long long lld;
 
/* Strassen's Algorithm for matrix multiplication
Complexity: O(n^2.808) */
 
/* Schoolbook product of the n x l matrix `a` and the l x m matrix `b`.
   Returns a freshly allocated n x m matrix (array of row pointers, every
   row allocated with new[]); the caller is responsible for releasing it. */
inline lld** MatrixMultiply(lld** a, lld** b, int n,
                                    int l, int m)
{
    lld** product = new lld*[n];
    for (int row = 0; row < n; ++row)
        product[row] = new lld[m];

    for (int row = 0; row < n; ++row) {
        for (int col = 0; col < m; ++col) {
            // Dot product of row `row` of a with column `col` of b.
            lld dot = 0;
            for (int t = 0; t < l; ++t)
                dot += a[row][t] * b[t][col];
            product[row][col] = dot;
        }
    }
    return product;
}
 
inline lld** Strassen(lld** a, lld** b, int n,
                                int l, int m)
{
    if (n == 1 || l == 1 || m == 1)
        return MatrixMultiply(a, b, n, l, m);
 
    lld** c = new lld*[n];
    for (int i = 0; i < n; i++)
        c[i] = new lld[m];
 
    int adjN = (n >> 1) + (n & 1);
    int adjL = (l >> 1) + (l & 1);
    int adjM = (m >> 1) + (m & 1);
 
    lld**** As = new lld***[2];
    for (int x = 0; x < 2; x++) {
        As[x] = new lld**[2];
        for (int y = 0; y < 2; y++) {
            As[x][y] = new lld*[adjN];
            for (int i = 0; i < adjN; i++) {
                As[x][y][i] = new lld[adjL];
                for (int j = 0; j < adjL; j++) {
                    int I = i + (x & 1) * adjN;
                    int J = j + (y & 1) * adjL;
                    As[x][y][i][j] = (I < n && J < l) ? a[I][J] : 0;
                }
            }
        }
    }
 
    lld**** Bs = new lld***[2];
    for (int x = 0; x < 2; x++) {
        Bs[x] = new lld**[2];
        for (int y = 0; y < 2; y++) {
            Bs[x][y] = new lld*[adjN];
            for (int i = 0; i < adjL; i++) {
                Bs[x][y][i] = new lld[adjM];
                for (int j = 0; j < adjM; j++) {
                    int I = i + (x & 1) * adjL;
                    int J = j + (y & 1) * adjM;
                    Bs[x][y][i][j] = (I < l && J < m) ? b[I][J] : 0;
                }
            }
        }
    }
 
    lld*** s = new lld**[10];
    for (int i = 0; i < 10; i++) {
        switch (i) {
        case 0:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[0][1][j][k] - Bs[1][1][j][k];
                }
            }
            break;
        case 1:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][0][j][k] + As[0][1][j][k];
                }
            }
            break;
        case 2:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[1][0][j][k] + As[1][1][j][k];
                }
            }
            break;
        case 3:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[1][0][j][k] - Bs[0][0][j][k];
                }
            }
            break;
        case 4:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][0][j][k] + As[1][1][j][k];
                }
            }
            break;
        case 5:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[0][0][j][k] + Bs[1][1][j][k];
                }
            }
            break;
        case 6:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][1][j][k] - As[1][1][j][k];
                }
            }
            break;
        case 7:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[1][0][j][k] + Bs[1][1][j][k];
                }
            }
            break;
        case 8:
            s[i] = new lld*[adjN];
            for (int j = 0; j < adjN; j++) {
                s[i][j] = new lld[adjL];
                for (int k = 0; k < adjL; k++) {
                    s[i][j][k] = As[0][0][j][k] - As[1][0][j][k];
                }
            }
            break;
        case 9:
            s[i] = new lld*[adjL];
            for (int j = 0; j < adjL; j++) {
                s[i][j] = new lld[adjM];
                for (int k = 0; k < adjM; k++) {
                    s[i][j][k] = Bs[0][0][j][k] + Bs[0][1][j][k];
                }
            }
            break;
        }
    }
 
    lld*** p = new lld**[7];
    p[0] = Strassen(As[0][0], s[0], adjN, adjL, adjM);
    p[1] = Strassen(s[1], Bs[1][1], adjN, adjL, adjM);
    p[2] = Strassen(s[2], Bs[0][0], adjN, adjL, adjM);
    p[3] = Strassen(As[1][1], s[3], adjN, adjL, adjM);
    p[4] = Strassen(s[4], s[5], adjN, adjL, adjM);
    p[5] = Strassen(s[6], s[7], adjN, adjL, adjM);
    p[6] = Strassen(s[8], s[9], adjN, adjL, adjM);
 
    for (int i = 0; i < adjN; i++) {
        for (int j = 0; j < adjM; j++) {
            c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
            if (j + adjM < m)
                c[i][j + adjM] = p[0][i][j] + p[1][i][j];
            if (i + adjN < n)
                c[i + adjN][j] = p[2][i][j] + p[3][i][j];
            if (i + adjN < n && j + adjM < m)
                c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
        }
    }
 
    for (int x = 0; x < 2; x++) {
        for (int y = 0; y < 2; y++) {
            for (int i = 0; i < adjN; i++) {
                delete[] As[x][y][i];
            }
            delete[] As[x][y];
        }
        delete[] As[x];
    }
    delete[] As;
 
    for (int x = 0; x < 2; x++) {
        for (int y = 0; y < 2; y++) {
            for (int i = 0; i < adjL; i++) {
                delete[] Bs[x][y][i];
            }
            delete[] Bs[x][y];
        }
        delete[] Bs[x];
    }
    delete[] Bs;
 
    for (int i = 0; i < 10; i++) {
        switch (i) {
        case 0:
        case 3:
        case 5:
        case 7:
        case 9:
            for (int j = 0; j < adjL; j++) {
                delete[] s[i][j];
            }
            break;
        case 1:
        case 2:
        case 4:
        case 6:
        case 8:
            for (int j = 0; j < adjN; j++) {
                delete[] s[i][j];
            }
            break;
        }
        delete[] s[i];
    }
    delete[] s;
 
    for (int i = 0; i < 7; i++) {
        for (int j = 0; j < (n >> 1); j++) {
            delete[] p[i][j];
        }
        delete[] p[i];
    }
    delete[] p;
 
    return c;
}
 
int main()
{
    lld** matA;
    matA = new lld*[2];
    for (int i = 0; i < 2; i++)
        matA[i] = new lld[3];
    matA[0][0] = 1;
    matA[0][1] = 2;
    matA[0][2] = 3;
    matA[1][0] = 4;
    matA[1][1] = 5;
    matA[1][2] = 6;
 
    lld** matB;
    matB = new lld*[3];
    for (int i = 0; i < 3; i++)
        matB[i] = new lld[2];
    matB[0][0] = 7;
    matB[0][1] = 8;
    matB[1][0] = 9;
    matB[1][1] = 10;
    matB[2][0] = 11;
    matB[2][1] = 12;
 
    lld** matC = Strassen(matA, matB, 2, 3, 2);
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            printf("%lld ", matC[i][j]);
        }
        printf("\n");
    }
 
    return 0;
}

                    

Java

public class Main {
    public static void main(String[] args) {
        // Initialize matrices A and B
        int[][] matA = {{1, 2, 3}, {4, 5, 6}};
        int[][] matB = {{7, 8}, {9, 10}, {11, 12}};
 
        // Pad the matrices with zeros if necessary
        int n = Math.max(Math.max(matA.length, matA[0].length), Math.max(matB.length, matB[0].length));
        int m = 1;
        while (m < n) m *= 2;
        int[][] newMatA = new int[m][m];
        int[][] newMatB = new int[m][m];
        for (int i = 0; i < matA.length; i++) {
            for (int j = 0; j < matA[0].length; j++) {
                newMatA[i][j] = matA[i][j];
            }
        }
        for (int i = 0; i < matB.length; i++) {
            for (int j = 0; j < matB[0].length; j++) {
                newMatB[i][j] = matB[i][j];
            }
        }
 
        // Perform Strassen's algorithm on matrices A and B
        int[][] matC = strassen(newMatA, newMatB);
 
        // Print the resulting matrix
        for (int i = 0; i < matA.length; i++) {
            for (int j = 0; j < matB[0].length; j++) {
                System.out.print(matC[i][j] + " ");
            }
            System.out.println();
        }
    }
 
    public static int[][] strassen(int[][] A, int[][] B) {
        int n = A.length;
        int[][] R = new int[n][n];
 
        // base case: 1x1 matrix
        if (n == 1) {
            R[0][0] = A[0][0] * B[0][0];
        } else {
            // split input matrices into quarters
            int[][] A11 = new int[n/2][n/2];
            int[][] A12 = new int[n/2][n/2];
            int[][] A21 = new int[n/2][n/2];
            int[][] A22 = new int[n/2][n/2];
 
            int[][] B11 = new int[n/2][n/2];
            int[][] B12 = new int[n/2][n/2];
            int[][] B21 = new int[n/2][n/2];
            int[][] B22 = new int[n/2][n/2];
 
            // Fill in the sub-matrices
            split(A, A11, 0 , 0);
            split(A, A12, 0 , n/2);
            split(A, A21, n/2, 0);
            split(A, A22, n/2, n/2);
 
            split(B, B11, 0 , 0);
            split(B, B12, 0 , n/2);
            split(B, B21, n/2, 0);
            split(B, B22, n/2, n/2);
 
            // Calculate p1 to p7
            int [][] P1 = strassen(add(A11, A22), add(B11,B22));
            int [][] P2 = strassen(add(A21, A22), B11);
            int [][] P3 = strassen(A11, subtract(B12, B22));
            int [][] P4 = strassen(A22, subtract(B21, B11));
            int [][] P5 = strassen(add(A11, A12), B22);
            int [][] P6 = strassen(subtract(A21, A11), add(B11, B12));
            int [][] P7 = strassen(subtract(A12, A22), add(B21, B22));
 
            // Calculate the 4 quarters of the resulting matrix
            int [][] C11 = add(subtract(add(P1, P4), P5), P7);
            int [][] C12 = add(P3, P5);
            int [][] C21 = add(P2, P4);
            int [][] C22 = add(subtract(add(P1, P3), P2), P6);
 
            // Combine the 4 quarters into a single result matrix
            join(C11, R, 0 , 0);
            join(C12, R, 0 , n/2);
            join(C21, R, n/2, 0);
            join(C22, R, n/2, n/2);
        }
        return R;
    }
 
    // Function to add two matrices
    public static int[][] add(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }
 
    // Function to subtract two matrices
    public static int[][] subtract(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }
 
    // Function to split parent matrix into child matrices
    public static void split(int[][] P, int[][] C, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<C.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<C.length; j1++, j2++)
                C[i1][j1] = P[i2][j2];
    }
 
    // Function to join child matrices into parent matrix
    public static void join(int[][] C, int[][] P, int iB, int jB) {
        for(int i1 = 0, i2=iB; i1<C.length; i1++, i2++)
            for(int j1 = 0, j2=jB; j1<C.length; j1++, j2++)
                P[i2][j2] = C[i1][j1];
    }
}

                    

Python3

import numpy as np
 
def strassen(A, B):
    """Multiply two square NumPy matrices with Strassen's algorithm.

    The matrices are halved with integer division at every level, so the
    side length is expected to be a power of two (the caller pads).
    Returns a new array holding the product.
    """
    size = A.shape[0]

    # A 1x1 product is just the element-wise product.
    if size == 1:
        return A * B

    half = size // 2

    # Quadrants of each operand (views into the inputs, not copies).
    a11, a12 = A[:half, :half], A[:half, half:]
    a21, a22 = A[half:, :half], A[half:, half:]
    b11, b12 = B[:half, :half], B[:half, half:]
    b21, b22 = B[half:, :half], B[half:, half:]

    # Seven recursive products instead of the naive eight.
    m1 = strassen(a11 + a22, b11 + b22)
    m2 = strassen(a21 + a22, b11)
    m3 = strassen(a11, b12 - b22)
    m4 = strassen(a22, b21 - b11)
    m5 = strassen(a11 + a12, b22)
    m6 = strassen(a21 - a11, b11 + b12)
    m7 = strassen(a12 - a22, b21 + b22)

    # Stitch the four result quadrants together: top row, then bottom row.
    top = np.hstack((m1 + m4 - m5 + m7, m3 + m5))
    bottom = np.hstack((m2 + m4, m1 + m3 - m2 + m6))
    return np.vstack((top, bottom))
 
def main():
    """Multiply a sample 2x3 matrix by a 3x2 matrix with strassen() and print it.

    Both operands are zero-padded to a common power-of-two square size before
    the call, and the padding is sliced off the result again before printing.
    """
    # Initialize matrices A and B
    A = np.array([[1, 2, 3], [4, 5, 6]])
    B = np.array([[7, 8], [9, 10], [11, 12]])

    # Shape of the true product (rows of A x columns of B), recorded before
    # padding so the output window always follows the inputs.
    out_rows, out_cols = A.shape[0], B.shape[1]

    # Pad the matrices with zeros up to the next power of two
    n = max(A.shape[0], A.shape[1], B.shape[0], B.shape[1])
    m = 1
    while m < n:
        m *= 2
    A_padded = np.pad(A, ((0, m - A.shape[0]), (0, m - A.shape[1])), mode='constant')
    B_padded = np.pad(B, ((0, m - B.shape[0]), (0, m - B.shape[1])), mode='constant')

    # Perform Strassen's algorithm on the padded matrices
    C = strassen(A_padded, B_padded)

    # Print only the part of the result that belongs to the real product
    print(C[:out_rows, :out_cols])

if __name__ == "__main__":
    main()

                    

Javascript

// Strassen's Algorithm for matrix multiplication
// Complexity: O(n^2.808)
 
// Schoolbook product of the n x l matrix `a` and the l x m matrix `b`.
// Returns a new n x m array of arrays; the inputs are not modified.
function MatrixMultiply(a, b, n, l, m) {
    const product = new Array(n);

    for (let row = 0; row < n; row++) {
        const outRow = new Array(m).fill(0);
        for (let col = 0; col < m; col++) {
            // Dot product of row `row` of a with column `col` of b.
            let dot = 0;
            for (let t = 0; t < l; t++) {
                dot += a[row][t] * b[t][col];
            }
            outRow[col] = dot;
        }
        product[row] = outRow;
    }
    return product;
}
 
// Multiplies the n x l matrix `a` by the l x m matrix `b` with Strassen's
// seven-product recursion and returns a new n x m array of arrays.
//
// Each operand is cut into four quadrants of ceil(dim / 2) size (zero padded
// when a dimension is odd). Ten sums/differences of quadrants feed seven
// recursive products, which are then recombined. As soon as any dimension is
// 1 or less, the standard MatrixMultiply is used instead.
function Strassen(a, b, n, l, m) {
    // "<= 1" rather than "=== 1": with every dimension equal to 0 the halved
    // sizes stay 0 and the recursion would otherwise never terminate.
    if (n <= 1 || l <= 1 || m <= 1)
        return MatrixMultiply(a, b, n, l, m);

    // Quadrant sizes: half of each dimension, rounded up.
    const adjN = (n >> 1) + (n & 1);
    const adjL = (l >> 1) + (l & 1);
    const adjM = (m >> 1) + (m & 1);

    // rows x cols window of src starting at (rowOff, colOff); cells that fall
    // outside the srcRows x srcCols bounds are filled with 0.
    const quadrant = (src, srcRows, srcCols, rowOff, colOff, rows, cols) =>
        Array.from({ length: rows }, (_, i) =>
            Array.from({ length: cols }, (_, j) => {
                const I = i + rowOff;
                const J = j + colOff;
                return (I < srcRows && J < srcCols) ? src[I][J] : 0;
            }));

    // Element-wise sum / difference of two equally sized matrices.
    const add = (x, y) => x.map((row, i) => row.map((v, j) => v + y[i][j]));
    const sub = (x, y) => x.map((row, i) => row.map((v, j) => v - y[i][j]));

    // As[x][y] is adjN x adjL, Bs[x][y] is adjL x adjM
    // (x = 0/1 -> top/bottom half, y = 0/1 -> left/right half).
    const As = [0, 1].map(x => [0, 1].map(y =>
        quadrant(a, n, l, x * adjN, y * adjL, adjN, adjL)));
    const Bs = [0, 1].map(x => [0, 1].map(y =>
        quadrant(b, l, m, x * adjL, y * adjM, adjL, adjM)));

    // The ten auxiliary matrices, built directly with their final contents.
    const s = [
        sub(Bs[0][1], Bs[1][1]), // B12 - B22
        add(As[0][0], As[0][1]), // A11 + A12
        add(As[1][0], As[1][1]), // A21 + A22
        sub(Bs[1][0], Bs[0][0]), // B21 - B11
        add(As[0][0], As[1][1]), // A11 + A22
        add(Bs[0][0], Bs[1][1]), // B11 + B22
        sub(As[0][1], As[1][1]), // A12 - A22
        add(Bs[1][0], Bs[1][1]), // B21 + B22
        sub(As[0][0], As[1][0]), // A11 - A21
        add(Bs[0][0], Bs[0][1]), // B11 + B12
    ];

    // The seven recursive products; every one is adjN x adjM.
    const p = [
        Strassen(As[0][0], s[0], adjN, adjL, adjM),
        Strassen(s[1], Bs[1][1], adjN, adjL, adjM),
        Strassen(s[2], Bs[0][0], adjN, adjL, adjM),
        Strassen(As[1][1], s[3], adjN, adjL, adjM),
        Strassen(s[4], s[5], adjN, adjL, adjM),
        Strassen(s[6], s[7], adjN, adjL, adjM),
        Strassen(s[8], s[9], adjN, adjL, adjM),
    ];

    // Recombine into the n x m result. The bounds checks drop the padded
    // row/column that exists only when n or m is odd.
    const c = Array.from({ length: n }, () => new Array(m).fill(0));
    for (let i = 0; i < adjN; i++) {
        for (let j = 0; j < adjM; j++) {
            c[i][j] = p[4][i][j] + p[3][i][j] - p[1][i][j] + p[5][i][j];
            if (j + adjM < m)
                c[i][j + adjM] = p[0][i][j] + p[1][i][j];
            if (i + adjN < n)
                c[i + adjN][j] = p[2][i][j] + p[3][i][j];
            if (i + adjN < n && j + adjM < m)
                c[i + adjN][j + adjM] = p[4][i][j] + p[0][i][j] - p[2][i][j] - p[6][i][j];
        }
    }

    return c;
}
 
// Demo: multiply a 2 x 3 matrix by a 3 x 2 matrix and print the 2 x 2 product,
// one space-separated row per line.
const matA = [
    [1, 2, 3],
    [4, 5, 6],
];
const matB = [
    [7, 8],
    [9, 10],
    [11, 12],
];

const matC = Strassen(matA, matB, 2, 3, 2);
[0, 1].forEach((rowIndex) => {
    console.log(matC[rowIndex].join(' '));
});

                    

Output
58 64 
139 154 











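Time Complexity: O(n^2.808), i.e. O(n^(log2 7)). The algorithm first checks whether the matrices have been reduced to size 1 and, if so, returns the result of a standard matrix multiplication. Otherwise, it divides each matrix into 4 submatrices and performs 7 matrix multiplications recursively on these half-size blocks. Finally, it combines the results of the multiplications with O(n^2) additions and subtractions to obtain the final result. This gives the recurrence T(n) = 7T(n/2) + O(n^2), which solves to O(n^(log2 7)).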

 Space complexity : O(n^2)



Last Updated : 15 Feb, 2024
Like Article
Save Article
Previous
Next
Share your thoughts in the comments
Similar Reads