Queries to count array elements greater than or equal to a given number with updates

Given two arrays arr[] and query[] of sizes N and Q respectively, and an integer M, the task is, for each query[i], to count the array elements which are greater than or equal to query[i], decrease all such elements by M, and perform the remaining queries on the updated array.
Examples:

Input: arr[] = {1, 2, 3, 4}, query[] = {4, 3, 1}, M = 1 
Output: 1 2 4 
Explanation: 
query[0]: Count of array elements which are greater than or equal to 4 in arr[] is 1 and decreasing the number by M modifies array to {1, 2, 3, 3}. 
query[1]: Count of array elements which are greater than or equal to 3 in arr[] is 2 and decreasing all such numbers by M modifies array to {1, 2, 2, 2}. 
query[2]: Count of array elements which are greater than or equal to 1 in arr[] is 4 and decreasing all such numbers by M modifies array to {0, 1, 1, 1}.

Input: arr[] = {1, 2, 3}, query = {3, 3}, M = 2 
Output: 1 0 
Explanation: 
query[0]: Count of array elements which are greater than or equal to 3 in arr[] is 1 and decreasing that number by M modifies array to arr[] = {1, 2, 1}. 
query[1]: Count of array elements which are greater than or equal to 3 in arr[] is 0. So the array remains unchanged.

Naive Approach: The simplest approach is to iterate over the entire array for every query and check whether the current array element is greater than or equal to query[i]. For every element that satisfies this condition, subtract M from it and increment the count. Finally, print the count for each query. 

Time Complexity: O(Q * N) 
Auxiliary Space: O(1)

Efficient Approach: The problem can be solved using Segment Trees.



  1. First, sort the array arr[] in non decreasing order.
  2. Now, find the first position of element, say l, which contains an element >= query[i].
  3. If no such elements exist then answer for that query will be 0. Otherwise answer will be N – l.
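  4. Finally update the segment tree in the given range l to N – 1 and subtract M from all elements in the given range. Note that steps 2 and 3 are only valid while the array is sorted: for M > 1 the decreased suffix can drop below elements of the untouched prefix (e.g. {1, 2, 3} with query 3 and M = 2 becomes {1, 2, 1}), so the sorted order has to be restored before the next query is answered.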

Illustration: 
Consider the following example : arr[] = {1, 2, 3, 4}, query[] = {4, 3, 1}, M = 1 
After sorting arr[] = {1, 2, 3, 4}. 
query[0]: K = 4 and arr[3] >= 4 so l = 3 and result = 4 – 3 = 1 and updated arr[] = {1, 2, 3, 3} 
query[1]: K = 3 and arr[2] >=3 so l = 2 and result = 4 – 2 = 2 and updated arr[] = {1, 2, 2, 2} 
query[2]: K = 1 and arr[0] >=1 so l = 0 and result = 4 – 0 = 4 and updated arr[] = {0, 1, 1, 1}

Below is the implementation of above approach:

C++

filter_none

edit
close

play_arrow

link
brightness_4
code

// C++ program for the above approach
#include <bits/stdc++.h>
using namespace std;
  
// Function to build a segment tree.
//
// sum : node storage, indexed from 1 (children of rt are rt*2 and rt*2+1)
// a   : 0-based input values; tree position p holds a[p - 1]
// l, r: inclusive 1-based range covered by node rt
// rt  : index of the node being built
//
// After the call every node covering [l, r] holds the sum of its range.
void build(vector<int>& sum,
           vector<int>& a,
           int l, int r, int rt)
{
    // Empty range: nothing to build. Without this guard a call such as
    // build(sum, a, 1, 0, 1) (n == 0) would recurse forever.
    if (l > r)
        return;

    // Check for base case
    if (l == r) {
        sum[rt] = a[l - 1];
        return;
    }

    // Find mid point
    int m = (l + r) >> 1;

    // Recursively build the
    // segment tree
    build(sum, a, l, m, rt << 1);
    build(sum, a, m + 1, r, rt << 1 | 1);

    // Pull up: an internal node stores the sum of its two children,
    // otherwise it would keep its initial value and range sums
    // read from it would be wrong.
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
  
// Push-down step of the lazy segment tree: hands the pending
// addition stored at node rt to its two children.
//
// sum   : per-node range sums
// add   : per-node pending (lazy) addition
// rt    : node whose pending value is propagated
// ln, rn: number of elements covered by the left / right child
void pushDown(vector<int>& sum,
              vector<int>& add,
              int rt, int ln, int rn)
{
    // Nothing pending at this node
    if (add[rt] == 0)
        return;

    const int left = rt << 1;
    const int right = left | 1;

    // Children inherit the pending value ...
    add[left] += add[rt];
    add[right] += add[rt];

    // ... and their sums change by that value once per covered element
    sum[left] += add[rt] * ln;
    sum[right] += add[rt] * rn;

    // The pending value has been handed over
    add[rt] = 0;
}
  
// Function to update the segment tree
void update(vector<int>& sum,
            vector<int>& add,
            int L, int R, int C, int l,
            int r, int rt)
{
    // Complete overlap
    if (L <= l && r <= R) {
        sum[rt] += C * (r - l + 1);
        add[rt] += C;
        return;
    }
  
    // Find mid
    int m = (l + r) >> 1;
  
    // Perform push down operation
    // on segment tree
    pushDown(sum, add, rt, m - l + 1,
             r - m);
  
    // Recursively update the segment tree
    if (L <= m)
        update(sum, add, L, R, C, l, m,
               rt << 1);
  
    if (R > m)
        update(sum, add, L, R, C, m + 1, r,
               rt << 1 | 1);
}
  
// Function to process the query
int query(vector<int>& sum, 
          vector<int>& add,
          int L, int R, int l, 
          int r, int rt)
{
    // Base case
    if (L <= l && r <= R) {
        return sum[rt];
    }
  
    // Find mid
    int m = (l + r) >> 1;
  
    // Perform push down operation
    // on segment tree
    pushDown(sum, add, rt, m - l + 1,
             r - m);
  
    int ans = 0;
  
    // Recursively calculate the result
    // of the query
    if (L <= m)
        ans += query(sum, add, L, R, l, m,
                     rt << 1);
    if (R > m)
        ans += query(sum, add, L, R, m + 1, r,
                     rt << 1 | 1);
  
    // Return the result
    return ans;
}
  
// Function to answer the queries: for every b[i] it counts the
// elements that are greater than or equal to b[i], decreases exactly
// those elements by m, and finally prints all counts separated
// (and followed) by a single space.
//
// n: number of elements of a to use
// q: number of queries of b to process
// a: input values; sorted in place by this function
// b: query thresholds
// m: amount subtracted from every counted element
//
// The values are kept in a sorted working copy. Subtracting m from
// the suffix of elements >= b[i] can push that suffix below the
// untouched prefix (e.g. {1, 2, 3}, threshold 3, m = 2 gives
// {1, 2, 1}), so the two sorted runs are merged again after every
// update. A binary search over a structure that is no longer sorted
// would report wrong counts for the following queries.
// Cost per query: O(log n) for the search plus O(n) for the
// decrement and merge.
void sequenceMaintenance(int n, int q,
                         vector<int>& a,
                         vector<int>& b,
                         int m)
{
    // Sort the input array
    sort(a.begin(), a.end());

    // Sorted working copy of the first n values; clamped so that
    // n == 0 (or an n larger than the array) is handled safely
    const int len = max(0, min(n, static_cast<int>(a.size())));
    vector<int> cur(a.begin(), a.begin() + len);

    vector<int> ans;

    // Iterate over the queries
    for (int i = 0; i < q; i++) {
        // First element that is >= b[i]; everything from here to
        // the end of the sorted copy is counted
        auto first = lower_bound(cur.begin(), cur.end(), b[i]);

        // Store result in array
        ans.push_back(static_cast<int>(cur.end() - first));

        if (first == cur.end())
            continue;

        // Decrease all counted elements
        for (auto it = first; it != cur.end(); ++it)
            *it -= m;

        // Prefix and suffix are each still sorted; merge them only
        // if the decrement actually broke the overall order
        if (first != cur.begin() && *(first - 1) > *first)
            inplace_merge(cur.begin(), first, cur.end());
    }

    // Print the result of queries
    for (int count : ans) {
        cout << count << " ";
    }
}
  
// Driver Code
int main()
{
    int N = 4;
    int Q = 3;
    int M = 1;
    vector<int> arr = { 1, 2, 3, 4 };
    vector<int> query = { 4, 3, 1 };
  
    // Function Call
    sequenceMaintenance(N, Q, arr, query, M);
    return 0;
}

chevron_right


Python3

filter_none

edit
close

play_arrow

link
brightness_4
code

# Python3 program for the above approach
  
# Function to build a segment tree
def build(sum, a, l, r, rt):
    """Fill `sum` so that every node covering [l, r] holds its range sum.

    sum  -- node storage, indexed from 1 (children of rt are rt*2, rt*2+1)
    a    -- 0-based input values; tree position p holds a[p - 1]
    l, r -- inclusive 1-based range covered by node rt
    rt   -- index of the node being built
    """
    # Empty range: nothing to build. Without this guard a call such as
    # build(sum, a, 1, 0, 1) (n == 0) would recurse without end.
    if (l > r):
        return

    # Check for base case
    if (l == r):
        sum[rt] = a[l - 1]
        return

    # Find mid point
    m = (l + r) >> 1

    # Recursively build the
    # segment tree
    build(sum, a, l, m, rt << 1)
    build(sum, a, m + 1, r, rt << 1 | 1)

    # Pull up: an internal node stores the sum of its two children,
    # otherwise it would keep its initial value
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1]
  
# Push-down step of the lazy segment tree: hands the pending
# addition stored at node rt to its two children
def pushDown(sum, add, rt, ln, rn):
    """Propagate the pending addition of node `rt` to its children.

    sum, add -- per-node range sums and pending (lazy) additions
    rt       -- node whose pending value is propagated
    ln, rn   -- number of elements covered by the left / right child
    """
    # Nothing pending at this node
    if not add[rt]:
        return

    left = rt << 1
    right = left | 1

    # Children inherit the pending value ...
    add[left] += add[rt]
    add[right] += add[rt]

    # ... and their sums change by that value once per covered element
    sum[left] += add[rt] * ln
    sum[right] += add[rt] * rn

    # The pending value has been handed over
    add[rt] = 0
  
# Function to update the segment tree
def update(sum, add, L, R, C, l, r, rt):
    """Add C to every position in the inclusive range [L, R].

    sum, add -- per-node range sums and pending (lazy) additions
    L, R     -- inclusive 1-based range to update
    C        -- value added to each position in [L, R]
    l, r     -- inclusive range covered by node rt
    rt       -- current node
    """
    # No overlap: leave this subtree untouched instead of
    # descending into nodes that do not exist
    if (R < l or r < L):
        return

    # Complete overlap
    if (L <= l and r <= R):
        sum[rt] += C * (r - l + 1)
        add[rt] += C
        return

    # Find mid
    m = (l + r) >> 1

    # Perform push down operation
    # on segment tree
    pushDown(sum, add, rt, m - l + 1, r - m)

    # Recursively update the segment tree
    if (L <= m):
        update(sum, add, L, R, C, l,
               m, rt << 1)

    if (R > m):
        update(sum, add, L, R, C, m + 1,
               r, rt << 1 | 1)

    # Pull up: recompute this node from its children so that a
    # partially covered node does not keep a stale sum
    sum[rt] = sum[rt << 1] + sum[rt << 1 | 1]
  
# Function to process the query: sum of the stored values
# over the inclusive range [L, R]
def queryy(sum, add, L, R, l, r, rt):
    """Return the sum over [L, R] from the subtree rooted at `rt`.

    sum, add -- per-node range sums and pending (lazy) additions
    L, R     -- inclusive 1-based range being asked for
    l, r     -- inclusive range covered by node rt
    rt       -- current node
    """
    # Node lies completely inside the requested range
    if L <= l and r <= R:
        return sum[rt]

    mid = (l + r) >> 1

    # Children must see the pending addition before they are read
    pushDown(sum, add, rt, mid - l + 1, r - mid)

    # Contribution of the left half, if the range reaches into it
    left_part = queryy(sum, add, L, R, l, mid, rt << 1) if L <= mid else 0

    # Contribution of the right half, if the range reaches into it
    right_part = (queryy(sum, add, L, R, mid + 1, r, rt << 1 | 1)
                  if R > mid else 0)

    return left_part + right_part
  
# Function to count, for every query, the numbers which are greater
# than or equal to the query value and to decrease them by m
def sequenceMaintenance(n, q, a, b, m):
    """Answer the queries and print the counts separated by spaces.

    n -- number of elements of a to use
    q -- number of queries of b to process
    a -- input values (not modified; a sorted copy is used)
    b -- query thresholds
    m -- amount subtracted from every counted element

    The values are kept in a sorted working list. Subtracting m from
    the suffix of elements >= b[i] can push that suffix below the
    untouched prefix (e.g. [1, 2, 3], threshold 3, m = 2 gives
    [1, 2, 1]), so the order is restored after every update; a binary
    search over values that are no longer sorted would report wrong
    counts. The search index is kept in `pos` and never stored in `m`,
    so the decrement always uses the caller's m.
    """
    from bisect import bisect_left

    # Sorted working copy of the first n values
    cur = sorted(a)[:max(n, 0)]
    ans = []

    # Iterate over the queries
    for i in range(q):

        # First position holding a value >= b[i]; everything from
        # there to the end of the sorted list is counted
        pos = bisect_left(cur, b[i])

        # Store result in array
        ans.append(len(cur) - pos)

        if pos == len(cur):
            continue

        # Decrease all counted elements
        for j in range(pos, len(cur)):
            cur[j] -= m

        # Prefix and suffix are each still sorted; re-sort only if the
        # decrement actually broke the overall order
        if pos > 0 and cur[pos - 1] > cur[pos]:
            cur.sort()

    # Print the result of queries
    for i in ans:
        print(i, end = " ")
  
# Driver Code
if __name__ == '__main__':

    # Sample input of the first example; the expected output is "1 2 4"
    arr = [1, 2, 3, 4]
    query = [4, 3, 1]
    M = 1

    # Sizes are taken from the lists themselves
    N = len(arr)
    Q = len(query)

    # Function call
    sequenceMaintenance(N, Q, arr, query, M)
  
# This code is contributed by mohit kumar 29

chevron_right


C#

filter_none

edit
close

play_arrow

link
brightness_4
code

// C# program for the above approach 
using System; 
using System.Collections;
using System.Collections.Generic;  
   
class GFG{

// Function to build a segment tree.
// sum : node storage, indexed from 1 (children of rt are rt*2, rt*2+1)
// a   : 0-based input values; tree position p holds a[p - 1]
// l, r: inclusive 1-based range covered by node rt
// rt  : index of the node being built
static void build(ArrayList sum,
                  ArrayList a,
                  int l, int r,
                  int rt)
{

    // Empty range: nothing to build (also stops the recursion
    // when the tree is built for zero elements)
    if (l > r)
        return;

    // Check for base case
    if (l == r)
    {
        sum[rt] = a[l - 1];
        return;
    }

    // Find mid point
    int m = (l + r) >> 1;

    // Recursively build the
    // segment tree
    build(sum, a, l, m, rt << 1);
    build(sum, a, m + 1, r, rt << 1 | 1);

    // Pull up: an internal node stores the sum of its two children
    sum[rt] = (int)sum[rt << 1] + (int)sum[rt << 1 | 1];
}

// Function for push down operation
// on the segment tree: hands the pending addition of node rt
// to its children, which cover ln and rn elements respectively
static void pushDown(ArrayList sum,
                     ArrayList add,
                     int rt, int ln,
                     int rn)
{
    if ((int)add[rt] != 0)
    {
        add[rt << 1] = (int)add[rt << 1] +
                       (int)add[rt];
        add[rt << 1 | 1] = (int)add[rt << 1 | 1] +
                           (int)add[rt];
        sum[rt << 1] = (int)sum[rt << 1] +
                       (int)add[rt] * ln;
        sum[rt << 1 | 1] = (int)sum[rt << 1 | 1] +
                           (int)add[rt] * rn;
        add[rt] = 0;
    }
}

// Function to update the segment tree: adds C to every position
// in the inclusive range [L, R]; node rt covers [l, r]
static void update(ArrayList sum,
                   ArrayList add,
                   int L, int R, int C,
                   int l, int r, int rt)
{

    // No overlap: leave this subtree untouched
    if (R < l || r < L)
        return;

    // Complete overlap
    if (L <= l && r <= R)
    {
        sum[rt] = (int)sum[rt] +
                    C * (r - l + 1);
        add[rt] = (int)add[rt] + C;
        return;
    }

    // Find mid
    int m = (l + r) >> 1;

    // Perform push down operation
    // on segment tree
    pushDown(sum, add, rt, m - l + 1,
                           r - m);

    // Recursively update the segment tree
    if (L <= m)
        update(sum, add, L, R, C, l, m,
               rt << 1);
    if (R > m)
        update(sum, add, L, R, C, m + 1, r,
               rt << 1 | 1);

    // Pull up: recompute this node from its children so that a
    // partially covered node does not keep a stale sum
    sum[rt] = (int)sum[rt << 1] + (int)sum[rt << 1 | 1];
}

// Function to process the query: sum of the stored values over
// the inclusive range [L, R]; node rt covers [l, r]
static int query(ArrayList sum,
                 ArrayList add,
                 int L, int R, int l,
                 int r, int rt)
{

    // Base case
    if (L <= l && r <= R)
    {
        return (int)sum[rt];
    }

    // Find mid
    int m = (l + r) >> 1;

    // Perform push down operation
    // on segment tree
    pushDown(sum, add, rt, m - l + 1,
                           r - m);

    int ans = 0;

    // Recursively calculate the result
    // of the query
    if (L <= m)
        ans += query(sum, add, L, R, l, m,
                     rt << 1);
    if (R > m)
        ans += query(sum, add, L, R, m + 1, r,
                     rt << 1 | 1);

    // Return the result
    return ans;
}

// Function to answer the queries: for every b[i] it counts the
// elements that are greater than or equal to b[i], decreases exactly
// those elements by m, and finally prints all counts separated
// (and followed) by a single space.
//
// The values are kept in a sorted working array. Subtracting m from
// the suffix of elements >= b[i] can push that suffix below the
// untouched prefix (e.g. {1, 2, 3}, threshold 3, m = 2 gives
// {1, 2, 1}), so the two sorted runs are merged again after every
// update; a binary search over values that are no longer sorted
// would report wrong counts. The search uses its own index
// variables and never writes to m, so the decrement always uses
// the caller's m.
static void sequenceMaintenance(int n, int q,
                                ArrayList a,
                                ArrayList b,
                                int m)
{

    // Sort the input array
    a.Sort();

    // Sorted working copy of the first n values; clamped so that
    // n == 0 (or an n larger than the list) is handled safely
    int len = Math.Max(0, Math.Min(n, a.Count));
    int[] cur = new int[len];
    for(int i = 0; i < len; i++)
    {
        cur[i] = (int)a[i];
    }

    // Scratch buffer reused by every merge
    int[] merged = new int[len];

    List<int> ans = new List<int>();

    // Iterate over the queries
    for(int i = 0; i < q; i++)
    {
        int key = (int)b[i];

        // Binary search for the first position holding a value >= key
        int lo = 0, hi = len;
        while (lo < hi)
        {
            int mid = lo + ((hi - lo) >> 1);
            if (cur[mid] >= key)
                hi = mid;
            else
                lo = mid + 1;
        }
        int pos = lo;

        // Store result in array
        ans.Add(len - pos);

        if (pos == len)
            continue;

        // Decrease all counted elements
        for(int j = pos; j < len; j++)
        {
            cur[j] -= m;
        }

        // Prefix and suffix are each still sorted; merge them only
        // if the decrement actually broke the overall order
        if (pos > 0 && cur[pos - 1] > cur[pos])
        {
            int x = 0, y = pos, k = 0;
            while (x < pos && y < len)
            {
                if (cur[x] <= cur[y])
                    merged[k++] = cur[x++];
                else
                    merged[k++] = cur[y++];
            }
            while (x < pos)
                merged[k++] = cur[x++];
            while (y < len)
                merged[k++] = cur[y++];

            // Swap the buffers instead of copying back
            int[] tmp = cur;
            cur = merged;
            merged = tmp;
        }
    }

    // Print the result of queries
    for(int i = 0; i < ans.Count; i++)
    {
        Console.Write(ans[i] + " ");
    }
}

// Driver Code
public static void Main(string[] args)
{
    int N = 4;
    int Q = 3;
    int M = 1;

    ArrayList arr = new ArrayList(){ 1, 2, 3, 4 };
    ArrayList query = new ArrayList(){ 4, 3, 1 };

    // Function call
    sequenceMaintenance(N, Q, arr, query, M);
}
}
  
// This code is contributed by rutvik_56

chevron_right


Output: 

1 2 4

Time Complexity : O (N + (Q * logN)) 
Auxiliary Space : O (N)
 

Attention reader! Don’t stop learning now. Get hold of all the important DSA concepts with the DSA Self Paced Course at a student-friendly price and become industry ready.




My Personal Notes arrow_drop_up

Recommended Posts:


I am currently pursuing B Tech (Hons) in Computer Science I love technology and explore new things I want to become a very good programmer I love to travel and meet new peoples

If you like GeeksforGeeks and would like to contribute, you can also write an article using contribute.geeksforgeeks.org or mail your article to contribute@geeksforgeeks.org. See your article appearing on the GeeksforGeeks main page and help other Geeks.

Please Improve this article if you find anything incorrect by clicking on the "Improve Article" button below.



Improved By : mohit kumar 29, rutvik_56