
Maximize XOR from remaining digits after removing K digits in String

Last Updated : 18 Sep, 2023

Given a string s consisting of digits and an integer K, remove exactly K digits from s so that the bitwise XOR of the remaining digits is as large as possible, and return that maximum XOR.

Examples:

Input: s = “1234”, K = 1
Output: 7
Explanation: Removing 3 leaves 1, 2 and 4, and 1 ^ 2 ^ 4 = 7, which is the largest XOR possible.

Input: s = “6479”, K = 2
Output: 15
Explanation: Removing 4 and 7 leaves 6 and 9, and 6 ^ 9 = 15, which is the largest XOR possible.
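Before looking at an efficient method, it helps to pin down exactly what is being computed. The following Python sketch simply tries every way of keeping N - K digits. The name `max_xor_brute_force` is our own and is not part of the original article, and returning -1 when K exceeds the length of s is an assumption. The sketch runs in exponential time, so it is only a reference for small inputs. Note that the word "exactly" matters: for s = "12" and K = 1, only a single digit may remain, so the answer is 2 and not 1 ^ 2 = 3.

```python
from functools import reduce
from itertools import combinations
from operator import xor


def max_xor_brute_force(s: str, K: int) -> int:
    """Try every way to keep len(s) - K digits; return the best XOR (-1 if K > len(s))."""
    digits = [int(ch) for ch in s]
    keep = len(digits) - K
    if keep < 0:
        return -1  # assumption: signal an impossible removal with -1
    best = -1
    # combinations() picks by position, so repeated digits are handled correctly;
    # keep == 0 yields a single empty tuple whose XOR is 0
    for kept in combinations(digits, keep):
        best = max(best, reduce(xor, kept, 0))
    return best


print(max_xor_brute_force("1234", 1))  # 7
print(max_xor_brute_force("6479", 2))  # 15
print(max_xor_brute_force("12", 1))    # 2
```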

Approach: This can be solved with the following idea:

The key insight is that every digit is at most 9, so it fits in 4 bits. The XOR of any group of digits therefore lies between 0 and 15. Instead of trying all C(N, K) ways to choose the removed digits, we only need to know which pairs (number of digits removed so far, XOR of the digits kept so far) are reachable. There are at most (K + 1) × 16 such pairs, so the number of states stays small however long the string is.
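To see why the XOR part of the state stays small, the following Python sketch ignores K for a moment and collects the XOR of every subset of the digits. The helper name `reachable_xors` is hypothetical. However many digits are processed, the set never holds more than the 16 values 0 to 15, because XOR-ing 4-bit values always gives another 4-bit value.

```python
def reachable_xors(s: str) -> set:
    """XOR values of every subset of the digits in s, ignoring K (illustrative only)."""
    seen = {0}  # the empty subset has XOR 0
    for ch in s:
        d = int(ch)
        # every value seen so far can either include d or not;
        # the comprehension is built before the in-place union runs
        seen |= {x ^ d for x in seen}
    return seen


print(sorted(reachable_xors("1234")))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(len(reachable_xors("6479")))     # 16
```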

Here are the steps of the approach:

  • Create a boolean table reach of size (K + 1) × 16. Here reach[r][x] is true if the digits processed so far can be handled so that exactly r of them are removed and the XOR of the kept ones is x. Initially only reach[0][0] is true, because nothing has been removed yet and the XOR of no digits is 0.
  • Loop through each digit of s, and for each digit d, do the following:
    • Create a new table nxt of the same size, with every entry false.
    • For every state (r, x) with reach[r][x] true:
      • Keep d: mark nxt[r][x ^ d] as true.
      • Remove d: if r < K, mark nxt[r + 1][x] as true.
    • Replace reach with nxt. Building a fresh table guarantees that, in every state, the current digit is either kept or removed exactly once.
  • After all digits are processed, the answer is the largest x for which reach[K][x] is true. If K is larger than the length of s, no such x exists and the implementations below return -1.
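Tracking reachable (removed count, XOR) states is easiest to see on a tiny input. This Python sketch prints the reachable states after each digit of "12" with K = 1. The generator `trace_states` is a hypothetical helper, and it uses a set of (removed, xor) pairs in place of a boolean table.

```python
def trace_states(s: str, K: int):
    """Yield (digit, sorted reachable (removed, xor) states) after each digit of s."""
    states = {(0, 0)}  # nothing removed yet, and the empty XOR is 0
    for ch in s:
        d = int(ch)
        nxt = set()
        for removed, x in states:
            nxt.add((removed, x ^ d))      # keep d
            if removed < K:
                nxt.add((removed + 1, x))  # remove d
        states = nxt  # a fresh set, so d is applied at most once per state
        yield d, sorted(states)


for digit, states in trace_states("12", 1):
    print(digit, states)
# 1 [(0, 1), (1, 0)]
# 2 [(0, 3), (1, 1), (1, 2)]
```

Among the final states with removed == 1, the largest XOR is 2, which corresponds to keeping only the digit 2. The best XOR over all subsets would be 3 (1 ^ 2), which is why the removal count has to be part of the state.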

Below is the implementation of the above approach:

C++




// C++ code for the above approach:
#include <bits/stdc++.h>
using namespace std;

// Function to get the maximum XOR of the digits left
// after removing exactly K of them (-1 if K > N)
int solve(string s, int K)
{
    // reach[r][x] is true if, among the digits seen so far,
    // exactly r can be removed leaving a XOR of x.
    // Digits are at most 9, so x always fits in 0..15.
    vector<vector<bool> > reach(K + 1, vector<bool>(16, false));
    reach[0][0] = true;

    for (char c : s) {
        int digit = c - '0';
        vector<vector<bool> > nxt(K + 1,
                                  vector<bool>(16, false));

        for (int r = 0; r <= K; r++) {
            for (int x = 0; x < 16; x++) {
                if (!reach[r][x])
                    continue;

                // Keep the current digit
                nxt[r][x ^ digit] = true;

                // Remove the current digit
                if (r < K)
                    nxt[r + 1][x] = true;
            }
        }
        reach = nxt;
    }

    // Largest XOR reachable with exactly K removals
    int ans = -1;
    for (int x = 0; x < 16; x++)
        if (reach[K][x])
            ans = x;

    // Result after removing K characters
    return ans;
}
 
// Driver code
int main()
{
    string s = "1234";
    int K = 1;
 
    // Function call
    cout << solve(s, K) << endl;
 
    s = "6479";
    K = 2;
    // Function call
    cout << solve(s, K) << endl;
 
    return 0;
}


Java




// JAVA code for the above approach:

class GFG {

    // Function to get the maximum XOR of the digits left
    // after removing exactly K of them (-1 if K > N)
    static int solve(String s, int K)
    {
        // reach[r][x] is true if, among the digits seen so
        // far, exactly r can be removed leaving a XOR of x.
        // Digits are at most 9, so x always fits in 0..15.
        boolean[][] reach = new boolean[K + 1][16];
        reach[0][0] = true;

        for (int i = 0; i < s.length(); i++) {
            int digit = s.charAt(i) - '0';
            boolean[][] nxt = new boolean[K + 1][16];

            for (int r = 0; r <= K; r++) {
                for (int x = 0; x < 16; x++) {
                    if (!reach[r][x])
                        continue;

                    // Keep the current digit
                    nxt[r][x ^ digit] = true;

                    // Remove the current digit
                    if (r < K)
                        nxt[r + 1][x] = true;
                }
            }
            reach = nxt;
        }

        // Largest XOR reachable with exactly K removals
        int ans = -1;
        for (int x = 0; x < 16; x++) {
            if (reach[K][x])
                ans = x;
        }

        // Result after removing K characters
        return ans;
    }
 
    // Driver code
    public static void main(String[] args)
    {
        String s = "1234";
        int K = 1;
 
        // Function call
        System.out.println(solve(s, K));
 
        s = "6479";
        K = 2;
        // Function call
        System.out.println(solve(s, K));
    }
}
 
// This code is contributed by akshitaguprzj3


Python3




# Python3 code for the above approach:
def solve(s, K):
    # reach[r][x] is True if, among the digits seen so far,
    # exactly r can be removed leaving a XOR of x.
    # Digits are at most 9, so x always fits in 0..15.
    reach = [[False] * 16 for _ in range(K + 1)]
    reach[0][0] = True

    for ch in s:
        digit = int(ch)
        nxt = [[False] * 16 for _ in range(K + 1)]

        for r in range(K + 1):
            for x in range(16):
                if not reach[r][x]:
                    continue

                # Keep the current digit
                nxt[r][x ^ digit] = True

                # Remove the current digit
                if r < K:
                    nxt[r + 1][x] = True
        reach = nxt

    # Largest XOR reachable with exactly K removals (-1 if K > N)
    ans = -1
    for x in range(16):
        if reach[K][x]:
            ans = x

    # Result after removing K characters
    return ans
 
# Driver code
 
 
s = "1234"
K = 1
 
# Function call
print(solve(s, K))
 
s = "6479"
K = 2
 
# Function call
print(solve(s, K))
 
 
# This code is contributed by akshitaguprzj3


C#




// C# code for the above approach:
using System;

class GFG {
    // Function to get the maximum XOR of the digits left
    // after removing exactly K of them (-1 if K > N)
    static int Solve(string s, int K)
    {
        // reach[r, x] is true if, among the digits seen so
        // far, exactly r can be removed leaving a XOR of x.
        // Digits are at most 9, so x always fits in 0..15.
        bool[,] reach = new bool[K + 1, 16];
        reach[0, 0] = true;

        foreach (char c in s) {
            int digit = c - '0';
            bool[,] nxt = new bool[K + 1, 16];

            for (int r = 0; r <= K; r++) {
                for (int x = 0; x < 16; x++) {
                    if (!reach[r, x])
                        continue;

                    // Keep the current digit
                    nxt[r, x ^ digit] = true;

                    // Remove the current digit
                    if (r < K)
                        nxt[r + 1, x] = true;
                }
            }
            reach = nxt;
        }

        // Largest XOR reachable with exactly K removals
        int ans = -1;
        for (int x = 0; x < 16; x++) {
            if (reach[K, x])
                ans = x;
        }

        // Result after removing K characters
        return ans;
    }
 
    // Driver code
    public static void Main(string[] args)
    {
        string s = "1234";
        int K = 1;
 
        // Function call
        Console.WriteLine(Solve(s, K));
 
        s = "6479";
        K = 2;
        // Function call
        Console.WriteLine(Solve(s, K));
    }
}
 
// This code is contributed by Tapesh(tapeshdua420)


Javascript




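// JavaScript code for the above approach:
function solve(s, K) {
    // reach[r][x] is true if, among the digits seen so far,
    // exactly r can be removed leaving a XOR of x.
    // Digits are at most 9, so x always fits in 0..15.
    let reach = Array.from({ length: K + 1 }, () => new Array(16).fill(false));
    reach[0][0] = true;

    for (const ch of s) {
        const digit = parseInt(ch, 10);
        const nxt = Array.from({ length: K + 1 }, () => new Array(16).fill(false));

        for (let r = 0; r <= K; r++) {
            for (let x = 0; x < 16; x++) {
                if (!reach[r][x]) continue;

                // Keep the current digit
                nxt[r][x ^ digit] = true;

                // Remove the current digit
                if (r < K) nxt[r + 1][x] = true;
            }
        }
        reach = nxt;
    }

    // Largest XOR reachable with exactly K removals (-1 if K > N)
    let ans = -1;
    for (let x = 0; x < 16; x++) {
        if (reach[K][x]) ans = x;
    }

    // Result after removing K characters
    return ans;
}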
 
// Driver code
const s1 = "1234";
const K1 = 1;
console.log(solve(s1, K1));
 
const s2 = "6479";
const K2 = 2;
console.log(solve(s2, K2));
//Contributed by Aditi Tyagi


Output

7
15
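Two examples are a weak test for this problem. An algorithm that ignores the "exactly K" requirement and returns the best XOR over all subsets of the digits also prints 7 for "1234" and 15 for "6479". A randomized cross-check against exhaustive search is a cheap safeguard. The harness below is a sketch and not part of the original article. Its `solve` is a bounded-state DP, where reach[r][x] means r digits have been removed so far and the kept XOR is x. `brute_force` tries every set of kept positions, and `cross_check` compares the two on 500 random strings. All three names are our own, as is the fixed seed.

```python
import random
from functools import reduce
from itertools import combinations
from operator import xor


def solve(s, K):
    # Bounded-state DP: reach[r][x] means r digits removed, kept XOR x
    reach = [[False] * 16 for _ in range(K + 1)]
    reach[0][0] = True
    for ch in s:
        d = int(ch)
        nxt = [[False] * 16 for _ in range(K + 1)]
        for r in range(K + 1):
            for x in range(16):
                if reach[r][x]:
                    nxt[r][x ^ d] = True  # keep d
                    if r < K:
                        nxt[r + 1][x] = True  # remove d
        reach = nxt
    return max((x for x in range(16) if reach[K][x]), default=-1)


def brute_force(s, K):
    # Try every set of len(s) - K positions to keep
    if K > len(s):
        return -1
    return max(reduce(xor, kept, 0)
               for kept in combinations(map(int, s), len(s) - K))


def cross_check(trials=500, seed=1):
    rng = random.Random(seed)  # fixed seed keeps the check reproducible
    for _ in range(trials):
        n = rng.randint(1, 8)
        s = "".join(rng.choice("0123456789") for _ in range(n))
        K = rng.randint(0, n)
        assert solve(s, K) == brute_force(s, K), (s, K)
    return True


print(cross_check())  # True
```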







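Time Complexity: O(N * K * 16), i.e. O(N * K), since each of the N digits updates at most (K + 1) × 16 states
Auxiliary Space: O(K * 16), i.e. O(K), for the two (K + 1) × 16 tables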


