Sum of prime numbers in range [L, R] from given Array for Q queries

Given an array arr[] of size N and a list of Q queries, answer each query, which is of one of the following two types:

  • Query Type 1: Given two integers L and R, find the sum of prime elements from index L to R where 0 <= L <= R <= N-1.
  • Query Type 2: Given two integers i and X, change arr[i] = X where 0 <= i <= N-1.

Note: The first element of each query determines the type of the query to be answered.

Example: 

Input: arr[] = {1, 3, 5, 7, 9, 11}, Q = { { 1, 1, 3}, {2, 1, 10}, {1, 1, 3 } }
Output: 
15
12
Explanation: 
First query is of type 1, so answer is (3 + 5 + 7) = 15
Second query is of type 2, so arr[1] = 10
Third query is of type 1, where arr[1] = 10, which is not prime hence answer is (5 + 7) = 12

Input: arr[] = {1, 2, 35, 7, 14, 11}, Q = { {2, 4, 3}, {1, 4, 5 } }
Output: 14
Explanation:
First query is of type 2, So update arr[4] = 3
Second query is of type 1, since arr[4] = 3, which is prime. So answer is (3 + 11) = 14

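Naive Approach: For each query of type 1, iterate over the indices from L to R, test every element for primality by trial division, and add up the elements that are prime. For each query of type 2, simply assign arr[i] = X.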



Time Complexity: O(Q * N * sqrt(max(arr[i])))

Approach:  To optimize the problem use Segment tree and Sieve Of Eratosthenes.

  • First, create a boolean array that will mark the prime numbers.
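  • Now, while building the segment tree, store an array element in its leaf node only if it is prime; a non-prime element is stored as 0, so every node holds the sum of the prime elements in its range.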

C++

filter_none

edit
close

play_arrow

link
brightness_4
code

// C++ program for the above approach
#include <bits/stdc++.h>
using namespace std;
  
// The sieve table covers the values 0 .. MAX - 1
int const MAX = 1000001;

// After SieveOfEratosthenes() has run, prime[v] is
// true exactly when v is a prime number
bool prime[MAX];

// Function to find the prime numbers.
// Fills prime[] for every value in [0, MAX)
// using the Sieve of Eratosthenes.
void SieveOfEratosthenes()
{
    // Start by assuming that every value is prime;
    // composites are struck out below
    std::memset(prime, true, sizeof(prime));

    // 0 and 1 are not prime. The marking loop never
    // touches them, so they must be cleared here,
    // otherwise an element equal to 1 would be
    // added to the prime sums
    prime[0] = false;
    prime[1] = false;

    // Every composite below MAX has a prime
    // factor p with p * p < MAX
    for (int p = 2; p * p < MAX; p++) {

        // p was struck out by a smaller prime,
        // so its multiples are struck out as well
        if (!prime[p])
            continue;

        // Multiples of p below p * p were already
        // marked by smaller primes. The bound is
        // i < MAX because prime[MAX] is one past
        // the end of the array
        for (int i = p * p; i < MAX; i += p)
            prime[i] = false;
    }
}
  
// Returns the index that splits the segment
// [s, e] into [s, mid] and [mid + 1, e]
int getMid(int s, int e)
{
    // Computed as an offset from s
    const int half_width = (e - s) / 2;
    return s + half_width;
}
  
// Function to get the sum of
// values in the given range
// of the array
  
int getSumUtil(int* st, int ss,
               int se, int qs,
               int qe, int si)
{
    // If segment of this node is a
    // part of given range, then
    // return the sum of the segment
    if (qs <= ss && qe >= se)
        return st[si];
  
    // If segment of this node is
    // outside the given range
    if (se < qs || ss > qe)
        return 0;
  
    // If a part of this segment
    // overlaps with the given range
    int mid = getMid(ss, se);
  
    return getSumUtil(st, ss, mid,
                      qs, qe,
                      2 * si + 1)
           + getSumUtil(st, mid + 1,
                        se, qs, qe,
                        2 * si + 2);
}
  
// Function to update the nodes which
// have the given index in their range
void updateValueUtil(int* st, int ss,
                     int se, int i,
                     int diff, int si)
{
    // If the input index lies
    // outside the range of
    // this segment
    if (i < ss || i > se)
        return;
  
    // If the input index is in
    // range of this node, then update
    // the value of the node and its children
    st[si] = st[si] + diff;
    if (se != ss) {
        int mid = getMid(ss, se);
        updateValueUtil(st, ss, mid, i,
                        diff, 2 * si + 1);
        updateValueUtil(st, mid + 1,
                        se, i, diff,
                        2 * si + 2);
    }
}
  
// Function to update a value in
// input array and segment tree
void updateValue(int arr[], int* st,
                 int n, int i,
                 int new_val)
{
    // Check for erroneous input index
    if (i < 0 || i > n - 1) {
        cout << "-1";
        return;
    }
  
    // Get the difference between
    // new value and old value
    int diff = new_val - arr[i];
  
    int prev_val = arr[i];
  
    // Update the value in array
    arr[i] = new_val;
  
    // Update the values of
    // nodes in segment tree
    // only if either previous
    // value or new value
    // or both are prime
    if (prime[new_val]
        || prime[prev_val]) {
  
        // If only new value is prime
        if (!prime[prev_val])
            updateValueUtil(st, 0, n - 1,
                            i, new_val, 0);
  
        // If only new value is prime
        else if (!prime[new_val])
            updateValueUtil(st, 0, n - 1,
                            i, -prev_val, 0);
  
        // If both are prime
        else
            updateValueUtil(st, 0, n - 1,
                            i, diff, 0);
    }
}
  
// Return sum of elements in range
// from index qs (quey start) to qe
// (query end). It mainly uses getSumUtil()
int getSum(int* st, int n, int qs, int qe)
{
    // Check for erroneous input values
    if (qs < 0 || qe > n - 1 || qs > qe) {
        cout << "-1";
        return -1;
    }
  
    return getSumUtil(st, 0, n - 1,
                      qs, qe, 0);
}
  
// Function that constructs Segment Tree
int constructSTUtil(int arr[], int ss,
                    int se, int* st,
                    int si)
{
    // If there is one element in
    // array, store it in current node of
    // segment tree and return
    if (ss == se) {
  
        // Only add those elements in segment
        // tree which are prime
        if (prime[arr[ss]])
            st[si] = arr[ss];
        else
            st[si] = 0;
        return st[si];
    }
  
    // If there are more than one
    // elements, then recur for left and
    // right subtrees and store the
    // sum of values in this node
    int mid = getMid(ss, se);
    st[si]
        = constructSTUtil(arr, ss, mid,
                          st, si * 2 + 1)
          + constructSTUtil(arr, mid + 1,
                            se, st,
                            si * 2 + 2);
    return st[si];
}
  
// Function to construct segment
// tree from given array
int* constructST(int arr[], int n)
{
    // Allocate memory for the segment tree
  
    // Height of segment tree
    int x = (int)(ceil(log2(n)));
  
    // Maximum size of segment tree
    int max_size = 2 * (int)pow(2, x) - 1;
  
    // Allocate memory
    int* st = new int[max_size];
  
    // Fill the allocated memory st
    constructSTUtil(arr, 0, n - 1, st, 0);
  
    // Return the constructed segment tree
    return st;
}
  
// Driver code
int main()
{
    int arr[] = { 1, 3, 5, 7, 9, 11 };
    int n = sizeof(arr) / sizeof(arr[0]);
  
    int Q[3][3]
        = { { 1, 1, 3 },
            { 2, 1, 10 },
            { 1, 1, 3 } };
  
    // Function call
    SieveOfEratosthenes();
  
    // Build segment tree from given array
    int* st = constructST(arr, n);
  
    // Print sum of values in
    // array from index 1 to 3
    cout << getSum(st, n, 1, 3) << endl;
  
    // Update: set arr[1] = 10
    // and update corresponding
    // segment tree nodes
    updateValue(arr, st, n, 1, 10);
  
    // Find sum after the value is updated
    cout << getSum(st, n, 1, 3) << endl;
    return 0;
}

chevron_right


Python3

filter_none

edit
close

play_arrow

link
brightness_4
code

# Python3 program for the above approach
import math
# The sieve table covers the values 0 .. MAX - 1
MAX = 1000001

# Create a boolean array prime[]
# and initialize all entries as True.
# After SieveOfEratosthenes() has run, prime[i]
# is True exactly when i is a prime number
prime = [True] * MAX

# Function to find prime numbers.
# Updates the module-level list prime in place.
def SieveOfEratosthenes():

    # 0 and 1 are not prime. The marking loop never
    # touches them, so they must be cleared here,
    # otherwise an element equal to 1 would be
    # added to the prime sums
    prime[0] = False
    prime[1] = False

    # Every composite below MAX has a prime
    # factor p with p * p < MAX
    p = 2
    while p * p < MAX:

        # Check if prime[p] is not
        # changed, then it is a prime
        if prime[p]:

            # Strike out p*p, p*p + p, ... below MAX
            # with a single slice assignment.
            # Multiples of p below p*p were already
            # marked by smaller primes
            count = len(range(p * p, MAX, p))
            prime[p * p:MAX:p] = [False] * count

        p += 1
          
# Returns the index that splits the segment
# [s, e] into [s, mid] and [mid + 1, e]
def getMid(s, e):

    # Computed as an offset from s
    half_width = (e - s) // 2
    return s + half_width
  
# Recursive helper that answers a range-sum query.
#   st     : segment tree list
#   ss, se : array range covered by node si
#   qs, qe : array range being queried
#   si     : index of the current node in st
# Returns the sum stored in the tree for the
# part of [ss, se] that lies inside [qs, qe].
def getSumUtil(st, ss, se, qs, qe, si):

    # The node's range lies completely inside
    # the query: its stored sum is the answer
    fully_covered = ss >= qs and se <= qe
    if fully_covered:
        return st[si]

    # The node's range and the query do not
    # touch: nothing to contribute
    disjoint = qs > se or qe < ss
    if disjoint:
        return 0

    # Partial overlap: split at the middle and
    # combine the answers of the two children
    mid = getMid(ss, se)
    left_sum = getSumUtil(st, ss, mid, qs, qe, 2 * si + 1)
    right_sum = getSumUtil(st, mid + 1, se, qs, qe, 2 * si + 2)
    return left_sum + right_sum
  
# Recursive helper that adds diff to every node
# whose range contains array index i.
#   st     : segment tree list
#   ss, se : array range covered by node si
#   i      : array index whose value changed
#   diff   : amount to add to the stored sums
#   si     : index of the current node in st
def updateValueUtil(st, ss, se, i, diff, si):

    # Nodes whose range does not contain
    # index i are left untouched
    if not (ss <= i <= se):
        return

    # This node's sum includes arr[i]
    st[si] += diff

    # A leaf has no children to descend into
    if ss == se:
        return

    # Exactly one child contains index i; the
    # other one would return at the guard above
    mid = getMid(ss, se)
    if i <= mid:
        updateValueUtil(st, ss, mid, i, diff, 2 * si + 1)
    else:
        updateValueUtil(st, mid + 1, se, i, diff, 2 * si + 2)
          
# Function to update a value in
# input array and segment tree.
#   arr     : input list
#   st      : segment tree built over arr
#   n       : number of elements in arr
#   i       : index to overwrite
#   new_val : value to store at arr[i]
# Prints -1 and changes nothing when i is not
# a valid index or new_val is outside the range
# covered by the prime list.
def updateValue(arr, st, n, i, new_val):

    # Check for erroneous input index
    if i < 0 or i > n - 1:
        print(-1)
        return

    # The prime list only has entries for
    # 0 .. len(prime) - 1. A negative value would
    # silently index from the end of the list and
    # a larger one would raise IndexError
    if new_val < 0 or new_val >= len(prime):
        print(-1)
        return

    prev_val = arr[i]

    # Update the value in array
    arr[i] = new_val

    # The tree stores a value only when it is
    # prime and 0 otherwise, so work out what the
    # old and the new value contribute to the sums
    old_part = prev_val if prime[prev_val] else 0
    new_part = new_val if prime[new_val] else 0

    # Only touch the tree when the contribution
    # of this index actually changes:
    #   only new value prime -> +new_val
    #   only old value prime -> -prev_val
    #   both prime           -> new_val - prev_val
    if new_part != old_part:
        updateValueUtil(st, 0, n - 1, i,
                        new_part - old_part, 0)
              
# Entry point for a range-sum query: returns the
# sum stored in the tree for indexes qs (query
# start) to qe (query end), using getSumUtil().
# Returns -1 for an invalid range.
def getSum(st, n, qs, qe):

    # A usable range satisfies 0 <= qs <= qe <= n - 1
    valid_range = 0 <= qs and qe <= n - 1 and qs <= qe
    if not valid_range:
        return -1

    # The root node (index 0) covers [0, n - 1]
    return getSumUtil(st, 0, n - 1, qs, qe, 0)
  
# Recursively fills the segment tree.
#   arr    : input list
#   ss, se : array range covered by node si
#   st     : segment tree list being filled
#   si     : index of the current node in st
# Returns the sum stored in node si.
def constructSTUtil(arr, ss, se, st, si):

    # Leaf: a single array element. It is kept
    # only when it is prime, otherwise the leaf
    # stores 0
    if ss == se:
        value = arr[ss]
        st[si] = value if prime[value] else 0
        return st[si]

    # Internal node: build both halves and
    # store the sum of their results
    mid = getMid(ss, se)
    left_sum = constructSTUtil(arr, ss, mid, st, 2 * si + 1)
    right_sum = constructSTUtil(arr, mid + 1, se, st, 2 * si + 2)

    st[si] = left_sum + right_sum
    return st[si]
  
# Function to construct segment
# tree from given array.
#   arr : input list
#   n   : number of elements in arr
# Returns the tree as a list, or an empty
# list when n <= 0 (there is nothing to build).
def constructST(arr, n):

    # An empty array has no tree. Without this
    # guard the build would start on the invalid
    # range [0, -1]
    if n <= 0:
        return []

    # Smallest power of two that is >= n, i.e.
    # the number of leaves of the full binary
    # tree. Integer arithmetic avoids rounding
    # through floating point log2() / pow()
    leaves = 1
    while leaves < n:
        leaves *= 2

    # Maximum size of segment tree
    max_size = 2 * leaves - 1

    # Allocate memory; nodes the build never
    # reaches stay 0
    st = [0] * max_size

    # Fill the allocated memory st
    constructSTUtil(arr, 0, n - 1, st, 0)

    # Return the constructed segment tree
    return st
  
# Driver code
arr = [ 1, 3, 5, 7, 9, 11 ]
n = len(arr)

# Each query is [type, a, b]:
#   type 1 -> print the sum of the prime
#             elements in arr[a..b]
#   type 2 -> set arr[a] = b
Q = [ [ 1, 1, 3 ],
      [ 2, 1, 10 ],
      [ 1, 1, 3 ] ]

# Fill the prime list before the tree
# is built, since the build reads it
SieveOfEratosthenes()

# Build segment tree from given array
st = constructST(arr, n)

# Answer the queries in order
for kind, a, b in Q:
    if kind == 1:
        print(getSum(st, n, a, b))
    elif kind == 2:
        updateValue(arr, st, n, a, b)
  
# This code is contributed by Stuti Pathak

chevron_right


Output: 

15
12

 

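Time Complexity: O(MAX * log(log(MAX)) + N + Q * log N), where the first two terms are the one-time cost of the sieve and of building the segment tree, and each query takes O(log N)
Auxiliary Space: O(N + MAX) for the segment tree and the sieve table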
 

Attention reader! Don’t stop learning now. Get hold of all the important DSA concepts with the DSA Self Paced Course at a student-friendly price and become industry ready.




My Personal Notes arrow_drop_up

Recommended Posts:


Check out this Author's contributed articles.

If you like GeeksforGeeks and would like to contribute, you can also write an article using contribute.geeksforgeeks.org or mail your article to contribute@geeksforgeeks.org. See your article appearing on the GeeksforGeeks main page and help other Geeks.

Please Improve this article if you find anything incorrect by clicking on the "Improve Article" button below.



Improved By : stutipathak31jan