Euler Tour | Subtree Sum using Segment Tree
  • Difficulty Level : Medium
  • Last Updated : 15 Dec, 2020

Euler tour tree (ETT) is a method for representing a rooted tree as a sequence of numbers. When traversing the tree using depth-first search (DFS), insert each node into a vector twice: once when entering it and once more after visiting all its children. This representation is very useful for solving subtree problems, and one such problem is Subtree Sum.

Prerequisite : Segment Tree(Sum of given range)

Naive Approach : 
Consider a rooted tree with 6 vertices: node 1 is the root with children 2 and 3, node 2 has children 6 and 5, and node 3 has the single child 4. The weight of each node equals its label (node i has weight i). Apply DFS to answer each query.

Queries : 
1. Sum of the subtree rooted at node 1. 
2. Update the value of node 6 to 10. 
3. Sum of the subtree rooted at node 2.



Answers : 
1. 6 + 5 + 4 + 3 + 2 + 1 = 21 
2. No output; the value of node 6 becomes 10.
3. 10 + 5 + 2 = 17
Time Complexity Analysis : 
Each such query can be answered with a depth-first search (DFS) over the subtree in O(n) time, so q queries take O(q*n) time in the worst case. 
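As a point of reference, here is a minimal Python sketch of the naive approach on the example tree. It assumes the tree is stored as a dictionary mapping each node to its list of children and that node i has weight i; the names children, weight and subtree_sum are illustrative, not part of the article's code. Every sum walks the whole subtree, so each query costs O(n).

```python
# Example tree: node 1 is the root and node i has weight i.
children = {1: [2, 3], 2: [6, 5], 3: [4], 4: [], 5: [], 6: []}
weight = {i: i for i in range(1, 7)}


def subtree_sum(root):
    """Naive approach: visit every node below root, O(n) per query."""
    total = 0
    stack = [root]              # explicit stack instead of recursion
    while stack:
        u = stack.pop()
        total += weight[u]
        stack.extend(children[u])
    return total


print(subtree_sum(1))  # query 1: 21
weight[6] = 10         # query 2: the update itself is O(1)
print(subtree_sum(2))  # query 3: 17
```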

Efficient Approach : 
The time complexity of each query can be reduced to O(log(n)) by flattening the rooted tree with the Euler tour technique and building a segment tree over the resulting sequence. So, when the number of queries is q, the total complexity becomes O(q*log(n)), after O(n) preprocessing.

Euler Tour : 
In the Euler tour technique, each vertex is added to the vector twice: when descending into it and when leaving it.
Let us understand this with the help of the previous example : 

On performing depth-first search (DFS) with the Euler tour technique on the given rooted tree, the vector formed is : 

s[]={1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}

C++




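#include <vector>
using namespace std;

// v[u] lists the children of node u and s stores the
// Euler tour, declared the same way as in the full program.
vector<int> v[1001];
vector<int> s;

// DFS function to traverse the tree. It records 'root' on
// entry; the caller records it again after dfs() returns.
int dfs(int root)
{
    s.push_back(root);
    if (v[root].size() == 0)
        return root;
 
    for (int i = 0; i < (int)v[root].size(); i++) {
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    return root;
}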
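The recursive dfs() above can overflow the call stack on very deep trees. A minimal iterative sketch in Python, assuming the tree is a dictionary mapping each node to its list of children (euler_tour is a hypothetical helper name, not from the article), produces the same sequence:

```python
def euler_tour(children, root):
    """Return the tour that lists each node on entry and again on exit."""
    tour = [root]
    # Each stack entry is (node, index of the next child to visit).
    stack = [(root, 0)]
    while stack:
        node, i = stack[-1]
        kids = children.get(node, [])
        if i < len(kids):
            stack[-1] = (node, i + 1)
            child = kids[i]
            tour.append(child)          # entering child
            stack.append((child, 0))
        else:
            stack.pop()
            tour.append(node)           # leaving node
    return tour


children = {1: [2, 3], 2: [6, 5], 3: [4]}
print(euler_tour(children, 1))
# [1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1]
```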

Now, use vector s[] to create a segment tree.



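The leaves of the segment tree hold the weights of the nodes in the order they appear in s[], i.e. 1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1, and every internal node stores the sum of its two children.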
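To make the structure concrete, the following Python sketch builds the segment tree recursively in the same shape as the article's segment() function, but with 1-based positions, and records the range each node covers. The dictionary seg and the function build are illustrative names, not the article's code.

```python
tour = [1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1]   # Euler tour s[] from the example
weight = [0, 1, 2, 3, 4, 5, 6]                # weight[i] is the weight of node i
leaves = [weight[u] for u in tour]            # leaf values in tour order

seg = {}  # pos -> (low, high, sum); positions are 1-based


def build(low, high, pos):
    """Store the range low..high and its sum at index pos (children 2*pos, 2*pos+1)."""
    if low == high:
        seg[pos] = (low, high, leaves[low - 1])
    else:
        mid = (low + high) // 2
        build(low, mid, 2 * pos)
        build(mid + 1, high, 2 * pos + 1)
        seg[pos] = (low, high, seg[2 * pos][2] + seg[2 * pos + 1][2])


build(1, len(leaves), 1)
for pos in (1, 2, 3):
    print(pos, seg[pos])
# 1 (1, 12, 42)
# 2 (1, 6, 25)
# 3 (7, 12, 17)
```

The root covers the whole tour and stores 42, twice the total weight 21, because every node appears twice.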

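For the sum and update queries, store the entry time and exit time of each node of the rooted tree, i.e. the 1-based positions of its first and second occurrence in s[]. These two positions serve as the index range of the node's subtree. 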

s[]={1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1}

Node  Entry time  Exit time
   1        1           12
   2        2           7
   3        8           11
   4        9           10
   5        5           6
   6        3           4
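The table can be derived mechanically from s[]: the first occurrence of a node is its entry time and the second is its exit time. A small Python sketch, assuming 1-based positions as in the table (exit_ is spelled with an underscore only to avoid shadowing Python's built-in exit):

```python
tour = [1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1]   # Euler tour s[] from the example

entry, exit_ = {}, {}
for pos, node in enumerate(tour, start=1):    # 1-based positions, as in the table
    if node not in entry:
        entry[node] = pos    # first occurrence: DFS enters the node
    else:
        exit_[node] = pos    # second occurrence: DFS leaves the node

for node in sorted(entry):
    print(node, entry[node], exit_[node])
# 1 1 12
# 2 2 7
# 3 8 11
# 4 9 10
# 5 5 6
# 6 3 4
```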

Query of type 1 : 
Compute the range sum on the segment tree over [entry time, exit time] of the queried node. Every node of its subtree, and no other node, appears exactly twice in this range, so the range sum is always twice the expected answer. Divide it by 2.
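The halving argument can be checked on the example without a segment tree. In this Python sketch a plain list slice stands in for the O(log(n)) range query, so it shows only the arithmetic, not the complexity; the entry and exit times are copied from the table above, and range_sum_halved is a hypothetical name.

```python
tour = [1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1]   # Euler tour s[] from the example
weight = [0, 1, 2, 3, 4, 5, 6]                # weight[i] is the weight of node i
entry = {1: 1, 2: 2, 3: 8, 4: 9, 5: 5, 6: 3}  # from the entry/exit table
exit_ = {1: 12, 2: 7, 3: 11, 4: 10, 5: 6, 6: 4}


def range_sum_halved(node):
    # Positions entry..exit (1-based, inclusive) become the slice [entry-1, exit).
    # The slice is O(n); a segment tree answers the same sum in O(log(n)).
    window = tour[entry[node] - 1 : exit_[node]]
    return sum(weight[u] for u in window) // 2


print(range_sum_halved(1))  # 21
print(range_sum_halved(2))  # 13, i.e. (2 + 6 + 6 + 5 + 5 + 2) // 2 = 2 + 6 + 5
```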

Query of type 2 : 
For an update query, update both leaves of the segment tree that correspond to the node: the one at its entry time and the one at its exit time. 
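Putting the two query types together, here is a compact Python sketch. It swaps the article's recursive segment tree for an iterative (bottom-up) one, which stores the leaves in seg[n:2n] and supports the same point update and range sum in O(log(n)); the helper names point_set, range_sum, update_node and subtree_sum are ours, not the article's.

```python
tour = [1, 2, 6, 6, 5, 5, 2, 3, 4, 4, 3, 1]   # Euler tour s[] from the example
weight = [0, 1, 2, 3, 4, 5, 6]                # weight[i] is the weight of node i
entry = {1: 1, 2: 2, 3: 8, 4: 9, 5: 5, 6: 3}  # 1-based entry times
exit_ = {1: 12, 2: 7, 3: 11, 4: 10, 5: 6, 6: 4}

n = len(tour)
# Bottom-up segment tree: leaf for tour position p (1-based) is seg[n + p - 1],
# and internal node k holds seg[2k] + seg[2k + 1].
seg = [0] * n + [weight[u] for u in tour]
for k in range(n - 1, 0, -1):
    seg[k] = seg[2 * k] + seg[2 * k + 1]


def point_set(pos, val):
    """Set the leaf at 1-based tour position pos to val and refresh its ancestors."""
    k = n + pos - 1
    seg[k] = val
    k //= 2
    while k >= 1:
        seg[k] = seg[2 * k] + seg[2 * k + 1]
        k //= 2


def range_sum(l, r):
    """Sum of the leaves at 1-based tour positions l..r, inclusive."""
    res = 0
    lo, hi = n + l - 1, n + r          # half-open leaf interval [lo, hi)
    while lo < hi:
        if lo & 1:                     # lo is a right child: take it, move right
            res += seg[lo]
            lo += 1
        if hi & 1:                     # hi - 1 is a left child: take it
            hi -= 1
            res += seg[hi]
        lo //= 2
        hi //= 2
    return res


def update_node(node, val):
    # Query of type 2: both copies of the node in the tour must change.
    point_set(entry[node], val)
    point_set(exit_[node], val)


def subtree_sum(node):
    # Query of type 1: the range holds every subtree node twice.
    return range_sum(entry[node], exit_[node]) // 2


print(subtree_sum(1))   # 21
update_node(6, 10)
print(subtree_sum(2))   # 17
```

Updating only one of the two positions would leave the two copies of the node out of sync, so halving the range sum would no longer give the subtree sum.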

Below is the implementation of the above approach : 

C++




// C++ program for implementation of
// Euler Tour | Subtree Sum.
#include <bits/stdc++.h>
using namespace std;
 
vector<int> v[1001];
vector<int> s;
int seg[1001] = { 0 };
 
// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
int ar[] = { 0, 1, 2, 3, 4, 5, 6 };
 
int vertices = 6;
int edges = 5;
 
// A recursive function that builds the
// segment tree over the node weights in
// Euler tour order, ar[s[low]] ... ar[s[high]].
// 'pos' is index of current node
// in segment tree seg[].
void segment(int low, int high, int pos)
{
    if (high == low) {
        seg[pos] = ar[s[low]];
    }
    else {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
/* Return sum of elements in range
   from index l to r. It uses the
   seg[] array created using segment()
   function. 'node' is index of current
   node in segment tree seg[], which
   covers positions start..end.
*/
int query(int node, int start,
          int end, int l, int r)
{
    if (r < start || end < l) {
        return 0;
    }
 
    if (l <= start && end <= r) {
        return seg[node];
    }
 
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                   mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                   end, l, r);
 
    return (p1 + p2);
}
 
/* A recursive function to update the
   nodes which have the given index in
   their range. The following are
   parameters pos --> index of current
   node in segment tree seg[]. idx -->
   position of the element to be updated
   in the Euler tour s[].
   val --> new value to store at idx
*/
void update(int pos, int low, int high,
            int idx, int val)
{
    if (low == high) {
        seg[pos] = val;
    }
    else {
        int mid = (low + high) / 2;
 
        if (low <= idx && idx <= mid) {
            update(2 * pos, low, mid,
                   idx, val);
        }
        else {
            update(2 * pos + 1, mid + 1,
                   high, idx, val);
        }
 
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
/* A recursive function to form the
   Euler tour s[] of a rooted tree.
*/
int dfs(int root)
{
    // pushing each node in vector s
    s.push_back(root);
    if (v[root].size() == 0)
        return root;
 
    for (int i = 0; i < v[root].size(); i++) {
        int temp = dfs(v[root][i]);
        s.push_back(temp);
    }
    return root;
}
 
// Driver program to test above functions
int main()
{
    // Edges between the nodes
    v[1].push_back(2);
    v[1].push_back(3);
    v[2].push_back(6);
    v[2].push_back(5);
    v[3].push_back(4);
 
    // Calling dfs function.
    int temp = dfs(1);
    s.push_back(temp);
 
    // Storing entry time and exit
    // time of each node
    vector<pair<int, int> > p;
 
    for (int i = 0; i <= vertices; i++)
        p.push_back(make_pair(0, 0));
 
    for (int i = 0; i < s.size(); i++) {
        if (p[s[i]].first == 0)
            p[s[i]].first = i + 1;
        else
            p[s[i]].second = i + 1;
    }
 
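    // Build segment tree from the weights in s[].
    // segment() numbers positions from 0, while
    // query() and update() number them from 1;
    // both visit the same seg[] nodes because
    // shifting a range by one shifts its
    // midpoint by exactly one.
    segment(0, s.size() - 1, 1);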
 
    // query of type 1 return the
    // sum of subtree at node 1.
    int node = 1;
    int e = p[node].first;
    int f = p[node].second;
 
    int ans = query(1, 1, s.size(), e, f);
 
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
 
    // query of type 2: update the weight
    // of node 6 to 10.
    int val = 10;
    node = 6;
 
    e = p[node].first;
    f = p[node].second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);
 
    // query of type 1 return the
    // sum of subtree at node 2.
    node = 2;
 
    e = p[node].first;
    f = p[node].second;
 
    ans = query(1, 1, s.size(), e, f);
 
    // print the sum of subtree
    cout << "Subtree sum of node " << node << " is : " << (ans / 2) << endl;
 
    return 0;
}

Java




// Java program for implementation of
// Euler Tour | Subtree Sum.
import java.util.ArrayList;
 
class Graph{
 
static class Pair
{
    int first, second;
     
    public Pair(int first, int second)
    {
        this.first = first;
        this.second = second;
    }
}
 
@SuppressWarnings("unchecked")
static ArrayList<Integer>[] v = new ArrayList[1001];
static ArrayList<Integer> s = new ArrayList<>();
static int[] seg = new int[1001];
 
// Value/Weight of each node of tree,
// value of 0th(no such node) node is 0.
static int ar[] = { 0, 1, 2, 3, 4, 5, 6 };
 
static int vertices = 6;
static int edges = 5;
 
// A recursive function that builds the
// segment tree over the node weights in
// Euler tour order, ar[s.get(low)] ... ar[s.get(high)].
// 'pos' is index of current node
// in segment tree seg[].
static void segment(int low, int high, int pos)
{
    if (high == low)
    {
        seg[pos] = ar[s.get(low)];
    }
    else
    {
        int mid = (low + high) / 2;
        segment(low, mid, 2 * pos);
        segment(mid + 1, high, 2 * pos + 1);
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1];
    }
}
 
// Return sum of elements in range from index
// l to r. It uses the seg[] array created
// using segment() function. 'node' is index
// of current node in segment tree seg[].
static int query(int node, int start,
                 int end, int l, int r)
{
    if (r < start || end < l)
    {
        return 0;
    }
 
    if (l <= start && end <= r)
    {
        return seg[node];
    }
 
    int mid = (start + end) / 2;
    int p1 = query(2 * node, start,
                       mid, l, r);
    int p2 = query(2 * node + 1, mid + 1,
                       end, l, r);
 
    return (p1 + p2);
}
 
/*
 * A recursive function to update the nodes which
 * have the given index in their range.
 * The following are parameters
 * pos --> index of current node in segment tree seg[].
 * idx --> position of the element to be updated
 *         in the Euler tour s.
 * val --> new value to store at position idx
 */
static void update(int pos, int low, int high,
                   int idx, int val)
{
    if (low == high)
    {
        seg[pos] = val;
    }
    else
    {
        int mid = (low + high) / 2;
 
        if (low <= idx && idx <= mid)
        {
            update(2 * pos, low, mid,
                       idx, val);
        }
        else
        {
            update(2 * pos + 1, mid + 1,
                       high, idx, val);
        }
 
        seg[pos] = seg[2 * pos] +
                   seg[2 * pos + 1];
    }
}
 
// A recursive function to form the
// Euler tour s of a rooted tree.
static int dfs(int root)
{
     
    // Pushing each node in ArrayList s
    s.add(root);
     
    if (v[root].size() == 0)
        return root;
 
    for(int i = 0; i < v[root].size(); i++)
    {
        int temp = dfs(v[root].get(i));
        s.add(temp);
    }
    return root;
}
 
// Driver code
public static void main(String[] args)
{
    for(int i = 0; i < 1001; i++)
    {
        v[i] = new ArrayList<>();
    }
 
    // Edges between the nodes
    v[1].add(2);
    v[1].add(3);
    v[2].add(6);
    v[2].add(5);
    v[3].add(4);
 
    // Calling dfs function.
    int temp = dfs(1);
    s.add(temp);
 
    // Storing entry time and exit
    // time of each node
    ArrayList<Pair> p = new ArrayList<>();
 
    for(int i = 0; i <= vertices; i++)
        p.add(new Pair(0, 0));
 
    for(int i = 0; i < s.size(); i++)
    {
        if (p.get(s.get(i)).first == 0)
            p.get(s.get(i)).first = i + 1;
        else
            p.get(s.get(i)).second = i + 1;
    }
 
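    // Build segment tree from the weights in s.
    // segment() numbers positions from 0, while
    // query() and update() number them from 1;
    // both visit the same seg[] nodes because
    // shifting a range by one shifts its
    // midpoint by exactly one.
    segment(0, s.size() - 1, 1);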
 
    // Query of type 1 return the
    // sum of subtree at node 1.
    int node = 1;
    int e = p.get(node).first;
    int f = p.get(node).second;
 
    int ans = query(1, 1, s.size(), e, f);
 
    // Print the sum of subtree
    System.out.println("Subtree sum of node " +
                       node + " is : " + (ans / 2));
 
    // Query of type 2: update the weight
    // of node 6 to 10.
    int val = 10;
    node = 6;
 
    e = p.get(node).first;
    f = p.get(node).second;
    update(1, 1, s.size(), e, val);
    update(1, 1, s.size(), f, val);
 
    // Query of type 1 return the
    // sum of subtree at node 2.
    node = 2;
 
    e = p.get(node).first;
    f = p.get(node).second;
 
    ans = query(1, 1, s.size(), e, f);
 
    // Print the sum of subtree
    System.out.println("Subtree sum of node " +
                       node + " is : " + (ans / 2));
}
}
 
// This code is contributed by sanjeev2552

Python3




# Python3 program for implementation of
# Euler Tour | Subtree Sum.
v = [[] for i in range(1001)]
s = []
seg = [0]*1001
 
# Value/Weight of each node of tree,
# value of 0th(no such node) node is 0.
ar = [0, 1, 2, 3, 4, 5, 6]
 
vertices = 6
edges = 5
 
# A recursive function that builds the
# segment tree over the node weights in
# Euler tour order, ar[s[low]] ... ar[s[high]].
# 'pos' is index of current node
# in segment tree seg.
def segment(low, high, pos):
    if (high == low):
        seg[pos] = ar[s[low]]
    else:
        mid = (low + high) // 2
        segment(low, mid, 2 * pos)
        segment(mid + 1, high, 2 * pos + 1)
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1]
         
''' Return sum of elements in range
from index l to r. It uses the
seg array created using segment()
function. 'node' is index of current
node in segment tree seg.
'''
def query(node, start,end, l, r):
    if (r < start or end < l):
        return 0
     
    if (l <= start and end <= r):
        return seg[node]
         
    mid = (start + end) // 2
    p1 = query(2 * node, start,mid, l, r)
    p2 = query(2 * node + 1, mid + 1, end, l, r)
     
    return (p1 + p2)
 
''' A recursive function to update the
nodes which have the given index in
their range. The following are
parameters pos --> index of current
node in segment tree seg. idx -->
position of the element to be updated
in the Euler tour s.
val --> new value to store at position idx
'''
def update(pos, low, high, idx, val):
    if (low == high):
        seg[pos] = val
    else:
        mid = (low + high) // 2
        if (low <= idx and idx <= mid):
            update(2 * pos, low, mid,idx, val)
        else:
            update(2 * pos + 1, mid + 1, high, idx, val)
         
        seg[pos] = seg[2 * pos] + seg[2 * pos + 1]
     
''' A recursive function to form the
    Euler tour s of a rooted tree.
'''
def dfs(root):

    # pushing each node into list s
    s.append(root)
    if (len(v[root]) == 0):
        return root
     
    for i in range(len(v[root])):
        temp = dfs(v[root][i])
        s.append(temp)
     
    return root
 
# Edges between the nodes
v[1].append(2)
v[1].append(3)
v[2].append(6)
v[2].append(5)
v[3].append(4)
 
# Calling dfs function.
temp = dfs(1)
s.append(temp)
 
# Storing entry time and exit
# time of each node
p = []
 
for i in range(vertices + 1):
    p.append([0, 0])
 
for i in range(len(s)):
    if (p[s[i]][0] == 0):
        p[s[i]][0] = i + 1
    else:
        p[s[i]][1] = i + 1
 
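# Build segment tree from the weights in s.
# segment() numbers positions from 0, while
# query() and update() number them from 1;
# both visit the same seg nodes because
# shifting a range by one shifts its
# midpoint by exactly one.
segment(0, len(s) - 1, 1)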
 
# query of type 1 return the
# sum of subtree at node 1.
node = 1
e = p[node][0]
f = p[node][1]
 
ans = query(1, 1, len(s), e, f)
 
# print the sum of subtree
print("Subtree sum of node", node, "is :", ans // 2)
 
# query of type 2: update the weight
# of node 6 to 10.
val = 10
node = 6
 
e = p[node][0]
f = p[node][1]
update(1, 1, len(s), e, val)
update(1, 1, len(s), f, val)
 
# query of type 1 return the
# sum of subtree at node 2.
node = 2
 
e = p[node][0]
f = p[node][1]
 
ans = query(1, 1, len(s), e, f)
 
# print the sum of subtree
print("Subtree sum of node",node,"is :",ans // 2)
 
# This code is contributed by SHUBHAMSINGH10
Output: 
Subtree sum of node 1 is : 21
Subtree sum of node 2 is : 17
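The driver only checks two queries. As a sanity check of the claim that a subtree's range in the Euler tour sums to exactly twice its subtree sum, this Python sketch compares the tour-based answer with the naive DFS on many small random trees. The random tree generator (each node u > 1 picks a random parent with a smaller label) and the weights in [-9, 9] are assumptions made only for testing, and a list slice again stands in for the segment tree query; all function names are hypothetical.

```python
import random


def euler_tour(children, root):
    """Iterative DFS listing each node on entry and again on exit."""
    tour = [root]
    stack = [(root, 0)]
    while stack:
        node, i = stack[-1]
        if i < len(children[node]):
            stack[-1] = (node, i + 1)
            child = children[node][i]
            tour.append(child)
            stack.append((child, 0))
        else:
            stack.pop()
            tour.append(node)
    return tour


def naive_subtree_sum(children, weight, node):
    total, stack = 0, [node]
    while stack:
        u = stack.pop()
        total += weight[u]
        stack.extend(children[u])
    return total


def check(num_nodes, seed):
    """Compare tour-range sums with twice the naive sums on a random tree."""
    rng = random.Random(seed)
    children = {u: [] for u in range(1, num_nodes + 1)}
    for u in range(2, num_nodes + 1):
        children[rng.randint(1, u - 1)].append(u)   # parent has a smaller label
    weight = [0] + [rng.randint(-9, 9) for _ in range(num_nodes)]

    tour = euler_tour(children, 1)
    entry, exit_ = {}, {}
    for pos, u in enumerate(tour, start=1):
        if u in entry:
            exit_[u] = pos
        else:
            entry[u] = pos

    for u in range(1, num_nodes + 1):
        window = tour[entry[u] - 1 : exit_[u]]
        if sum(weight[x] for x in window) != 2 * naive_subtree_sum(children, weight, u):
            return False
    return True


print(all(check(n, seed) for n in range(1, 30) for seed in range(20)))  # True
```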

 

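Time Complexity : O(n + q*log(n)), where building the Euler tour and the segment tree takes O(n) and each query or update takes O(log(n)).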
 
