Find the count of distinct numbers in a range

Given an array of size N whose elements all lie in the range 0 to 63, answer Q queries on it.

Queries are as follows:

  • 1 X Y, i.e., change the element at index X to Y (indices are 1-based)
  • 2 L R, i.e., print the count of distinct elements between indices L and R, inclusive (indices are 1-based)

Examples:



Input:
N = 7
ar = {1, 2, 1, 3, 1, 2, 1}
Q = 5
{ {2, 1, 4},
  {1, 4, 2},
  {1, 5, 2},
  {2, 4, 6},
  {2, 1, 7} }
Output:
3
1
2

Input:
N = 15
ar = {4, 6, 3, 2, 2, 3, 6, 5, 5, 5, 4, 2, 1, 5, 1}
Q = 5
{ {1, 6, 5},
  {1, 4, 2},
  {2, 6, 14},
  {2, 6, 8},
  {2, 1, 6} };

Output:
5
2
5

Pre-requisites: Segment Tree and Bit Manipulation
Approach:

  • Because every element lies between 0 and 63, the set of values present in a segment fits into a single 64-bit bitmask: bit i of the mask is set if and only if the value i occurs in that segment.
  • Build a classical segment tree in which each node stores the bitmask of the values present in the segment it covers; the number of distinct elements in that segment is the number of set bits in its mask.
  • Build the segment tree bottom-up: a leaf's bitmask has only the bit of its own element set (1 << value), and the bitmask of an internal node is the bitwise OR of the bitmasks of its left and right children. An update works the same way: it rewrites the leaf's mask and then recomputes the masks of the nodes on the path back to the root. To answer a query, OR together the masks of the nodes covering [L, R] and count the set bits of the result.

Below is the implementation of the above approach:

C++

filter_none

edit
close

play_arrow

link
brightness_4
code

// C++ Program to find the distinct
// elements in a range
#include <bits/stdc++.h>
using namespace std;
  
// Range query on the bitmask segment tree.
//
// Returns the bitwise OR of the masks stored for the part of the current
// node's segment that lies inside the queried range.
//   start, end  - inclusive bounds of the segment covered by `node`
//   left, right - inclusive bounds of the queried range
//   node        - index of the current node in `seg`; its children are
//                 stored at 2 * node and 2 * node + 1
//   seg         - segment tree storage, one bitmask per node
long long query(int start, int end, int left, int right,
                int node, long long seg[])
{
    // Segment lies completely outside the queried range: contribute 0,
    // which leaves an OR unchanged.
    if (right < start || end < left)
        return 0;

    // Segment lies completely inside the queried range: its stored mask
    // is the answer for this subtree.
    if (left <= start && end <= right)
        return seg[node];

    // Partial overlap: split at the midpoint and merge both halves.
    const int mid = (start + end) / 2;
    const long long lowerHalf = query(start, mid, left, right,
                                      2 * node, seg);
    const long long upperHalf = query(mid + 1, end, left, right,
                                      2 * node + 1, seg);
    return lowerHalf | upperHalf;
}
  
// Point update on the bitmask segment tree.
//
// Writes Value into ar[index] and refreshes the mask of every node on the
// path from the root down to the leaf for that position.
//   left, right - inclusive bounds of the segment covered by `node`
//   index       - 0-based position being changed
//   Value       - new element; used as a bit position in a long long mask
//   node        - index of the current node in `seg`
//   ar          - underlying array, modified in place
//   seg         - segment tree storage, modified in place
void update(int left, int right, int index, int Value,
            int node, int ar[], long long seg[])
{
    // Reached the leaf that represents ar[index].
    if (left == right) {
        ar[index] = Value;
        seg[node] = 1LL << Value;
        return;
    }

    const int mid = (left + right) / 2;
    const int leftChild = 2 * node;
    const int rightChild = leftChild + 1;

    if (index <= mid) {
        // Position lies in the left half [left, mid].
        update(left, mid, index, Value, leftChild, ar, seg);
    }
    else {
        // Position lies in the right half [mid + 1, right].
        update(mid + 1, right, index, Value, rightChild, ar, seg);
    }

    // Recompute this node's mask from its two children.
    seg[node] = seg[leftChild] | seg[rightChild];
}
  
// Builds the bitmask segment tree for ar[left..right] rooted at `node`.
//   left, right - inclusive bounds of the segment covered by `node`
//   node        - index of the current node in `seg`; its children are
//                 stored at 2 * node and 2 * node + 1
//   ar          - input array; each element is used as a bit position
//   seg         - segment tree storage, filled in by this function
void build(int left, int right, int node, int ar[],
           long long seg[])
{
    if (left != right) {
        const int mid = (left + right) / 2;
        const int leftChild = 2 * node;
        const int rightChild = leftChild + 1;

        // Children first, so their masks exist before they are merged.
        build(left, mid, leftChild, ar, seg);
        build(mid + 1, right, rightChild, ar, seg);

        seg[node] = seg[leftChild] | seg[rightChild];
        return;
    }

    // Leaf: only the bit of the single element it covers is set.
    seg[node] = 1LL << ar[left];
}
  
// Utility Function to answer the queries
void getDistinctCount(vector<vector<int> >& queries,
                      int ar[], long long seg[], int n)
{
  
    for (int i = 0; i < queries.size(); i++) {
  
        int op = queries[i][0];
  
        if (op == 2) {
  
            int l = queries[i][1], r = queries[i][2];
  
            long long tempMask = query(0, n - 1, l - 1,
                                       r - 1, 1, seg);
  
            int countOfBits = 0;
  
            // Counting the set bits which denote the
            // distinct elements
            for (int i = 63; i >= 0; i--) {
  
                if (tempMask & (1LL << i)) {
  
                    countOfBits++;
                }
            }
  
            cout << countOfBits << '\n';
        }
        else {
  
            int index = queries[i][1];
            int val = queries[i][2];
  
            // Updating the value
            update(0, n - 1, index - 1, val, 1, ar, seg);
        }
    }
}
  
// Driver Code
int main()
{
    int n = 7;
    int ar[] = { 1, 2, 1, 3, 1, 2, 1 };
  
    long long seg[4 * n] = { 0 };
    build(0, n - 1, 1, ar, seg);
  
    int q = 5;
    vector<vector<int> > queries = {
        { 2, 1, 4 },
        { 1, 4, 2 },
        { 1, 5, 2 },
        { 2, 4, 6 },
        { 2, 1, 7 }
    };
  
    getDistinctCount(queries, ar, seg, n);
  
    return 0;
}

chevron_right


Python3

filter_none

edit
close

play_arrow

link
brightness_4
code

# Python3 Program to find the distinct
# elements in a range
  
# Range query on the bitmask segment tree
def query(start, end, left, 
          right, node, seg):
    """Return the OR of the masks covering the queried range.

    start, end  -- inclusive bounds of the segment covered by ``node``
    left, right -- inclusive bounds of the queried range
    node        -- index of the current node in ``seg``; its children are
                   stored at ``2 * node`` and ``2 * node + 1``
    seg         -- segment tree storage, one bitmask per node
    """
    # Segment lies completely outside the queried range: contribute 0,
    # which leaves an OR unchanged.
    if right < start or end < left:
        return 0

    # Segment lies completely inside the queried range.
    if left <= start and end <= right:
        return seg[node]

    # Partial overlap: split at the midpoint and merge both halves.
    mid = (start + end) // 2
    lowerHalf = query(start, mid, left, right, 2 * node, seg)
    upperHalf = query(mid + 1, end, left, right, 2 * node + 1, seg)
    return lowerHalf | upperHalf
  
# Point update on the bitmask segment tree
def update(left, right, index, 
           Value, node, ar, seg):
    """Set ``ar[index]`` to ``Value`` and refresh the affected masks.

    left, right -- inclusive bounds of the segment covered by ``node``
    index       -- 0-based position being changed
    Value       -- new element; used as a bit position in the mask
    node        -- index of the current node in ``seg``
    ar          -- underlying list, modified in place
    seg         -- segment tree storage, modified in place
    """
    # Reached the leaf that represents ar[index].
    if left == right:
        ar[index] = Value
        seg[node] = 1 << Value
        return

    mid = (left + right) // 2
    leftChild = 2 * node
    rightChild = leftChild + 1

    if index <= mid:
        # Position lies in the left half [left, mid].
        update(left, mid, index, Value, leftChild, ar, seg)
    else:
        # Position lies in the right half [mid + 1, right].
        update(mid + 1, right, index, Value, rightChild, ar, seg)

    # Recompute this node's mask from its two children.
    seg[node] = seg[leftChild] | seg[rightChild]
  
# Building the Segment Tree
def build(left, right, node, ar, seg):
    """Fill ``seg`` with the bitmask segment tree for ``ar[left..right]``.

    left, right -- inclusive bounds of the segment covered by ``node``
    node        -- index of the current node in ``seg``; its children are
                   stored at ``2 * node`` and ``2 * node + 1``
    ar          -- input list; each element is used as a bit position
    seg         -- segment tree storage, filled in by this function
    """
    if left == right:
        # Leaf: only the bit of the single element it covers is set.
        seg[node] = 1 << ar[left]
        return

    mid = (left + right) // 2

    # Children first, so their masks exist before they are merged.
    build(left, mid, 2 * node, ar, seg)
    build(mid + 1, right, 2 * node + 1, ar, seg)

    # An internal node's mask is the OR of its children's masks.
    seg[node] = seg[2 * node] | seg[2 * node + 1]
  
# Utility Function to answer the queries
def getDistinctCount(queries, ar, seg, n):
    """Process the queries in order against ``ar`` and its tree ``seg``.

    queries -- each entry is ``[op, a, b]``:
               op == 2   : print the number of distinct values in the
                           1-based inclusive range [a, b]
               otherwise : set the element at 1-based index a to b
    ar, seg -- list and its segment tree; modified by update queries
    n       -- number of elements covered by the tree
    """
    for entry in queries:

        op = entry[0]

        if op == 2:

            l = entry[1]
            r = entry[2]

            # Queries are 1-based; the tree works on 0-based positions.
            tempMask = query(0, n - 1, l - 1,
                             r - 1, 1, seg)

            # Every set bit among bits 0..63 marks one distinct value
            # in the range.
            countOfBits = sum((tempMask >> bit) & 1 for bit in range(64))

            print(countOfBits)
        else:

            index = entry[1]
            val = entry[2]

            # Updating the value (converted to a 0-based position)
            update(0, n - 1, index - 1,
                   val, 1, ar, seg)
  
# Driver Code: runs the first sample from the article
if __name__ == '__main__':
    # Sample array and its length.
    n = 7
    ar = [1, 2, 1, 3, 1, 2, 1]

    # Node 1 is the root and children live at 2 * node and 2 * node + 1;
    # the tree storage holds 4 * n entries.
    seg = [0] * (4 * n)
    build(0, n - 1, 1, ar, seg)

    # Number of queries, followed by the queries as [op, a, b] triples.
    q = 5
    queries = [
        [2, 1, 4],
        [1, 4, 2],
        [1, 5, 2],
        [2, 4, 6],
        [2, 1, 7],
    ]

    getDistinctCount(queries, ar, seg, n)
  
# This code is contributed by Mohit Kumar

chevron_right


Output:

3
1
2

Time complexity: O(N + Q*Log(N))




My Personal Notes arrow_drop_up

Check out this Author's contributed articles.

If you like GeeksforGeeks and would like to contribute, you can also write an article using contribute.geeksforgeeks.org or mail your article to contribute@geeksforgeeks.org. See your article appearing on the GeeksforGeeks main page and help other Geeks.

Please Improve this article if you find anything incorrect by clicking on the "Improve Article" button below.



Improved By : mohit kumar 29