# Maximize sum of squares of array elements possible by replacing pairs with their Bitwise AND and Bitwise OR

• Last Updated : 01 Aug, 2022

Given an array arr[] consisting of N integers, the task is to find the maximum sum of the squares of the array elements that can be obtained by performing the following operation any number of times:

• Select any pair of array elements (arr[i], arr[j]) with i ≠ j.
• Replace arr[i] by arr[i] AND arr[j].
• Replace arr[j] by arr[i] OR arr[j], where both replacements use the values the pair held before the operation.
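A minimal Python sketch of one operation, assuming 0-based indices and that both replacements are computed from the old values (if the OR used the already-replaced arr[i], arr[j] would never change, since (a AND b) OR b = b). The helpers `apply_op` and `sum_of_squares` are illustrative names, not part of the original article:

```python
def apply_op(arr, i, j):
    """Apply one operation to the pair (arr[i], arr[j]) in place.

    Both new values are computed from the old ones, so the AND and
    the OR see the same inputs.
    """
    a, b = arr[i], arr[j]
    arr[i], arr[j] = a & b, a | b


def sum_of_squares(arr):
    """Sum of the squares of all elements."""
    return sum(v * v for v in arr)


arr = [1, 3, 5]
apply_op(arr, 1, 2)          # the pair (3, 5) from the first example
print(arr)                   # [1, 1, 7]
print(sum_of_squares(arr))   # 51
```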

Examples:

Input: arr[] = {1, 3, 5}
Output: 51
Explanation:
For the pair (arr[1], arr[2]) = (3, 5), perform the operation:
Replace 3 with 3 AND 5, which is equal to 1.
Replace 5 with 3 OR 5, which is equal to 7.
The modified array obtained after the above steps is {1, 1, 7}.
Therefore, the maximum sum of squares is 1 * 1 + 1 * 1 + 7 * 7 = 51.

Input: arr[] = {8, 9, 9, 1}
Output: 243
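For the second example, one operation that reaches the optimum is on the pair (8, 1): it becomes (0, 9), leaving {0, 9, 9, 9} with 3 * 81 = 243. A hedged way to confirm both expected outputs on inputs this small is an exhaustive breadth-first search over every reachable array. The function `brute_force_max` is a hypothetical checking helper; its cost grows exponentially, so it only suits tiny arrays:

```python
from collections import deque


def brute_force_max(arr):
    """Explore every array reachable through the AND/OR operation and
    return the largest sum of squares seen. Small inputs only."""
    start = tuple(sorted(arr))   # order does not matter, so store multisets
    seen = {start}
    queue = deque([start])
    best = 0
    while queue:
        state = queue.popleft()
        best = max(best, sum(v * v for v in state))
        n = len(state)
        for i in range(n):
            for j in range(i + 1, n):
                nxt = list(state)
                a, b = state[i], state[j]
                nxt[i], nxt[j] = a & b, a | b
                key = tuple(sorted(nxt))
                if key not in seen:
                    seen.add(key)
                    queue.append(key)
    return best


print(brute_force_max([1, 3, 5]))     # 51
print(brute_force_max([8, 9, 9, 1]))  # 243
```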

Approach: Let x and y be the two selected elements, and let z = x AND y and w = x OR y. Then x + y = z + w: a bit set in both x and y is set in both z and w, and a bit set in exactly one of them is set only in w.
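The identity x + y = (x AND y) + (x OR y) can be checked exhaustively over small values. In this sketch the helper name `identity_holds` and the bound 256 are arbitrary choices for illustration:

```python
def identity_holds(limit):
    """Check x + y == (x & y) + (x | y) for all 0 <= x, y < limit."""
    return all(
        x + y == (x & y) + (x | y)
        for x in range(limit)
        for y in range(limit)
    )


print(identity_holds(256))     # True
print((3 & 5) + (3 | 5))       # 8, the same as 3 + 5
```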

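• Assume, without loss of generality, that x ≤ y. An AND can only clear bits and an OR can only set them, so z ≤ x ≤ y ≤ w. Writing d = x − z ≥ 0, the identity x + y = z + w gives w = y + d, so the operation rewrites x + y as (x − d) + (y + d):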

$$
\begin{aligned}
M &= x^2 + y^2 && \text{(old sum of squares)}\\
N &= (x - d)^2 + (y + d)^2 && \text{(new sum of squares)}, \quad d \ge 0\\
N - M &= 2d^2 + 2d(y - x) = 2d(y + d - x)
\end{aligned}
$$

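• If d > 0, the difference is positive, because x ≤ y makes y + d − x ≥ d > 0; hence the operation strictly increases the total sum of squares. If d = 0, the bits of x are a subset of the bits of y, and the operation leaves both values unchanged. Intuitively, moving an amount d from the smaller number to the larger one always increases the sum of squares, because the square grows faster for larger numbers.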
• Writing the numbers from the first example in binary shows what the operation does to the bits:

```
x = 3 = 011
y = 5 = 101
z = 1 = 001
w = 7 = 111
```

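• The total number of set bits in x and y is 2 + 2 = 4, and in z and w it is 1 + 3 = 4. In fact, more is preserved: at every bit position, the number of elements with that bit set stays the same, because a bit set in both x and y stays set in both z and w, and a bit set in exactly one of them moves to w. So the operations can only move set bits between elements within the same position, and the task reduces to finding the best way to distribute each position's set bits.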
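Both observations can be checked numerically. This sketch (the helper name `check_operation` is hypothetical) defines d = x − z, so that w = y + d follows from x + y = z + w, and verifies for all small pairs that the change in the sum of squares equals 2d(y + d − x), that it is never negative, and that every bit position keeps its count of set bits. With d defined this way the check does not even need x ≤ y, since y + d − x = w − x and an OR is never smaller than either operand:

```python
def check_operation(x, y):
    """Return True if one AND/OR operation on (x, y) changes the sum of
    squares by exactly 2d(y + d - x) >= 0 and keeps every bit count."""
    z, w = x & y, x | y
    d = x - z                          # then w == y + d since x + y == z + w
    diff = (z * z + w * w) - (x * x + y * y)
    formula_ok = diff == 2 * d * (y + d - x) and diff >= 0
    # At each bit position k, the pair holds the same number of set bits
    bits_ok = all(
        ((x >> k) & 1) + ((y >> k) & 1) == ((z >> k) & 1) + ((w >> k) & 1)
        for k in range(max(x, y).bit_length() + 1)
    )
    return formula_ok and bits_ok


print(check_operation(3, 5))  # True: z = 1, w = 7, d = 2, difference 16
print(all(check_operation(x, y) for x in range(64) for y in range(64)))  # True
```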

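For example, take arr[] = {5, 2, 3, 4, 5, 6, 7}:

```
1 0 1 = 5
0 1 0 = 2
0 1 1 = 3
1 0 0 = 4
1 0 1 = 5
1 1 0 = 6
1 1 1 = 7
-----
5 4 4   (set bits per position)
```

These column counts cannot change. Now rebuild the numbers greedily: each new number takes one set bit from every column that still has bits left. This yields 7, 7, 7, 7, 4, 0, 0, so the maximum sum of squares is 4 * 49 + 4 * 4 = 212.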
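The column counts and the greedy rebuild for this example can be reproduced in a few lines of Python. `BITS = 3` is an assumption that only fits this example, where every element is below 2^3:

```python
arr = [5, 2, 3, 4, 5, 6, 7]
BITS = 3  # enough for this example only: every element is below 2**3

# counts[k] = number of elements with bit k set (k = 0 is the rightmost column)
counts = [sum((v >> k) & 1 for v in arr) for k in range(BITS)]
print(counts)  # [4, 4, 5]

# The i-th rebuilt number has bit k set exactly while counts[k] > i
rebuilt = [sum(1 << k for k in range(BITS) if counts[k] > i) for i in range(len(arr))]
print(rebuilt)                      # [7, 7, 7, 7, 4, 0, 0]
print(sum(v * v for v in rebuilt))  # 212
```

The plain sum is unchanged (both arrays add up to 32); only the sum of squares grows.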

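• Therefore, for every bit position 0 to 20 (the implementations below assume every element is below 2^21), count how many array elements have that bit set.
• Then construct N numbers one at a time: each new number takes one set bit from every position whose count is still positive, and decrements that count.
• This greedy arrangement is optimal. Every operation that changes the array strictly increases the sum of squares, and only finitely many arrays are reachable, so repeated operations must end in an array where, for every pair, the set bits of one element are a subset of the other's. With the per-position counts fixed, that array is unique up to order: its i-th largest element has bit j set exactly when at least i of the original elements had bit j set, which is what the construction produces.
• After obtaining each number, add its square to the answer.
• Print the value of the answer after the above steps.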
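To gain some confidence in the optimality argument, this sketch compares the greedy construction with an exhaustive search on random small arrays. Both `greedy_max` and `brute_force_max` are hypothetical helper names; the default of 21 bit positions mirrors the implementations below, and the random values stay below 16 so the search remains small:

```python
import random
from collections import deque


def greedy_max(arr, bits=21):
    """Count set bits per position, then rebuild len(arr) numbers, each
    taking one bit from every position that still has bits left."""
    counts = [sum((v >> k) & 1 for v in arr) for k in range(bits)]
    total = 0
    for _ in arr:
        num = 0
        for k in range(bits):
            if counts[k] > 0:
                num |= 1 << k
                counts[k] -= 1
        total += num * num
    return total


def brute_force_max(arr):
    """Largest sum of squares over every array reachable by the operation."""
    start = tuple(sorted(arr))
    seen, queue, best = {start}, deque([start]), 0
    while queue:
        state = queue.popleft()
        best = max(best, sum(v * v for v in state))
        for i in range(len(state)):
            for j in range(i + 1, len(state)):
                nxt = list(state)
                nxt[i], nxt[j] = state[i] & state[j], state[i] | state[j]
                key = tuple(sorted(nxt))
                if key not in seen:
                    seen.add(key)
                    queue.append(key)
    return best


random.seed(0)
for _ in range(200):
    arr = [random.randrange(16) for _ in range(random.randint(1, 5))]
    assert greedy_max(arr) == brute_force_max(arr), arr
print("greedy matches brute force on 200 random arrays")
```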

Below is the implementation for the above approach:

## C++

```cpp
// C++ program for the above approach
#include <iostream>
#include <vector>
using namespace std;

// Function to find the maximum sum
// of squares for each element of the
// array after updates
void findMaximumSum(const vector<int>& arr, int n)
{
    // binary[j] = number of elements with bit j set
    int binary[31] = {0};

    // Update the binary bit count at
    // corresponding indices for
    // each element (elements are assumed non-negative)
    for (int x : arr) {

        int idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; long long because
    // total * total can exceed the range of int
    long long ans = 0;

    // Construct number according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {

        long long total = 0;

        // Traverse each binary bit
        // (assumes every element is below 2^21)
        for (int j = 0; j < 21; ++j) {

            // If current bit is set
            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    cout << ans << endl;
}

// Driver Code
int main()
{
    // Given array arr[]
    vector<int> arr = { 8, 9, 9, 1 };

    int N = arr.size();

    // Function call
    findMaximumSum(arr, N);

    return 0;
}
```

## C

```c
// C program for the above approach
#include <stdio.h>

// Function to find the maximum sum
// of squares for each element of the
// array after updates
void findMaximumSum(const int* arr, int n)
{
    // binary[j] = number of elements with bit j set
    int binary[31] = {0};

    // Update the binary bit count at
    // corresponding indices for
    // each element (elements are assumed non-negative)
    for (int i = 0; i < n; i++) {
        int x = arr[i];
        int idx = 0;

        // Iterate all set bits
        while (x) {

            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; long long because
    // total * total can exceed the range of int
    long long ans = 0;

    // Construct number according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {

        long long total = 0;

        // Traverse each binary bit
        // (assumes every element is below 2^21)
        for (int j = 0; j < 21; ++j) {

            // If current bit is set
            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    printf("%lld\n", ans);
}

// Driver Code
int main()
{
    // Given array arr[]
    int arr[] = { 8, 9, 9, 1 };
    int N = sizeof(arr) / sizeof(arr[0]);

    // Function call
    findMaximumSum(arr, N);

    return 0;
}

// This code is contributed by phalashi.
```

## Java

```java
// Java program for the above approach
import java.util.*;

class GFG{

// Function to find the maximum sum
// of squares for each element of the
// array after updates
static void findMaximumSum(int []arr, int n)
{
    // binary[j] = number of elements with bit j set
    int []binary = new int[31];

    // Update the binary bit count at
    // corresponding indices for
    // each element (elements are assumed non-negative)
    for(int x : arr)
    {
        int idx = 0;

        // Iterate all set bits
        while (x > 0)
        {
            // If current bit is set
            if ((x & 1) > 0)
                binary[idx]++;

            x >>= 1;
            idx++;
        }
    }

    // Stores the maximum value; long because
    // total * total can exceed the range of int
    long ans = 0;

    // Construct number according
    // to the above defined rule
    for(int i = 0; i < n; ++i)
    {
        long total = 0;

        // Traverse each binary bit
        // (assumes every element is below 2^21)
        for(int j = 0; j < 21; ++j)
        {
            // If current bit is set
            if (binary[j] > 0)
            {
                total += 1L << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    System.out.print(ans + "\n");
}

// Driver Code
public static void main(String[] args)
{
    // Given array arr[]
    int[] arr = { 8, 9, 9, 1 };

    int N = arr.length;

    // Function call
    findMaximumSum(arr, N);
}
}

// This code is contributed by Amit Katiyar
```

## Python3

```python
# Python3 program for the above approach

# Function to find the maximum sum
# of squares for each element of the
# array after updates
def findMaximumSum(arr, n):

    # binary[j] = number of elements with bit j set
    binary = [0] * 31

    # Stores the maximum value
    ans = 0

    # Update the binary bit count at
    # corresponding indices for
    # each element (elements are assumed non-negative)
    for x in arr:
        idx = 0

        # Iterate all set bits
        while (x):

            # If current bit is set
            if (x & 1):
                binary[idx] += 1

            x >>= 1
            idx += 1

    # Construct number according
    # to the above defined rule
    for i in range(n):
        total = 0

        # Traverse each binary bit
        # (assumes every element is below 2^21)
        for j in range(21):

            # If current bit is set
            if (binary[j] > 0):
                total += 1 << j
                binary[j] -= 1

        # Square the constructed number
        ans += total * total

    # Print the answer
    print(ans)

# Driver Code

# Given array arr[]
arr = [ 8, 9, 9, 1 ]

N = len(arr)

# Function call
findMaximumSum(arr, N)

# This code is contributed by code_hunt
```

## C#

```csharp
// C# program for the
// above approach
using System;
class GFG{

// Function to find the maximum
// sum of squares for each element
// of the array after updates
static void findMaximumSum(int []arr,
                           int n)
{
  // binary[j] = number of elements
  // with bit j set
  int []binary = new int[31];

  // Update the binary bit
  // count at corresponding
  // indices for each element
  // (elements are assumed non-negative;
  // a copy is shifted so arr is unchanged)
  for(int i = 0; i < arr.Length; i++)
  {
    int x = arr[i];
    int idx = 0;

    // Iterate all set bits
    while (x > 0)
    {
      // If current bit is set
      if ((x & 1) > 0)
        binary[idx]++;

      x >>= 1;
      idx++;
    }
  }

  // Stores the maximum value; long
  // because total * total can
  // exceed the range of int
  long ans = 0;

  // Construct number according
  // to the above defined rule
  for(int i = 0; i < n; ++i)
  {
    long total = 0;

    // Traverse each binary bit
    // (assumes every element is below 2^21)
    for(int j = 0; j < 21; ++j)
    {
      // If current bit is set
      if (binary[j] > 0)
      {
        total += 1L << j;
        binary[j]--;
      }
    }

    // Square the constructed
    // number
    ans += total * total;
  }

  // Print the answer
  Console.Write(ans + "\n");
}

// Driver Code
public static void Main(String[] args)
{
  // Given array []arr
  int[] arr = {8, 9, 9, 1};

  int N = arr.Length;

  // Function call
  findMaximumSum(arr, N);
}
}

// This code is contributed by 29AjayKumar
```


Output

`243`

Time Complexity: O(N log₂ A), where A is the maximum element of the array.
Auxiliary Space: O(1)
