Open In App

CSES Solutions – Path Queries

Last Updated : 03 May, 2024
Improve
Improve
Like Article
Like
Save
Share
Report

You are given a rooted tree consisting of n nodes. The nodes are numbered 1, 2, …, n, and node 1 is the root. Each node has a value.

Your task is to process the following types of queries:

  1. change the value of node s to x
  2. calculate the sum of values on the path from the root to node s

Examples:

Input: N = 5, Q = 3, v[] = {4, 2, 5, 2, 1}, edges[][] = {{1, 2}, {1, 3}, {3, 4}, {3, 5}}, queries[][] = {{2, 4}, {1, 3, 2}, {2, 4}}
Output:
11
8

Input: N = 10, Q = 2, v[] = {1, 8, 6, 8, 6, 2, 9, 2, 3, 2}, edges = {{10, 5}, {6, 2}, {10, 7}, {5, 2}, {3, 9}, {8, 3}, {1, 4}, {6, 4}, {8, 7}}, queries[][] = {{2, 10}, {2, 6}}
Output:
27
11

Approach: To solve the problem, follow the below idea:

To efficiently handle these queries, we’ll use a combination of depth-first search (DFS) and segment trees.

Depth-First Search (DFS):

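We’ll perform a depth-first traversal of the tree and record, for each node, a start time (when the node is entered) and an end time (when the node is left). This flattens the tree into an array of 2n positions in which every subtree occupies a contiguous interval. We place +value at a node's start time and −value at its end time. Then the prefix sum up to any position equals the sum of the values of the nodes that are still "open" at that moment, which are exactly the nodes on the path from the root to the current node.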

Segment Trees:

A segment tree is an effective data structure for point updates combined with range-sum queries. We build it over the 2n Euler-tour positions. A type 1 query then becomes two point updates, and a type 2 query becomes one prefix-sum query.

Implementation:

  • We store the tree as an adjacency list and run a DFS from the root to record each node's start (entry) and end (exit) time. We then initialize a segment tree over these 2n time positions, placing each node's value at its start time and the negated value at its end time.
  • For each type 1 query (change the value of node s to x), we overwrite the two positions belonging to s: x at its start time and −x at its end time.
  • For each type 2 query, we ask the segment tree for the sum over positions [0, end[s] − 1]. Every node whose subtree has already been closed before that point contributes +v − v = 0. The result is therefore exactly the sum of the values on the path from the root to s.

Step-by-step algorithm:

  • Input: Read the tree’s size, initial node values, and queries.
  • DFS: Traverse the tree using DFS to compute start and end times for each node.
  • Segment Tree: Initialize a segment tree to represent node values efficiently.
  • Initialization: For each node, place its initial value at its start time and the negation of that value at its end time.
  • Query Types: Handle two types of queries: change node value and calculate sum on path.
  • Type 1 Query: Update segment tree and node value accordingly.
  • Type 2 Query: Query segment tree to calculate sum on the path.
  • Updating Tree: When updating node values, update corresponding segment tree nodes.
  • Querying Tree: When querying sum on a path, use segment tree to compute it efficiently.
  • Output: Output the results of the queries as per the query types.

Below is the implementation of the above algorithm:

C++
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
const int maxn = 400001;
const ll MOD = 1e9 + 7;
// segTree: sum segment tree (1-indexed heap layout) over the Euler-tour
// positions; a[]: 1-based node values.
ll segTree[1600010], a[400010], n, q;

std::vector<int> adj[maxn];
// st[v] / en[v]: timestamps at which the DFS enters / leaves node v.
int timer = 0, st[maxn], en[maxn];

// Returns the sum of Euler-tour positions [i, j] inside node p's interval
// [l, r]; used to answer type 2 (path-sum) queries.
// The DFS hands out 2n timestamps (0 .. 2n-1, with en[1] == 2n-1), so the
// whole tree must span [0, 2n - 1].
ll query(int i, int j, int p = 1, int l = 0,
         int r = 2 * n - 1)
{
    // Empty range: the clamped bounds crossed, nothing of [i, j] is here.
    if (i > j)
        return 0;
    // Complete overlap: node interval [l, r] lies fully inside [i, j].
    if (l >= i && r <= j)
        return segTree[p];
    // Partial overlap: clamp the query range to each half.
    int mid = (l + r) / 2;
    ll left = query(i, std::min(j, mid), p * 2, l, mid);
    ll right = query(std::max(i, mid + 1), j, p * 2 + 1, mid + 1, r);
    return left + right;
}

// Sets Euler-tour position x to val, then recomputes the sums of all
// ancestors of that leaf. Uses the same [0, 2n - 1] range as query().
// val is ll so the long long values stored in a[] are not narrowed.
void update(int x, ll val, int p = 1, int l = 0,
            int r = 2 * n - 1)
{
    // Leaf for position x: store the value directly.
    if (l == r) {
        segTree[p] = val;
        return;
    }
    int mid = (l + r) / 2;
    if (x <= mid)
        update(x, val, p * 2, l, mid);
    else
        update(x, val, p * 2 + 1, mid + 1, r);
    segTree[p] = segTree[p * 2] + segTree[p * 2 + 1];
}

// Euler tour starting at s (whose parent is p): st[v] is stamped when v is
// entered and en[v] once all of v's children are finished. Neighbours are
// visited in adjacency-list order, skipping the parent, exactly like the
// recursive formulation; the explicit stack avoids deep call recursion on
// path-shaped trees.
void dfs(int s, int p)
{
    struct Frame {
        int node;
        int parent;
        std::size_t next; // index of the next neighbour to inspect
    };
    std::vector<Frame> frames;
    frames.push_back({ s, p, 0 });
    st[s] = timer++;
    while (!frames.empty()) {
        Frame& top = frames.back();
        if (top.next < adj[top.node].size()) {
            int child = adj[top.node][top.next++];
            if (child != top.parent) {
                st[child] = timer++;
                // push_back may reallocate; `top` is not used afterwards.
                frames.push_back({ child, top.node, 0 });
            }
        }
        else {
            en[top.node] = timer++;
            frames.pop_back();
        }
    }
}

int main()
{
    // Sample Input
    n = 5, q = 3;
    int v[] = { 4, 2, 5, 2, 1 };
    vector<vector<int> > edges
        = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };
    vector<vector<int> > queries
        = { { 2, 4 }, { 1, 3, 2 }, { 2, 4 } };
    // 1-based indexing
    for (int i = 1; i <= n; i++) {
        a[i] = v[i - 1];
    }
    // Construct the graph
    for (int i = 0; i < n - 1; i++) {
        int a = edges[i][0];
        int b = edges[i][1];
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; i++) {
        update(st[i], a[i]);
    }
    for (int i = 1; i <= n; i++) {
        update(en[i], -a[i]);
    }
    for (int i = 0; i < queries.size(); i++) {
        int t = queries[i][0];
        int s = queries[i][1];
        if (t == 1) {
            int x = queries[i][2];
            update(st[s], x);
            update(en[s], -x);
        }
        else
            cout << query(st[1], en[s] - 1) << endl;
    }
}
Java
import java.util.*;

public class Main {
    static final int maxn = 400001;
    static final int MOD = (int)1e9 + 7;
    // Segment tree storage (1-indexed heap layout) and 1-based node values.
    static long[] segTree = new long[1600010];
    static long[] a = new long[400010];
    static int n, q;
    // Adjacency list; st[v] / en[v] are the Euler-tour entry / exit
    // timestamps of node v.
    @SuppressWarnings("unchecked")
    static List<Integer>[] adj = new ArrayList[maxn];
    static int timer = 0;
    static int[] st = new int[maxn];
    static int[] en = new int[maxn];

    // Returns the sum of Euler-tour positions [i, j] inside node p's
    // interval [l, r]; used to answer type 2 (path-sum) queries.
    static long query(int i, int j, int p, int l, int r)
    {
        // Empty range: the clamped bounds crossed.
        if (i > j)
            return 0;
        // Complete overlap
        if (l >= i && r <= j)
            return segTree[p];
        // Partial overlap: clamp the query range to each half.
        int mid = (l + r) / 2;
        long left = query(i, Math.min(j, mid), p * 2, l, mid);
        long right = query(Math.max(i, mid + 1), j, p * 2 + 1,
                           mid + 1, r);
        return left + right;
    }

    // Sets position x to val and recomputes the sums of its ancestors.
    // val is long so node values are stored without an (int) cast.
    static void update(int x, long val, int p, int l, int r)
    {
        // Leaf for position x.
        if (l == r) {
            segTree[p] = val;
            return;
        }
        int mid = (l + r) / 2;
        if (x <= mid)
            update(x, val, p * 2, l, mid);
        else
            update(x, val, p * 2 + 1, mid + 1, r);
        segTree[p] = segTree[p * 2] + segTree[p * 2 + 1];
    }

    // Euler tour from s (whose parent is p): st[v] is stamped on entry and
    // en[v] once all children are finished. Neighbours are visited in list
    // order, skipping the parent, like the recursive version; the explicit
    // stack avoids StackOverflowError on deep, path-shaped trees.
    // Each stack frame is { node, parent, index of next neighbour }.
    static void dfs(int s, int p)
    {
        Deque<int[]> stack = new ArrayDeque<>();
        st[s] = timer++;
        stack.push(new int[] { s, p, 0 });
        while (!stack.isEmpty()) {
            int[] top = stack.peek();
            int node = top[0];
            if (top[2] < adj[node].size()) {
                int child = adj[node].get(top[2]++);
                if (child != top[1]) {
                    st[child] = timer++;
                    stack.push(new int[] { child, node, 0 });
                }
            }
            else {
                en[node] = timer++;
                stack.pop();
            }
        }
    }

    public static void main(String[] args)
    {
        // Hard-coded sample input
        n = 5;
        q = 3;
        int[] v = { 4, 2, 5, 2, 1 };
        int[][] edges
            = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 } };
        int[][] queries
            = { { 2, 4 }, { 1, 3, 2 }, { 2, 4 } };
        for (int i = 0; i <= n; i++)
            adj[i] = new ArrayList<>();
        // 1-based node values
        for (int i = 1; i <= n; i++)
            a[i] = v[i - 1];
        // Undirected adjacency list
        for (int[] edge : edges) {
            adj[edge[0]].add(edge[1]);
            adj[edge[1]].add(edge[0]);
        }
        dfs(1, 0);
        // The DFS hands out 2n timestamps 0 .. 2n-1 (en[1] == 2n - 1), so
        // the segment tree must cover [0, 2n - 1].
        int hi = 2 * n - 1;
        // +value at entry time, -value at exit time.
        for (int i = 1; i <= n; i++)
            update(st[i], a[i], 1, 0, hi);
        for (int i = 1; i <= n; i++)
            update(en[i], -a[i], 1, 0, hi);
        for (int[] query : queries) {
            int t = query[0];
            int s = query[1];
            if (t == 1) {
                int x = query[2];
                update(st[s], x, 1, 0, hi);
                update(en[s], -x, 1, 0, hi);
            }
            else
                // Prefix [0, en[s] - 1] = sum on the root-to-s path.
                System.out.println(query(st[1], en[s] - 1, 1, 0, hi));
        }
    }
}

// This code is contributed by Ayush Mishra
Python
import sys
from collections import defaultdict

# Capacity constants (MOD is defined but not used by this solution).
maxn = 400_001
MOD = 10**9 + 7
# Segment tree storage (1-indexed heap layout) and 1-based node values.
segTree = [0] * 1_600_010
a = [0] * 400_010
n = q = 0
# Adjacency list keyed by node id.
adj = defaultdict(list)
# Euler-tour clock plus per-node entry (st) / exit (en) timestamps.
timer = 0
st = [0] * maxn
en = [0] * maxn

# Range-sum query over the segment tree (answers type 2 / path-sum queries).


def query(i, j, p, l, r):
    """Return the sum of positions i..j below tree node p, which covers [l, r]."""
    if i > j:
        # Clamped bounds crossed: no part of [i, j] lies under this node.
        return 0
    if i <= l and r <= j:
        # This node's whole interval is inside the query range.
        return segTree[p]
    half = (l + r) // 2
    from_left = query(i, min(j, half), 2 * p, l, half)
    from_right = query(max(i, half + 1), j, 2 * p + 1, half + 1, r)
    return from_left + from_right

# Point assignment into the segment tree, then re-aggregation of ancestors.


def update(x, val, p, l, r):
    """Set position x to val below tree node p (covering [l, r]) and refresh sums."""
    if l != r:
        half = (l + r) // 2
        if x > half:
            update(x, val, 2 * p + 1, half + 1, r)
        else:
            update(x, val, 2 * p, l, half)
        segTree[p] = segTree[2 * p] + segTree[2 * p + 1]
    else:
        # Reached the leaf for position x.
        segTree[p] = val

# Euler tour: stamp each node on entry (st) and on exit (en).


def dfs(s, p):
    """Assign entry/exit timestamps to every node reachable from s.

    s is the start node and p its parent (0 for the root). Children are
    visited in adjacency-list order, skipping the parent, exactly like the
    recursive formulation. An explicit stack is used so that deep,
    path-shaped trees (e.g. n = 2*10**5) do not exceed Python's recursion
    limit.
    """
    global timer
    st[s] = timer
    timer += 1
    # Each frame: (node, parent, iterator over the node's remaining neighbours).
    stack = [(s, p, iter(adj[s]))]
    while stack:
        node, parent, neighbours = stack[-1]
        for child in neighbours:
            if child != parent:
                st[child] = timer
                timer += 1
                stack.append((child, node, iter(adj[child])))
                break
        else:
            # All neighbours handled: leave the node.
            en[node] = timer
            timer += 1
            stack.pop()


def main():
    """Run the hard-coded sample and print the answer to every type 2 query."""
    global n, q
    # Sample input
    n, q = 5, 3
    v = [4, 2, 5, 2, 1]
    edges = [(1, 2), (1, 3), (3, 4), (3, 5)]
    queries = [(2, 4), (1, 3, 2), (2, 4)]
    # 1-based node values
    for i in range(1, n + 1):
        a[i] = v[i - 1]
    # Undirected adjacency list
    for edge_a, edge_b in edges[:n - 1]:
        adj[edge_a].append(edge_b)
        adj[edge_b].append(edge_a)
    dfs(1, 0)
    # The DFS hands out 2n timestamps 0 .. 2n-1 (en[1] == 2n - 1), so the
    # segment tree must span [0, 2n - 1] for every timestamp to get its own leaf.
    hi = 2 * n - 1
    # +value at entry time, -value at exit time.
    for i in range(1, n + 1):
        update(st[i], a[i], 1, 0, hi)
        update(en[i], -a[i], 1, 0, hi)
    for qr in queries:
        t, s = qr[0], qr[1]
        if t == 1:
            x = qr[2]
            update(st[s], x, 1, 0, hi)
            update(en[s], -x, 1, 0, hi)
        else:
            # Prefix [0, en[s] - 1]: every finished subtree contributes
            # +v - v = 0, leaving exactly the root-to-s path sum.
            print(query(st[1], en[s] - 1, 1, 0, hi))


if __name__ == "__main__":
    main()
JavaScript
const maxn = 400001;
const MOD = 1e9 + 7;
let segTree = Array(1600010).fill(0);
let a = Array(400010).fill(0);
let n, q;

let adj = new Array(maxn).fill(null).map(() => []);
let timer = 0;
let st = new Array(maxn).fill(0);
let en = new Array(maxn).fill(0);

// Function to answer queries of type 1
function query(i, j, p = 1, l = 0, r = 2 * n - 2) {
    // No overlap
    if (i > j) return 0;
    // Complete overlap
    if (l >= i && r <= j) return segTree[p];
    // Partial overlap
    let mid = Math.floor((l + r) / 2);
    let left = query(i, Math.min(j, mid), p * 2, l, mid);
    let right = query(Math.max(i, mid + 1), j, p * 2 + 1, mid + 1, r);
    return left + right;
}

// Function to update the segment tree
function update(x, val, p = 1, l = 0, r = 2 * n - 2) {
    // If l == r, then it is a leaf node within the range
    if (l === r) {
        segTree[p] = val;
        return;
    }
    let mid = Math.floor((l + r) / 2);
    if (x <= mid) {
        update(x, val, p * 2, l, mid);
    } else {
        update(x, val, p * 2 + 1, mid + 1, r);
    }
    segTree[p] = segTree[p * 2] + segTree[p * 2 + 1];
}

// Function to run dfs on the graph
function dfs(s, p) {
    st[s] = timer++;
    for (let u of adj[s]) {
        if (u !== p) {
            dfs(u, s);
        }
    }
    en[s] = timer++;
}

// Hard-coded sample input: 5 nodes, 3 queries.
n = 5;
q = 3;
let v = [4, 2, 5, 2, 1];
let edges = [[1, 2], [1, 3], [3, 4], [3, 5]];
let queries = [[2, 4], [1, 3, 2], [2, 4]];

// Undirected adjacency list; destructured names avoid shadowing the global a[].
for (const [from, to] of edges.slice(0, n - 1)) {
    adj[from].push(to);
    adj[to].push(from);
}
dfs(1, 0);

// +value at each node's entry time ...
for (let node = 1; node <= n; node++) {
    a[node] = v[node - 1];
    update(st[node], a[node]);
}
// ... and -value at its exit time.
for (let node = 1; node <= n; node++) {
    update(en[node], -a[node]);
}

for (const [type, node, value] of queries) {
    if (type !== 1) {
        // Prefix [0, en[node] - 1] = sum on the root-to-node path.
        console.log(query(st[1], en[node] - 1));
        continue;
    }
    update(st[node], value);
    update(en[node], -value);
}

 // This code is contributed by Ayush Mishra

Output
11
8

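Time Complexity: O((n + q) · log n), where n is the number of nodes in the tree and q is the number of queries. The DFS takes O(n), initializing the segment tree with 2n point updates takes O(n log n), and each query costs O(log n).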
Auxiliary Space: O(n), where n is the number of nodes.



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads