Open In App

Divide and Conquer | Set 5 (Strassen’s Matrix Multiplication)

Improve
Improve
Improve
Like Article
Like
Save Article
Save
Share
Report issue
Report

Given two square matrices A and B of size n x n each, find their multiplication matrix. 

Naive Method: The following is a simple way to multiply two matrices.

C++




void multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
            {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
}
 
// This code is contributed by noob2000.


C




/* Naive O(N^3) product: overwrites C with A * B, where all three are N x N
 * and N is a compile-time constant defined outside this snippet. */
void multiply(int A[][N], int B[][N], int C[][N])
{
    int row, col, t;

    for (row = 0; row < N; ++row) {
        for (col = 0; col < N; ++col) {
            C[row][col] = 0;
            for (t = 0; t < N; ++t)
                C[row][col] += A[row][t] * B[t][col];
        }
    }
}


Java




// java code
static int multiply(int A[][N], int B[][N], int C[][N])
{
    for (int i = 0; i < N; i++)
    {
        for (int j = 0; j < N; j++)
        {
            C[i][j] = 0;
            for (int k = 0; k < N; k++)
            {
                C[i][j] += A[i][k]*B[k][j];
            }
        }
    }
}
 
// This code is contributed by shivanisinghss2110


Python3




def multiply(A, B, C):
    """Store the naive O(N^3) product A * B into C.

    All three matrices are N x N lists of lists; N is a module-level global
    defined outside this snippet.
    """
    for i in range(N):
        row_a = A[i]
        for j in range(N):
            C[i][j] = 0
            for k in range(N):
                C[i][j] += row_a[k] * B[k][j]
 
# this code is contributed by shivanisinghss2110


C#




// C# code
// Naive O(n^3) product: overwrites C with A * B (all n x n).
// The original declared an int return type without returning anything and
// used the invalid parameter syntax "int A[,N]"; it now returns void, takes
// rectangular int[,] arrays and reads n from A.
static void multiply(int[,] A, int[,] B, int[,] C)
{
    int n = A.GetLength(0);
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < n; j++)
        {
            C[i, j] = 0;
            for (int k = 0; k < n; k++)
            {
                C[i, j] += A[i, k] * B[k, j];
            }
        }
    }
}
 
// This code is contributed by rutvik_56.


Javascript




<script>
 
/**
 * Writes the naive O(N^3) product A * B into C.
 * All three are N x N arrays of arrays; N is a global defined elsewhere.
 */
function multiply(A, B, C)
{
    for (let row = 0; row < N; ++row) {
        const out = C[row];
        for (let col = 0; col < N; ++col) {
            out[col] = 0;
            for (let t = 0; t < N; ++t) {
                out[col] += A[row][t] * B[t][col];
            }
        }
    }
}
 
</script>


The time complexity of the above method is $O(N^3)$.

Divide and Conquer :

The following is a simple Divide and Conquer method to multiply two square matrices.

  1. Divide matrices A and B into 4 sub-matrices of size N/2 x N/2 each, as shown in the diagram below.
  2. Recursively calculate the following values, which are the four quadrants of the result: ae + bg, af + bh, ce + dg and cf + dh.

strassen_new

Implementation:

C++




#include <bits/stdc++.h>
using namespace std;
 
#define ROW_1 4
#define COL_1 4
 
#define ROW_2 4
#define COL_2 4
 
// Prints the heading "<display> =>" followed by the inclusive sub-matrix
// matrix[start_row..end_row][start_column..end_column], with each value
// right-aligned in a 10-character column, and a trailing blank line.
// The matrix and label are only read, so they are taken by const reference
// instead of being copied on every call.
void print(const string& display, const vector<vector<int> >& matrix,
           int start_row, int start_column, int end_row,
           int end_column)
{
    cout << '\n' << display << " =>" << '\n';
    for (int i = start_row; i <= end_row; i++) {
        for (int j = start_column; j <= end_column; j++)
            cout << setw(10) << matrix[i][j];
        cout << '\n';
    }
    cout << endl;
}
 
// Stores matrix_A + matrix_B into the top-left split_index x split_index
// block of matrix_C; entries of matrix_C outside that block are untouched.
// The operands are only read, so they are taken by const reference (the
// previous by-value parameters copied both matrices on every call).
void add_matrix(const vector<vector<int> >& matrix_A,
                const vector<vector<int> >& matrix_B,
                vector<vector<int> >& matrix_C,
                int split_index)
{
    for (int i = 0; i < split_index; i++)
        for (int j = 0; j < split_index; j++)
            matrix_C[i][j] = matrix_A[i][j] + matrix_B[i][j];
}
 
// Multiplies matrix_A (row_1 x col_1) by matrix_B (row_2 x col_2) using the
// simple divide-and-conquer scheme. Square operands of the same even size n
// are split into four n/2 x n/2 quadrants, and each result quadrant is the
// sum of two recursive quadrant products. That is 8 recursive calls, so the
// whole method is O(n^3).
//
// Any other shape is multiplied with the naive triple loop. This covers odd
// sizes, non-square operands and the 1 x 1 base case. Previously only 1 x 1
// was handled there, so for example a 3 x 3 product silently dropped its last
// row and column, and a column vector times a row vector was wrong.
//
// Returns the row_1 x col_2 product. If col_1 != row_2 it prints an error and
// returns an empty matrix.
vector<vector<int> >
multiply_matrix(const vector<vector<int> >& matrix_A,
                const vector<vector<int> >& matrix_B)
{
    // Guard the [0] accesses: an empty matrix has no first row to measure.
    const int row_1 = static_cast<int>(matrix_A.size());
    const int col_1 = row_1 > 0 ? static_cast<int>(matrix_A[0].size()) : 0;
    const int row_2 = static_cast<int>(matrix_B.size());
    const int col_2 = row_2 > 0 ? static_cast<int>(matrix_B[0].size()) : 0;

    if (col_1 != row_2) {
        cout << "\nError: The number of columns in Matrix "
                "A  must be equal to the number of rows in "
                "Matrix B\n";
        return {};
    }

    vector<vector<int> > result_matrix(row_1, vector<int>(col_2, 0));

    // The quadrant split needs square operands of the same even size.
    const bool splittable = row_1 == col_1 && col_1 == col_2 && col_1 >= 2
                            && col_1 % 2 == 0;
    if (!splittable) {
        // Naive product; also serves as the recursion's base case.
        for (int i = 0; i < row_1; i++)
            for (int k = 0; k < col_1; k++)
                for (int j = 0; j < col_2; j++)
                    result_matrix[i][j] += matrix_A[i][k] * matrix_B[k][j];
        return result_matrix;
    }

    const int split_index = col_1 / 2;

    // Copies the split_index x split_index quadrant of m whose top-left
    // corner is (row, col).
    auto quadrant = [split_index](const vector<vector<int> >& m, int row,
                                  int col) -> vector<vector<int> > {
        vector<vector<int> > q(split_index, vector<int>(split_index));
        for (int i = 0; i < split_index; i++)
            for (int j = 0; j < split_index; j++)
                q[i][j] = m[row + i][col + j];
        return q;
    };

    // Writes x + y into the result quadrant whose top-left corner is (row, col).
    auto store_sum = [&result_matrix, split_index](
                         const vector<vector<int> >& x,
                         const vector<vector<int> >& y, int row, int col) {
        for (int i = 0; i < split_index; i++)
            for (int j = 0; j < split_index; j++)
                result_matrix[row + i][col + j] = x[i][j] + y[i][j];
    };

    const vector<vector<int> > a00 = quadrant(matrix_A, 0, 0);
    const vector<vector<int> > a01 = quadrant(matrix_A, 0, split_index);
    const vector<vector<int> > a10 = quadrant(matrix_A, split_index, 0);
    const vector<vector<int> > a11 = quadrant(matrix_A, split_index, split_index);
    const vector<vector<int> > b00 = quadrant(matrix_B, 0, 0);
    const vector<vector<int> > b01 = quadrant(matrix_B, 0, split_index);
    const vector<vector<int> > b10 = quadrant(matrix_B, split_index, 0);
    const vector<vector<int> > b11 = quadrant(matrix_B, split_index, split_index);

    // C00 = a00*b00 + a01*b10, C01 = a00*b01 + a01*b11, and so on.
    store_sum(multiply_matrix(a00, b00), multiply_matrix(a01, b10), 0, 0);
    store_sum(multiply_matrix(a00, b01), multiply_matrix(a01, b11), 0, split_index);
    store_sum(multiply_matrix(a10, b00), multiply_matrix(a11, b10), split_index, 0);
    store_sum(multiply_matrix(a10, b01), multiply_matrix(a11, b11), split_index, split_index);

    return result_matrix;
}
 
int main()
{
    vector<vector<int> > matrix_A = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array A", matrix_A, 0, 0, ROW_1 - 1, COL_1 - 1);
 
    vector<vector<int> > matrix_B = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array B", matrix_B, 0, 0, ROW_2 - 1, COL_2 - 1);
 
    vector<vector<int> > result_matrix(
        multiply_matrix(matrix_A, matrix_B));
 
    print("Result Array", result_matrix, 0, 0, ROW_1 - 1,
          COL_2 - 1);
}
 
// Time Complexity: O(n^3)
// Code Contributed By: lucasletum


Java




//Java program to find the resultant
//product matrix for a given pair of matrices
//using Divide and Conquer Approach
 
import java.io.*;
import java.util.*;
 
class GFG {

  // Dimensions of the sample matrices (main itself passes literal sizes).
  static int ROW_1 = 4,COL_1 = 4, ROW_2 = 4, COL_2 = 4;

  /** Prints the top-left r x c part of a (each value followed by a space), then a blank line. */
  public static void printMat(int[][] a, int r, int c){
    for(int i=0;i<r;i++){
      for(int j=0;j<c;j++){
        System.out.print(a[i][j]+" ");
      }
      System.out.println("");
    }
    System.out.println("");
  }

  /**
   * Prints {@code display + " =>"} and a blank line, then the inclusive sub-matrix
   * matrix[start_row..end_row][start_column..end_column], then a blank line.
   */
  public static void print(String display, int[][] matrix,int start_row, int start_column, int end_row,int end_column)
  {
    System.out.println(display + " =>\n");
    for (int i = start_row; i <= end_row; i++) {
      for (int j = start_column; j <= end_column; j++) {
        System.out.print(matrix[i][j]+" ");
      }
      System.out.println("");
    }
    System.out.println("");
  }

  /** Stores matrix_A + matrix_B into the top-left split_index x split_index block of matrix_C. */
  public static void add_matrix(int[][] matrix_A,int[][] matrix_B,int[][] matrix_C, int split_index)
  {
    for (int i = 0; i < split_index; i++){
      for (int j = 0; j < split_index; j++){
        matrix_C[i][j] = matrix_A[i][j] + matrix_B[i][j];
      }
    }
  }

  /** Sets the top-left r x c entries of a to zero (arrays from {@code new int[..][..]} already are). */
  public static void initWithZeros(int a[][], int r, int c){
    for(int i=0;i<r;i++){
      for(int j=0;j<c;j++){
        a[i][j]=0;
      }
    }
  }

  /**
   * Multiplies matrix_A (row_1 x col_1) by matrix_B (row_2 x col_2) using the
   * simple divide-and-conquer scheme. Square operands of the same even size n
   * are split into four n/2 x n/2 quadrants, and each result quadrant is the
   * sum of two recursive quadrant products. That is 8 recursive calls, so the
   * method is O(n^3) overall.
   *
   * <p>Any other shape is multiplied with the naive triple loop. This covers
   * odd sizes, non-square operands and the 1 x 1 base case. Previously only
   * 1 x 1 was handled there, so for example a 3 x 3 product silently dropped
   * its last row and column.
   *
   * @return the row_1 x col_2 product, or a 1 x 1 zero matrix (after printing
   *         an error) when col_1 != row_2
   */
  public static int[][] multiply_matrix(int[][] matrix_A,int[][] matrix_B)
  {
    int row_1 = matrix_A.length;
    int col_1 = row_1 > 0 ? matrix_A[0].length : 0;
    int row_2 = matrix_B.length;
    int col_2 = row_2 > 0 ? matrix_B[0].length : 0;

    if (col_1 != row_2) {
      System.out.println("\nError: The number of columns in Matrix A  must be equal to the number of rows in Matrix B\n");
      return new int[1][1];
    }

    // Java zero-initialises new arrays, so no explicit clearing is needed.
    int[][] result_matrix = new int[row_1][col_2];

    boolean splittable = row_1 == col_1 && col_1 == col_2 && col_1 >= 2 && col_1 % 2 == 0;
    if (!splittable) {
      // Naive product; also serves as the recursion's base case.
      for (int i = 0; i < row_1; i++)
        for (int k = 0; k < col_1; k++)
          for (int j = 0; j < col_2; j++)
            result_matrix[i][j] += matrix_A[i][k] * matrix_B[k][j];
      return result_matrix;
    }

    int split_index = col_1 / 2;

    int[][] a00 = quadrant(matrix_A, 0, 0, split_index);
    int[][] a01 = quadrant(matrix_A, 0, split_index, split_index);
    int[][] a10 = quadrant(matrix_A, split_index, 0, split_index);
    int[][] a11 = quadrant(matrix_A, split_index, split_index, split_index);
    int[][] b00 = quadrant(matrix_B, 0, 0, split_index);
    int[][] b01 = quadrant(matrix_B, 0, split_index, split_index);
    int[][] b10 = quadrant(matrix_B, split_index, 0, split_index);
    int[][] b11 = quadrant(matrix_B, split_index, split_index, split_index);

    int[][] result_matrix_00 = new int[split_index][split_index];
    int[][] result_matrix_01 = new int[split_index][split_index];
    int[][] result_matrix_10 = new int[split_index][split_index];
    int[][] result_matrix_11 = new int[split_index][split_index];

    add_matrix(multiply_matrix(a00, b00),multiply_matrix(a01, b10),result_matrix_00, split_index);
    add_matrix(multiply_matrix(a00, b01),multiply_matrix(a01, b11),result_matrix_01, split_index);
    add_matrix(multiply_matrix(a10, b00),multiply_matrix(a11, b10),result_matrix_10, split_index);
    add_matrix(multiply_matrix(a10, b01),multiply_matrix(a11, b11),result_matrix_11, split_index);

    place(result_matrix_00, result_matrix, 0, 0);
    place(result_matrix_01, result_matrix, 0, split_index);
    place(result_matrix_10, result_matrix, split_index, 0);
    place(result_matrix_11, result_matrix, split_index, split_index);
    return result_matrix;
  }

  /** Copies the size x size block of m whose top-left corner is (row, col). */
  private static int[][] quadrant(int[][] m, int row, int col, int size) {
    int[][] q = new int[size][size];
    for (int i = 0; i < size; i++)
      for (int j = 0; j < size; j++)
        q[i][j] = m[row + i][col + j];
    return q;
  }

  /** Copies block into target so that its top-left corner lands at (row, col). */
  private static void place(int[][] block, int[][] target, int row, int col) {
    for (int i = 0; i < block.length; i++)
      for (int j = 0; j < block[i].length; j++)
        target[row + i][col + j] = block[i][j];
  }

  public static void main (String[] args) {
    int[][] matrix_A = { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };

    System.out.println("Array A =>");
    printMat(matrix_A,4,4);

    int[][] matrix_B = { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };

    System.out.println("Array B =>");
    printMat(matrix_B,4,4);

    int[][] result_matrix =  multiply_matrix(matrix_A, matrix_B);

    System.out.println("Result Array =>");
    printMat(result_matrix,4,4);
  }
}
// Time Complexity: O(n^3)
//This code is contributed by shruti456rawal


Python3




# Python program to find the resultant
# product matrix for a given pair of matrices
# using Divide and Conquer Approach
 
# Dimensions of the sample matrices used by the driver code.
ROW_1 = COL_1 = ROW_2 = COL_2 = 4
 
#Function to print the matrix
def printMat(a, r, c):
    """Print the first r rows / c columns of a, each value followed by a space,
    one row per line, then a blank line."""
    for i in range(r):
        print("".join(str(a[i][j]) + " " for j in range(c)))
    print()
 
#Function to print the matrix
def printt(display, matrix, start_row, start_column, end_row,end_column):
    """Print display + " =>" and a blank line, then the inclusive sub-matrix
    matrix[start_row..end_row][start_column..end_column] (each value followed
    by a space), then a blank line."""
    print(display + " =>\n")
    columns = range(start_column, end_column + 1)
    for i in range(start_row, end_row + 1):
        print("".join(str(matrix[i][j]) + " " for j in columns))
    print()
 
#Function to add two matrices
def add_matrix(matrix_A, matrix_B, matrix_C, split_index):
    """Store matrix_A + matrix_B into the top-left split_index x split_index
    block of matrix_C, in place; other entries of matrix_C are untouched."""
    for i in range(split_index):
        row_a, row_b, row_c = matrix_A[i], matrix_B[i], matrix_C[i]
        for j in range(split_index):
            row_c[j] = row_a[j] + row_b[j]
 
#Function to initialize matrix with zeros
def initWithZeros(a, r, c):
    """Set the top-left r x c entries of a to 0, in place."""
    for i, j in ((i, j) for i in range(r) for j in range(c)):
        a[i][j] = 0
 
#Function to multiply two matrices
def multiply_matrix(matrix_A, matrix_B):
    """Multiply matrix_A (row_1 x col_1) by matrix_B (row_2 x col_2).

    Square operands of the same even size n are split into four n/2 x n/2
    quadrants, and each result quadrant is the sum of two recursive quadrant
    products. That is 8 recursive calls, so the method is O(n^3) overall.

    Any other shape is multiplied with the naive triple loop. This covers odd
    sizes, non-square operands and the 1 x 1 base case. Previously only 1 x 1
    was handled there, so for example a 3 x 3 product silently dropped its last
    row and column.

    Returns the row_1 x col_2 product as a new list of lists, or 0 after
    printing an error when col_1 != row_2.
    """
    # Guard the [0] accesses: an empty matrix has no first row to measure.
    row_1 = len(matrix_A)
    col_1 = len(matrix_A[0]) if row_1 else 0
    row_2 = len(matrix_B)
    col_2 = len(matrix_B[0]) if row_2 else 0

    if (col_1 != row_2):
        print("\nError: The number of columns in Matrix A  must be equal to the number of rows in Matrix B\n")
        return 0

    result_matrix = [[0] * col_2 for _ in range(row_1)]

    splittable = row_1 == col_1 == col_2 and col_1 >= 2 and col_1 % 2 == 0
    if not splittable:
        # Naive product; also serves as the recursion's base case.
        for i in range(row_1):
            for k in range(col_1):
                a_ik = matrix_A[i][k]
                for j in range(col_2):
                    result_matrix[i][j] += a_ik * matrix_B[k][j]
        return result_matrix

    split_index = col_1 // 2

    def quadrant(m, row, col):
        # Copy of the split_index x split_index block whose top-left corner is (row, col).
        return [r[col:col + split_index] for r in m[row:row + split_index]]

    def store_sum(x, y, row, col):
        # Write x + y into the result quadrant whose top-left corner is (row, col).
        for i in range(split_index):
            for j in range(split_index):
                result_matrix[row + i][col + j] = x[i][j] + y[i][j]

    a00 = quadrant(matrix_A, 0, 0)
    a01 = quadrant(matrix_A, 0, split_index)
    a10 = quadrant(matrix_A, split_index, 0)
    a11 = quadrant(matrix_A, split_index, split_index)
    b00 = quadrant(matrix_B, 0, 0)
    b01 = quadrant(matrix_B, 0, split_index)
    b10 = quadrant(matrix_B, split_index, 0)
    b11 = quadrant(matrix_B, split_index, split_index)

    store_sum(multiply_matrix(a00, b00), multiply_matrix(a01, b10), 0, 0)
    store_sum(multiply_matrix(a00, b01), multiply_matrix(a01, b11), 0, split_index)
    store_sum(multiply_matrix(a10, b00), multiply_matrix(a11, b10), split_index, 0)
    store_sum(multiply_matrix(a10, b01), multiply_matrix(a11, b11), split_index, split_index)

    return result_matrix
 
# Driver Code
# Each sample row repeats one value four times: 1, 2, 3, 2.
matrix_A = [[v] * 4 for v in (1, 2, 3, 2)]

print("Array A =>")
printMat(matrix_A, 4, 4)

matrix_B = [[v] * 4 for v in (1, 2, 3, 2)]

print("Array B =>")
printMat(matrix_B, 4, 4)

result_matrix = multiply_matrix(matrix_A, matrix_B)

print("Result Array =>")
printMat(result_matrix, 4, 4)


C#




//C# program to find the resultant
//product matrix for a given pair of matrices
//using Divide and Conquer Approach
 
using System;
using System.Collections.Generic;
 
class GFG {

  // Dimensions of the sample matrices (Main itself passes literal sizes).
  static int ROW_1 = 4,COL_1 = 4, ROW_2 = 4, COL_2 = 4;

  // Prints the top-left r x c part of a (each value followed by a space),
  // one row per line, then a blank line.
  public static void printMat(int[, ] a, int r, int c){
    for(int i = 0; i < r; i++){
      for(int j = 0; j < c; j++){
        Console.Write(a[i, j]+" ");
      }
      Console.WriteLine("");
    }
    Console.WriteLine("");
  }

  // Prints display + " =>" and a blank line, then the inclusive sub-matrix
  // matrix[start_row..end_row, start_column..end_column], then a blank line.
  public static void print(string display, int[, ] matrix,int start_row, int start_column, int end_row,int end_column)
  {
    Console.WriteLine(display + " =>\n");
    for (int i = start_row; i <= end_row; i++) {
      for (int j = start_column; j <= end_column; j++) {
        Console.Write(matrix[i, j]+" ");
      }
      Console.WriteLine("");
    }
    Console.WriteLine("");
  }

  // Stores matrix_A + matrix_B into the top-left split_index x split_index block of matrix_C.
  public static void add_matrix(int[, ] matrix_A,int[, ] matrix_B,int[, ] matrix_C, int split_index)
  {
    for (int i = 0; i < split_index; i++){
      for (int j = 0; j < split_index; j++){
        matrix_C[i, j] = matrix_A[i, j] + matrix_B[i, j];
      }
    }
  }

  // Sets the top-left r x c entries of a to zero (new int[,] arrays already are).
  public static void initWithZeros(int[, ] a, int r, int c){
    for(int i=0;i<r;i++){
      for(int j=0;j<c;j++){
        a[i, j]=0;
      }
    }
  }

  // Multiplies matrix_A (row_1 x col_1) by matrix_B (row_2 x col_2) using the
  // simple divide-and-conquer scheme. Square operands of the same even size n
  // are split into four n/2 x n/2 quadrants, and each result quadrant is the
  // sum of two recursive quadrant products. That is 8 recursive calls, so the
  // method is O(n^3) overall.
  //
  // Any other shape is multiplied with the naive triple loop. This covers odd
  // sizes, non-square operands and the 1 x 1 base case. Previously only 1 x 1
  // was handled there, so for example a 3 x 3 product silently dropped its
  // last row and column.
  //
  // Returns the row_1 x col_2 product. If col_1 != row_2 it prints an error
  // and returns a 1 x 1 zero matrix.
  public static int[, ] multiply_matrix(int[, ] matrix_A,int[, ] matrix_B)
  {
    int row_1 = matrix_A.GetLength(0);
    int col_1 = matrix_A.GetLength(1);
    int row_2 = matrix_B.GetLength(0);
    int col_2 = matrix_B.GetLength(1);

    if (col_1 != row_2) {
      Console.WriteLine("\nError: The number of columns in Matrix A  must be equal to the number of rows in Matrix B\n");
      return new int[1, 1];
    }

    // C# zero-initialises new arrays, so no explicit clearing is needed.
    int[, ] result_matrix = new int[row_1, col_2];

    bool splittable = row_1 == col_1 && col_1 == col_2 && col_1 >= 2 && col_1 % 2 == 0;
    if (!splittable) {
      // Naive product; also serves as the recursion's base case.
      for (int i = 0; i < row_1; i++)
        for (int k = 0; k < col_1; k++)
          for (int j = 0; j < col_2; j++)
            result_matrix[i, j] += matrix_A[i, k] * matrix_B[k, j];
      return result_matrix;
    }

    int split_index = col_1 / 2;

    int[, ] a00 = quadrant(matrix_A, 0, 0, split_index);
    int[, ] a01 = quadrant(matrix_A, 0, split_index, split_index);
    int[, ] a10 = quadrant(matrix_A, split_index, 0, split_index);
    int[, ] a11 = quadrant(matrix_A, split_index, split_index, split_index);
    int[, ] b00 = quadrant(matrix_B, 0, 0, split_index);
    int[, ] b01 = quadrant(matrix_B, 0, split_index, split_index);
    int[, ] b10 = quadrant(matrix_B, split_index, 0, split_index);
    int[, ] b11 = quadrant(matrix_B, split_index, split_index, split_index);

    int[, ] result_matrix_00 = new int[split_index, split_index];
    int[, ] result_matrix_01 = new int[split_index, split_index];
    int[, ] result_matrix_10 = new int[split_index, split_index];
    int[, ] result_matrix_11 = new int[split_index, split_index];

    add_matrix(multiply_matrix(a00, b00),multiply_matrix(a01, b10),result_matrix_00, split_index);
    add_matrix(multiply_matrix(a00, b01),multiply_matrix(a01, b11),result_matrix_01, split_index);
    add_matrix(multiply_matrix(a10, b00),multiply_matrix(a11, b10),result_matrix_10, split_index);
    add_matrix(multiply_matrix(a10, b01),multiply_matrix(a11, b11),result_matrix_11, split_index);

    place(result_matrix_00, result_matrix, 0, 0);
    place(result_matrix_01, result_matrix, 0, split_index);
    place(result_matrix_10, result_matrix, split_index, 0);
    place(result_matrix_11, result_matrix, split_index, split_index);
    return result_matrix;
  }

  // Copies the size x size block of m whose top-left corner is (row, col).
  private static int[, ] quadrant(int[, ] m, int row, int col, int size) {
    int[, ] q = new int[size, size];
    for (int i = 0; i < size; i++)
      for (int j = 0; j < size; j++)
        q[i, j] = m[row + i, col + j];
    return q;
  }

  // Copies block into target so that its top-left corner lands at (row, col).
  private static void place(int[, ] block, int[, ] target, int row, int col) {
    for (int i = 0; i < block.GetLength(0); i++)
      for (int j = 0; j < block.GetLength(1); j++)
        target[row + i, col + j] = block[i, j];
  }

  public static void Main (string[] args) {
    int[, ] matrix_A = { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };

    Console.WriteLine("Array A =>");
    printMat(matrix_A,4,4);

    int[, ] matrix_B = { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };

    Console.WriteLine("Array B =>");
    printMat(matrix_B,4,4);

    int[, ] result_matrix =  multiply_matrix(matrix_A, matrix_B);

    Console.WriteLine("Result Array =>");
    printMat(result_matrix,4,4);
  }
}
 
// Time Complexity: O(n^3)
//This code is contributed by phasing17


Javascript




function multiplyMatrix(matrixA, matrixB) {
  const n = matrixA.length;
  const resultMatrix = new Array(n).fill(null).map(() => new Array(n).fill(0));
 
  if (n === 1) {
    resultMatrix[0][0] = matrixA[0][0] * matrixB[0][0];
    return resultMatrix;
  }
 
  const splitIndex = n / 2;
 
  const a00 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
  const a01 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
  const a10 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
  const a11 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
 
  const b00 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
  const b01 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
  const b10 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
  const b11 = new Array(splitIndex).fill(null).map(() => new Array(splitIndex).fill(0));
 
  for (let i = 0; i < splitIndex; i++) {
    for (let j = 0; j < splitIndex; j++) {
      a00[i][j] = matrixA[i][j];
      a01[i][j] = matrixA[i][j + splitIndex];
      a10[i][j] = matrixA[splitIndex + i][j];
      a11[i][j] = matrixA[i + splitIndex][j + splitIndex];
 
      b00[i][j] = matrixB[i][j];
      b01[i][j] = matrixB[i][j + splitIndex];
      b10[i][j] = matrixB[splitIndex + i][j];
      b11[i][j] = matrixB[i + splitIndex][j + splitIndex];
    }
  }
 
  const resultMatrix00 = addMatrix(multiplyMatrix(a00, b00), multiplyMatrix(a01, b10), splitIndex);
  const resultMatrix01 = addMatrix(multiplyMatrix(a00, b01), multiplyMatrix(a01, b11), splitIndex);
  const resultMatrix10 = addMatrix(multiplyMatrix(a10, b00), multiplyMatrix(a11, b10), splitIndex);
  const resultMatrix11 = addMatrix(multiplyMatrix(a10, b01), multiplyMatrix(a11, b11), splitIndex);
 
  for (let i = 0; i < splitIndex; i++) {
    for (let j = 0; j < splitIndex; j++) {
      resultMatrix[i][j] = resultMatrix00[i][j];
      resultMatrix[i][j + splitIndex] = resultMatrix01[i][j];
      resultMatrix[splitIndex + i][j] = resultMatrix10[i][j];
      resultMatrix[i + splitIndex][j + splitIndex] = resultMatrix11[i][j];
    }
  }
 
  return resultMatrix;
}
 
function addMatrix(matrixA, matrixB, n) {
  const resultMatrix = new Array(n).fill(null).map(() => new Array(n).fill(0));
 
  for (let i = 0; i < n; i++) {
    for (let j = 0; j < n; j++) {
      resultMatrix[i][j] = matrixA[i][j] + matrixB[i][j];
    }
  }
 
  return resultMatrix;
}
// Each sample row repeats one value four times: 1, 2, 3, 2.
const matrixA = [1, 2, 3, 2].map((v) => new Array(4).fill(v));
const matrixB = [1, 2, 3, 2].map((v) => new Array(4).fill(v));

for (const [label, operand] of [["Array A =>", matrixA], ["Array B =>", matrixB]]) {
  console.log(label);
  console.log(operand);
}

const resultMatrix = multiplyMatrix(matrixA, matrixB);

console.log("Result Array =>");
console.log(resultMatrix);


Output

Array A =>
         1         1         1         1
         2         2         2         2
         3         3         3         3
         2         2         2         2


Array B =>
         1         1         1         1
         2         2         2         2
         3         3         3         3
         2         2         2         2


Result Array =>
         8         8         8         8
        16        16        16        16
        24        24        24        24
        16        16        16        16

In the above method, we do 8 multiplications of matrices of size N/2 x N/2 and 4 additions. Addition of two matrices takes $O(N^2)$ time, so the time complexity can be written as

$T(N) = 8T(N/2) + O(N^2)$

By the Master Theorem, the time complexity of the above method is $O(N^3)$,
which is unfortunately the same as that of the naive method above.

Simple Divide and Conquer also leads to $O(N^3)$. Can there be a better way?

In the above divide and conquer method, the main reason for the high time complexity is the 8 recursive calls. The idea of Strassen’s method is to reduce the number of recursive calls to 7. Strassen’s method is similar to the simple divide and conquer method above: it also divides the matrices into sub-matrices of size N/2 x N/2, as shown in the diagram above. In Strassen’s method, however, the four sub-matrices of the result are calculated using the following formulae.
 

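stressen_formula_new_new

With $A = \begin{bmatrix} a & b \\ c & d \end{bmatrix}$ and $B = \begin{bmatrix} e & f \\ g & h \end{bmatrix}$ split into quadrants as in the diagram:

$p_1 = a(f - h)$, $p_2 = (a + b)h$, $p_3 = (c + d)e$, $p_4 = d(g - e)$, $p_5 = (a + d)(e + h)$, $p_6 = (b - d)(g + h)$, $p_7 = (a - c)(e + f)$

$AB = \begin{bmatrix} p_5 + p_4 - p_2 + p_6 & p_1 + p_2 \\ p_3 + p_4 & p_1 + p_5 - p_3 - p_7 \end{bmatrix}$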

Time Complexity of Strassen’s Method

Addition and subtraction of two matrices take $O(N^2)$ time, so the time complexity can be written as

$T(N) = 7T(N/2) + O(N^2)$

By the Master Theorem, the time complexity of the above method is
$O(N^{\log_2 7})$, which is approximately $O(N^{2.8074})$.

Strassen’s method is generally not preferred in practical applications, for the following reasons.

  1. The constants used in Strassen’s method are high and for a typical application Naive method works better. 
  2. For Sparse matrices, there are better methods especially designed for them. 
  3. The submatrices in recursion take extra space. 
  4. Because computer arithmetic on non-integer values has limited precision, larger errors accumulate in Strassen’s algorithm than in the naive method.

Implementation:

C++




#include <bits/stdc++.h>
using namespace std;
 
#define ROW_1 4
#define COL_1 4
 
#define ROW_2 4
#define COL_2 4
 
// Prints the heading "<display> =>" followed by the inclusive sub-matrix
// matrix[start_row..end_row][start_column..end_column], with each value
// right-aligned in a 10-character column, and a trailing blank line.
// The matrix and label are only read, so they are taken by const reference
// instead of being copied on every call.
void print(const string& display, const vector<vector<int> >& matrix,
           int start_row, int start_column, int end_row,
           int end_column)
{
    cout << '\n' << display << " =>" << '\n';
    for (int i = start_row; i <= end_row; i++) {
        for (int j = start_column; j <= end_column; j++)
            cout << setw(10) << matrix[i][j];
        cout << '\n';
    }
    cout << endl;
}
 
// Returns matrix_A + multiplier * matrix_B over the top-left
// split_index x split_index block. Entries of matrix_A outside that block are
// returned unchanged. With multiplier = -1 this computes a difference.
// matrix_A is taken by value because that copy becomes the result. matrix_B
// is only read, so it is taken by const reference instead of being copied.
vector<vector<int> >
add_matrix(vector<vector<int> > matrix_A,
           const vector<vector<int> >& matrix_B, int split_index,
           int multiplier = 1)
{
    for (int i = 0; i < split_index; i++)
        for (int j = 0; j < split_index; j++)
            matrix_A[i][j] += multiplier * matrix_B[i][j];
    return matrix_A;
}
 
// Multiplies matrix_A (row_1 x col_1) by matrix_B (row_2 x col_2) using
// Strassen's method. Square operands of the same even size n are split into
// four n/2 x n/2 quadrants. Seven recursive half-size products (p..v) are
// then recombined into the four result quadrants, so
// T(n) = 7T(n/2) + O(n^2), which is O(n^log2(7)).
//
// Any other shape is multiplied with the naive triple loop. This covers odd
// sizes, non-square operands and the 1 x 1 base case. Previously only 1 x 1
// was handled there, so for example a 3 x 3 product silently dropped its last
// row and column.
//
// Returns the row_1 x col_2 product. If col_1 != row_2 it prints an error and
// returns an empty matrix.
vector<vector<int> >
multiply_matrix(const vector<vector<int> >& matrix_A,
                const vector<vector<int> >& matrix_B)
{
    typedef vector<vector<int> > Matrix;

    // Guard the [0] accesses: an empty matrix has no first row to measure.
    const int row_1 = static_cast<int>(matrix_A.size());
    const int col_1 = row_1 > 0 ? static_cast<int>(matrix_A[0].size()) : 0;
    const int row_2 = static_cast<int>(matrix_B.size());
    const int col_2 = row_2 > 0 ? static_cast<int>(matrix_B[0].size()) : 0;

    if (col_1 != row_2) {
        cout << "\nError: The number of columns in Matrix "
                "A  must be equal to the number of rows in "
                "Matrix B\n";
        return {};
    }

    Matrix result_matrix(row_1, vector<int>(col_2, 0));

    // The quadrant split needs square operands of the same even size.
    const bool splittable = row_1 == col_1 && col_1 == col_2 && col_1 >= 2
                            && col_1 % 2 == 0;
    if (!splittable) {
        // Naive product; also serves as the recursion's base case.
        for (int i = 0; i < row_1; i++)
            for (int k = 0; k < col_1; k++)
                for (int j = 0; j < col_2; j++)
                    result_matrix[i][j] += matrix_A[i][k] * matrix_B[k][j];
        return result_matrix;
    }

    const int split_index = col_1 / 2;

    // Copies the split_index x split_index quadrant of m whose top-left
    // corner is (row, col).
    auto quadrant = [split_index](const Matrix& m, int row,
                                  int col) -> Matrix {
        Matrix q(split_index, vector<int>(split_index));
        for (int i = 0; i < split_index; i++)
            for (int j = 0; j < split_index; j++)
                q[i][j] = m[row + i][col + j];
        return q;
    };

    // Returns x + sign * y element-wise (sign is +1 or -1).
    auto combine = [split_index](const Matrix& x, const Matrix& y,
                                 int sign) -> Matrix {
        Matrix sum(x);
        for (int i = 0; i < split_index; i++)
            for (int j = 0; j < split_index; j++)
                sum[i][j] += sign * y[i][j];
        return sum;
    };

    const Matrix a00 = quadrant(matrix_A, 0, 0);
    const Matrix a01 = quadrant(matrix_A, 0, split_index);
    const Matrix a10 = quadrant(matrix_A, split_index, 0);
    const Matrix a11 = quadrant(matrix_A, split_index, split_index);
    const Matrix b00 = quadrant(matrix_B, 0, 0);
    const Matrix b01 = quadrant(matrix_B, 0, split_index);
    const Matrix b10 = quadrant(matrix_B, split_index, 0);
    const Matrix b11 = quadrant(matrix_B, split_index, split_index);

    // Strassen's seven half-size products (instead of eight).
    const Matrix p = multiply_matrix(a00, combine(b01, b11, -1));
    const Matrix q = multiply_matrix(combine(a00, a01, 1), b11);
    const Matrix r = multiply_matrix(combine(a10, a11, 1), b00);
    const Matrix s = multiply_matrix(a11, combine(b10, b00, -1));
    const Matrix t = multiply_matrix(combine(a00, a11, 1),
                                     combine(b00, b11, 1));
    const Matrix u = multiply_matrix(combine(a01, a11, -1),
                                     combine(b10, b11, 1));
    const Matrix v = multiply_matrix(combine(a00, a10, -1),
                                     combine(b00, b01, 1));

    // Recombine: C00 = t+s+u-q, C01 = p+q, C10 = r+s, C11 = t+p-r-v.
    for (int i = 0; i < split_index; i++)
        for (int j = 0; j < split_index; j++) {
            result_matrix[i][j] = t[i][j] + s[i][j] + u[i][j] - q[i][j];
            result_matrix[i][j + split_index] = p[i][j] + q[i][j];
            result_matrix[i + split_index][j] = r[i][j] + s[i][j];
            result_matrix[i + split_index][j + split_index]
                = t[i][j] + p[i][j] - r[i][j] - v[i][j];
        }

    return result_matrix;
}
 
int main()
{
    vector<vector<int> > matrix_A = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array A", matrix_A, 0, 0, ROW_1 - 1, COL_1 - 1);
 
    vector<vector<int> > matrix_B = { { 1, 1, 1, 1 },
                                      { 2, 2, 2, 2 },
                                      { 3, 3, 3, 3 },
                                      { 2, 2, 2, 2 } };
 
    print("Array B", matrix_B, 0, 0, ROW_2 - 1, COL_2 - 1);
 
    vector<vector<int> > result_matrix(
        multiply_matrix(matrix_A, matrix_B));
 
    print("Result Array", result_matrix, 0, 0, ROW_1 - 1,
          COL_2 - 1);
}
 
// Time Complexity: T(N) = 7T(N/2) + O(N^2) => O(N^(log2 7)),
// which is approximately O(N^2.8074). Code contributed by:
// lucasletum


Java




/**
 ** Java Program to Implement Strassen Algorithm
 **/
  
import java.util.Scanner;
  
/** Class Strassen **/
public class Strassen
{
    /** Function to multiply matrices **/
    public int[][] multiply(int[][] A, int[][] B)
    {       
        int n = A.length;
        int[][] R = new int[n][n];
        /** base case **/
        if (n == 1)
            R[0][0] = A[0][0] * B[0][0];
        else
        {
            int[][] A11 = new int[n/2][n/2];
            int[][] A12 = new int[n/2][n/2];
            int[][] A21 = new int[n/2][n/2];
            int[][] A22 = new int[n/2][n/2];
            int[][] B11 = new int[n/2][n/2];
            int[][] B12 = new int[n/2][n/2];
            int[][] B21 = new int[n/2][n/2];
            int[][] B22 = new int[n/2][n/2];
  
            /** Dividing matrix A into 4 halves **/
            split(A, A11, 0 , 0);
            split(A, A12, 0 , n/2);
            split(A, A21, n/2, 0);
            split(A, A22, n/2, n/2);
            /** Dividing matrix B into 4 halves **/
            split(B, B11, 0 , 0);
            split(B, B12, 0 , n/2);
            split(B, B21, n/2, 0);
            split(B, B22, n/2, n/2);
  
            /**
              M1 = (A11 + A22)(B11 + B22)
              M2 = (A21 + A22) B11
              M3 = A11 (B12 - B22)
              M4 = A22 (B21 - B11)
              M5 = (A11 + A12) B22
              M6 = (A21 - A11) (B11 + B12)
              M7 = (A12 - A22) (B21 + B22)
            **/
  
            int [][] M1 = multiply(add(A11, A22), add(B11, B22));
            int [][] M2 = multiply(add(A21, A22), B11);
            int [][] M3 = multiply(A11, sub(B12, B22));
            int [][] M4 = multiply(A22, sub(B21, B11));
            int [][] M5 = multiply(add(A11, A12), B22);
            int [][] M6 = multiply(sub(A21, A11), add(B11, B12));
            int [][] M7 = multiply(sub(A12, A22), add(B21, B22));
  
            /**
              C11 = M1 + M4 - M5 + M7
              C12 = M3 + M5
              C21 = M2 + M4
              C22 = M1 - M2 + M3 + M6
            **/
            int [][] C11 = add(sub(add(M1, M4), M5), M7);
            int [][] C12 = add(M3, M5);
            int [][] C21 = add(M2, M4);
            int [][] C22 = add(sub(add(M1, M3), M2), M6);
  
            /** join 4 halves into one result matrix **/
            join(C11, R, 0 , 0);
            join(C12, R, 0 , n/2);
            join(C21, R, n/2, 0);
            join(C22, R, n/2, n/2);
        }
        /** return result **/   
        return R;
    }
    /** Function to sub two matrices **/
    public int[][] sub(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i][j] = A[i][j] - B[i][j];
        return C;
    }
    /** Function to add two matrices **/
    public int[][] add(int[][] A, int[][] B)
    {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i][j] = A[i][j] + B[i][j];
        return C;
    }
    /** Function to split parent matrix into child matrices **/
    public void split(int[][] P, int[][] C, int iB, int jB)
    {
        for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
            for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
                C[i1][j1] = P[i2][j2];
    }
    /** Function to join child matrices intp parent matrix **/
    public void join(int[][] C, int[][] P, int iB, int jB)
    {
        for(int i1 = 0, i2 = iB; i1 < C.length; i1++, i2++)
            for(int j1 = 0, j2 = jB; j1 < C.length; j1++, j2++)
                P[i2][j2] = C[i1][j1];
    }   
    /** Main function **/
    public static void main (String[] args)
    {
        Scanner scan = new Scanner(System.in);
        System.out.println("Strassen Multiplication Algorithm Test\n");
        /** Make an object of Strassen class **/
        Strassen s = new Strassen();
  
         
        int N = 4;
        /** Accept two 2d matrices **/
         
        int[][] A =     { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
 
        int[][] B =     { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
        System.out.println("\nArray A =>");
     
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(A[i][j] +" ");
            System.out.println();
        }
         
        System.out.println("\nArray B =>");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(B[i][j] +" ");
            System.out.println();
        }
  
        int[][] C = s.multiply(A, B);
  
        System.out.println("\nProduct of matrices A and  B : ");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                System.out.print(C[i][j] +" ");
            System.out.println();
        }
  
    }
}


Python3




# Version 3.6
 
import numpy as np
 
def split(matrix):
    """
    Cut a 2-D matrix into its four quadrants.

    Input: nxn matrix (a 2-D numpy array)
    Output: tuple (top-left, top-right, bottom-left, bottom-right), i.e. the
            a, b, c, d quadrants; each is n/2 x n/2 when n is even
    """
    n_rows, n_cols = matrix.shape
    mid_r = n_rows // 2
    mid_c = n_cols // 2
    top = matrix[:mid_r]
    bottom = matrix[mid_r:]
    return top[:, :mid_c], top[:, mid_c:], bottom[:, :mid_c], bottom[:, mid_c:]
 
def strassen(x, y):
    """
    Computes matrix product by divide and conquer approach, recursively.
    Input: nxn matrices x and y (numpy arrays or nested lists). n need not be
           a power of two: at any level where n is odd, both operands are
           zero-padded by one row and one column before splitting, and the
           product is cropped back to n x n.
    Output: nxn numpy array, product of x and y
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = x.shape[0]

    # Empty input: splitting a 0x0 matrix would recurse forever.
    if n == 0:
        return np.zeros((0, 0), dtype=np.result_type(x, y))

    # Base case when size of matrices is 1x1
    if n == 1:
        return x * y

    # Odd size: split() would yield quadrants of unequal shape (e.g. 1x1 and
    # 2x2 for n == 3), so sums such as a + d fail or misalign. Pad with zeros
    # to an even size; the extra row/column contributes nothing to the product.
    if n % 2:
        pad_width = ((0, 1), (0, 1))
        return strassen(np.pad(x, pad_width), np.pad(y, pad_width))[:n, :n]

    # Splitting the matrices into quadrants. This will be done recursively
    # until the base case is reached.
    a, b, c, d = split(x)
    e, f, g, h = split(y)

    # Computing the 7 products, recursively (p1, p2...p7)
    p1 = strassen(a, f - h)
    p2 = strassen(a + b, h)
    p3 = strassen(c + d, e)
    p4 = strassen(d, g - e)
    p5 = strassen(a + d, e + h)
    p6 = strassen(b - d, g + h)
    p7 = strassen(a - c, e + f)

    # Computing the values of the 4 quadrants of the final matrix c
    c11 = p5 + p4 - p2 + p6
    c12 = p1 + p2
    c21 = p3 + p4
    c22 = p1 + p5 - p3 - p7

    # Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22))))

    return c


C#




using System;
 
/**
 ** C# Program to Implement Strassen Algorithm
 **/
 
/** Class Strassen **/

public class Strassen
{
    /**
     * Multiplies two square matrices using Strassen's algorithm.
     *
     * Both operands are split into four quadrants, seven half-size products
     * (M1..M7) are computed recursively, and the result quadrants are rebuilt
     * from them. When n is odd the operands are zero-padded by one row and one
     * column before splitting and the product is cropped back to n x n, so any
     * square size works. (Splitting an odd size into n / 2 quadrants would drop
     * the last row and column and yield a wrong product.)
     *
     * A: left operand, assumed n x n
     * B: right operand, assumed n x n
     * returns: a new n x n matrix holding A * B
     **/
    public int[,] Multiply(int[,] A, int[,] B)
    {
        int n = A.GetLength(0);

        /** empty input: return an empty product instead of recursing forever **/
        if (n == 0)
        {
            return new int[0, 0];
        }

        /** base case **/
        if (n == 1)
        {
            return new int[,] { { A[0, 0] * B[0, 0] } };
        }

        /** odd size: quadrants would be unequal, so pad to an even size **/
        if (n % 2 != 0)
        {
            return Crop(Multiply(Pad(A, n + 1), Pad(B, n + 1)), n);
        }

        int half = n / 2;
        int[,] R = new int[n, n];
        int[,] A11 = new int[half, half];
        int[,] A12 = new int[half, half];
        int[,] A21 = new int[half, half];
        int[,] A22 = new int[half, half];
        int[,] B11 = new int[half, half];
        int[,] B12 = new int[half, half];
        int[,] B21 = new int[half, half];
        int[,] B22 = new int[half, half];

        /** Dividing matrix A into 4 halves **/
        Split(A, A11, 0, 0);
        Split(A, A12, 0, half);
        Split(A, A21, half, 0);
        Split(A, A22, half, half);

        /** Dividing matrix B into 4 halves **/
        Split(B, B11, 0, 0);
        Split(B, B12, 0, half);
        Split(B, B21, half, 0);
        Split(B, B22, half, half);

        /**
          M1 = (A11 + A22)(B11 + B22)
          M2 = (A21 + A22) B11
          M3 = A11 (B12 - B22)
          M4 = A22 (B21 - B11)
          M5 = (A11 + A12) B22
          M6 = (A21 - A11) (B11 + B12)
          M7 = (A12 - A22) (B21 + B22)
        **/
        int[,] M1 = Multiply(Add(A11, A22), Add(B11, B22));
        int[,] M2 = Multiply(Add(A21, A22), B11);
        int[,] M3 = Multiply(A11, Sub(B12, B22));
        int[,] M4 = Multiply(A22, Sub(B21, B11));
        int[,] M5 = Multiply(Add(A11, A12), B22);
        int[,] M6 = Multiply(Sub(A21, A11), Add(B11, B12));
        int[,] M7 = Multiply(Sub(A12, A22), Add(B21, B22));

        /**
          C11 = M1 + M4 - M5 + M7
          C12 = M3 + M5
          C21 = M2 + M4
          C22 = M1 - M2 + M3 + M6
        **/
        int[,] C11 = Add(Sub(Add(M1, M4), M5), M7);
        int[,] C12 = Add(M3, M5);
        int[,] C21 = Add(M2, M4);
        int[,] C22 = Add(Sub(Add(M1, M3), M2), M6);

        /** join 4 halves into one result matrix **/
        Join(C11, R, 0, 0);
        Join(C12, R, 0, half);
        Join(C21, R, half, 0);
        Join(C22, R, half, half);

        /** return result **/
        return R;
    }

    /** Function to sub two matrices: returns A - B (square, same size) **/
    public int[,] Sub(int[,] A, int[,] B)
    {
        int n = A.GetLength(0);
        int[,] C = new int[n, n];
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                C[i, j] = A[i, j] - B[i, j];
            }
        }

        return C;
    }

    /** Function to add two matrices: returns A + B (square, same size) **/
    public int[,] Add(int[,] A, int[,] B)
    {
        int n = A.GetLength(0);
        int[,] C = new int[n, n];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++)
                C[i, j] = A[i, j] + B[i, j];
        return C;
    }
    /** Copies the square block of P starting at (iB, jB) into C (C's size) **/
    public void Split(int[, ] P, int[, ] C, int iB, int jB)
    {
        for(int i1 = 0, i2 = iB; i1 < C.GetLength(0); i1++, i2++)
            for(int j1 = 0, j2 = jB; j1 < C.GetLength(0); j1++, j2++)
                C[i1, j1] = P[i2, j2];
    }
    /** Copies child matrix C into parent P with its top-left corner at (iB, jB) **/
    public void Join(int[, ] C, int[, ] P, int iB, int jB)
    {
        for(int i1 = 0, i2 = iB; i1 < C.GetLength(0); i1++, i2++)
            for(int j1 = 0, j2 = jB; j1 < C.GetLength(0); j1++, j2++)
                P[i2, j2] = C[i1, j1];
    }
    /** Returns a size x size zero matrix with M copied into its top-left corner **/
    private int[,] Pad(int[,] M, int size)
    {
        int[,] P = new int[size, size];
        Join(M, P, 0, 0);
        return P;
    }
    /** Returns a copy of the top-left size x size corner of M **/
    private int[,] Crop(int[,] M, int size)
    {
        int[,] C = new int[size, size];
        Split(M, C, 0, 0);
        return C;
    }
    /** Main function: multiplies two fixed 4x4 sample matrices and prints them **/
    public static void Main (string[] args)
    {
        Console.WriteLine("Strassen Multiplication Algorithm Test\n");
        /** Make an object of Strassen class **/
        Strassen s = new Strassen();

        int N = 4;
        /** Sample input matrices (hard-coded; nothing is read from stdin) **/
        int[, ] A =     { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };

        int[, ] B =     { { 1, 1, 1, 1 },
                        { 2, 2, 2, 2 },
                        { 3, 3, 3, 3 },
                        { 2, 2, 2, 2 } };
        Console.WriteLine("\nArray A =>");

        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                Console.Write(A[i, j] +" ");
            Console.WriteLine();
        }

        Console.WriteLine("\nArray B =>");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                Console.Write(B[i, j] +" ");
            Console.WriteLine();
        }

        int[, ] C = s.Multiply(A, B);

        Console.WriteLine("\nProduct of matrices A and  B : ");
        for (int i = 0; i < N; i++)
        {
            for (int j = 0; j < N; j++)
                Console.Write(C[i, j] +" ");
            Console.WriteLine();
        }

    }
}


Javascript




function split(matrix) {
 
    /*
        Splits a given matrix into quarters.
    Input: nxn matrix
    Output: tuple containing 4 n/2 x n/2 matrices corresponding to a, b, c, d
*/
    let row = matrix.length,
        col = matrix[0].length;
    let row2 = Math.floor(row / 2),
        col2 = Math.floor(col / 2);
    return [matrix.slice(0, row2).map(x => x.slice(0, col2)),
        matrix.slice(0, row2).map(x => x.slice(col2)),
        matrix.slice(row2).map(x => x.slice(0, col2)),
        matrix.slice(row2).map(x => x.slice(col2))
    ];
}
 
function strassen(x, y) {

    /*
    Computes matrix product by divide and conquer approach, recursively.
    Input: nxn matrices x and y (arrays of row arrays). n need not be a power
           of two: at any level where n is odd, both operands are padded with
           a zero row and column before splitting, and the product is cropped
           back to n x n.
    Output: nxn matrix, product of x and y
    */
    const n = x.length;

    // Empty input: splitting an empty matrix would recurse forever.
    if (n === 0) {
        return [];
    }

    // Base case when size of matrices is 1x1
    if (n === 1) {
        return [
            [x[0][0] * y[0][0]]
        ];
    }

    // Odd size: split() would produce quadrants of unequal shape, and add/sub
    // would then silently combine misaligned entries. Padding with zeros to
    // an even size does not change the product's top-left n x n block.
    if (n % 2 !== 0) {
        const pad = m => m.map(row => row.concat([0]))
            .concat([new Array(n + 1).fill(0)]);
        return strassen(pad(x), pad(y)).slice(0, n).map(row => row.slice(0, n));
    }

    // Splitting the matrices into quadrants. This will be done recursively
    // until the base case is reached.
    let [a, b, c, d] = split(x);
    let [e, f, g, h] = split(y);

    // Computing the 7 products, recursively (p1, p2...p7)
    let p1 = strassen(a, sub(f, h));
    let p2 = strassen(add(a, b), h);
    let p3 = strassen(add(c, d), e);
    let p4 = strassen(d, sub(g, e));
    let p5 = strassen(add(a, d), add(e, h));
    let p6 = strassen(sub(b, d), add(g, h));
    let p7 = strassen(sub(a, c), add(e, f));

    // Computing the values of the 4 quadrants of the final matrix c
    let c11 = add(sub(add(p5, p4), p2), p6);
    let c12 = add(p1, p2);
    let c21 = add(p3, p4);
    let c22 = sub(sub(add(p1, p5), p3), p7);

    // Combining the 4 quadrants into a single matrix by stacking horizontally and vertically.
    let top = c11.map((x, i) => x.concat(c12[i]));
    let bottom = c21.map((x, i) => x.concat(c22[i]));
    return top.concat(bottom);
}
 
function add(a, b) {
    return a.map((x, i) => x.map((y, j) => y + b[i][j]));
}
 
function sub(a, b) {
    return a.map((x, i) => x.map((y, j) => y - b[i][j]));
}
 
// Driver's code: multiply two fixed 4x4 sample matrices and print all three.
const makeSampleMatrix = () => [
    [1, 1, 1, 1],
    [2, 2, 2, 2],
    [3, 3, 3, 3],
    [2, 2, 2, 2]
];

let A = makeSampleMatrix();
let B = makeSampleMatrix();
let C = strassen(A, B);

const showMatrix = (label, matrix) => {
    console.log(label);
    console.table(matrix);
};
showMatrix("Array A =>", A);
showMatrix("Array B =>", B);
showMatrix("Result Array =>", C);
// This code is contributed by Prajwal Kandekar


Output

Array A =>
         1         1         1         1
         2         2         2         2
         3         3         3         3
         2         2         2         2


Array B =>
         1         1         1         1
         2         2         2         2
         3         3         3         3
         2         2         2         2


Result Array =>
         8         8         8         8
        16        16        16        16
        24        24        24        24
        16        16        16        16

Easy way to remember Strassen’s Matrix Equation
 

References: 
Introduction to Algorithms 3rd Edition by Clifford Stein, Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest 

 



Last Updated : 01 Jun, 2023
Like Article
Save Article
Previous
Next
Share your thoughts in the comments
Similar Reads