Open In App

Queries to find the sum of weights of all nodes with vertical width from given range in a Binary Tree

Last Updated : 21 Mar, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a Binary Tree consisting of N nodes having values in the range [0, N – 1], an array wt[] of size N where wt[i] is the weight of the node with value i, and a 2D array Q[][] consisting of queries of the type {L, R}, the task for each query is to find the sum of weights of all nodes whose vertical width lies in the inclusive range [L, R]. The vertical width of the root is 0; a left child's width is one less than its parent's, and a right child's width is one more than its parent's.

Examples:

Input: N = 4, wt[] = {1, 2, 3, 4}, Q[][] = {{0, 1}, {-1, 1}}

                                 0(1)
                                /    \
                             3(4)    1(2)
                                     /
                                  2(3)
Output:
6
10
Explanation:
For Query 1: The range is [0, 1]

  1. Nodes at width position 0: 0, 2. Therefore, the sum of their weights is 1 + 3 = 4.
  2. Nodes at width position 1: 1. Therefore, the sum of their weights is 2.

Therefore, the sum of weights of all nodes = 4 + 2 = 6.

For Query 2: The range is [-1, 1]

  1. Nodes at width position -1: 3. Therefore, the sum of their weights is 4.
  2. Nodes at width position 0: 0, 2. Therefore, the sum of their weights is 1 + 3 = 4.
  3. Nodes at width position 1: 1. Therefore, the sum of their weights is 2.

Therefore, the sum of weights of all nodes = 4 + 4 + 2 = 10.

Input: N= 8, wt[] = {8, 6, 4, 5, 1, 2, 9, 1}, Q[][] = {{-1, 1}, {-2, -1}, {0, 3}}

                                           1
                                         /   \
                                        3     2
                                       / \     \
                                      5   6     4
                                               / \
                                              7   0
Output:
25
7
29

Naive Approach: The simplest approach to solve the given problem is to perform a traversal on the tree for each query and then print the sum of weights of all the nodes that lie in the given range [L, R].

Time Complexity: O(N*Q)
Auxiliary Space: O(H), where H is the height of the tree (the traversal's recursion stack)

Efficient Approach: The above approach can be optimized by precomputing the sum of weights for every width position and storing them in a map M ordered by width. Then convert these sums into prefix sums, so that each query can be answered with two lookups (O(log N) each in an ordered map) instead of a full traversal. Follow the below steps to solve the problem:

  • Perform the DFS traversal on the given tree and update the sum at all the width positions and store them in a map M by performing the following operations:
    • Initially, the position pos of the root is 0.
    • In each recursive call, update the weight of the node in M by updating M[pos] = M[pos] + wt[node].
    • Recurse into its left child with position pos − 1.
    • Recurse into its right child with position pos + 1.
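  • Iterate over the map M in increasing order of width and replace each value with the running total of all values seen so far, so that M[x] becomes the sum of the weights of all nodes whose width is at most x.
  • Now, traverse the array Queries[] and for each query [L, R], print the value of (M[R] – M[L – 1]) as the result. Here M[x] means the prefix sum stored at the largest width ≤ x present in M, or 0 if there is none. This keeps queries whose bounds lie outside the tree's widths correct. A query with L > R covers no widths, so its answer is 0.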

Below is the implementation of the above approach:

C++




// C++ program for the above approach
 
#include <bits/stdc++.h>
using namespace std;
 
// Structure of a node.
// Members carry default initialisers, so a Node created without newNode()
// (e.g. `new Node` or a local `Node n;`) never holds indeterminate values.
struct Node {
    int data{0};           // node value; used as an index into wt[]
    Node* left{nullptr};   // left child, or null
    Node* right{nullptr};  // right child, or null
};
 
// Function to create new node
Node* newNode(int d)
{
    Node* n = new Node;
    n->data = d;
    n->left = NULL;
    n->right = NULL;
    return n;
}
 
// Function to pre-compute the sum of
// weights at each width position
void findwt(Node* root, int wt[],
            map<int, int>& um,
            int width = 0)
{
    // Base Case
    if (root == NULL) {
        return;
    }
 
    // Update the current width
    // position weight
    um[width] += wt[root->data];
 
    // Recursive Call to its left
    findwt(root->left, wt, um,
           width - 1);
 
    // Recursive Call to its right
    findwt(root->right, wt, um,
           width + 1);
}
 
// Function to find the sum of the
// weights of nodes whose vertical
// widths lies over the range [L, R]
// for Q queries
void solveQueries(int wt[], Node* root,
                  vector<vector<int> > queries)
{
    // Stores the weight sum
    // of each width position
    map<int, int> um;
 
    // Function Call to fill um
    findwt(root, wt, um);
 
    // Stores the sum of all previous
    // nodes, while traversing Map
    int x = 0;
 
    // Traverse the Map um
    for (auto it = um.begin();
         it != um.end(); it++) {
        x += it->second;
        um[it->first] = x;
    }
 
    // Iterate over all queries
    for (int i = 0;
         i < queries.size(); i++) {
 
        int l = queries[i][0];
        int r = queries[i][1];
 
        // Print the result for the
        // current query [l, r]
        cout << um[r] - um[l - 1]
             << "\n";
    }
}
 
// Driver Code
int main()
{
    int N = 8;
 
    // Given Tree
    Node* root = newNode(1);
    root->left = newNode(3);
    root->left->left = newNode(5);
    root->left->right = newNode(6);
    root->right = newNode(2);
    root->right->right = newNode(4);
    root->right->right->left = newNode(7);
    root->right->right->right = newNode(0);
    int wt[] = { 8, 6, 4, 5, 1, 2, 9, 1 };
    vector<vector<int> > queries{ { -1, 1 },
                                  { -2, -1 },
                                  { 0, 3 } };
 
    solveQueries(wt, root, queries);
 
    return 0;
}


Java




// Java program for the above approach
import java.io.*;
import java.util.*;
 
class GFG{

    // Structure of a node
    static class Node
    {
        int data;   // node value; used as an index into wt[]
        Node left;
        Node right;
    }

    // Function to create new node
    static Node newNode(int d)
    {
        Node n = new Node();
        n.data = d;
        n.left = null;
        n.right = null;
        return n;
    }

    // Function to pre-compute the sum of weights at each width position.
    // `root` is at `width`; a left child is at width - 1 and a right child
    // at width + 1 relative to its parent.
    static void findwt(Node root, int wt[],
                       TreeMap<Integer, Integer> um,
                       int width)
    {
        // Base Case
        if (root == null)
        {
            return;
        }

        // Update the current width position weight
        if (um.containsKey(width))
        {
            um.put(width, um.get(width) + wt[root.data]);
        }
        else
        {
            um.put(width, wt[root.data]);
        }

        // Recursive Call to its left
        findwt(root.left, wt, um, width - 1);

        // Recursive Call to its right
        findwt(root.right, wt, um, width + 1);
    }

    // Computes the answer to every query [l, r]: the total weight of the
    // nodes whose width lies in [l, r].
    // prefix maps each width w to the total weight of all widths <= w, so
    // an answer is prefix(<= r) - prefix(< l). floorEntry/lowerEntry find
    // the nearest stored width, which keeps query bounds outside the tree's
    // widths correct. A query with l > r is an empty range and yields 0.
    static long[] answerQueries(int wt[], Node root, int[][] queries)
    {
        TreeMap<Integer, Integer> um = new TreeMap<Integer, Integer>();
        findwt(root, wt, um, 0);

        // Running totals in ascending width order; long avoids int overflow
        TreeMap<Integer, Long> prefix = new TreeMap<Integer, Long>();
        long running = 0;
        for (Map.Entry<Integer, Integer> e : um.entrySet())
        {
            running += e.getValue();
            prefix.put(e.getKey(), running);
        }

        long[] answers = new long[queries.length];
        for (int i = 0; i < queries.length; i++)
        {
            int l = queries[i][0];
            int r = queries[i][1];
            if (l > r)
            {
                continue; // empty range: answer stays 0
            }

            Map.Entry<Integer, Long> upToR = prefix.floorEntry(r);
            Map.Entry<Integer, Long> beforeL = prefix.lowerEntry(l);
            long total = (upToR == null) ? 0 : upToR.getValue();
            long before = (beforeL == null) ? 0 : beforeL.getValue();
            answers[i] = total - before;
        }
        return answers;
    }

    // Function to print, for each query [L, R], the sum of the weights of
    // nodes whose vertical widths lie in [L, R]
    static void solveQueries(int wt[], Node root,
                             int[][] queries)
    {
        for (long ans : answerQueries(wt, root, queries))
        {
            System.out.println(ans);
        }
    }

    // Driver Code: expected output 25, 7, 29
    public static void main(String[] args)
    {
        // Given Tree
        Node root = newNode(1);
        root.left = newNode(3);
        root.left.left = newNode(5);
        root.left.right = newNode(6);
        root.right = newNode(2);
        root.right.right = newNode(4);
        root.right.right.left = newNode(7);
        root.right.right.right = newNode(0);

        int wt[] = { 8, 6, 4, 5, 1, 2, 9, 1 };
        int queries[][] = { { -1, 1 },
                            { -2, -1 },
                            { 0, 3 } };

        solveQueries(wt, root, queries);
    }
}
 
// This code is contributed by Dharanendra L V.


Python3




# Python3 program for the above approach
 
from collections import defaultdict
 
# Structure of a node

class Node:
    """Binary-tree node.

    ``data`` is the node's value (used by findwt as an index into the
    weight list); ``left`` and ``right`` are child nodes, initially None.
    """

    def __init__(self, data):
        self.data, self.left, self.right = data, None, None
 
# Function to create new node
def newNode(d):
    """Return a new ``Node`` holding ``d`` with both children set to None."""
    node = Node(d)
    node.left = node.right = None
    return node
 
# Function to pre-compute the sum of
# weights at each width position
def findwt(root, wt, um, width=0):
    """Add the weight of every node under ``root`` to ``um[w]``, where ``w``
    is that node's vertical width.

    ``root`` sits at ``width``. A left child is one position to the left
    (width - 1) and a right child one to the right (width + 1).
    ``um`` must tolerate ``um[w] += x`` on a missing key (e.g. a
    ``defaultdict(int)``).

    An explicit stack replaces recursion, so skewed trees deeper than
    Python's recursion limit no longer raise RecursionError. Nodes are
    still visited in pre-order (node, left, right).
    """
    stack = [(root, width)]
    while stack:
        node, pos = stack.pop()
        if node is None:
            continue
        um[pos] += wt[node.data]
        # Push right first so the left subtree is processed first
        stack.append((node.right, pos + 1))
        stack.append((node.left, pos - 1))
 
def buildPrefixSums(um):
    """Return ``(widths, prefix)``.

    ``widths`` is the ascending list of keys of ``um``, and ``prefix[i]`` is
    the sum of ``um[w]`` over every ``w <= widths[i]``.
    """
    widths = sorted(um)
    prefix = []
    running = 0
    for w in widths:
        running += um[w]
        prefix.append(running)
    return widths, prefix


def rangeWeightSum(widths, prefix, l, r):
    """Total weight of the widths in the inclusive range ``[l, r]``.

    Binary search locates the bounds, so ``l`` and ``r`` need not be keys:
    bounds outside the tree's widths are handled, and ``l > r`` gives 0.
    """
    from bisect import bisect_left, bisect_right

    if l > r:
        return 0
    hi = bisect_right(widths, r)  # number of widths <= r
    lo = bisect_left(widths, l)   # number of widths < l
    total_hi = prefix[hi - 1] if hi else 0
    total_lo = prefix[lo - 1] if lo else 0
    return total_hi - total_lo


def solveQueries(wt, root, queries):
    """Print, for each query ``[l, r]``, the total weight of the nodes whose
    vertical width lies in the inclusive range ``[l, r]``."""
    um = defaultdict(int)
    findwt(root, wt, um)
    widths, prefix = buildPrefixSums(um)
    for q in queries:
        print(rangeWeightSum(widths, prefix, q[0], q[1]))
 
if __name__ == '__main__':
    # Given Tree (node values are indices into wt)
    root = newNode(1)
    root.left = newNode(3)
    root.left.left = newNode(5)
    root.left.right = newNode(6)
    root.right = newNode(2)
    root.right.right = newNode(4)
    root.right.right.left = newNode(7)
    root.right.right.right = newNode(0)
    wt = [8, 6, 4, 5, 1, 2, 9, 1]
    queries = [[-1, 1], [-2, -1], [0, 3]]
    # Expected output: 25, 7, 29
    solveQueries(wt, root, queries)


C#




using System;
using System.Collections.Generic;
 
class GFG {
    // Structure of a node
    public class Node {
        public int data;   // node value; used as an index into wt[]
        public Node left;
        public Node right;
    }

    // Function to create new node
    static Node newNode(int d)
    {
        Node n = new Node();
        n.data = d;
        n.left = null;
        n.right = null;
        return n;
    }

    // Function to pre-compute the sum of weights at each width position.
    // `root` is at `width`; a left child is at width - 1 and a right child
    // at width + 1 relative to its parent.
    static void findwt(Node root, int[] wt,
                       SortedDictionary<int, int> um,
                       int width)
    {
        // Base Case
        if (root == null) {
            return;
        }

        // Update the current width position weight
        if (um.ContainsKey(width)) {
            um[width] = um[width] + wt[root.data];
        }
        else {
            um.Add(width, wt[root.data]);
        }

        // Recursive Call to its left
        findwt(root.left, wt, um, width - 1);

        // Recursive Call to its right
        findwt(root.right, wt, um, width + 1);
    }

    // Number of entries in the ascending list `widths` that are <= x
    // (inclusive == true) or < x (inclusive == false).
    static int countWidths(List<int> widths, int x, bool inclusive)
    {
        int idx = widths.BinarySearch(x);
        if (idx < 0) {
            return ~idx; // insertion point = count of smaller entries
        }
        return inclusive ? idx + 1 : idx;
    }

    // Computes the answer to every query [l, r]: the total weight of the
    // nodes whose width lies in [l, r]. Binary search over the sorted
    // widths keeps query bounds outside the tree's widths correct.
    // A query with l > r is an empty range and yields 0.
    internal static long[] answerQueries(int[] wt, Node root,
                                         int[][] queries)
    {
        SortedDictionary<int, int> um
            = new SortedDictionary<int, int>();
        findwt(root, wt, um, 0);

        // widths ascending; prefix[i] = total weight of widths <= widths[i].
        // long avoids int overflow of the running total.
        List<int> widths = new List<int>(um.Count);
        List<long> prefix = new List<long>(um.Count);
        long running = 0;
        foreach (KeyValuePair<int, int> it in um) {
            running += it.Value;
            widths.Add(it.Key);
            prefix.Add(running);
        }

        long[] answers = new long[queries.Length];
        for (int i = 0; i < queries.Length; i++) {
            int l = queries[i][0];
            int r = queries[i][1];
            if (l > r) {
                continue; // empty range: answer stays 0
            }

            int upToR = countWidths(widths, r, true);
            int beforeL = countWidths(widths, l, false);
            long total = upToR == 0 ? 0 : prefix[upToR - 1];
            long before = beforeL == 0 ? 0 : prefix[beforeL - 1];
            answers[i] = total - before;
        }
        return answers;
    }

    // Function to print, for each query [L, R], the sum of the weights of
    // nodes whose vertical widths lie in [L, R]
    static void solveQueries(int[] wt, Node root,
                             int[][] queries)
    {
        foreach (long ans in answerQueries(wt, root, queries)) {
            Console.WriteLine(ans);
        }
    }

    // Driver Code: expected output 25, 7, 29
    public static void Main(string[] args)
    {
        // Given Tree
        Node root = newNode(1);
        root.left = newNode(3);
        root.left.left = newNode(5);
        root.left.right = newNode(6);
        root.right = newNode(2);
        root.right.right = newNode(4);
        root.right.right.left = newNode(7);
        root.right.right.right = newNode(0);

        int[] wt = { 8, 6, 4, 5, 1, 2, 9, 1 };
        int[][] queries
            = { new int[] { -1, 1 }, new int[] { -2, -1 },
                new int[] { 0, 3 } };

        solveQueries(wt, root, queries);
    }
}


Javascript




// Binary-tree node: `data` is the node's value (used by findwt as an
// index into the weight array); `left`/`right` are children, initially null.
class Node {
    constructor(data) {
        Object.assign(this, { data, left: null, right: null });
    }
}
 
// Function to create new node: returns a Node holding `d` with both
// children explicitly set to null.
function newNode(d) {
    const node = new Node(d);
    Object.assign(node, { left: null, right: null });
    return node;
}
 
// Accumulates into the Map `um` the total weight of the nodes at each
// vertical width position of the subtree rooted at `root`.
// `width` is the position of `root`; a left child is at width - 1 and a
// right child at width + 1. An explicit stack visits nodes in the same
// pre-order (node, left, right) as a recursive walk, so keys are first
// inserted into `um` in that same order.
function findwt(root, wt, um, width=0) {
    const pending = [[root, width]];
    while (pending.length > 0) {
        const [node, pos] = pending.pop();
        if (node == null) {
            continue;
        }
        const weight = wt[node.data];
        um.set(pos, um.has(pos) ? um.get(pos) + weight : weight);
        // Right is pushed first so the left subtree is popped first
        pending.push([node.right, pos + 1], [node.left, pos - 1]);
    }
}
 
// Returns { widths, prefix } for a Map of width -> weight:
// - widths is sorted ascending numerically (a Map iterates in insertion
//   order, which is NOT key order);
// - prefix[i] is the total weight of all widths <= widths[i].
function buildPrefixSums(um) {
    const widths = Array.from(um.keys()).sort((a, b) => a - b);
    const prefix = [];
    let running = 0;
    for (const w of widths) {
        running += um.get(w);
        prefix.push(running);
    }
    return { widths, prefix };
}

// Number of entries in the ascending array `sorted` that are <= x
// (inclusive) or < x (not inclusive), found by binary search.
function countBelow(sorted, x, inclusive) {
    let lo = 0;
    let hi = sorted.length;
    while (lo < hi) {
        const mid = (lo + hi) >> 1;
        if (sorted[mid] < x || (inclusive && sorted[mid] === x)) {
            lo = mid + 1;
        } else {
            hi = mid;
        }
    }
    return lo;
}

// Total weight of the widths in the inclusive range [l, r]. Bounds need
// not be present in `widths`; l > r is an empty range and yields 0.
function rangeWeightSum(widths, prefix, l, r) {
    if (l > r) {
        return 0;
    }
    const upTo = (count) => (count === 0 ? 0 : prefix[count - 1]);
    return upTo(countBelow(widths, r, true)) - upTo(countBelow(widths, l, false));
}

// Function to print, for each query [L, R], the sum of the weights of
// nodes whose vertical widths lie in [L, R]
function solveQueries(wt, root, queries) {
    // Stores the weight sum of each width position
    const um = new Map();
    findwt(root, wt, um, 0);

    const { widths, prefix } = buildPrefixSums(um);

    for (const [l, r] of queries) {
        console.log(rangeWeightSum(widths, prefix, l, r));
    }
}
 
// Driver Code: builds the sample tree and answers the three sample
// queries (expected output: 25, 7, 29).
(function main() {
    // Given Tree (node values are indices into wt)
    const root = newNode(1);
    root.left = newNode(3);
    root.left.left = newNode(5);
    root.left.right = newNode(6);
    root.right = newNode(2);
    root.right.right = newNode(4);
    root.right.right.left = newNode(7);
    root.right.right.right = newNode(0);

    const wt = [8, 6, 4, 5, 1, 2, 9, 1];
    const queries = [[-1, 1], [-2, -1], [0, 3]];

    solveQueries(wt, root, queries);
})();


Output

25
7
29

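Time Complexity: O((N + Q) * log N), since building the ordered map costs O(log N) per node and each query needs two O(log N) lookups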
Auxiliary Space: O(N)



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads