Open In App

Merge Sort Tree with point update using Policy Based Data Structure

  • A Merge Sort Tree (MST) is a data structure that allows for efficient querying of the number of elements less than or equal to a certain value in a range of an array. It is built using a divide-and-conquer approach similar to the merge sort algorithm, where the array is recursively divided into two halves and a tree is constructed to represent the sorted subarrays.
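  • A Policy Based Data Structure (PBDS) is one of the GNU C++ library extensions in the `__gnu_pbds` namespace. The one relevant to this technique is the order-statistics tree (`tree` with `tree_order_statistics_node_update`). It is a balanced binary search tree that supports insertion and deletion in O(log n). It also answers `order_of_key(k)`, the number of stored keys strictly less than `k`, in O(log n). Note that persistent segment trees, persistent arrays and persistent Union-Find structures are *persistent* data structures, which is a different concept from PBDS.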

A regular Merge Sort Tree stores a sorted vector in every node, and changing one element there costs O(n). If every node instead stores a Policy Based order-statistics tree, a point update only needs one O(log n) erase and one O(log n) insert per node on a root-to-leaf path. This makes the Merge Sort Tree with point update using Policy Based Data Structure a useful tool for range query problems that also modify the array. A typical example is counting the number of elements less than or equal to a certain value in a given range.

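Approach: The algorithm for the Merge Sort Tree with point update using Policy Based Data Structure is as follows:
  1. Build a segment tree over the positions 0..n−1. Every node covering [l, r] stores the pairs (a[i], i) for i in [l, r] in an order-statistics tree. Pairing each value with its index keeps equal values distinct.
  2. To count the elements ≤ x in [ql, qr], decompose [ql, qr] into O(log n) fully covered nodes. Sum `order_of_key((x, +∞))` over those nodes.
  3. To set a[idx] = val, visit every node from the root to the leaf of idx. In each one, erase (a[idx], idx) and insert (val, idx). Then store val in a[idx].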



Below is the implementation for the above approach:




// C++ code for the above approach:
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
  
// Persistent Segment Tree
class PST {
public:
    // One tree node. `val` is the sum of all diffs applied inside the
    // node's range; `lc` / `rc` are indices into `t` of its children.
    struct Node {
        int val;
        int lc;
        int rc;

        Node(int v = 0, int l = -1, int r = -1)
            : val(v), lc(l), rc(r)
        {
        }
    };

    // Node pool shared by every version. t[0] is an all-zero sentinel whose
    // children point back to itself, so it represents an empty subtree of
    // any size and recursion through it never leaves the pool.
    std::vector<Node> t;

    PST() { t.push_back(Node(0, 0, 0)); }

    /// Creates a new version in which position `idx` has been increased by
    /// `diff`, leaving the version rooted at `cur` untouched.
    /// @param cur  root of the version to derive from (0 = empty tree)
    /// @param l,r  range covered by `cur` (use the same range for all calls)
    /// @param idx  position to modify; callers must ensure l <= idx <= r,
    ///             otherwise the diff is recorded at the nearest boundary leaf
    /// @param diff amount added at `idx`
    /// @return index of the new version's root
    int update(int cur, int l, int r, int idx, int diff)
    {
        // Copy first: push_back may reallocate `t` and invalidate t[cur].
        const Node base = t[cur];
        const int new_node = static_cast<int>(t.size());
        t.push_back(base);
        t[new_node].val += diff;
        if (l == r) {
            return new_node;
        }
        const int mid = l + (r - l) / 2;
        if (idx <= mid) {
            // Compute the child before indexing `t` again, because the
            // recursive call grows (and may reallocate) the pool.
            const int child = update(base.lc, l, mid, idx, diff);
            t[new_node].lc = child;
        }
        else {
            const int child = update(base.rc, mid + 1, r, idx, diff);
            t[new_node].rc = child;
        }
        return new_node;
    }

    /// Sum of the values at positions [ql, qr] in the version rooted at `cur`.
    /// An empty or disjoint interval (including ql > qr) yields 0.
    int query(int cur, int l, int r, int ql, int qr)
    {
        // The sentinel holds 0 everywhere; disjoint ranges contribute 0.
        // Checking disjointness first also guarantees termination at leaves.
        if (cur == 0 || qr < l || r < ql) {
            return 0;
        }
        if (ql <= l && r <= qr) {
            return t[cur].val;
        }
        const int mid = l + (r - l) / 2;
        return query(t[cur].lc, l, mid, ql, qr)
               + query(t[cur].rc, mid + 1, r, ql, qr);
    }
};
  
// Merge Sort Tree with point update
class MST {
public:
    // GNU policy-based order-statistics tree. Keys are (value, index) pairs
    // so that equal values at different positions stay distinct entries.
    using OrderedSet = __gnu_pbds::tree<
        std::pair<int, int>, __gnu_pbds::null_type,
        std::less<std::pair<int, int>>, __gnu_pbds::rb_tree_tag,
        __gnu_pbds::tree_order_statistics_node_update>;

    std::vector<int> a; // current array contents; modify only via update()
    int n;              // number of elements

    /// Builds the tree over `v` in O(n log^2 n) time and O(n log n) memory.
    /// An empty `v` is allowed: every query then returns 0.
    MST(std::vector<int> v)
        : a(std::move(v)), n(static_cast<int>(a.size())), seg_(4 * a.size())
    {
        for (int i = 0; i < n; i++) {
            for_each_on_path(i, [&](OrderedSet& s) {
                s.insert(std::make_pair(a[i], i));
            });
        }
    }

    /// Counts elements <= x among positions [ql, qr] intersected with
    /// [l, r] (pass l = 0, r = n - 1 for a plain range query). O(log^2 n).
    /// @param cur kept only for source compatibility with older callers;
    ///            it is ignored.
    /// @return 0 for an empty or out-of-range interval.
    int query(int cur, int l, int r, int ql, int qr, int x) const
    {
        (void)cur;
        const int lo = std::max({ql, l, 0});
        const int hi = std::min({qr, r, n - 1});
        if (lo > hi) {
            return 0;
        }
        return count_leq(1, 0, n - 1, lo, hi, x);
    }

    /// Point assignment a[idx] = new_val in O(log^2 n).
    /// @throws std::out_of_range if idx is not in [0, n).
    void update(int idx, int new_val)
    {
        if (idx < 0 || idx >= n) {
            throw std::out_of_range("MST::update: index out of range");
        }
        const int old_val = a[idx];
        for_each_on_path(idx, [&](OrderedSet& s) {
            s.erase(std::make_pair(old_val, idx));
            s.insert(std::make_pair(new_val, idx));
        });
        a[idx] = new_val;
    }

private:
    // seg_[node] holds (a[i], i) for every i in the node's range. Node 1 is
    // the root covering [0, n-1]; the children of node k are 2k and 2k+1.
    std::vector<OrderedSet> seg_;

    // Applies `f` to the set of every node on the root-to-leaf path of idx.
    template <class F>
    void for_each_on_path(int idx, F f)
    {
        int node = 1;
        int l = 0;
        int r = n - 1;
        while (true) {
            f(seg_[node]);
            if (l == r) {
                break;
            }
            const int mid = l + (r - l) / 2;
            if (idx <= mid) {
                node = 2 * node;
                r = mid;
            }
            else {
                node = 2 * node + 1;
                l = mid + 1;
            }
        }
    }

    // Number of stored values <= x in positions [ql, qr] within node's [l, r].
    int count_leq(int node, int l, int r, int ql, int qr, int x) const
    {
        if (qr < l || r < ql) {
            return 0;
        }
        if (ql <= l && r <= qr) {
            // Every pair (v, i) with v <= x sorts below (x, INT_MAX).
            return static_cast<int>(seg_[node].order_of_key(
                std::make_pair(x, std::numeric_limits<int>::max())));
        }
        const int mid = l + (r - l) / 2;
        return count_leq(2 * node, l, mid, ql, qr, x)
               + count_leq(2 * node + 1, mid + 1, r, ql, qr, x);
    }
};
  
// Driver code
int main()
{
    vector<int> v = { 1, 4, 2, 3, 5 };
    MST mst(v);
  
    // Query for values less than or
    // equal to 3 in the range [1, 3]
    cout << mst.query(v.size() - 1, 0, v.size() - 1, 1, 3,
                      3)
         << endl;
  
    // Update a[2] = 5
    mst.update(2, 5);
  
    // Query for values less than or equal
    // to 3 in the range [1, 3]
    cout << mst.query(v.size() - 1, 0, v.size() - 1, 1, 3,
                      3)
         << endl;
  
    return 0;
}

Output

2
1

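Complexity analysis: assume a Merge Sort Tree whose nodes are order-statistics trees.
  • Building inserts each of the n elements into the O(log n) nodes on its root-to-leaf path. Each insertion costs O(log n), so building takes O(n log² n) time and O(n log n) memory.
  • A query visits O(log n) fully covered nodes and performs one O(log n) `order_of_key` in each, so it takes O(log² n).
  • A point update performs one erase and one insert on each of the O(log n) nodes on a path, so it also takes O(log² n).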


Article Tags :
DSA