Maximize sum of squares of array elements possible by replacing pairs with their Bitwise AND and Bitwise OR

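Given an array arr[] consisting of N integers, the task is to find the maximum sum of the squares of array elements possible from the given array by performing the following operation any number of times:

Select a pair of distinct indices (i, j), then simultaneously replace arr[i] with arr[i] AND arr[j] and arr[j] with arr[i] OR arr[j].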

Examples:

Input: arr[] = {1, 3, 5}
Output: 51
Explanation:
For the pair (arr[1], arr[2]), perform the following operations:
Replace 3 with 3 AND 5, which is equal to 1.
Replace 5 with 3 OR 5, which is equal to 7.
The modified array obtained after the above steps is {1, 1, 7}.
Therefore, the maximized sum of the squares can be calculated by 1 * 1 + 1 * 1 + 7 * 7 = 51.

Input: arr[] = {8, 9, 9, 1}
Output: 243
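For tiny inputs, these expected outputs can be cross-checked by brute force. The sketch below is a sanity check, not a solution: the helper name `max_sum_bruteforce` is hypothetical, and it assumes the operation may be applied any number of times to any pair of distinct indices. It explores every array reachable from the input and keeps the largest sum of squares. The search terminates because every reachable value only uses bits already present in the input, so there are finitely many states.

```python
from itertools import permutations


def max_sum_bruteforce(arr):
    """Explore every array reachable via the AND/OR operation
    and return the largest sum of squares (tiny inputs only)."""
    start = tuple(arr)
    seen = {start}
    stack = [start]
    best = 0
    while stack:
        cur = stack.pop()
        best = max(best, sum(v * v for v in cur))
        # Apply the operation to every ordered pair of distinct indices,
        # always using the values from the current state.
        for i, j in permutations(range(len(cur)), 2):
            nxt = list(cur)
            nxt[i], nxt[j] = cur[i] & cur[j], cur[i] | cur[j]
            nxt = tuple(nxt)
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return best


print(max_sum_bruteforce([1, 3, 5]))     # 51
print(max_sum_bruteforce([8, 9, 9, 1]))  # 243
```

For the second example a single operation on (8, 1) already suffices: it yields {0, 9, 9, 9}, whose sum of squares is 3 × 81 = 243.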

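Approach: The key observation is that if x and y are the two selected elements, z = x AND y and w = x OR y, then x + y = z + w. A bit that is set in both x and y is set in both z and w, and a bit that is set in exactly one of x and y is set only in w, so each bit contributes the same amount to both sums. Moreover, z ≤ min(x, y) and w ≥ max(x, y), so the operation only moves value from the smaller element to the larger one.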
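The identity x + y = (x AND y) + (x OR y) is easy to confirm exhaustively for small values. In this sketch the helper name `check_and_or_sum` and the bound of 64 are arbitrary choices for illustration; it uses Python's built-in `&` and `|` operators.

```python
def check_and_or_sum(limit):
    """Return True if x + y == (x & y) + (x | y) for all 0 <= x, y < limit."""
    return all(
        x + y == (x & y) + (x | y)
        for x in range(limit)
        for y in range(limit)
    )


# Worked instance from the first example: x = 3, y = 5
x, y = 3, 5
z, w = x & y, x | y
print(z, w, x + y, z + w)    # 1 7 8 8
print(check_and_or_sum(64))  # True
```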



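Assume without loss of generality that x ≤ y. Since z + w = x + y, z ≤ x and w ≥ y, we can write z = x − d and w = y + d for some d ≥ 0. Then:

$$S_{\text{old}} = x^2 + y^2$$
$$S_{\text{new}} = (x - d)^2 + (y + d)^2$$
$$S_{\text{new}} - S_{\text{old}} = 2d(y - x + d) \ge 0$$

The difference is strictly positive whenever d > 0, so every operation that actually changes the pair increases the sum of squares.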
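The difference formula can be checked numerically against the direct computation. In this sketch, the hypothetical helpers `gain` and `gain_formula` compute the change in the pair's sum of squares directly and via 2d(y − x + d), taking x ≤ y and d = x − (x AND y), so that the new pair is (x − d, y + d).

```python
def gain(x, y):
    """Change in x*x + y*y after replacing (x, y) with (x & y, x | y)."""
    return (x & y) ** 2 + (x | y) ** 2 - (x * x + y * y)


def gain_formula(x, y):
    """Same change computed as 2d(y - x + d), with x <= y and d = x - (x & y)."""
    x, y = min(x, y), max(x, y)
    d = x - (x & y)
    return 2 * d * (y - x + d)


print(gain(3, 5), gain_formula(3, 5))  # 16 16
print(all(gain(x, y) == gain_formula(x, y) >= 0
          for x in range(32) for y in range(32)))  # True
```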

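For the first example, x = 3 and y = 5:

x = 3 = 0 1 1
y = 5 = 1 0 1
z = 1 = 0 0 1
w = 7 = 1 1 1

Each bit position holds the same number of set bits before (x, y) and after (z, w) the operation; only their distribution between the two elements changes.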
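The same per-bit bookkeeping can be checked in code. In this sketch, the hypothetical helper `bit_counts` returns, for each bit position j, how many values have bit j set. Index 0 is the least significant bit, so the lists read in the opposite order of the columns above.

```python
def bit_counts(values, bits=8):
    """counts[j] = number of values with bit j set (index 0 = least significant bit)."""
    return [sum((v >> j) & 1 for v in values) for j in range(bits)]


before = [3, 5]
after = [3 & 5, 3 | 5]        # [1, 7]
print(bit_counts(before, 3))  # [2, 1, 1]
print(bit_counts(after, 3))   # [2, 1, 1]
```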

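Because each operation preserves the number of elements having each bit set, and strictly increases the sum of squares whenever it changes the array, repeatedly applying it ends in a state where every pair is nested (the set bits of one element are a subset of those of the other). With the per-bit counts fixed, there is only one such state, and it can be built greedily: count the set bits at each position, then construct the N resulting numbers one at a time, giving each new number one bit from every position whose count is still positive.

For example: arr[] = {5, 2, 3, 4, 5, 6, 7}

1 0 1 = 5
0 1 0 = 2
0 1 1 = 3
1 0 0 = 4
1 0 1 = 5
1 1 0 = 6
1 1 1 = 7
-----
5 4 4 (number of set bits at each position)

Building the numbers greedily gives 7, 7, 7, 7, 4, 0, 0, so the maximum sum of squares is 4 × 49 + 16 = 212.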
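The greedy construction has an equivalent closed form: bit j is handed to the first counts[j] numbers, so the i-th constructed number (0-based) contains bit j exactly when counts[j] > i. The sketch below applies this to the example array above; the helper name `construct` is hypothetical, and it assumes non-negative values below 2^31.

```python
def construct(arr, bits=31):
    """Rebuild len(arr) numbers from the per-bit counts of arr (sketch)."""
    # counts[j] = how many elements have bit j set
    counts = [sum((v >> j) & 1 for v in arr) for j in range(bits)]
    # The i-th number (0-based) gets bit j exactly when counts[j] > i,
    # i.e. bit j is given to the first counts[j] numbers.
    return [sum(1 << j for j in range(bits) if counts[j] > i)
            for i in range(len(arr))]


nums = construct([5, 2, 3, 4, 5, 6, 7])
print(nums)                      # [7, 7, 7, 7, 4, 0, 0]
print(sum(v * v for v in nums))  # 212
```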
 

Below is the implementation of the above approach:

// C++ program for the above approach
#include <iostream>
#include <vector>
using namespace std;

// Function to find the maximum sum
// of squares of the array elements
// after performing the operations
// (assumes non-negative elements)
long long findMaximumSum(
    const vector<int>& arr, int n)
{
    // binary[j] stores how many elements
    // have the j-th bit set
    int binary[31] = { 0 };

    // Update the binary bit count at
    // corresponding indices for
    // each element
    for (int x : arr) {

        int idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value
    long long ans = 0;

    // Construct numbers according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {

        long long total = 0;

        // Take one set bit from every
        // position that still has one
        for (int j = 0; j < 31; ++j) {

            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Return the answer
    return ans;
}

// Driver Code
int main()
{
    // Given array arr[]
    vector<int> arr = { 8, 9, 9, 1 };

    int N = arr.size();

    // Function call
    cout << findMaximumSum(arr, N) << endl;

    return 0;
}

// Java program for the above approach
class GFG{

// Function to find the maximum sum
// of squares of the array elements
// after performing the operations
// (assumes non-negative elements)
static long findMaximumSum(int []arr, int n)
{
    // binary[j] stores how many elements
    // have the j-th bit set
    int []binary = new int[31];

    // Update the binary bit count at
    // corresponding indices for
    // each element
    for(int x : arr)
    {
        int idx = 0;

        // Iterate all set bits
        while (x > 0)
        {
            // If current bit is set
            if ((x & 1) > 0)
                binary[idx]++;

            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value
    long ans = 0;

    // Construct numbers according
    // to the above defined rule
    for(int i = 0; i < n; ++i)
    {
        long total = 0;

        // Take one set bit from every
        // position that still has one
        for(int j = 0; j < 31; ++j)
        {
            if (binary[j] > 0)
            {
                total += 1L << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Return the answer
    return ans;
}

// Driver Code
public static void main(String[] args)
{
    // Given array arr[]
    int[] arr = { 8, 9, 9, 1 };

    int N = arr.length;

    // Function call
    System.out.println(findMaximumSum(arr, N));
}
}
 
// This code is contributed by Amit Katiyar

# Python3 program for the above approach

# Function to find the maximum sum
# of squares of the array elements
# after performing the operations
# (assumes non-negative elements)
def findMaximumSum(arr, n):

    # binary[j] stores how many elements
    # have the j-th bit set
    bits = max(arr, default=0).bit_length()
    binary = [0] * bits

    # Update the binary bit count at
    # corresponding indices for
    # each element
    for x in arr:
        idx = 0

        # Iterate all set bits
        while x:

            # If current bit is set
            if x & 1:
                binary[idx] += 1

            x >>= 1
            idx += 1

    # Stores the maximum value
    ans = 0

    # Construct numbers according
    # to the above defined rule
    for i in range(n):
        total = 0

        # Take one set bit from every
        # position that still has one
        for j in range(bits):
            if binary[j] > 0:
                total += 1 << j
                binary[j] -= 1

        # Square the constructed number
        ans += total * total

    # Return the answer
    return ans

# Driver Code

# Given array arr[]
arr = [ 8, 9, 9, 1 ]

N = len(arr)

# Function call
print(findMaximumSum(arr, N))
 
# This code is contributed by code_hunt

// C# program for the
// above approach
using System;
class GFG{

// Function to find the maximum
// sum of squares of the array
// elements after the operations
// (assumes non-negative elements)
static long findMaximumSum(int []arr,
                           int n)
{
  // binary[j] stores how many
  // elements have the j-th bit set
  int []binary = new int[31];

  // Update the binary bit
  // count at corresponding
  // indices for each element
  // (a copy is shifted, so arr
  // itself is left unchanged)
  for(int i = 0; i < arr.Length; i++)
  {
    int x = arr[i];
    int idx = 0;

    // Iterate all set bits
    while (x > 0)
    {
      // If current bit is set
      if ((x & 1) > 0)
        binary[idx]++;

      x >>= 1;
      idx++;
    }
  }

  // Stores the maximum
  // value
  long ans = 0;

  // Construct numbers according
  // to the above defined rule
  for(int i = 0; i < n; ++i)
  {
    long total = 0;

    // Take one set bit from every
    // position that still has one
    for(int j = 0; j < 31; ++j)
    {
      if (binary[j] > 0)
      {
        total += 1L << j;
        binary[j]--;
      }
    }

    // Square the constructed
    // number
    ans += total * total;
  }

  // Return the answer
  return ans;
}

// Driver Code
public static void Main(String[] args)
{
  // Given array arr[]
  int[] arr = {8, 9, 9, 1};

  int N = arr.Length;

  // Function call
  Console.WriteLine(findMaximumSum(arr, N));
}
}
 
// This code is contributed by 29AjayKumar

Output
243

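Time Complexity: O(N · B), where B is the number of bit positions examined per element (about log₂ A for the maximum element A of the array, and at most 31 for non-negative 32-bit integers).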
Auxiliary Space: O(1)

Jadavpur University IT Undergrad 22
