Open In App

Maximize XOR from remaining digits after removing K digits in String

Given a string s consisting of digits and an integer K, the task is to remove exactly K digits from the string so that the XOR of the remaining digits is as large as possible.

Examples:



Input: s = “1234”,  K = 1
Output:  7
Explanation: After removing 3 from “1234”, the remaining digits are 1, 2 and 4, and 1 ^ 2 ^ 4 = 7, which is the maximum XOR.

Input: s = “6479”,  K = 2
Output: 15
Explanation: After removing 4 and 7 from “6479”, the remaining digits are 6 and 9, and 6 ^ 9 = 15, which is the maximum XOR.



Approach: This can be solved with the following idea:

The key insight behind this approach is that we can use the priority queue to keep track of the best XOR values for all possible combinations of the remaining digits, and we can efficiently update the priority queue using the binary representation of the mask. This allows us to narrow down the search space and find the maximum XOR value with much less computational effort. 

Here are the steps of the approach:

Below is the implementation of the above approach:




// C++ code for the above approach:
#include <bits/stdc++.h>
using namespace std;
 
// Function to get the maximum XOR of the digits that remain after
// removing exactly K digits from s.
//
//   s : string of decimal digits
//   K : number of digits that must be removed
//
// Returns the largest achievable XOR (0 when every digit is removed),
// or -1 when exactly K digits cannot be removed (K < 0 or K > length).
int solve(string s, int K)
{
    const int N = s.length();

    if (K < 0 || K > N) {
        return -1;
    }

    // reach[j] holds every XOR value that can be formed from the digits
    // processed so far when exactly j of them have been removed.
    vector<set<int> > reach(K + 1);
    reach[0].insert(0);

    for (int i = 0; i < N; i++) {
        const int digit = s[i] - '0';

        // Walk j downwards so that reach[j - 1] still describes the
        // state before the current digit at the moment it is read.
        for (int j = min(i + 1, K); j >= 0; j--) {
            set<int> next;

            // Option 1: keep the digit, so it joins the XOR
            for (int value : reach[j]) {
                next.insert(value ^ digit);
            }

            // Option 2: remove the digit, using up one removal
            if (j > 0) {
                next.insert(reach[j - 1].begin(), reach[j - 1].end());
            }

            reach[j].swap(next);
        }
    }

    // 0 <= K <= N guarantees reach[K] is non-empty; a std::set is
    // ordered, so its last element is the maximum.
    return *reach[K].rbegin();
}
 
// Driver code
int main()
{
    string s = "1234";
    int K = 1;
 
    // Function call
    cout << solve(s, K) << endl;
 
    s = "6479";
    K = 2;
    // Function call
    cout << solve(s, K) << endl;
 
    return 0;
}




// JAVA code for the above approach:
 
import java.util.PriorityQueue;
 
class GFG {

    // Decimal digits never exceed 9 (binary 1001), so the XOR of any
    // group of digits always lies in the range 0..15.
    private static final int XOR_RANGE = 16;

    /**
     * Returns the maximum XOR of the digits that remain after removing
     * exactly K digits from s.
     *
     * @param s string of decimal digits
     * @param K number of digits that must be removed
     * @return the largest achievable XOR (0 when every digit is
     *         removed), or -1 when exactly K digits cannot be removed
     *         (K &lt; 0 or K &gt; s.length())
     * @throws IllegalArgumentException if s contains a character that
     *         is not a decimal digit
     */
    static int solve(String s, int K)
    {
        int N = s.length();

        if (K < 0 || K > N) {
            return -1;
        }

        // reachable[j][x] is true when XOR value x can be formed from
        // the digits processed so far with exactly j of them removed.
        boolean[][] reachable = new boolean[K + 1][XOR_RANGE];
        reachable[0][0] = true;

        for (int i = 0; i < N; i++) {
            int digit
                = Character.getNumericValue(s.charAt(i));
            if (digit < 0 || digit > 9) {
                throw new IllegalArgumentException(
                    "s must contain only decimal digits: " + s);
            }

            // Walk j downwards so that reachable[j - 1] still describes
            // the state before the current digit when it is read.
            for (int j = Math.min(i + 1, K); j >= 0; j--) {
                boolean[] next = new boolean[XOR_RANGE];

                for (int x = 0; x < XOR_RANGE; x++) {
                    // Option 1: keep the digit, so it joins the XOR
                    if (reachable[j][x]) {
                        next[x ^ digit] = true;
                    }

                    // Option 2: remove the digit, using one removal
                    if (j > 0 && reachable[j - 1][x]) {
                        next[x] = true;
                    }
                }

                reachable[j] = next;
            }
        }

        // Scan from the largest candidate down; the first hit is the
        // maximum XOR with exactly K digits removed.
        for (int x = XOR_RANGE - 1; x >= 0; x--) {
            if (reachable[K][x]) {
                return x;
            }
        }

        // Not reachable for 0 <= K <= N; present to satisfy the compiler.
        return -1;
    }

    // Pair class to store mask and xor value.
    // solve() does not use it; it is kept so that code in this package
    // that refers to GFG.Pair continues to compile.
    static class Pair {
        int mask;
        int xorVal;

        Pair(int mask, int xorVal)
        {
            this.mask = mask;
            this.xorVal = xorVal;
        }
    }

    // Driver code
    public static void main(String[] args)
    {
        String s = "1234";
        int K = 1;

        // Function call
        System.out.println(solve(s, K));

        s = "6479";
        K = 2;
        // Function call
        System.out.println(solve(s, K));
    }
}
 
// This code is contributed by akshitaguprzj3




# Python3 code for the above approach:
def solve(s, K):
    """Return the maximum XOR of the digits left after removing exactly K digits.

    Args:
        s: String of decimal digits.
        K: Number of digits that must be removed.

    Returns:
        The largest achievable XOR (0 when every digit is removed), or -1
        when exactly K digits cannot be removed (K < 0 or K > len(s)).

    Raises:
        ValueError: If a character of s cannot be converted by int().
    """
    N = len(s)

    if K < 0 or K > N:
        return -1

    # reach[j] holds every XOR value that can be formed from the digits
    # processed so far when exactly j of them have been removed.
    reach = [set() for _ in range(K + 1)]
    reach[0].add(0)

    for i in range(N):
        digit = int(s[i])

        # Walk j downwards so that reach[j - 1] still describes the state
        # before the current digit at the moment it is read.
        for j in range(min(i + 1, K), -1, -1):
            # Option 1: keep the digit, so it joins the XOR
            nxt = {value ^ digit for value in reach[j]}

            # Option 2: remove the digit, using up one removal
            if j > 0:
                nxt |= reach[j - 1]

            reach[j] = nxt

    # 0 <= K <= N guarantees reach[K] is non-empty
    return max(reach[K])
 
# Driver code
 
 
# Run the two sample cases from the problem statement. The loop variables
# keep the names s and K used in the statement.
for s, K in (("1234", 1), ("6479", 2)):
    # Function call
    print(solve(s, K))
 
 
# This code is contributed by akshitaguprzj3




// C# code for the above approach:
using System;
using System.Collections.Generic;
 
// Priority Queue Implementation
public class PriorityQueue<T> {
    // Backing store laid out as an implicit binary tree: the children
    // of slot i live in slots 2 * i + 1 and 2 * i + 2.
    private readonly List<T> items;

    // Ordering delegate; the element for which it yields the lowest
    // result sits at the front of the queue.
    private readonly Comparison<T> order;

    // Creates an empty queue ordered by the given comparison.
    public PriorityQueue(Comparison<T> comparison)
    {
        items = new List<T>();
        order = comparison;
    }

    // Adds an item and restores the heap order.
    public void Enqueue(T item)
    {
        items.Add(item);
        SiftUp(items.Count - 1);
    }

    // Removes and returns the front item.
    // Throws InvalidOperationException when the queue is empty.
    public T Dequeue()
    {
        EnsureNotEmpty();

        T front = items[0];

        // Move the last item into the root slot, shrink the list, then
        // push that item down to where it belongs.
        int tail = items.Count - 1;
        items[0] = items[tail];
        items.RemoveAt(tail);
        SiftDown(0);

        return front;
    }

    // Returns the front item without removing it.
    // Throws InvalidOperationException when the queue is empty.
    public T Peek()
    {
        EnsureNotEmpty();
        return items[0];
    }

    // Number of items currently stored.
    public int Count
    {
        get { return items.Count; }
    }

    // True when no items are stored.
    public bool IsEmpty
    {
        get { return items.Count == 0; }
    }

    // Shared guard for the operations that need at least one item.
    private void EnsureNotEmpty()
    {
        if (items.Count == 0) {
            throw new InvalidOperationException(
                "Priority queue is empty.");
        }
    }

    // Moves the item at 'child' towards the root while it orders
    // strictly before its parent.
    private void SiftUp(int child)
    {
        while (child > 0) {
            int parent = (child - 1) / 2;
            if (order(items[child], items[parent]) >= 0) {
                return;
            }

            Exchange(child, parent);
            child = parent;
        }
    }

    // Moves the item at 'node' towards the leaves while one of its
    // children orders strictly before it.
    private void SiftDown(int node)
    {
        while (true) {
            int best = node;
            int firstChild = 2 * node + 1;

            // Check the left child, then the right child, each against
            // the best candidate found so far.
            for (int child = firstChild; child <= firstChild + 1;
                 child++) {
                if (child < items.Count
                    && order(items[child], items[best]) < 0) {
                    best = child;
                }
            }

            if (best == node) {
                return;
            }

            Exchange(node, best);
            node = best;
        }
    }

    // Swaps the items stored in two slots.
    private void Exchange(int a, int b)
    {
        T held = items[a];
        items[a] = items[b];
        items[b] = held;
    }
}
 
class GFG {
    // Decimal digits never exceed 9 (binary 1001), so the XOR of any
    // group of digits always lies in the range 0..15.
    const int XorRange = 16;

    /// <summary>
    /// Returns the maximum XOR of the digits that remain after removing
    /// exactly K digits from s.
    /// </summary>
    /// <param name="s">String of decimal digits.</param>
    /// <param name="K">Number of digits that must be removed.</param>
    /// <returns>
    /// The largest achievable XOR (0 when every digit is removed), or
    /// -1 when exactly K digits cannot be removed (K is negative or
    /// larger than the length of s).
    /// </returns>
    /// <exception cref="ArgumentNullException">s is null.</exception>
    /// <exception cref="FormatException">
    /// s contains a character other than '0'..'9'.
    /// </exception>
    static int Solve(string s, int K)
    {
        if (s == null) {
            throw new ArgumentNullException(nameof(s));
        }

        int N = s.Length;

        if (K < 0 || K > N) {
            return -1;
        }

        // reachable[j][x] is true when XOR value x can be formed from
        // the digits processed so far with exactly j of them removed.
        bool[][] reachable = new bool[K + 1][];
        for (int j = 0; j <= K; j++) {
            reachable[j] = new bool[XorRange];
        }
        reachable[0][0] = true;

        for (int i = 0; i < N; i++) {
            char c = s[i];
            if (c < '0' || c > '9') {
                throw new FormatException(
                    "s must contain only decimal digits.");
            }
            int digit = c - '0';

            // Walk j downwards so that reachable[j - 1] still describes
            // the state before the current digit when it is read.
            for (int j = Math.Min(i + 1, K); j >= 0; j--) {
                bool[] next = new bool[XorRange];

                for (int x = 0; x < XorRange; x++) {
                    // Option 1: keep the digit, so it joins the XOR
                    if (reachable[j][x]) {
                        next[x ^ digit] = true;
                    }

                    // Option 2: remove the digit, using one removal
                    if (j > 0 && reachable[j - 1][x]) {
                        next[x] = true;
                    }
                }

                reachable[j] = next;
            }
        }

        // Scan from the largest candidate down; the first hit is the
        // maximum XOR with exactly K digits removed.
        for (int x = XorRange - 1; x >= 0; x--) {
            if (reachable[K][x]) {
                return x;
            }
        }

        // Not reachable for 0 <= K <= N; present to satisfy the compiler.
        return -1;
    }

    // Driver code
    public static void Main(string[] args)
    {
        string s = "1234";
        int K = 1;

        // Function call
        Console.WriteLine(Solve(s, K));

        s = "6479";
        K = 2;
        // Function call
        Console.WriteLine(Solve(s, K));
    }
}
 
// This code is contributed by Tapesh(tapeshdua420)




/**
 * Returns the maximum XOR of the digits that remain after removing
 * exactly K digits from s.
 *
 * @param {string} s - String of decimal digits.
 * @param {number} K - Number of digits that must be removed.
 * @returns {number} The largest achievable XOR (0 when every digit is
 *     removed), or -1 when exactly K digits cannot be removed
 *     (K < 0 or K > s.length).
 */
function solve(s, K) {
    const N = s.length;

    if (K < 0 || K > N) {
        return -1;
    }

    // reach[j] holds every XOR value that can be formed from the digits
    // processed so far when exactly j of them have been removed.
    const reach = Array.from({ length: K + 1 }, () => new Set());
    reach[0].add(0);

    for (let i = 0; i < N; i++) {
        const digit = parseInt(s[i], 10);

        // Walk j downwards so that reach[j - 1] still describes the
        // state before the current digit at the moment it is read.
        for (let j = Math.min(i + 1, K); j >= 0; j--) {
            const next = new Set();

            // Option 1: keep the digit, so it joins the XOR
            for (const value of reach[j]) {
                next.add(value ^ digit);
            }

            // Option 2: remove the digit, using up one removal
            if (j > 0) {
                for (const value of reach[j - 1]) {
                    next.add(value);
                }
            }

            reach[j] = next;
        }
    }

    // 0 <= K <= N guarantees reach[K] is non-empty
    return Math.max(...reach[K]);
}
 
// Driver code
// Sample inputs from the problem statement
const s1 = "1234", K1 = 1;
const s2 = "6479", K2 = 2;

// Print one answer per line, in the order the samples are listed
for (const [digits, removals] of [[s1, K1], [s2, K2]]) {
    console.log(solve(digits, removals));
}
//Contributed by Aditi Tyagi

Output
7
15







Time Complexity: O(N * 2^K * log K)
Auxiliary Space: O(2^K)


Article Tags :