
Maximize sum of squares of array elements possible by replacing pairs with their Bitwise AND and Bitwise OR

Last Updated : 19 Jul, 2023

Given an array arr[] consisting of N integers, the task is to find the maximum possible sum of the squares of the array elements after performing the following operation any number of times:

  • Select any pair of array elements (arr[i], arr[j]) with i ≠ j.
  • Replace arr[i] by arr[i] AND arr[j].
  • Replace arr[j] by arr[i] OR arr[j], where both results are computed from the original values of the pair.
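A minimal Python sketch of a single operation, assuming 0-based indices. The helper names `apply_operation` and `sum_of_squares` are ours, not part of the original article; the sketch returns a new list instead of modifying the input.

```python
def apply_operation(arr, i, j):
    """Return a copy of arr after one operation on the pair (i, j)."""
    result = list(arr)
    # Both new values are computed from the original pair.
    result[i] = arr[i] & arr[j]
    result[j] = arr[i] | arr[j]
    return result


def sum_of_squares(arr):
    return sum(v * v for v in arr)


# First example: operate on the pair (arr[1], arr[2]) of {1, 3, 5}.
after = apply_operation([1, 3, 5], 1, 2)
print(after, sum_of_squares(after))  # [1, 1, 7] 51
```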

Examples:

Input: arr[] = {1, 3, 5}
Output: 51
Explanation:
For the pair (arr[1], arr[2]), perform the following operations:
Replace 3 with 3 AND 5, which is equal to 1.
Replace 5 with 3 OR 5, which is equal to 7.
The modified array obtained after the above steps is {1, 1, 7}.
Therefore, the maximized sum of the squares can be calculated by 1 * 1 + 1 * 1 + 7 * 7 = 51.

Input: arr[] = {8, 9, 9, 1}
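Output: 243
Explanation:
For the pair (arr[0], arr[3]), perform the following operations:
Replace 8 with 8 AND 1, which is equal to 0.
Replace 1 with 8 OR 1, which is equal to 9.
The modified array obtained after the above steps is {0, 9, 9, 9}.
Therefore, the maximized sum of the squares is 0 * 0 + 9 * 9 + 9 * 9 + 9 * 9 = 243.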

Approach: Let x and y be the two selected elements, and let z = x AND y and w = x OR y. Then x + y = z + w, because every bit set in both x and y is set in both z and w, and every bit set in exactly one of them is set only in w.
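The identity x + y = z + w, together with z ≤ min(x, y) and w ≥ max(x, y), can be checked exhaustively for small values with the Python sketch below. The function name and the default bound of 64 are arbitrary choices for illustration.

```python
def check_and_or_identity(limit=64):
    """Verify the AND/OR facts used in the approach for all 0 <= x, y < limit."""
    for x in range(limit):
        for y in range(limit):
            z, w = x & y, x | y
            # Shared bits appear in both z and w; bits set in only one
            # operand appear in w, so the sum is preserved.
            assert x + y == z + w
            # AND can only clear bits and OR can only set them.
            assert z <= min(x, y) and w >= max(x, y)
    return True


print(check_and_or_identity())  # True
```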

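  • Assume, without loss of generality, that x ≤ y. Since AND can only clear bits and OR can only set them, z ≤ x ≤ y ≤ w. Let d = x - z ≥ 0; then z = x - d and, because z + w = x + y, w = y + d. So the expression can be rewritten as x + y = (x - d) + (y + d).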

Old sum of squares: $M = x^2 + y^2$
New sum of squares: $N = (x - d)^2 + (y + d)^2$, where $d \ge 0$
Difference: $N - M = 2d^2 + 2d(y - x) = 2d(y + d - x)$

  • Since y ≥ x, the factor y + d - x is at least d, so the difference is positive whenever d > 0. Hence every operation that actually changes the pair strictly increases the total sum of squares (if d = 0, then z = x and w = y, so nothing changes). Intuitively, moving value from the smaller number to the larger one increases the sum of squares.
  • Writing the numbers from the first example in binary, we observe the following:

x = 3 = 0 1 1
y = 5 = 1 0 1
z = 1 = 0 0 1
w = 7 = 1 1 1

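  • The total number of set bits is preserved: x and y have 2 + 2 = 4 set bits, and z and w have 1 + 3 = 4. A stronger statement holds: for every bit position, the number of elements with that bit set is unchanged, because a bit set in both x and y stays set in both z and w, and a bit set in exactly one of them moves to w. So only the per-position bit counts matter, and the task is to distribute these bits among the N numbers so that the sum of squares is as large as possible. By the argument above, the bits should be concentrated in as few numbers as possible.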
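Both invariants (unchanged per-position bit counts and the gain 2d(y + d - x) ≥ 0) can be spot-checked with the Python sketch below. `bit_counts` and `check_invariants` are hypothetical helper names; the 8-bit width, the number of random trials and the fixed seed are arbitrary assumptions.

```python
import random


def bit_counts(arr, bits=8):
    """counts[b] = number of elements whose bit b is set."""
    return [sum((v >> b) & 1 for v in arr) for b in range(bits)]


def check_invariants(trials=1000, seed=0):
    rng = random.Random(seed)
    for _ in range(trials):
        x, y = rng.randrange(256), rng.randrange(256)
        z, w = x & y, x | y
        # Every bit position keeps the same number of set bits.
        assert bit_counts([x, y]) == bit_counts([z, w])
        # With lo <= hi and d = lo - z, the gain equals 2d(hi + d - lo) >= 0.
        lo, hi = min(x, y), max(x, y)
        d = lo - z
        gain = (z * z + w * w) - (x * x + y * y)
        assert gain == 2 * d * (hi + d - lo) >= 0
    return True


print(bit_counts([3, 5]), bit_counts([1, 7]))  # both [2, 1, 1, 0, 0, 0, 0, 0]
print(check_invariants())  # True
```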

For example, consider arr[] = {5, 2, 3, 4, 5, 6, 7}:

1 0 1 = 5
0 1 0 = 2
0 1 1 = 3
1 0 0 = 4
1 0 1 = 5
1 1 0 = 6
1 1 1 = 7
-----
5 4 4   (number of set bits in each column)
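The counts per bit position (from the most significant bit) are 5, 4 and 4. Building each new number by taking one bit from every position that still has bits left produces 7, 7, 7, 7, 4, 0, 0, and the sum of their squares is 4 × 49 + 16 = 212.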
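The column counts and the greedily constructed numbers for this example can be reproduced with the short Python sketch below. `column_counts` and `build_numbers` are hypothetical helper names, and `bits=3` is enough here only because every element is below 8.

```python
def column_counts(arr, bits):
    """Set-bit counts per position, listed from the most significant bit."""
    return [sum((v >> b) & 1 for v in arr) for b in reversed(range(bits))]


def build_numbers(arr, bits):
    """Greedily build len(arr) numbers, taking one bit from every non-empty position."""
    counts = [sum((v >> b) & 1 for v in arr) for b in range(bits)]
    numbers = []
    for _ in arr:
        value = 0
        for b in range(bits):
            if counts[b] > 0:
                value |= 1 << b
                counts[b] -= 1
        numbers.append(value)
    return numbers


arr = [5, 2, 3, 4, 5, 6, 7]
print(column_counts(arr, 3))      # [5, 4, 4]
nums = build_numbers(arr, 3)
print(nums)                       # [7, 7, 7, 7, 4, 0, 0]
print(sum(v * v for v in nums))   # 212
```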
 

  • For every bit position, count how many array elements have that bit set.
  • Then construct N numbers one at a time: for each new number, take one bit from every position whose count is still positive, and decrement that count.
  • Add the square of each constructed number to the answer.
  • Print the answer after all N numbers have been constructed.
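The greedy construction follows from the derivation above: every operation that changes a pair strictly increases the sum of squares, and the greedy numbers are nested bit-subsets of one another, so no operation changes them. As a hedged sanity check on small inputs rather than a proof, the Python sketch below explores every array reachable through the operation with a breadth-first search and reports the largest sum of squares found. `exhaustive_max` is a hypothetical name, and the search is exponential, so it only suits tiny arrays; the bit-counting approach gives 51, 243 and 212 for the three sample arrays.

```python
from collections import deque
from itertools import combinations


def exhaustive_max(arr):
    """Largest sum of squares over every array reachable by any number of operations."""
    # The sum of squares ignores element order, so states are stored as sorted tuples.
    start = tuple(sorted(arr))
    seen = {start}
    queue = deque([start])
    best = 0
    while queue:
        state = queue.popleft()
        best = max(best, sum(v * v for v in state))
        # Apply the operation to every pair of positions.
        for i, j in combinations(range(len(state)), 2):
            nxt = list(state)
            nxt[i], nxt[j] = state[i] & state[j], state[i] | state[j]
            key = tuple(sorted(nxt))
            if key not in seen:
                seen.add(key)
                queue.append(key)
    return best


for sample in ([1, 3, 5], [8, 9, 9, 1], [5, 2, 3, 4, 5, 6, 7]):
    print(sample, exhaustive_max(sample))  # 51, 243 and 212 respectively
```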

Below is the implementation for the above approach:

C++




// C++ program for the above approach
#include <bits/stdc++.h>
using namespace std;

// Number of bit positions considered
// (enough for non-negative 32-bit integers)
const int BITS = 31;

// Function to find the maximum sum
// of squares of the array elements
// after any number of operations
void findMaximumSum(
    const vector<int>& arr, int n)
{
    // binary[j] = number of elements
    // having the j-th bit set
    int binary[BITS] = { 0 };

    // Update the binary bit count at
    // corresponding indices for
    // each element (assumes x >= 0)
    for (auto x : arr) {

        int idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; long long
    // because squares can overflow int
    long long ans = 0;

    // Construct numbers according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {

        long long total = 0;

        // Traverse each binary bit
        for (int j = 0; j < BITS; ++j) {

            // Take one bit from this
            // position if any is left
            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    cout << ans << endl;
}
 
// Driver Code
int main()
{
    // Given array arr[]
    vector<int> arr = { 8, 9, 9, 1 };
 
    int N = arr.size();
 
    // Function call
    findMaximumSum(arr, N);
 
    return 0;
}


C




// C program for the above approach
#include <stdio.h>

// Number of bit positions considered
// (enough for non-negative 32-bit integers)
#define BITS 31

// Function to find the maximum sum
// of squares of the array elements
// after any number of operations
void findMaximumSum(const int* arr, int n)
{
    // binary[j] = number of elements
    // having the j-th bit set
    int binary[BITS] = { 0 };

    // Update the binary bit count at
    // corresponding indices for
    // each element (assumes x >= 0)
    for (int i = 0; i < n; i++) {
        int x = arr[i];
        int idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; long long
    // because squares can overflow int
    long long ans = 0;

    // Construct numbers according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {

        long long total = 0;

        // Traverse each binary bit
        for (int j = 0; j < BITS; ++j) {

            // Take one bit from this
            // position if any is left
            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    printf("%lld\n", ans);
}
 
// Driver Code
int main()
{
   
    // Given array arr[]
    int arr[] = { 8, 9, 9, 1 };
    int N = sizeof(arr) / sizeof(arr[0]);
 
    // Function call
    findMaximumSum(arr, N);
 
    return 0;
}
 
// This code is contributed by phalashi.


Java




// Java program for the above approach
import java.util.*;

class GFG{

// Number of bit positions considered
// (enough for non-negative 32-bit integers)
static final int BITS = 31;

// Function to find the maximum sum
// of squares of the array elements
// after any number of operations
static void findMaximumSum(int []arr, int n)
{
    // binary[j] = number of elements
    // having the j-th bit set
    int []binary = new int[BITS];

    // Update the binary bit count at
    // corresponding indices for
    // each element
    for(int x : arr)
    {
        int idx = 0;

        // Iterate all set bits
        while (x > 0)
        {
            // If current bit is set
            if ((x & 1) > 0)
                binary[idx]++;

            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; long
    // because squares can overflow int
    long ans = 0;

    // Construct numbers according
    // to the above defined rule
    for(int i = 0; i < n; ++i)
    {
        long total = 0;

        // Traverse each binary bit
        for(int j = 0; j < BITS; ++j)
        {
            // Take one bit from this
            // position if any is left
            if (binary[j] > 0)
            {
                total += 1L << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    System.out.println(ans);
}
 
// Driver Code
public static void main(String[] args)
{
     
    // Given array arr[]
   int[] arr = { 8, 9, 9, 1 };
 
    int N = arr.length;
 
    // Function call
    findMaximumSum(arr, N);
}
}
 
// This code is contributed by Amit Katiyar


Python3




# Python3 program for the above approach

# Number of bit positions considered
# (enough for non-negative 32-bit integers)
BITS = 31

# Function to find the maximum sum
# of squares of the array elements
# after any number of operations
def findMaximumSum(arr, n):

    # binary[j] = number of elements
    # having the j-th bit set
    binary = [0] * BITS

    # Update the binary bit count at
    # corresponding indices for
    # each element (assumes x >= 0)
    for x in arr:
        idx = 0

        # Iterate all set bits
        while x:

            # If current bit is set
            if x & 1:
                binary[idx] += 1

            x >>= 1
            idx += 1

    # Stores the maximum value
    ans = 0

    # Construct numbers according
    # to the above defined rule
    for i in range(n):
        total = 0

        # Traverse each binary bit
        for j in range(BITS):

            # Take one bit from this
            # position if any is left
            if binary[j] > 0:
                total += 1 << j
                binary[j] -= 1

        # Square the constructed number
        ans += total * total

    # Print the answer
    print(ans)
 
# Driver Code
 
# Given array arr[]
arr = [ 8, 9, 9, 1 ]
  
N = len(arr)
  
# Function call
findMaximumSum(arr, N)
 
# This code is contributed by code_hunt


C#




// C# program for the
// above approach
using System;
class GFG{

// Number of bit positions considered
// (enough for non-negative 32-bit integers)
const int BITS = 31;

// Function to find the maximum
// sum of squares of the array
// elements after any number
// of operations
static void findMaximumSum(int []arr,
                           int n)
{
  // binary[j] = number of elements
  // having the j-th bit set
  int []binary = new int[BITS];

  // Update the binary bit
  // count at corresponding
  // indices for each element
  foreach (int v in arr)
  {
    // Work on a copy so that
    // arr is not modified
    int x = v;
    int idx = 0;

    // Iterate all set bits
    while (x > 0)
    {
      // If current bit is set
      if ((x & 1) > 0)
        binary[idx]++;

      x >>= 1;
      idx++;
    }
  }

  // Stores the maximum value; long
  // because squares can overflow int
  long ans = 0;

  // Construct numbers according
  // to the above defined rule
  for(int i = 0; i < n; ++i)
  {
    long total = 0;

    // Traverse each binary bit
    for(int j = 0; j < BITS; ++j)
    {
      // Take one bit from this
      // position if any is left
      if (binary[j] > 0)
      {
        total += 1L << j;
        binary[j]--;
      }
    }

    // Square the constructed
    // number
    ans += total * total;
  }

  // Print the answer
  Console.WriteLine(ans);
}
 
// Driver Code
public static void Main(String[] args)
{
  // Given array []arr
  int[] arr = {8, 9, 9, 1};
 
  int N = arr.Length;
 
  // Function call
  findMaximumSum(arr, N);
}
}
 
// This code is contributed by 29AjayKumar


Javascript




<script>
 
// Javascript program for the above approach

// Number of bit positions considered
// (enough for non-negative 32-bit integers)
var BITS = 31;

// Function to find the maximum sum
// of squares of the array elements
// after any number of operations
function findMaximumSum(arr, n)
{
    // binary[j] = number of elements
    // having the j-th bit set
    var binary = Array(BITS).fill(0);

    // Update the binary bit count at
    // corresponding indices for
    // each element (assumes x >= 0)
    var i, j;
    for (i = 0; i < arr.length; i++) {
        var x = arr[i];
        var idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value (exact only
    // while it stays below 2^53)
    var ans = 0;

    // Construct numbers according
    // to the above defined rule
    for (i = 0; i < n; ++i) {

        var total = 0;

        // Traverse each binary bit
        for (j = 0; j < BITS; ++j) {

            // Take one bit from this
            // position if any is left
            if (binary[j] > 0) {
                total += Math.pow(2, j);
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    document.write(ans);
}
 
// Driver Code
    // Given array arr[]
    var arr = [8, 9, 9, 1];
 
    var N = arr.length;
 
    // Function call
    findMaximumSum(arr, N);
 
</script>


Output

243






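Time Complexity: O(N × B), where B is the number of bit positions examined; B ≈ log2 A suffices, where A is the maximum element of the array.
Auxiliary Space: O(B) for the bit counts, which is O(1) for a fixed B.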

Another Approach:

  1. Initialize maxSum to 0 to store the maximum sum of squares.
  2. Iterate through all pairs of indices using two nested loops: the outer variable i ranges from 0 to n - 1, where n is the size of the array, and the inner variable j ranges from i + 1 to n - 1.
  3. For each pair (arr[i], arr[j]), perform the following steps:
       • Compute new_i = arr[i] AND arr[j] and new_j = arr[i] OR arr[j].
       • Compute the sum of squares of the modified array: initialize sum to 0 and, for each k from 0 to n - 1, add new_i * new_i if k equals i, new_j * new_j if k equals j, and arr[k] * arr[k] otherwise.
       • Update maxSum if sum is greater than the current maxSum.
  4. After the loops, maxSum holds the maximum sum of squares.
  5. Output the value of maxSum.

Note that this approach evaluates only a single operation. Because the operation may be applied any number of times, it can return less than the bit-counting approach above: it matches both examples (51 and 243), but for arr[] = {5, 2, 3, 4, 5, 6, 7} the best single operation gives 188, while the optimum is 212.
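Because these steps evaluate only one operation, their result can fall short of the optimum when several operations are needed. The hedged Python sketch below follows the same pair-by-pair search (`best_single_operation` is our name) and shows this on arr[] = {5, 2, 3, 4, 5, 6, 7}, whose bit counts 5, 4, 4 give an optimum of 4 × 7² + 4² = 212. Unlike the listed steps, it starts from the unmodified sum so that arrays with fewer than two elements are also handled.

```python
def sum_of_squares(arr):
    return sum(v * v for v in arr)


def best_single_operation(arr):
    """Best sum of squares reachable with at most one operation."""
    # Start from the unmodified array (zero operations).
    best = sum_of_squares(arr)
    n = len(arr)
    for i in range(n):
        for j in range(i + 1, n):
            trial = list(arr)
            trial[i], trial[j] = arr[i] & arr[j], arr[i] | arr[j]
            best = max(best, sum_of_squares(trial))
    return best


print(best_single_operation([1, 3, 5]))              # 51
print(best_single_operation([8, 9, 9, 1]))           # 243
print(best_single_operation([5, 2, 3, 4, 5, 6, 7]))  # 188 (the optimum is 212)
```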

Below is the implementation of the above approach:

C++




#include <iostream>
#include <vector>
#include <algorithm>
 
using namespace std;
 
int findMaxSumOfSquares(vector<int>& arr) {
    int n = arr.size();
    int maxSum = 0;
 
    // Iterate through all pairs of array elements
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            // Calculate the bitwise AND and OR operations
            int new_i = arr[i] & arr[j];
            int new_j = arr[i] | arr[j];
 
            // Calculate the sum of squares
            int sum = 0;
            for (int k = 0; k < n; k++) {
                if (k == i)
                    sum += new_i * new_i;
                else if (k == j)
                    sum += new_j * new_j;
                else
                    sum += arr[k] * arr[k];
            }
 
            // Update the maximum sum if needed
            maxSum = max(maxSum, sum);
        }
    }
 
    return maxSum;
}
 
int main() {
    // Pre-defined array
    vector<int> arr = {1, 3, 5};
 
    int maxSum = findMaxSumOfSquares(arr);
    cout  << maxSum << endl;
 
    return 0;
}


Java




import java.util.ArrayList;
import java.util.List;
 
public class Main {
    public static int findMaxSumOfSquares(List<Integer> arr) {
        int n = arr.size();
        int maxSum = 0;
 
        // Iterate through all pairs of array elements
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                // Calculate the bitwise AND and OR operations
                int new_i = arr.get(i) & arr.get(j);
                int new_j = arr.get(i) | arr.get(j);
 
                // Calculate the sum of squares
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    if (k == i)
                        sum += new_i * new_i;
                    else if (k == j)
                        sum += new_j * new_j;
                    else
                        sum += arr.get(k) * arr.get(k);
                }
 
                // Update the maximum sum if needed
                maxSum = Math.max(maxSum, sum);
            }
        }
 
        return maxSum;
    }
 
    public static void main(String[] args) {
        // Pre-defined array
        List<Integer> arr = new ArrayList<>();
        arr.add(1);
        arr.add(3);
        arr.add(5);
 
        int maxSum = findMaxSumOfSquares(arr);
        System.out.println(maxSum);
    }
}


Python




 
def FindMaxSumOfSquares(arr):
    n = len(arr)
    maxSum = 0
 
    # Iterate through all pairs of array elements
    for i in range(n):
        for j in range(i + 1, n):
            # Calculate the bitwise AND and OR operations
            new_i = arr[i] & arr[j]
            new_j = arr[i] | arr[j]
 
            # Calculate the sum of squares
            sum_ = 0
            for k in range(n):
                if k == i:
                    sum_ += new_i * new_i
                elif k == j:
                    sum_ += new_j * new_j
                else:
                    sum_ += arr[k] * arr[k]
 
            # Update the maximum sum if needed
            maxSum = max(maxSum, sum_)
 
    return maxSum
 
# Driver code
 
# Pre-defined array
arr = [1, 3, 5]
 
maxSum = FindMaxSumOfSquares(arr)
print(maxSum)
 
 
# by phasing17


C#




using System;
using System.Collections.Generic;
 
class GFG
{
    public static int FindMaxSumOfSquares(List<int> arr)
    {
        int n = arr.Count;
        int maxSum = 0;
 
        // Iterate through all pairs of array elements
        for (int i = 0; i < n; i++)
        {
            for (int j = i + 1; j < n; j++)
            {
                // Calculate the bitwise AND and OR operations
                int new_i = arr[i] & arr[j];
                int new_j = arr[i] | arr[j];
 
                // Calculate the sum of squares
                int sum = 0;
                for (int k = 0; k < n; k++)
                {
                    if (k == i)
                        sum += new_i * new_i;
                    else if (k == j)
                        sum += new_j * new_j;
                    else
                        sum += arr[k] * arr[k];
                }
 
                // Update the maximum sum if needed
                maxSum = Math.Max(maxSum, sum);
            }
        }
 
        return maxSum;
    }
     
      // Driver code
    public static void Main(string[] args)
    {
        // Pre-defined array
        List<int> arr = new List<int> { 1, 3, 5 };
 
        int maxSum = FindMaxSumOfSquares(arr);
        Console.WriteLine(maxSum);
    }
}


Javascript




function findMaxSumOfSquares(arr) {
    const n = arr.length;
    let maxSum = 0;
 
    // Iterate through all pairs of array elements
    for (let i = 0; i < n; i++) {
        for (let j = i + 1; j < n; j++) {
            // Calculate the bitwise AND and OR operations
            const new_i = arr[i] & arr[j];
            const new_j = arr[i] | arr[j];
 
            // Calculate the sum of squares
            let sum = 0;
            for (let k = 0; k < n; k++) {
                if (k === i)
                    sum += new_i * new_i;
                else if (k === j)
                    sum += new_j * new_j;
                else
                    sum += arr[k] * arr[k];
            }
 
            // Update the maximum sum if needed
            maxSum = Math.max(maxSum, sum);
        }
    }
 
    return maxSum;
}
 
// Driver code
const arr = [1, 3, 5];
 
const maxSum = findMaxSumOfSquares(arr);
console.log(maxSum);


Output

51







Time Complexity: O(N^3)

Auxiliary Space: O(1)


