Subarray with XOR less than k

Given an array arr[] of n integers and an integer k, write a program to count the number of subarrays whose XOR is less than k.

Examples: 

Input: arr[] = {8, 9, 10, 11, 12}, k=3
Output: 3
Explanation: There are 3 subarrays with XOR < 3:

  • arr[0…1] = {8, 9} and 8 ^ 9 = 1
  • arr[0…3] = {8, 9, 10, 11} and 8 ^ 9 ^ 10 ^ 11 = 0
  • arr[2…3] = {10, 11} and 10 ^ 11 = 1

Input: arr[] = {12, 4, 6, 8, 21}, k=8
Output: 4
Explanation: There are 4 subarrays with XOR < 8:

  • arr[0…3] = {12, 4, 6, 8} and 12 ^ 4 ^ 6 ^ 8 = 6
  • arr[1…1] = {4} and its XOR is 4
  • arr[1…2] = {4, 6} and 4 ^ 6 = 2
  • arr[2…2] = {6} and its XOR is 6

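Naive Approach: Compute the XOR of every subarray and compare it with k. If the start index i is fixed and the end index j is extended one step at a time, the XOR of arr[i..j] is obtained from the XOR of arr[i..j-1] with a single ^ operation, so no subarray has to be re-scanned from scratch.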

Below is the implementation of the above approach: 

C++
// C++ program to count number of
// subarrays with XOR less than k
#include <iostream>
using namespace std;

// function to count number of
// subarrays with XOR less than k
int xorLessK(int arr[], int n, int k)
{
    int count = 0;

    // check all subarrays
    for (int i = 0; i < n; i++) {
        int tempXor = 0;
        for (int j = i; j < n; j++) {
            tempXor ^= arr[j];
            if (tempXor < k)
                count++;
        }
    }

    return count;
}

// Driver program to test above function
int main()
{
    int n, k = 3;
    int arr[] = { 8, 9, 10, 11, 12 };

    n = sizeof(arr) / sizeof(arr[0]);

    cout << xorLessK(arr, n, k);

    return 0;
}
Java
// Java program to count number of 
// subarrays with XOR less than k 

import java.io.*;

class GFG {
    
// function to count number of 
// subarrays with XOR less than k 
static int xorLessK(int arr[], int n, int k) 
{ 
    int count = 0; 

    // check all subarrays 
    for (int i = 0; i < n; i++) { 
        int tempXor = 0; 
        for (int j = i; j < n; j++) { 
            tempXor ^= arr[j]; 
            if (tempXor < k) 
                count++; 
        } 
    } 

    return count; 
} 

// Driver program to test above function 
    public static void main (String[] args) {

        
    int k = 3; 
    int arr[] = new int[] { 8, 9, 10, 11, 12 }; 
    int n = arr.length; 

    System.out.println(xorLessK(arr, n, k)); 
        
    }
}
C#
// C# program to count number of 
// subarrays with XOR less than k 
using System;

class GFG {
    
// function to count number of 
// subarrays with XOR less than k 
static int xorLessK(int []arr, int n, int k) 
{ 
    int count = 0; 

    // check all subarrays 
    for (int i = 0; i < n; i++) { 
        
        int tempXor = 0; 
        
        for (int j = i; j < n; j++) { 
            
            tempXor ^= arr[j]; 
            if (tempXor < k) 
                count++; 
        } 
    } 

    return count; 
} 

// Driver Code
static public void Main ()
{
        
    int k = 3; 
    int []arr = new int[] {8, 9, 10,
                           11, 12 }; 
    int n = arr.Length; 
    Console.WriteLine(xorLessK(arr, n, k)); 
    
} 
} 
Javascript
<script>
    // Javascript program to count number of 
    // subarrays with XOR less than k 
    
    // function to count number of 
    // subarrays with XOR less than k 
    function xorLessK(arr, n, k) 
    { 
        let count = 0; 

        // check all subarrays 
        for (let i = 0; i < n; i++) { 

            let tempXor = 0; 

            for (let j = i; j < n; j++) { 

                tempXor ^= arr[j]; 
                if (tempXor < k) 
                    count++; 
            } 
        } 

        return count; 
    } 
    
    let k = 3; 
    let arr = [8, 9, 10, 11, 12]; 
    let n = arr.length; 
    document.write(xorLessK(arr, n, k)); 
    
</script>
PHP
<?php
// PHP program to count number of
// subarrays with XOR less than k

// function to count number of
// subarrays with XOR less than k
function xorLessK($arr, $n, $k)
{
    $count = 0;

    // check all subarrays
    for ($i = 0; $i < $n; $i++) 
    {
        $tempXor = 0;
        for ($j = $i; $j < $n; $j++)
        {
            $tempXor ^= $arr[$j];
            if ($tempXor < $k)
                $count++;
        }
    }

    return $count;
}

    // Driver Code
    $k = 3;
    $arr = array(8, 9, 10, 11, 12);

    $n = count($arr);

    echo xorLessK($arr, $n, $k);

// This code is contributed by anuj_67.
?>
Python3
# Python 3 program to count number of
# subarrays with XOR less than k

# function to count number of
# subarrays with XOR less than k
def xorLessK(arr, n, k):
    count = 0

    # check all subarrays
    for i in range(n):
        tempXor = 0
        for j in range(i, n):
            tempXor ^= arr[j]
            if (tempXor < k):
                count += 1
    
    return count

# Driver Code
if __name__ == '__main__':
    k = 3
    arr = [8, 9, 10, 11, 12]

    n = len(arr)

    print(xorLessK(arr, n, k))

# This code is contributed by
# Sahil_shelangia

Output
3

Time Complexity: O(N * N)
Space complexity: O(1)

Efficient Approach: Compute all prefix XOR values xor[1:i] = a[1] ^ a[2] ^ … ^ a[i] for every i, with xor[1:0] = 0 for the empty prefix. Because x ^ x = 0, the XOR of any subarray a[l..r] can be written as xor[1:l-1] ^ xor[1:r], where xor[i:j] denotes the XOR of all elements with index i <= index <= j.

Explanation: 
Each prefix XOR is stored as a binary number in a trie, starting from the most significant bit. The left child means the next bit is 0 and the right child means the next bit is 1. 
For example, the picture below shows the numbers 1001 and 1010 stored in a trie. 

[Figure: trie storing the binary numbers 1001 and 1010]


Let xor[i:j] denote the XOR of all elements of the subarray a[i..j]. When we reach index i, the trie already contains xor[1:0] = 0 (the empty prefix), xor[1:1], xor[1:2], …, xor[1:i-1]. We count how many of these stored numbers give a value smaller than k when XORed with xor[1:i]. Since xor[1:j-1] ^ xor[1:i] = xor[j:i], this counts exactly the subarrays ending at index i whose XOR satisfies xor[j:i] < k. After the query, xor[1:i] itself is inserted into the trie for the later indices.
The remaining question is how to count the stored numbers whose XOR with xor[1:i] is smaller than k. We walk down the trie one bit at a time, from the most significant bit. At each step, let p be the current bit of xor[1:i], let q be the current bit of k, and let node be the current trie node. 

Take the case p = 1, q = 1. If we went to the right child, the current XOR bit would be 0 (the right child means 1, and 1 xor 1 = 0). Since q = 1, every number below the right child gives an XOR smaller than k whatever its remaining bits are, so we add the count of numbers below the right child to the answer. 
If we go to the left child instead, the current XOR bit is 1 (the left child means 0, and 0 xor 1 = 1). This equals q, so the comparison with k is decided by the lower bits; numbers with an XOR smaller than k may still exist there, so we move on to the left child.

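To count the numbers below a given node in O(1), we augment the trie: each node also stores how many inserted numbers pass through its left child and how many pass through its right child (left_count and right_count in the code below). These counts are incremented along the path on every insertion, so a value inserted twice is counted twice.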

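The other three cases are handled the same way:

  • p = 0, q = 1: the left child gives XOR bit 0, which is below q, so add the count of numbers below the left child and move on to the right child.
  • p = 1, q = 0: the XOR bit must be 0 to stay below k, so move on to the right child; everything below the left child gives XOR bit 1 and is skipped.
  • p = 0, q = 0: likewise, move on to the left child and skip everything below the right child.

If the child we need to move to does not exist, no further numbers can qualify and the walk stops. If all 32 bits are processed, the numbers that remain give an XOR exactly equal to k, so they are not counted.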

Below is the implementation of the above idea: 

C++
#include <iostream>
using namespace std;
class trienode {
public:
    int left_count, right_count;
    trienode* left;
    trienode* right;
    trienode()
    {
        left_count = 0;
        right_count = 0;

        // Left denotes 0
        left = NULL;

        // Right denotes 1
        right = NULL;
    }
};
void insert(trienode* root, int element)
{
    for (int i = 31; i >= 0; i--) {
        int x = (element >> i) & 1;

        // If the current bit is 1
        if (x) {
            root->right_count++;
            if (root->right == NULL)
                root->right = new trienode();
            root = root->right;
        }
        else {
            root->left_count++;
            if (root->left == NULL)
                root->left = new trienode();
            root = root->left;
        }
    }
}
int query(trienode* root, int element, int k)
{
    if (root == NULL)
        return 0;
    int res = 0;
    for (int i = 31; i >= 0; i--) {
        bool current_bit_of_k = (k >> i) & 1;
        bool current_bit_of_element = (element >> i) & 1;

        // If the current bit of k is 1
        if (current_bit_of_k) {
            // If current bit of element is 1
            if (current_bit_of_element) {
                res += root->right_count;
                if (root->left == NULL)
                    return res;
                root = root->left;
            }

            // If current bit of element is 0
            else {
                res += root->left_count;
                if (root->right == NULL)
                    return res;
                root = root->right;
            }
        }

        // If the current bit of k is zero
        else {
            // If current bit of element is 1
            if (current_bit_of_element) {
                if (root->right == NULL)
                    return res;
                root = root->right;
            }

            // If current bit of element is 0
            else {
                if (root->left == NULL)
                    return res;
                root = root->left;
            }
        }
    }
    return res;
}

// Driver code
int main()
{

    int n = 5, k = 3;
    int arr[] = { 8, 9, 10, 11, 12 };

    // Below three variables are used for storing
    // current XOR
    int temp, temp1, temp2 = 0;
    trienode* root = new trienode();
    insert(root, 0);
    long long total = 0;
    for (int i = 0; i < n; i++) {
        temp = arr[i];
        temp1 = temp2 ^ temp;
        total += query(root, temp1, k);
        insert(root, temp1);
        temp2 = temp1;
    }

    cout << total << endl;

    return 0;
}
Java
// Java code for above approach
import java.util.*;

class TrieNode {
    int left_count, right_count;
    TrieNode left;
    TrieNode right;

    TrieNode() {
        left_count = 0;
        right_count = 0;
        // Left denotes 0
        left = null;
        // Right denotes 1
        right = null;
    }
}

class Main {
    static void insert(TrieNode root, int element) {
        for (int i = 31; i >= 0; i--) {
            int x = (element >> i) & 1;
            // If the current bit is 1
            if (x == 1) {
                root.right_count++;
                if (root.right == null)
                    root.right = new TrieNode();
                root = root.right;
            } else {
                root.left_count++;
                if (root.left == null)
                    root.left = new TrieNode();
                root = root.left;
            }
        }
    }

    static int query(TrieNode root, int element, int k) {
        if (root == null)
            return 0;
        int res = 0;
        for (int i = 31; i >= 0; i--) {
            int current_bit_of_k = (k >> i) & 1;
            int current_bit_of_element = (element >> i) & 1;
            // If the current bit of k is 1
            if (current_bit_of_k == 1) {
                // If current bit of element is 1
                if (current_bit_of_element == 1) {
                    res += root.right_count;
                    if (root.left == null)
                        return res;
                    root = root.left;
                } 
                // If current bit of element is 0
                else {
                    res += root.left_count;
                    if (root.right == null)
                        return res;
                    root = root.right;
                }
            } 
            // If the current bit of k is zero
            else {
                // If current bit of element is 1
                if (current_bit_of_element == 1) {
                    if (root.right == null)
                        return res;
                    root = root.right;
                } 
                // If current bit of element is 0
                else {
                    if (root.left == null)
                        return res;
                    root = root.left;
                }
            }
        }
        return res;
    }

    // Driver code
    public static void main(String[] args) {
        int n = 5, k = 3;
        int[] arr = { 8, 9, 10, 11, 12 };

        // Below three variables are used for storing
        // current XOR
        int temp, temp1, temp2 = 0;
        TrieNode root = new TrieNode();
        insert(root, 0);
        long total = 0;
        for (int i = 0; i < n; i++) {
            temp = arr[i];
            temp1 = temp2 ^ temp;
            total += query(root, temp1, k);
            insert(root, temp1);
            temp2 = temp1;
        }

        System.out.println(total);
    }
}

// This code is contributed by Aman Kumar.
C#
// C# code for above approach
using System;

public class TrieNode {
  public int left_count, right_count;
  public TrieNode left;
  public TrieNode right;

  public TrieNode() {
    left_count = 0;
    right_count = 0;
    // Left denotes 0
    left = null;
    // Right denotes 1
    right = null;
  }
}

public class MainClass {
  public static void Insert(TrieNode root, int element) {
    for (int i = 31; i >= 0; i--) {
      int x = (element >> i) & 1;
      // If the current bit is 1
      if (x == 1) {
        root.right_count++;
        if (root.right == null)
          root.right = new TrieNode();
        root = root.right;
      } else {
        root.left_count++;
        if (root.left == null)
          root.left = new TrieNode();
        root = root.left;
      }
    }
  }

  public static int Query(TrieNode root, int element, int k) {
    if (root == null)
      return 0;
    int res = 0;
    for (int i = 31; i >= 0; i--) {
      int current_bit_of_k = (k >> i) & 1;
      int current_bit_of_element = (element >> i) & 1;
      // If the current bit of k is 1
      if (current_bit_of_k == 1) {
        // If current bit of element is 1
        if (current_bit_of_element == 1) {
          res += root.right_count;
          if (root.left == null)
            return res;
          root = root.left;
        }
        // If current bit of element is 0
        else {
          res += root.left_count;
          if (root.right == null)
            return res;
          root = root.right;
        }
      }
      // If the current bit of k is zero
      else {
        // If current bit of element is 1
        if (current_bit_of_element == 1) {
          if (root.right == null)
            return res;
          root = root.right;
        }
        // If current bit of element is 0
        else {
          if (root.left == null)
            return res;
          root = root.left;
        }
      }
    }
    return res;
  }

  // Driver code
  public static void Main(string[] args) {
    int n = 5, k = 3;
    int[] arr = { 8, 9, 10, 11, 12 };

    // Below three variables are used for storing
    // current XOR
    int temp, temp1, temp2 = 0;
    TrieNode root = new TrieNode();
    Insert(root, 0);
    long total = 0;
    for (int i = 0; i < n; i++) {
      temp = arr[i];
      temp1 = temp2 ^ temp;
      total += Query(root, temp1, k);
      Insert(root, temp1);
      temp2 = temp1;
    }

    Console.WriteLine(total);
  }
}

// This code is contributed by Utkarsh.
Javascript
// Javascript code for above approach

class TrieNode {

    constructor(){
        this.left_count = 0;
        this.right_count = 0;
        // Left denotes 0
        this.left = null;
        // Right denotes 1
        this.right = null;
    }
}

    
function insert(root, element) {
    for (let i = 31; i >= 0; i--) {
        let x = (element >> i) & 1;
        // If the current bit is 1
        if (x == 1) {
            root.right_count++;
            if (root.right == null)
                root.right = new TrieNode();
            root = root.right;
        } else {
            root.left_count++;
            if (root.left == null)
                root.left = new TrieNode();
            root = root.left;
        }
    }
}

function query(root, element, k) {
    if (root == null)
        return 0;
    let res = 0;
    for (let i = 31; i >= 0; i--) {
        let current_bit_of_k = (k >> i) & 1;
        let current_bit_of_element = (element >> i) & 1;
        // If the current bit of k is 1
        if (current_bit_of_k == 1) {
            // If current bit of element is 1
            if (current_bit_of_element == 1) {
                res += root.right_count;
                if (root.left == null)
                    return res;
                root = root.left;
            } 
            // If current bit of element is 0
            else {
                res += root.left_count;
                if (root.right == null)
                    return res;
                root = root.right;
            }
        } 
        // If the current bit of k is zero
        else {
            // If current bit of element is 1
            if (current_bit_of_element == 1) {
                if (root.right == null)
                    return res;
                root = root.right;
            } 
            // If current bit of element is 0
            else {
                if (root.left == null)
                    return res;
                root = root.left;
            }
        }
    }
    return res;
}

    

let n = 5, k = 3;
let arr = [8, 9, 10, 11, 12];

// Below three variables are used for storing
// current XOR
let temp, temp1, temp2 = 0;
let root = new TrieNode();
insert(root, 0);
let total = 0;
for (let i = 0; i < n; i++) {
    temp = arr[i];
    temp1 = temp2 ^ temp;
    total += query(root, temp1, k);
    insert(root, temp1);
    temp2 = temp1;
}

console.log(total);        

// The code is contributed by Arushi Jindal. 
Python3
# Python code for above approach
class TrieNode:
    def __init__(self):
        self.left_count = 0
        self.right_count = 0
        self.left = None # Left denotes 0
        self.right = None # Right denotes 1

def insert(root, element):
    for i in range(31, -1, -1):
        x = (element >> i) & 1

        # If the current bit is 1
        if x:
            root.right_count += 1
            if root.right is None:
                root.right = TrieNode()
            root = root.right
        else:
            root.left_count += 1
            if root.left is None:
                root.left = TrieNode()
            root = root.left

def query(root, element, k):
    if root is None:
        return 0
    res = 0
    for i in range(31, -1, -1):
        current_bit_of_k = (k >> i) & 1
        current_bit_of_element = (element >> i) & 1

        # If the current bit of k is 1
        if current_bit_of_k:
            # If current bit of element is 1
            if current_bit_of_element:
                res += root.right_count
                if root.left is None:
                    return res
                root = root.left

            # If current bit of element is 0
            else:
                res += root.left_count
                if root.right is None:
                    return res
                root = root.right

        # If the current bit of k is zero
        else:
            # If current bit of element is 1
            if current_bit_of_element:
                if root.right is None:
                    return res
                root = root.right

            # If current bit of element is 0
            else:
                if root.left is None:
                    return res
                root = root.left
    return res

# Driver code
if __name__ == "__main__":
    n = 5
    k = 3
    arr = [8, 9, 10, 11, 12]

    # Below three variables are used for storing
    # current XOR
    temp = 0
    temp1 = 0
    temp2 = 0
    root = TrieNode()
    insert(root, 0)
    total = 0
    for i in range(n):
        temp = arr[i]
        temp1 = temp2 ^ temp
        total += query(root, temp1, k)
        insert(root, temp1)
        temp2 = temp1

    print(total)

# This code is contributed by Pushpesh Raj.

Output
3

Time complexity: O(32 * N). Each insert and each query walks at most 32 trie levels, and there are N of each, so with the bit width treated as a constant the overall time complexity is O(N).
Space complexity: O(32 * N). Each of the N + 1 inserted prefix values can create up to 32 new nodes, so for a fixed bit width the overall space complexity is O(N).

Last Updated : 15 Mar, 2024