
Maximize sum of squares of array elements possible by replacing pairs with their Bitwise AND and Bitwise OR
Last Updated: 18 Oct, 2020

Given an array arr[] consisting of N integers, the task is to find the maximum possible sum of squares of the array elements after performing the following operation any number of times:

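  • Select any pair of array elements (arr[i], arr[j]), where i ≠ j.
  • Replace arr[i] by arr[i] AND arr[j].
  • Replace arr[j] by arr[i] OR arr[j]. Both results are computed from the values of arr[i] and arr[j] before the replacement.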
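A single operation can be sketched in Python as follows. This is a minimal sketch that assumes 0-based indices and reads both values before either one is overwritten. The helper name `apply_and_or` is hypothetical.

```python
def apply_and_or(arr, i, j):
    """Return a copy of arr after one AND/OR operation on indices i and j."""
    result = list(arr)
    a, b = arr[i], arr[j]  # read both values before overwriting either one
    result[i] = a & b      # arr[i] <- arr[i] AND arr[j]
    result[j] = a | b      # arr[j] <- arr[i] OR arr[j]
    return result


print(apply_and_or([1, 3, 5], 1, 2))  # [1, 1, 7]
```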

Examples:

Input: arr[] = {1, 3, 5}
Output: 51
Explanation:
For the pair (arr[1], arr[2]), perform the following operations:
Replace 3 with 3 AND 5, which is equal to 1.
Replace 5 with 3 OR 5, which is equal to 7.
The modified array obtained after the above steps is {1, 1, 7}.
Therefore, the maximized sum of the squares can be calculated by 1 * 1 + 1 * 1 + 7 * 7 = 51.

Input: arr[] = {8, 9, 9, 1}
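Output: 243
Explanation:
For the pair (arr[0], arr[3]), replace 8 with 8 AND 1, which is equal to 0, and replace 1 with 8 OR 1, which is equal to 9.
The modified array is {0, 9, 9, 9}, so the sum of squares is 0 * 0 + 9 * 9 + 9 * 9 + 9 * 9 = 243.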
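For tiny inputs, both expected outputs can be cross-checked by exhaustively exploring every array reachable through the operation. The sketch below uses breadth-first search over sorted tuples, since element order does not affect the sum of squares. It is exponential in general and is meant only as a sanity check. The function name `brute_force_max` is hypothetical.

```python
from collections import deque


def brute_force_max(arr):
    """Explore every array reachable via AND/OR operations and return the
    largest sum of squares seen (only practical for very small inputs)."""
    start = tuple(sorted(arr))  # element order does not change the sum
    seen = {start}
    queue = deque([start])
    best = 0
    while queue:
        state = queue.popleft()
        best = max(best, sum(v * v for v in state))
        n = len(state)
        for i in range(n):
            for j in range(i + 1, n):
                nxt = list(state)
                # AND goes to one slot, OR to the other (same multiset either way)
                nxt[i], nxt[j] = state[i] & state[j], state[i] | state[j]
                key = tuple(sorted(nxt))
                if key not in seen:
                    seen.add(key)
                    queue.append(key)
    return best


print(brute_force_max([1, 3, 5]))     # 51
print(brute_force_max([8, 9, 9, 1]))  # 243
```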

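Approach: Let x and y be the two selected elements, and let z = x AND y and w = x OR y. The key observation is that x + y = z + w. A bit set in both x and y is counted twice on each side, once in z and once in w. A bit set in exactly one of them is counted once on each side, in w only.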
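The identity x + y = (x AND y) + (x OR y) can be checked mechanically. The sketch below does so for every pair of small non-negative integers. The bound of 64 and the function name `and_or_preserves_sum` are arbitrary choices for illustration.

```python
def and_or_preserves_sum(limit):
    """Check x + y == (x & y) + (x | y) for all 0 <= x, y < limit."""
    return all(
        x + y == (x & y) + (x | y)
        for x in range(limit)
        for y in range(limit)
    )


print(and_or_preserves_sum(64))   # True
print(3 + 5, (3 & 5) + (3 | 5))   # 8 8
```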



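  • Without loss of generality, assume x ≤ y. The AND can only clear bits of x, so z ≤ x. Similarly, the OR can only add bits to y, so w ≥ y. Because z + w = x + y, we can write z = x - d and w = y + d for some d ≥ 0, that is, x + y = (x - d) + (y + d).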

Old sum of squares: $S = x^2 + y^2$
New sum of squares: $S' = (x - d)^2 + (y + d)^2$, where $d \ge 0$
Difference: $S' - S = 2d(y + d - x)$

  • Because x ≤ y, the factor y + d - x is at least d, so the difference is never negative, and it is strictly positive whenever d > 0. Hence an operation never decreases the total sum of squares. Intuitively, moving an amount d from the smaller element to the larger one increases the sum of squares, because squares grow faster for larger numbers.
  • Writing x = 3 and y = 5 from the first example in binary, together with z = x AND y and w = x OR y, gives:

x = 3 = 0 1 1
y = 5 = 1 0 1
z = 1 = 0 0 1
w = 7 = 1 1 1

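  • The total number of set bits is preserved: x and y have 2 + 2 = 4 set bits, and z and w have 1 + 3 = 4. More precisely, the number of elements with a given bit set stays the same at every bit position. A bit set in both x and y is kept in both z and w, and a bit set in only one of them ends up in w. The operation therefore only moves set bits between elements within the same bit position, so the task reduces to deciding how the set bits of each position are distributed among the elements.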
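The Python sketch below checks both claims for a single pair. It verifies the difference formula and the non-negative gain, and confirms that the per-position set-bit counts are unchanged. The function name `check_operation` and the 8-bit width are assumptions made for illustration.

```python
def check_operation(x, y, bits=8):
    """For one pair, return (d, gain, counts_preserved), where z = x & y,
    w = x | y, d = min(x, y) - z and gain = z^2 + w^2 - x^2 - y^2."""
    x, y = min(x, y), max(x, y)  # without loss of generality x <= y
    z, w = x & y, x | y
    d = x - z
    assert w == y + d                     # follows from z + w == x + y
    gain = z * z + w * w - x * x - y * y
    assert gain == 2 * d * (y + d - x)    # the difference formula
    counts_preserved = all(
        ((x >> k) & 1) + ((y >> k) & 1) == ((z >> k) & 1) + ((w >> k) & 1)
        for k in range(bits)
    )
    return d, gain, counts_preserved


print(check_operation(3, 5))  # (2, 16, True)
```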

For example, consider arr[] = {5, 2, 3, 4, 5, 6, 7}:

1 0 1 = 5
0 1 0 = 2
0 1 1 = 3
1 0 0 = 4
1 0 1 = 5
1 1 0 = 6
1 1 1 = 7
-----
5 4 4 (number of set bits in each column)
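Build the numbers greedily, taking one set bit from every column that still has one. This gives 7, 7, 7, 7, 4, 0, 0, as the column counts go from 5 4 4 to 4 3 3, 3 2 2, 2 1 1, 1 0 0 and finally 0 0 0. The answer is the sum of their squares: 4 × 49 + 16 = 212.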
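The column totals in the table can be computed with a short Python sketch. The helper name `bit_counts` is hypothetical. Its list is indexed from the least significant bit, so it is reversed to match the table's left-to-right order.

```python
def bit_counts(arr, bits):
    """counts[k] = number of elements of arr whose bit k is set."""
    return [sum((v >> k) & 1 for v in arr) for k in range(bits)]


example = [5, 2, 3, 4, 5, 6, 7]
counts = bit_counts(example, 3)
print(counts[::-1])  # [5, 4, 4], most significant bit first as in the table
```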

  • Therefore, for every bit position, count how many array elements have that bit set.
  • Then build the N numbers one at a time. For each number, take one set bit from every position whose count is still positive, and decrement that count.
  • After each number is built, add its square to the answer.
  • Print the answer after all N numbers have been built.
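The steps above can be sketched compactly in Python. The sketch assumes non-negative integers below 2**bits, and the function name `max_sum_of_squares` is hypothetical. Python integers do not overflow, so no 64-bit care is needed here.

```python
def max_sum_of_squares(arr, bits=31):
    """Greedy: count set bits per position, then build numbers one at a time,
    taking one available bit from every position that still has any."""
    counts = [0] * bits
    for v in arr:
        for k in range(bits):
            if (v >> k) & 1:
                counts[k] += 1

    total = 0
    for _ in range(len(arr)):
        number = 0
        for k in range(bits):
            if counts[k] > 0:       # bit k is still available
                number |= 1 << k
                counts[k] -= 1
        total += number * number
    return total


print(max_sum_of_squares([1, 3, 5]))              # 51
print(max_sum_of_squares([8, 9, 9, 1]))           # 243
print(max_sum_of_squares([5, 2, 3, 4, 5, 6, 7]))  # 212
```

Building each number greedily from the top of every column makes the numbers as unequal as possible, which is what the difference formula rewards.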

Below is the implementation of the above approach:

C++




// C++ program for the above approach
#include <iostream>
#include <vector>
using namespace std;

// Number of bit positions examined (enough for
// any non-negative 32-bit int)
const int BITS = 31;

// Function to find the maximum sum
// of squares for each element of the
// array after updates (assumes
// non-negative elements)
void findMaximumSum(
    const vector<int>& arr, int n)
{
    // binary[j] = number of elements with bit j set
    int binary[BITS] = { 0 };

    // Update the binary bit count at
    // corresponding indices for
    // each element
    for (int x : arr) {

        int idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; 64-bit because
    // the square of a 31-bit number overflows int
    long long ans = 0;

    // Construct number according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {

        long long total = 0;

        // Traverse each binary bit
        for (int j = 0; j < BITS; ++j) {

            // If bit j is still available,
            // use it in the current number
            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    cout << ans << endl;
}

// Driver Code
int main()
{
    // Given array arr[]
    vector<int> arr = { 8, 9, 9, 1 };

    int N = arr.size();

    // Function call
    findMaximumSum(arr, N);

    return 0;
}

Java




// Java program for the above approach
import java.util.*;

class GFG{

// Number of bit positions examined
// (enough for any non-negative int)
static final int BITS = 31;

// Function to find the maximum sum
// of squares for each element of the
// array after updates
static void findMaximumSum(int []arr, int n)
{
    // binary[j] = number of elements with bit j set
    int []binary = new int[BITS];

    // Update the binary bit count at
    // corresponding indices for
    // each element
    for(int x : arr)
    {
        int idx = 0;

        // Iterate all set bits
        while (x > 0)
        {

            // If current bit is set
            if ((x & 1) > 0)
                binary[idx]++;

            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; 64-bit because
    // the square of a 31-bit number overflows int
    long ans = 0;

    // Construct number according
    // to the above defined rule
    for(int i = 0; i < n; ++i)
    {
        long total = 0;

        // Traverse each binary bit
        for(int j = 0; j < BITS; ++j)
        {

            // If bit j is still available,
            // use it in the current number
            if (binary[j] > 0)
            {
                total += 1L << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    System.out.print(ans + "\n");
}

// Driver Code
public static void main(String[] args)
{

    // Given array arr[]
    int[] arr = { 8, 9, 9, 1 };

    int N = arr.length;

    // Function call
    findMaximumSum(arr, N);
}
}

// This code is contributed by Amit Katiyar

Python3




# Python3 program for the above approach

# Number of bit positions examined (enough for
# any non-negative 32-bit integer)
BITS = 31

# Function to find the maximum sum
# of squares for each element of the
# array after updates
def findMaximumSum(arr, n):

    # binary[j] = number of elements with bit j set
    binary = [0] * BITS

    # Stores the maximum value
    ans = 0

    # Update the binary bit count at
    # corresponding indices for
    # each element
    for x in arr:
        idx = 0

        # Iterate all set bits
        while x:

            # If current bit is set
            if x & 1:
                binary[idx] += 1

            x >>= 1
            idx += 1

    # Construct number according
    # to the above defined rule
    for i in range(n):
        total = 0

        # Traverse each binary bit
        for j in range(BITS):

            # If bit j is still available,
            # use it in the current number
            if binary[j] > 0:
                total += 1 << j
                binary[j] -= 1

        # Square the constructed number
        ans += total * total

    # Print the answer
    print(ans)

# Driver Code

# Given array arr[]
arr = [ 8, 9, 9, 1 ]

N = len(arr)

# Function call
findMaximumSum(arr, N)

# This code is contributed by code_hunt

C#




// C# program for the
// above approach
using System;
class GFG{

// Number of bit positions examined
// (enough for any non-negative int)
const int BITS = 31;

// Function to find the maximum
// sum of squares for each element
// of the array after updates
static void findMaximumSum(int []arr,
                           int n)
{
  // binary[j] = number of elements
  // with bit j set
  int []binary = new int[BITS];

  // Update the binary bit
  // count at corresponding
  // indices for each element
  for(int i = 0; i < arr.Length; i++)
  {
    int idx = 0;

    // Work on a copy so that the
    // input array is not modified
    int x = arr[i];

    // Iterate all set bits
    while (x > 0)
    {
      // If current bit is set
      if ((x & 1) > 0)
        binary[idx]++;

      x >>= 1;
      idx++;
    }
  }

  // Stores the maximum value; 64-bit
  // because squares overflow int
  long ans = 0;

  // Construct number according
  // to the above defined rule
  for(int i = 0; i < n; ++i)
  {
    long total = 0;

    // Traverse each binary bit
    for(int j = 0; j < BITS; ++j)
    {
      // If bit j is still available,
      // use it in the current number
      if (binary[j] > 0)
      {
        total += 1L << j;
        binary[j]--;
      }
    }

    // Square the constructed
    // number
    ans += total * total;
  }

  // Print the answer
  Console.Write(ans + "\n");
}

// Driver Code
public static void Main(String[] args)
{
  // Given array arr[]
  int[] arr = {8, 9, 9, 1};

  int N = arr.Length;

  // Function call
  findMaximumSum(arr, N);
}
}

// This code is contributed by 29AjayKumar
Output
243

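Time Complexity: O(N · B), where B is the number of bit positions examined per element (about log₂ A, where A is the maximum element of the array).
Auxiliary Space: O(B) for the per-bit counts, which is O(1) for fixed-width integers.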
