Open In App

Minimum cost to convert 3 X 3 matrix into magic square

Improve
Improve
Like Article
Like
Save
Share
Report

A magic square is an $n \times n$ matrix of distinct integers from $1$ to $n^2$ in which every row, every column, and both main diagonals have the same sum.
Consider a 3 X 3 matrix, s, of integers in the inclusive range [1, 9]. We can convert any digit $a$ to any other digit $b$ in the range [1, 9] at a cost of $|a - b|$.
Given s, convert it into a magic square at minimal cost by changing zero or more of its digits. The task is to find this minimum cost.

Note: The resulting matrix must contain distinct integers in the inclusive range [1, 9].

Examples: 

Input : mat[][] = { { 4, 9, 2 },
                    { 3, 5, 7 },
                    { 8, 1, 5 } };

Output : 1

The given matrix s is not a magic square. To convert it into a magic square, we change the bottom-right value, s[2][2], from 5 to 6 at a cost of |5 – 6| = 1.

Input : mat[][] = { { 4, 8, 2 },
                    { 4, 5, 7 },
                    { 6, 1, 6 } };

Output : 4

The idea is to find all 3 X 3 magic squares and, for each one, compute the cost of changing mat into that magic square. The result is the smallest of these costs.
Since s is always 3 X 3, there are only 8 possible magic squares to compare against.

There are two ways to obtain these 8 squares: hard-code them, or generate them.
Here we generate them. We examine all permutations of the integers 1, 2, 3, …, 9. For each one, we check whether it forms a magic square when the permutation is written into the square row by row, starting from the upper-left corner.

Below is the implementation of this approach in C++, Java, Python, JavaScript and C#:

C++




#include <bits/stdc++.h>
using namespace std;
 
// Return true if v, read row by row, is a 3 X 3 magic square.
//
// As defined above, a magic square must hold exactly the distinct integers
// 1..9, and every row, every column and both diagonals must share one sum.
// Vectors that do not have exactly 9 elements are rejected instead of being
// indexed out of bounds. The vector is taken by const reference because this
// is called once for each of the 9! permutations.
bool is_magic_square(const std::vector<int>& v)
{
    if (v.size() != 9)
        return false;

    // Each value 1..9 must occur exactly once; sums alone would accept
    // inputs such as nine 5s.
    bool seen[10] = {};
    for (int x : v) {
        if (x < 1 || x > 9 || seen[x])
            return false;
        seen[x] = true;
    }

    // Row-major view of v as a 3 X 3 matrix.
    auto at = [&v](int r, int c) { return v[3 * r + c]; };

    const int s = at(0, 0) + at(0, 1) + at(0, 2);

    // Every row and every column must match the first row's sum.
    for (int k = 0; k < 3; ++k) {
        if (at(k, 0) + at(k, 1) + at(k, 2) != s)
            return false;
        if (at(0, k) + at(1, k) + at(2, k) != s)
            return false;
    }

    // Main diagonal and anti-diagonal.
    return at(0, 0) + at(1, 1) + at(2, 2) == s
        && at(2, 0) + at(1, 1) + at(0, 2) == s;
}
 
// Generating all magic square
void find_magic_squares(vector<vector<int> >& magic_squares)
{
    vector<int> v(9);
 
    // Initializing the vector
    for (int i = 0; i < 9; ++i)
        v[i] = i + 1;
 
    // Producing all permutation of vector
    // and checking if it denote the magic square or not.
    do {
        if (is_magic_square(v)) {
            magic_squares.push_back(v);
        }
    } while (next_permutation(v.begin(), v.end()));
}
 
// Return the total cost of turning a into b: the sum of |a[i] - b[i]| over
// every position. For a 3 X 3 matrix stored row-wise both vectors hold 9
// values. Vectors of different lengths throw std::invalid_argument rather
// than reading past the end of the shorter one.
int diff(const std::vector<int>& a, const std::vector<int>& b)
{
    if (a.size() != b.size())
        throw std::invalid_argument("diff: vectors must have equal length");

    int res = 0;
    for (std::size_t i = 0; i < a.size(); ++i)
        res += std::abs(a[i] - b[i]);
    return res;
}
 
// Wrapper function
int wrapper(vector<int> v)
{
    int res = INT_MAX;
    vector<vector<int> > magic_squares;
 
    // generating all magic square
    find_magic_squares(magic_squares);
 
    for (int i = 0; i < magic_squares.size(); ++i) {
 
        // Finding the difference with each magic square
        // and assigning the minimum value.
        res = min(res, diff(v, magic_squares[i]));
    }
    return res;
}
 
// Driven Program
int main()
{
    // Taking matrix in vector in row wise to make
    // calculation easy
    vector<int> v;
    v.push_back(4);
    v.push_back(9);
    v.push_back(2);
 
    v.push_back(3);
    v.push_back(5);
    v.push_back(7);
 
    v.push_back(8);
    v.push_back(1);
    v.push_back(5);
 
    cout << wrapper(v) << endl;
 
    return 0;
}


Java




import java.util.*;
 
public class Main {
 
  // Return if given vector denote the magic square or not.
  public static boolean is_magic_square(ArrayList<Integer> v)
  {
    int[][] a = new int[3][3];
 
    // Convert vector into 3 X 3 matrix
    for (int i = 0; i < 3; ++i) {
      for (int j = 0; j < 3; ++j) {
        a[i][j] = v.get(3 * i + j);
      }
    }
    int s = 0;
    for (int j = 0; j < 3; ++j) {
      s += a[0][j];
    }
 
    // Checking if each row sum is same
    for (int i = 1; i <= 2; ++i) {
      int tmp = 0;
      for (int j = 0; j < 3; ++j) {
        tmp += a[i][j];
      }
      if (tmp != s) {
        return false;
      }
    }
 
    // Checking if each column sum is same
    for (int j = 0; j < 3; ++j) {
      int tmp = 0;
      for (int i = 0; i < 3; ++i) {
        tmp += a[i][j];
      }
      if (tmp != s) {
        return false;
      }
    }
 
    // Checking if diagonal 1 sum is same
    int tmp = 0;
    for (int i = 0; i < 3; ++i) {
      tmp += a[i][i];
    }
    if (tmp != s) {
      return false;
    }
 
    // Checking if diagonal 2 sum is same
    tmp = 0;
    for (int i = 0; i < 3; ++i) {
      tmp += a[2 - i][i];
    }
    if (tmp != s) {
      return false;
    }
    return true;
  }
 
  // Generating all magic square
  public static void find_magic_squares(
    ArrayList<ArrayList<Integer> > magic_squares)
  {
    ArrayList<Integer> v = new ArrayList<Integer>();
 
    // Initializing the vector
    for (int i = 0; i < 9; ++i) {
      v.add(i + 1);
    }
 
    // Producing all permutation of vector
    // and checking if it denote the magic square or not.
    do {
      if (is_magic_square(v)) {
        magic_squares.add(
          new ArrayList<Integer>(v));
      }
    } while (next_permutation(v));
  }
 
  public static boolean
    next_permutation(ArrayList<Integer> v)
  {
    int n = v.size();
    int i = n - 2;
    while (i >= 0 && v.get(i) >= v.get(i + 1)) {
      i--;
    }
    if (i < 0) {
      return false;
    }
    int j = n - 1;
    while (v.get(j) <= v.get(i)) {
      j--;
    }
    Collections.swap(v, i, j);
    Collections.reverse(v.subList(i + 1, n));
    return true;
  }
 
  // Return sum of difference between each element of two vector
  public static int diff(ArrayList<Integer> a,
                         ArrayList<Integer> b)
  {
    int res = 0;
    for (int i = 0; i < 9; ++i) {
      res += Math.abs(a.get(i) - b.get(i));
    }
    return res;
  }
 
  // Wrapper function
  public static int wrapper(ArrayList<Integer> v)
  {
    int res = Integer.MAX_VALUE;
    ArrayList<ArrayList<Integer> > magic_squares
      = new ArrayList<ArrayList<Integer> >();
 
    // generating all magic square
    find_magic_squares(magic_squares);
    for (int i = 0; i < magic_squares.size(); ++i) {
 
      // Finding the difference with each magic square
      // and assigning the minimum value.
      res = Math.min(res,
                     diff(v, magic_squares.get(i)));
    }
    return res;
  }
 
  // Driven Program
  public static void main(String[] args)
  {
 
    // Taking matrix in vector in row wise to make
    // calculation easy
    ArrayList<Integer> v = new ArrayList<Integer>();
    v.add(4);
    v.add(9);
    v.add(2);
    v.add(3);
    v.add(5);
    v.add(7);
    v.add(8);
    v.add(1);
    v.add(5);
    System.out.println(wrapper(v));
  }
}
// This code is contributed by Prajwal Kandekar


Python3




from itertools import permutations
from math import inf
 
# Return True if the given list denotes a 3 X 3 magic square.
def is_magic_square(lst):
    """Return True if ``lst`` (9 values, row-major) forms a 3 X 3 magic square.

    A magic square must contain exactly the distinct integers 1..9, and
    every row, every column and both diagonals must have the same sum.
    Sequences whose length is not 9 return False instead of raising
    IndexError.
    """
    # Each value 1..9 must appear exactly once; sums alone would accept
    # inputs such as nine 5s.
    if len(lst) != 9 or sorted(lst) != list(range(1, 10)):
        return False

    # Row-major view of the flat sequence as a 3 X 3 matrix.
    a = [list(lst[3 * i:3 * i + 3]) for i in range(3)]
    s = sum(a[0])

    lines = list(a)                                             # rows
    lines += [[a[i][j] for i in range(3)] for j in range(3)]    # columns
    lines.append([a[i][i] for i in range(3)])                   # main diagonal
    lines.append([a[2 - i][i] for i in range(3)])               # anti-diagonal
    return all(sum(line) == s for line in lines)
 
# Generating all magic squares
def find_magic_squares():
    """Return every 3 X 3 magic square as a tuple of 9 values (row-major).

    All 9! orderings of 1..9 are produced by itertools.permutations. Those
    accepted by is_magic_square are kept, in the order they are generated.
    """
    return [candidate
            for candidate in permutations(range(1, 10))
            if is_magic_square(candidate)]
 
# Return the sum of absolute differences between corresponding elements
def diff(a, b):
    """Return the conversion cost ``sum(|a[i] - b[i]|)`` over all positions.

    Raises:
        ValueError: if ``a`` and ``b`` have different lengths.
    """
    if len(a) != len(b):
        raise ValueError("diff: sequences must have equal length")
    return sum(abs(x - y) for x, y in zip(a, b))
 
# Wrapper function
def wrapper(v):
    """Return the minimum cost of turning ``v`` (9 values, row-major) into a magic square.

    Every magic square from find_magic_squares is compared with ``v``, and
    the smallest diff is returned (``inf`` if no square was found).
    """
    costs = (diff(v, square) for square in find_magic_squares())
    return min(costs, default=inf)
 
# Driver program
if __name__ == '__main__':
    # The first example's 3 X 3 matrix, flattened row by row so the
    # helpers can treat it as one 9-element list
    v = [4, 9, 2,
         3, 5, 7,
         8, 1, 5]
    print(wrapper(v))
 
# This code is contributed by Prince Kumar


Javascript




// JavaScript code for the above approach
 
function* permutations(lst) {
  if (lst.length <= 1) yield lst;
  else {
    for (let perm of permutations(lst.slice(1))) {
      for (let i = 0; i < lst.length; i++) {
        yield perm.slice(0, i).concat([lst[0]]).concat(perm.slice(i));
      }
    }
  }
}
 
// Return if given list denote the magic square or not.
function is_magic_square(lst) {
  const a = [[0, 0, 0], [0, 0, 0], [0, 0, 0]];
 
  // Convert list into 3 X 3 matrix
  for (let i = 0; i < 3; i++) {
    for (let j = 0; j < 3; j++) {
      a[i][j] = lst[3 * i + j];
    }
  }
 
  const s = a[0].reduce((acc, val) => acc + val, 0);
 
  // Checking if each row sum is same
  for (let i = 1; i < 3; i++) {
    const tmp = a[i].reduce((acc, val) => acc + val, 0);
    if (tmp !== s) {
      return false;
    }
  }
 
  // Checking if each column sum is same
  for (let j = 0; j < 3; j++) {
    const tmp = a.reduce((acc, row) => acc + row[j], 0);
    if (tmp !== s) {
      return false;
    }
  }
 
  // Checking if diagonal 1 sum is same
  const tmp1 = a[0][0] + a[1][1] + a[2][2];
  if (tmp1 !== s) {
    return false;
  }
 
  // Checking if diagonal 2 sum is same
  const tmp2 = a[0][2] + a[1][1] + a[2][0];
  if (tmp2 !== s) {
    return false;
  }
 
  return true;
}
 
// Generating all magic squares: keep every permutation of 1..9 that
// is_magic_square accepts, in the order permutations() yields them.
function find_magic_squares() {
  const digits = Array.from({ length: 9 }, (_, i) => i + 1);
  const found = [];

  for (const candidate of permutations(digits)) {
    if (is_magic_square(candidate)) found.push(candidate);
  }

  return found;
}
 
// Return sum of difference between each element of two list
function diff(a, b) {
  return a.reduce((acc, val, idx) => acc + Math.abs(val - b[idx]), 0);
}
 
// Wrapper function: return the minimum cost of converting v (9 values,
// row-major) into any 3 X 3 magic square, or Infinity if none were found.
function wrapper(v) {
  return find_magic_squares().reduce(
    (best, square) => Math.min(best, diff(v, square)),
    Infinity
  );
}
 
// Driver program: the first example's matrix, flattened row by row.
const v = [
  4, 9, 2,
  3, 5, 7,
  8, 1, 5,
];
console.log(wrapper(v));
 
// This code is contributed by prince


C#




// C# code for the above approach
using System;
using System.Collections.Generic;
using System.Linq;
 
namespace MagicSquare {
class Program {
    // Every 3 X 3 magic square. Generating them scans all 9! permutations,
    // so Wrapper builds the list on its first call and reuses it afterwards.
    private static List<List<int> > magicSquaresCache;

    // Yield every permutation of lst. The first element is inserted at each
    // position of every permutation of the rest. Lists with at most one
    // element are yielded unchanged.
    static IEnumerable<List<int> >
    Permutations(List<int> lst)
    {
        if (lst.Count <= 1) {
            yield return lst;
        }
        else {
            foreach(var perm in Permutations(
                lst.Skip(1).ToList()))
            {
                for (int i = 0; i < lst.Count; i++) {
                    yield return perm.Take(i)
                        .Concat(new List<int>{ lst[0] })
                        .Concat(perm.Skip(i))
                        .ToList();
                }
            }
        }
    }

    // Return true if lst, read row by row, is a 3 X 3 magic square. It must
    // hold exactly the distinct integers 1..9, and every row, every column
    // and both diagonals must share the same sum. Lists whose size is not 9
    // are rejected instead of throwing ArgumentOutOfRangeException.
    internal static bool IsMagicSquare(List<int> lst)
    {
        if (lst.Count != 9) {
            return false;
        }

        // Each value 1..9 must appear exactly once; sums alone would accept
        // inputs such as nine 5s.
        bool[] seen = new bool[10];
        foreach(int x in lst)
        {
            if (x < 1 || x > 9 || seen[x]) {
                return false;
            }
            seen[x] = true;
        }

        int[][] a = new int[][] { new int[] { 0, 0, 0 },
                                  new int[] { 0, 0, 0 },
                                  new int[] { 0, 0, 0 } };

        // Convert list into 3 X 3 matrix (row-major)
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                a[i][j] = lst[3 * i + j];
            }
        }

        int s = a[0].Sum();

        // Checking if each row sum is same
        for (int i = 1; i < 3; i++) {
            int tmp = a[i].Sum();
            if (tmp != s) {
                return false;
            }
        }

        // Checking if each column sum is same
        for (int j = 0; j < 3; j++) {
            int tmp = a.Sum(row => row[j]);
            if (tmp != s) {
                return false;
            }
        }

        // Checking if diagonal 1 sum is same
        int tmp1 = a[0][0] + a[1][1] + a[2][2];
        if (tmp1 != s) {
            return false;
        }

        // Checking if diagonal 2 sum is same
        int tmp2 = a[0][2] + a[1][1] + a[2][0];
        if (tmp2 != s) {
            return false;
        }

        return true;
    }

    // Generating all magic squares
    static List<List<int> > FindMagicSquares()
    {
        List<List<int> > magicSquares
            = new List<List<int> >();
        List<int> lst
            = new List<int>{ 1, 2, 3, 4, 5, 6, 7, 8, 9 };

        // Producing all permutations of the list
        // and checking whether each one denotes a magic square.
        foreach(var v in Permutations(lst))
        {
            if (IsMagicSquare(v)) {
                magicSquares.Add(v);
            }
        }

        return magicSquares;
    }

    // Return the conversion cost: the sum of |a[i] - b[i]| over all
    // positions. Throws ArgumentException if the lists differ in size.
    internal static int Diff(List<int> a, List<int> b)
    {
        if (a.Count != b.Count) {
            throw new ArgumentException(
                "Diff: lists must have equal size");
        }
        return a.Select((x, i) => Math.Abs(x - b[i]))
            .Sum();
    }

    // Return the minimum cost of converting v (a 3 X 3 matrix stored row by
    // row) into any 3 X 3 magic square.
    internal static int Wrapper(List<int> v)
    {
        if (magicSquaresCache == null) {
            magicSquaresCache = FindMagicSquares();
        }

        int res = int.MaxValue;

        foreach(var magicSquare in magicSquaresCache)
        {
            // Keep the cheapest conversion seen so far.
            res = Math.Min(res, Diff(v, magicSquare));
        }

        return res;
    }

    // Driver code
    static void Main(string[] args)
    {
        List<int> lst
            = new List<int>{ 4, 9, 2, 3, 5, 7, 8, 1, 5 };
        int res = Wrapper(lst);

        Console.WriteLine(res);
    }
}
}


Output

1

 



Last Updated : 23 Mar, 2023
Like Article
Save Article
Previous
Next
Share your thoughts in the comments
Similar Reads