# Maximize sum of squares of array elements possible by replacing pairs with their Bitwise AND and Bitwise OR

Given an array arr[] consisting of N integers, the task is to find the maximum sum of the squares of the array elements that can be obtained by performing the following operation any number of times:

• Select any pair of array elements (arr[i], arr[j]).
• Replace arr[i] by arr[i] AND arr[j].
• Replace arr[j] by arr[i] OR arr[j]. Both replacements use the values of the pair from before the operation.

Examples:

Input: arr[] = {1, 3, 5}
Output: 51
Explanation:
For the pair (arr[1], arr[2]) = (3, 5), perform the following operations:
Replace 3 with 3 AND 5, which is equal to 1.
Replace 5 with 3 OR 5, which is equal to 7.
The modified array obtained after the above steps is {1, 1, 7}.
Therefore, the maximum sum of squares is 1 * 1 + 1 * 1 + 7 * 7 = 51.

Input: arr[] = {8, 9, 9, 1}
Output: 243
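A minimal Python sketch of a single operation, applied to both examples. The helper names `apply_operation` and `sum_of_squares` are hypothetical and only serve this illustration. Both new values are computed from the pair as it was before the update; for the second example, one operation on the pair (8, 1) already reaches 243.

```python
def apply_operation(arr, i, j):
    """Hypothetical helper: arr[i] <- arr[i] AND arr[j], arr[j] <- arr[i] OR arr[j].

    Both results are computed from the values before the update.
    """
    a, b = arr[i], arr[j]
    arr[i], arr[j] = a & b, a | b


def sum_of_squares(arr):
    """Hypothetical helper: sum of the squares of all elements."""
    return sum(v * v for v in arr)


first = [1, 3, 5]
apply_operation(first, 1, 2)           # pair (3, 5)
print(first, sum_of_squares(first))    # [1, 1, 7] 51

second = [8, 9, 9, 1]
apply_operation(second, 0, 3)          # pair (8, 1)
print(second, sum_of_squares(second))  # [0, 9, 9, 9] 243
```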

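Approach: The idea is to observe that if x and y are the two selected elements, z = x AND y and w = x OR y, then x + y = z + w. This holds because every bit set in both x and y appears in both z and w, while every bit set in exactly one of them appears only in w.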
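The identity x + y = (x AND y) + (x OR y) can be sanity-checked by brute force over small values. This sketch is only a check, not part of the solution; the 8-bit range (`limit=256`) and the name `and_or_preserves_sum` are arbitrary choices for illustration.

```python
def and_or_preserves_sum(limit=256):
    """Return True if x + y == (x & y) + (x | y) for all 0 <= x, y < limit."""
    return all(
        x + y == (x & y) + (x | y)
        for x in range(limit)
        for y in range(limit)
    )


print(and_or_preserves_sum())  # True
```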

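• Without loss of generality, assume x ≤ y. Then z ≤ x (z keeps only the bits common to both) and w ≥ y (w contains every bit of y). Since z + w = x + y, we can write z = x − d and w = y + d for some d ≥ 0, that is, x + y = (x − d) + (y + d).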

$$
\begin{aligned}
S_{\text{old}} &= x^2 + y^2 \\
S_{\text{new}} &= (x - d)^2 + (y + d)^2, \quad d \ge 0 \\
S_{\text{new}} - S_{\text{old}} &= 2d(y - x) + 2d^2 = 2d\,(y + d - x)
\end{aligned}
$$
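The derivation above can be checked numerically: for every pair x ≤ y below a small bound, the sketch confirms that z = x − d and w = y + d with d ≥ 0, that the change in the sum of squares equals 2d(y + d − x), and that it is positive exactly when d > 0. The function name and the bound of 128 are illustrative choices.

```python
def check_square_gain(limit=128):
    """For all 0 <= x <= y < limit, check the claims of the derivation."""
    for x in range(limit):
        for y in range(x, limit):
            z, w = x & y, x | y
            d = x - z
            # z = x - d and w = y + d with d >= 0
            if d < 0 or w != y + d:
                return False
            gain = (z * z + w * w) - (x * x + y * y)
            # the change equals 2d(y + d - x) and is positive exactly when d > 0
            if gain != 2 * d * (y + d - x) or (gain > 0) != (d > 0):
                return False
    return True


print(check_square_gain())  # True
```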

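• Since y ≥ x, the factor y + d − x is at least d, so the difference is strictly positive whenever d > 0 (and zero when d = 0). Hence the operation never decreases the sum of squares, and it strictly increases it whenever it actually changes the pair of values. Intuitively, for a fixed total, moving value from the smaller element to the larger one increases the sum of squares.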
• Writing the integers in binary, we observe the following:

x = 3 = 0 1 1
y = 5 = 1 0 1
z = 1 = 0 0 1
w = 7 = 1 1 1

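• In total, x and y have 2 + 2 = 4 set bits, and z and w also have 1 + 3 = 4 set bits. In fact, the number of set bits at each individual position is preserved: a bit set in both x and y ends up in both z and w, and a bit set in only one of them ends up in w alone. So an operation only moves bits between elements within the same column, and the task becomes deciding how to distribute these bits.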
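The per-position preservation can be seen by counting set bits column by column. In this sketch, `bit_counts` is a hypothetical helper whose entry j counts the values with bit j set (index 0 is the least significant bit); the last line checks the property for every pair of values below 64.

```python
def bit_counts(values, width=21):
    """Entry j is the number of values whose bit j is set."""
    return [sum((v >> j) & 1 for v in values) for j in range(width)]


x, y = 3, 5
print(bit_counts([x, y], width=3))          # [2, 1, 1]
print(bit_counts([x & y, x | y], width=3))  # [2, 1, 1]

# The per-position counts are preserved for every pair of small values.
print(all(bit_counts([a, b]) == bit_counts([a & b, a | b])
          for a in range(64) for b in range(64)))  # True
```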

For example, consider arr[] = {5, 2, 3, 4, 5, 6, 7} written in binary:

1 0 1 = 5
0 1 0 = 2
0 1 1 = 3
1 0 0 = 4
1 0 1 = 5
1 1 0 = 6
1 1 1 = 7
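Column totals: 5 4 4 (number of set bits at each position, most significant first)
Since every operation keeps these column totals unchanged, the task reduces to distributing each column's set bits among the N elements. By the observation above, moving value into the larger element never hurts, so the bits should be packed into as few elements as possible. Here that yields 7, 7, 7, 7, 4, 0, 0, whose sum of squares is 4 × 49 + 16 = 212.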
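The packing for this example can be reproduced with a short greedy loop: each new number takes one bit from every position whose count is still positive. `build_numbers` is a hypothetical name, and `width=21` mirrors the 21 bit positions used by the implementations in this article.

```python
def build_numbers(values, width=21):
    """Greedy construction: each new number takes one bit from every
    position whose remaining count is positive."""
    counts = [sum((v >> j) & 1 for v in values) for j in range(width)]
    result = []
    for _ in values:
        total = 0
        for j in range(width):
            if counts[j] > 0:
                total |= 1 << j
                counts[j] -= 1
        result.append(total)
    return result


arr = [5, 2, 3, 4, 5, 6, 7]
nums = build_numbers(arr)
print(nums)                      # [7, 7, 7, 7, 4, 0, 0]
print(sum(v * v for v in nums))  # 212
```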

• Therefore, for every bit position j from 0 to 20 (the implementations below assume every element is less than 2^21), count how many array elements have bit j set.
• Then construct N new numbers one at a time: each number takes one set bit from every position whose count is still positive, and that count is decreased by 1.
• After constructing each number, add its square to the answer.
• Print the answer after all N numbers have been constructed.
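To gain some confidence that the greedy construction is both reachable and optimal, the sketch below compares it with an exhaustive search over every array reachable through the operation. `greedy_max` and `brute_force_max` are hypothetical names. The brute force treats the array as a multiset (element order does not affect the sum of squares) and is only practical for tiny inputs.

```python
from collections import deque


def greedy_max(arr, width=21):
    """Sum of squares of the numbers built by the greedy bit-packing rule."""
    counts = [sum((v >> j) & 1 for v in arr) for j in range(width)]
    ans = 0
    for _ in arr:
        total = 0
        for j in range(width):
            if counts[j] > 0:
                total += 1 << j
                counts[j] -= 1
        ans += total * total
    return ans


def brute_force_max(arr):
    """BFS over all multisets reachable by the operation; return the best sum of squares."""
    start = tuple(sorted(arr))
    seen = {start}
    queue = deque([start])
    best = 0
    while queue:
        cur = queue.popleft()
        best = max(best, sum(v * v for v in cur))
        for i in range(len(cur)):
            for j in range(len(cur)):
                if i == j:
                    continue
                nxt = list(cur)
                nxt[i], nxt[j] = cur[i] & cur[j], cur[i] | cur[j]
                key = tuple(sorted(nxt))
                if key not in seen:
                    seen.add(key)
                    queue.append(key)
    return best


for arr in ([1, 3, 5], [8, 9, 9, 1]):
    print(greedy_max(arr), brute_force_max(arr))  # 51 51, then 243 243
```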

Below is the implementation of the above approach:

```cpp
// C++ program for the above approach
#include <bits/stdc++.h>
using namespace std;

// Stores the maximum value
long long ans = 0;

// binary[j] = number of elements with bit j set
int binary[31];

// Function to find the maximum sum
// of squares for each element of the
// array after updates
void findMaximumSum(const vector<int>& arr, int n)
{
    // Update the binary bit count at
    // corresponding indices for
    // each element
    for (auto x : arr) {
        int idx = 0;

        // Iterate all set bits
        while (x) {
            // If current bit is set
            if (x & 1)
                binary[idx]++;
            x >>= 1;
            idx++;
        }
    }

    // Construct number according
    // to the above defined rule
    for (int i = 0; i < n; ++i) {
        long long total = 0;

        // Traverse each binary bit
        for (int j = 0; j < 21; ++j) {
            // If current bit is set
            if (binary[j] > 0) {
                total += 1LL << j;
                binary[j]--;
            }
        }

        // Square the constructed number
        ans += total * total;
    }

    // Print the answer
    cout << ans << endl;
}

// Driver Code
int main()
{
    // Given array arr[]
    vector<int> arr = { 8, 9, 9, 1 };
    int N = arr.size();

    // Function call
    findMaximumSum(arr, N);
    return 0;
}
```

```java
// Java program for the above approach
import java.util.*;

class GFG {

    // Stores the maximum value
    static long ans = 0;

    // binary[j] = number of elements with bit j set
    static int[] binary = new int[31];

    // Function to find the maximum sum
    // of squares for each element of the
    // array after updates
    static void findMaximumSum(int[] arr, int n)
    {
        // Update the binary bit count at
        // corresponding indices for
        // each element
        for (int x : arr) {
            int idx = 0;

            // Iterate all set bits
            while (x > 0) {
                // If current bit is set
                if ((x & 1) > 0)
                    binary[idx]++;

                x >>= 1;
                idx++;
            }
        }

        // Construct number according
        // to the above defined rule
        for (int i = 0; i < n; ++i) {
            long total = 0;

            // Traverse each binary bit
            for (int j = 0; j < 21; ++j) {
                // If current bit is set
                if (binary[j] > 0) {
                    total += 1L << j;
                    binary[j]--;
                }
            }

            // Square the constructed number
            ans += total * total;
        }

        // Print the answer
        System.out.print(ans + "\n");
    }

    // Driver Code
    public static void main(String[] args)
    {
        // Given array arr[]
        int[] arr = { 8, 9, 9, 1 };
        int N = arr.length;

        // Function call
        findMaximumSum(arr, N);
    }
}

// This code is contributed by Amit Katiyar
```

```python
# Python3 program for the above approach

# binary[j] = number of elements with bit j set
binary = [0] * 31

# Function to find the maximum sum
# of squares for each element of the
# array after updates
def findMaximumSum(arr, n):

    # Stores the maximum value
    ans = 0

    # Update the binary bit count at
    # corresponding indices for
    # each element
    for x in arr:
        idx = 0

        # Iterate all set bits
        while x:

            # If current bit is set
            if x & 1:
                binary[idx] += 1

            x >>= 1
            idx += 1

    # Construct number according
    # to the above defined rule
    for i in range(n):
        total = 0

        # Traverse each binary bit
        for j in range(21):

            # If current bit is set
            if binary[j] > 0:
                total += 1 << j
                binary[j] -= 1

        # Square the constructed number
        ans += total * total

    # Print the answer
    print(ans)

# Driver Code

# Given array arr[]
arr = [8, 9, 9, 1]
N = len(arr)

# Function call
findMaximumSum(arr, N)

# This code is contributed by code_hunt
```

```csharp
// C# program for the
// above approach
using System;

class GFG {

  // Stores the maximum
  // value
  static long ans = 0;

  // binary[j] = number of elements with bit j set
  static int[] binary = new int[31];

  // Function to find the maximum
  // sum of squares for each element
  // of the array after updates
  static void findMaximumSum(int[] arr, int n)
  {
    // Update the binary bit
    // count at corresponding
    // indices for each element
    for (int i = 0; i < arr.Length; i++)
    {
      int x = arr[i];
      int idx = 0;

      // Iterate all set bits
      while (x > 0)
      {
        // If current bit is set
        if ((x & 1) > 0)
          binary[idx]++;

        x >>= 1;
        idx++;
      }
    }

    // Construct number according
    // to the above defined rule
    for (int i = 0; i < n; ++i)
    {
      long total = 0;

      // Traverse each binary bit
      for (int j = 0; j < 21; ++j)
      {
        // If current bit is set
        if (binary[j] > 0)
        {
          total += 1L << j;
          binary[j]--;
        }
      }

      // Square the constructed
      // number
      ans += total * total;
    }

    // Print the answer
    Console.Write(ans + "\n");
  }

  // Driver Code
  public static void Main(String[] args)
  {
    // Given array arr[]
    int[] arr = {8, 9, 9, 1};
    int N = arr.Length;

    // Function call
    findMaximumSum(arr, N);
  }
}

// This code is contributed by 29AjayKumar
```

Output:
```
243
```

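Time Complexity: O(N · (log₂ A + 21)), where A is the maximum element of the array: counting the bits of each element takes O(log₂ A) steps, and building each of the N new numbers scans 21 bit positions.
Auxiliary Space: O(1), since only a fixed-size array of 31 bit counters is used.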
