Euler Tour | Subtree Sum using Segment Tree

Euler tour tree (ETT) is a method for representing a rooted tree as a number sequence. While traversing the tree using Depth First Search (DFS), insert each node into a vector twice: once when entering it and again after visiting all of its children. This method is very useful for solving subtree problems, and one such problem is Subtree Sum.

Prerequisite: Segment Tree (Sum of given range)

Naive Approach :
Consider a rooted tree with 6 vertices: node 1 is the root with children 2 and 3, node 2 has children 6 and 5, and node 3 has a single child 4. The weight of each node equals its label, so node i has weight i. Apply DFS to answer each of the following queries.

Queries :
1. Sum of the subtree rooted at node 1.
2. Update the value of node 6 to 10.
3. Sum of the subtree rooted at node 2.



Answers :
1. 6 + 5 + 4 + 3 + 2 + 1 = 21
2. No output; node 6 now has weight 10.
3. 10 + 5 + 2 = 17

Time Complexity Analysis :
Each sum query can be answered with a depth first search (DFS) of the subtree in O(n) time, so q queries take O(q*n) time in the worst case.
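A minimal sketch of the naive approach, assuming the example tree is stored as a list of children per node and the weights in a vector indexed by node (the helper names `subtreeSumNaive` and `exampleTree` are illustrative, not from the article). A sum query walks the whole subtree, which is where the O(n) cost comes from, while an update is just an assignment into the weight vector.

```cpp
#include <cassert>
#include <vector>

// Naive subtree sum: visit every node below `root` with a DFS.
// `children[u]` lists the children of u; `weight[u]` is u's value.
// Cost is proportional to the subtree size, i.e. O(n) in the worst case.
long long subtreeSumNaive(const std::vector<std::vector<int>>& children,
                          const std::vector<long long>& weight, int root) {
    long long sum = weight[root];
    for (int child : children[root])
        sum += subtreeSumNaive(children, weight, child);
    return sum;
}

// Builds the 6-node example tree: 1 -> {2, 3}, 2 -> {6, 5}, 3 -> {4}.
std::vector<std::vector<int>> exampleTree() {
    std::vector<std::vector<int>> children(7);  // index 0 is unused
    children[1] = {2, 3};
    children[2] = {6, 5};
    children[3] = {4};
    return children;
}
```

With weights {0, 1, 2, 3, 4, 5, 6}, `subtreeSumNaive(exampleTree(), weights, 1)` gives 21; after setting `weights[6] = 10`, the call for node 2 gives 17, matching the answers above.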

Efficient Approach :
The time for each query can be reduced to O(log(n)) by converting the rooted tree into a segment tree using the Euler tour technique. So, when the number of queries is q, the total complexity becomes O(q*log(n)), plus O(n) to build the tour and the segment tree once.
Euler Tour :
In the Euler tour technique, each vertex is added to the vector twice: once while descending into it and once while leaving it.

Let us understand this with the help of the previous example:

On performing a depth first search (DFS) using the Euler tour technique on the given rooted tree, the vector formed is:

s[]={1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}

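// DFS function that builds the Euler tour s[].
// v[root] holds the children of root.
int dfs(int root)
{
    // entering root: record it the first time
    s.push_back(root);
    if (v[root].size() == 0)
        return root;

    for (int i = 0; i < v[root].size(); i++) {
        // once the child's subtree is finished,
        // record the child the second time
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    // the caller pushes root the second time
    return root;
}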
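The fragment above depends on the globals `v` and `s`, and lets the caller push the returned node the second time. A self-contained sketch, with the hypothetical names `eulerTour`, `children` and `tour`, pushes the node both on entry and on exit inside the same call; for the example tree it yields the same sequence s[].

```cpp
#include <cassert>
#include <vector>

// Appends `root` to `tour` when the DFS enters it and again when it
// leaves it (after all children), so every node appears exactly twice.
void eulerTour(const std::vector<std::vector<int>>& children, int root,
               std::vector<int>& tour) {
    tour.push_back(root);  // entering root
    for (int child : children[root])
        eulerTour(children, child, tour);
    tour.push_back(root);  // leaving root
}
```

Calling it with node 1 as the root of the example tree yields {1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}. Pushing both copies in one call avoids needing a separate push in main() for the root.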



Now, use the vector s[] to build a segment tree: each leaf stores the weight of the node at that position of s[], and each internal node stores the sum of its two children.

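For the sum and update queries, store the entry time and exit time of each node of the rooted tree, i.e. the 1-based positions of its first and last occurrence in s[]. All nodes of its subtree lie between these two positions, so they serve as the index range of the subtree.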

s[]={1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}

Node  Entry time  Exit time
   1        1           12
   2        2           7
   3        8           11
   4        9           10
   5        5           6
   6        3           4
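The entry and exit times in the table can be read straight off s[]: the first occurrence of a node gives its entry time and the second gives its exit time. A sketch of that scan, mirroring the loop in main() of the implementation below (the function name `entryExitTimes` is our own):

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Given an Euler tour in which every node 1..n appears exactly twice,
// returns {entry, exit} for each node as 1-based positions in the tour.
// Index 0 of the result is unused.
std::vector<std::pair<int, int>> entryExitTimes(const std::vector<int>& tour, int n) {
    std::vector<std::pair<int, int>> times(n + 1, {0, 0});
    for (int i = 0; i < static_cast<int>(tour.size()); ++i) {
        int node = tour[i];
        if (times[node].first == 0)
            times[node].first = i + 1;   // first occurrence: entry time
        else
            times[node].second = i + 1;  // second occurrence: exit time
    }
    return times;
}
```

For s[] = {1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1} this reproduces the table, e.g. node 2 gets (2, 7) and node 6 gets (3, 4).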

Query of type 1 :
For a sum query, find the range sum on the segment tree over the range [entry time, exit time] of the node. Every node of the subtree appears exactly twice in this range, so the result is always twice the expected answer; divide it by 2.
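As a worked check of the halving rule, node 2 spans positions 2 to 7, whose weights are 2 + 6 + 6 + 5 + 5 + 2 = 26, and 26 / 2 = 13 = 2 + 6 + 5. If the weights never changed, a prefix-sum array over s[] would already answer every type 1 query in O(1); the sketch below (helper names are our own, positions are 1-based as in the table) illustrates only the range-and-halve step. The segment tree is needed once type 2 updates are allowed, since a single update would force the prefix sums to be rebuilt.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// prefix[i] = sum of weight[tour[0..i-1]], so the sum over 1-based
// positions [l, r] of the tour is prefix[r] - prefix[l - 1].
std::vector<long long> tourPrefixSums(const std::vector<int>& tour,
                                      const std::vector<long long>& weight) {
    std::vector<long long> prefix(tour.size() + 1, 0);
    for (std::size_t i = 0; i < tour.size(); ++i)
        prefix[i + 1] = prefix[i] + weight[tour[i]];
    return prefix;
}

// Subtree sum of a node from its entry/exit times: every node of the
// subtree is counted twice in the range, so the range sum is halved.
long long subtreeSumStatic(const std::vector<long long>& prefix,
                           int entryTime, int exitTime) {
    return (prefix[exitTime] - prefix[entryTime - 1]) / 2;
}
```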

Query of type 2 :
For an update query, update the leaves of the segment tree at both the entry time and the exit time of the node, since the node appears at both positions.
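A sketch of both query types on a bottom-up (iterative) segment tree, swapped in for the recursive segment tree of the implementation below only to keep the example short. Leaves correspond to the 1-based positions of s[] (the entry and exit times), and the struct `EulerSegTree` with its methods is hypothetical. A type 2 query assigns the new weight at both leaves of the node, so the doubled range sum stays consistent and halving it still gives the subtree sum.

```cpp
#include <cassert>
#include <vector>

// Bottom-up segment tree over the Euler tour, 1-based positions 1..n.
// Position p is stored at tree[p - 1 + n]; tree[i] = tree[2i] + tree[2i+1].
struct EulerSegTree {
    int n;
    std::vector<long long> tree;

    EulerSegTree(const std::vector<int>& tour, const std::vector<long long>& weight)
        : n(static_cast<int>(tour.size())), tree(2 * tour.size(), 0) {
        for (int i = 0; i < n; ++i) tree[n + i] = weight[tour[i]];
        for (int i = n - 1; i > 0; --i) tree[i] = tree[2 * i] + tree[2 * i + 1];
    }

    // Assign `val` to 1-based position p and refresh all its ancestors.
    void assign(int p, long long val) {
        int i = p - 1 + n;
        tree[i] = val;
        for (i /= 2; i > 0; i /= 2) tree[i] = tree[2 * i] + tree[2 * i + 1];
    }

    // Sum over 1-based positions [l, r], inclusive.
    long long sum(int l, int r) const {
        long long res = 0;
        for (int lo = l - 1 + n, hi = r + n; lo < hi; lo /= 2, hi /= 2) {
            if (lo & 1) res += tree[lo++];
            if (hi & 1) res += tree[--hi];
        }
        return res;
    }

    // Type 1: subtree sum of a node from its entry/exit times.
    long long subtreeSum(int entryTime, int exitTime) const {
        return sum(entryTime, exitTime) / 2;
    }

    // Type 2: the node occupies two leaves, so update both of them.
    void setNodeValue(int entryTime, int exitTime, long long val) {
        assign(entryTime, val);
        assign(exitTime, val);
    }
};
```

On the example tour with weights 1 to 6, `setNodeValue(3, 4, 10)` (node 6) followed by `subtreeSum(2, 7)` (node 2) gives 17, the same result the full program prints.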
Below is the implementation of the above approach:


// C++ program for implementation of
// Euler Tour | Subtree Sum.
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
  
vector<int> v[1001];
vector<int> s;
int seg[1001] = { 0 };
  
// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
int ar[] = { 0, 1, 2, 3, 4, 5, 6 };
  
int vertices = 6;
int edges = 5;
  
// A recursive function that constructs the
// Segment Tree over the Euler tour s[]: the leaf
// for position low of s[] stores ar[s[low]].
// 'low' and 'high' are 0-based positions in s[];
// 'pos' is index of current node in segment tree seg[].
void segment(int low, int high, int pos)
{
    if (high == low) {
        seg[pos] = ar[s[low]];
    }
    else {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
  
/* Return sum of elements in range
   from index l to r. It uses the
   seg[] array created using segment()
   function. 'node' is index of current
   node in segment tree seg[], which
   covers positions start..end.
*/
int query(int node, int start,
          int end, int l, int r)
{
    if (r < start || end < l) {
        return 0;
    }
  
    if (l <= start && end <= r) {
        return seg[node];
    }
  
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                   mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                   end, l, r);
  
    return (p1 + p2);
}
  
/* A recursive function to update the
   nodes which have the given index in
   their range. The parameters are:
   pos --> index of current node in
           segment tree seg[].
   idx --> position (in s[]) of the
           element to be updated.
   val --> new value to store at idx.
*/
void update(int pos, int low, int high,
            int idx, int val)
{
    if (low == high) {
        seg[pos] = val;
    }
    else {
        int mid = (low + high) / 2;
  
        if (low <= idx && idx <= mid) {
            update(2 * pos, low, mid,
                   idx, val);
        }
        else {
            update(2 * pos + 1, mid + 1,
                   high, idx, val);
        }
  
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
  
/* A recursive function that builds the
   Euler tour s[] of the tree rooted at root.
   It returns root so that the caller can
   push it again after the whole subtree.
*/
int dfs(int root)
{
    // pushing each node in vector s
    s.push_back(root);
    if (v[root].size() == 0)
        return root;
  
    for (int i = 0; i < v[root].size(); i++) {
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    return root;
}
  
// Driver program to test above functions
int main()
{
    // Edges between the nodes
    v[1].push_back(2);
    v[1].push_back(3);
    v[2].push_back(6);
    v[2].push_back(5);
    v[3].push_back(4);
  
    // Calling dfs function.
    int temp = dfs(1);
    s.push_back(temp);
  
    // Storing entry time and exit
    // time of each node
    vector<pair<int, int> > p;
  
    for (int i = 0; i <= vertices; i++)
        p.push_back(make_pair(0, 0));
  
    for (int i = 0; i < s.size(); i++) {
        if (p[s[i]].first == 0)
            p[s[i]].first = i + 1;
        else
            p[s[i]].second = i + 1;
    }
  
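    // Build segment tree from s[]. It is built over
    // 0-based positions 0..s.size()-1, while query()
    // and update() use 1-based positions 1..s.size().
    // Shifting both ends by one shifts every mid by one,
    // so both ranges give the same tree shape and leaf i
    // here is position i + 1 (the entry/exit time) there.
    segment(0, s.size() - 1, 1);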
  
    // query of type 1: return the
    // sum of the subtree rooted at node 1.
    int node = 1;
    int e = p[node].first;
    int f = p[node].second;
  
    int ans = query(1, 1, s.size(), e, f);
  
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
  
    // query of type 2: set the value of node 6
    // to 10. Node 6 occurs twice in s[], so update
    // the leaves at both its entry and exit times.
    int val = 10;
    node = 6;
  
    e = p[node].first;
    f = p[node].second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);
  
    // query of type 1: return the
    // sum of the subtree rooted at node 2.
    node = 2;
  
    e = p[node].first;
    f = p[node].second;
  
    ans = query(1, 1, s.size(), e, f);
  
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
  
    return 0;
}



Output:

Subtree sum of node 1 is : 21
Subtree sum of node 2 is : 17

Time Complexity : O(n) to build the Euler tour and the segment tree, plus O(log(n)) per query, i.e. O(n + q*log(n)) in total.



