Skip to content
Related Articles

Related Articles

Find distance between two nodes in the given Binary tree for Q queries
  • Difficulty Level : Hard
  • Last Updated : 28 Jan, 2021

Given a binary tree with N nodes and N-1 weighted edges, the distance between two nodes is the sum of the weights of the edges on the path between them. Each of the Q queries contains two integers U and V; the task is to find the distance between nodes U and V.

Examples: 

Input: 
 

Output: 3 5 12 12 
Explanation: 
Distance between nodes 1 to 3 = weight(1, 3) = 3 
Distance between nodes 2 to 3 = weight(1, 2) + weight(1, 3) = 5 
Distance between nodes 3 to 5 = weight(1, 3) + weight(1, 2) + weight(2, 5) = 12 
Distance between nodes 4 to 5 = weight(4, 2) + weight(2, 5) = 12 
 



Approach: The idea is to use LCA in a tree using Binary Lifting Technique.  

  • Binary Lifting is a Dynamic Programming approach where we pre-compute an array lca[i][j] where i = [1, n], j = [0, log(n)] and lca[i][j] contains the 2^j-th ancestor of node i (or -1 if no such ancestor exists). 
    • For computing the values of lca[][], the following recursion may be used

 lca[i][j] =\begin{cases} parent[i] & \text{ ;if } j=0 \\ lca[lca[i][j - 1]][j - 1] & \text{ ;if } j>0 \end{cases}

  • While computing the lca[][] array we also compute dist[][], where dist[i][j] contains the distance from node i to its 2^j-th ancestor. 
    • For computing the values of dist[][], the following recursion may be used.

 dist[i][j] =\begin{cases} cost(i, parent[i]) & \text{ ;if } j=0 \\ dist[i][j - 1] + dist[lca[i][j - 1]][j - 1] & \text{ ;if } j>0 \end{cases}

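  • After precomputation, the distance between (u, v) is obtained while finding their lowest common ancestor: every time u or v is lifted to its 2^j-th ancestor, the corresponding dist[][] value is added to the answer.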

Below is the implementation of the above approach:

C++




// C++ Program to find distance
// between two nodes using LCA
 
#include <bits/stdc++.h>
using namespace std;
 
// Upper bound on node labels: every table below is indexed directly by
// the node label, so labels must lie in [0, MAX).
#define MAX 1000
 
// Number of binary-lifting levels stored per node (jumps of 2^0 .. 2^(log-1)).
// NOTE(review): this macro rewrites every later use of the identifier `log`
// in this file, so the math function log() cannot be called below this line.
#define log 10 // log2(MAX)
 
// Array to store the level
// of each node
int level[MAX];
 
// lca[v][j]  : 2^j-th ancestor of node v; -1 marks a missing ancestor
//              (the driver pre-fills the rows it uses with -1).
// dist[v][j] : sum of edge weights on the path from v up to that ancestor.
int lca[MAX][log];
int dist[MAX][log];
 
// Vector to store tree
// graph[v] holds one {neighbour, edge cost} pair per edge incident to v.
vector<pair<int, int> > graph[MAX];
 
void addEdge(int u, int v, int cost)
{
    graph[u].push_back({ v, cost });
    graph[v].push_back({ u, cost });
}
 
// Pre-Processing to calculate
// values of lca[][], dist[][]
void dfs(int node, int parent,
         int h, int cost)
{
    // Using recursion formula to
    // calculate the values
    // of lca[][]
    lca[node][0] = parent;
 
    // Storing the level of
    // each node
    level[node] = h;
    if (parent != -1) {
        dist[node][0] = cost;
    }
 
    for (int i = 1; i < log; i++) {
        if (lca[node][i - 1] != -1) {
 
            // Using recursion formula to
            // calculate the values of
            // lca[][] and dist[][]
            lca[node][i]
                = lca[lca[node]
                         [i - 1]]
                     [i - 1];
 
            dist[node][i]
                = dist[node][i - 1]
                  + dist[lca[node][i - 1]]
                        [i - 1];
        }
    }
 
    for (auto i : graph[node]) {
        if (i.first == parent)
            continue;
        dfs(i.first, node,
h + 1, i.second);
    }
}
 
// Function to find the distance
// between given nodes u and v
void findDistance(int u, int v)
{
 
    int ans = 0;
 
    // The node which is present
    // farthest from the root node
    // is taken as v. If u is
    // farther from root node
    // then swap the two
    if (level[u] > level[v])
        swap(u, v);
 
    // Finding the ancestor of v
    // which is at same level as u
    for (int i = log - 1; i >= 0; i--) {
 
        if (lca[v][i] != -1
            && level[lca[v][i]]
                   >= level[u]) {
 
            // Adding distance of node
            // v till its 2^i-th ancestor
            ans += dist[v][i];
            v = lca[v][i];
        }
    }
 
    // If u is the ancestor of v
    // then u is the LCA of u and v
    if (v == u) {
 
        cout << ans << endl;
    }
 
    else {
 
        // Finding the node closest to the
        // root which is not the common
        // ancestor of u and v i.e. a node
        // x such that x is not the common
        // ancestor of u and v but lca[x][0] is
        for (int i = log - 1; i >= 0; i--) {
 
            if (lca[v][i] != lca[u][i]) {
 
                // Adding the distance
                // of v and u to
                // its 2^i-th ancestor
                ans += dist[u][i] + dist[v][i];
 
                v = lca[v][i];
                u = lca[u][i];
            }
        }
 
        // Adding the distance of u and v
        // to its first ancestor
        ans += dist[u][0] + dist[v][0];
 
        cout << ans << endl;
    }
}
 
// Driver Code
int main()
{
 
    // Number of nodes
    int n = 5;
 
    // Add edges with their cost
    addEdge(1, 2, 2);
    addEdge(1, 3, 3);
    addEdge(2, 4, 5);
    addEdge(2, 5, 7);
 
    // Initialising lca and dist values
    // with -1 and 0 respectively
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < log; j++) {
            lca[i][j] = -1;
            dist[i][j] = 0;
        }
    }
 
    // Perform DFS
    dfs(1, -1, 0, 0);
 
    // Query 1: {1, 3}
    findDistance(1, 3);
 
    // Query 2: {2, 3}
    findDistance(2, 3);
 
    // Query 3: {3, 5}
    findDistance(3, 5);
 
    return 0;
}

Java




// Java program to find distance
// between two nodes using LCA
import java.io.*;
import java.util.*;
 
class GFG{
 
static final int MAX = 1000;
// log2(MAX)
static final int log = 10;
 
// Array to store the level
// of each node
static int[] level = new int[MAX];
 
static int[][] lca = new int[MAX][log];
static int[][] dist = new int[MAX][log];
 
// Vector to store tree
@SuppressWarnings("unchecked")
static List<List<int[]> > graph = new ArrayList();
 
static void addEdge(int u, int v, int cost)
{
    graph.get(u).add(new int[]{ v, cost });
    graph.get(v).add(new int[]{ u, cost });
}
 
// Pre-Processing to calculate
// values of lca[][], dist[][]
static void dfs(int node, int parent,
                int h, int cost)
{
     
    // Using recursion formula to
    // calculate the values
    // of lca[][]
    lca[node][0] = parent;
 
    // Storing the level of
    // each node
    level[node] = h;
     
    if (parent != -1)
    {
        dist[node][0] = cost;
    }
 
    for(int i = 1; i < log; i++)
    {
        if (lca[node][i - 1] != -1)
        {
             
            // Using recursion formula to
            // calculate the values of
            // lca[][] and dist[][]
            lca[node][i] = lca[lca[node][i - 1]][i - 1];
 
            dist[node][i] = dist[node][i - 1] +
                            dist[lca[node][i - 1]][i - 1];
        }
    }
 
    for(int[] i : graph.get(node))
    {
        if (i[0] == parent)
            continue;
             
        dfs(i[0], node, h + 1, i[1]);
    }
}
 
// Function to find the distance
// between given nodes u and v
static void findDistance(int u, int v)
{
    int ans = 0;
 
    // The node which is present
    // farthest from the root node
    // is taken as v. If u is
    // farther from root node
    // then swap the two
    if (level[u] > level[v])
    {
        int temp = u;
        u = v;
        v = temp;
    }
 
    // Finding the ancestor of v
    // which is at same level as u
    for(int i = log - 1; i >= 0; i--)
    {
        if (lca[v][i] != -1 &&
      level[lca[v][i]] >= level[u])
        {
 
            // Adding distance of node
            // v till its 2^i-th ancestor
            ans += dist[v][i];
            v = lca[v][i];
        }
    }
 
    // If u is the ancestor of v
    // then u is the LCA of u and v
    if (v == u)
    {
        System.out.println(ans);
    }
 
    else
    {
         
        // Finding the node closest to the
        // root which is not the common
        // ancestor of u and v i.e. a node
        // x such that x is not the common
        // ancestor of u and v but lca[x][0] is
        for(int i = log - 1; i >= 0; i--)
        {
            if (lca[v][i] != lca[u][i])
            {
                 
                // Adding the distance
                // of v and u to
                // its 2^i-th ancestor
                ans += dist[u][i] + dist[v][i];
 
                v = lca[v][i];
                u = lca[u][i];
            }
        }
 
        // Adding the distance of u and v
        // to its first ancestor
        ans += dist[u][0] + dist[v][0];
 
        System.out.println(ans);
    }
}
 
// Driver Code
public static void main(String[] args)
{
     
    // Number of nodes
    int n = 5;
 
    for(int i = 0; i < MAX; i++)
    {
        graph.add(new ArrayList<int[]>());
    }
 
    // Add edges with their cost
    addEdge(1, 2, 2);
    addEdge(1, 3, 3);
    addEdge(2, 4, 5);
    addEdge(2, 5, 7);
 
    // Initialising lca and dist values
    // with -1 and 0 respectively
    for(int i = 1; i <= n; i++)
    {
        for(int j = 0; j < log; j++)
        {
            lca[i][j] = -1;
            dist[i][j] = 0;
        }
    }
 
    // Perform DFS
    dfs(1, -1, 0, 0);
 
    // Query 1: {1, 3}
    findDistance(1, 3);
 
    // Query 2: {2, 3}
    findDistance(2, 3);
 
    // Query 3: {3, 5}
    findDistance(3, 5);
}
}
 
// This code is contributed by jithin

Python3




# Python3 Program to find
# distance between two nodes
# using LCA
# Upper bound on node labels; every table below is indexed by label.
MAX = 1000

# Number of binary-lifting levels kept per node: lg2(MAX)
lg = 10

# level[v]: depth of node v, filled in by dfs()
level = [0] * MAX

# lca[v][j]:  2^j-th ancestor of node v
# dist[v][j]: sum of edge weights from v up to that ancestor
lca = [[0] * lg for _ in range(MAX)]
dist = [[0] * lg for _ in range(MAX)]

# Adjacency lists: graph[v] collects [neighbour, cost] entries
graph = [[] for _ in range(MAX)]
 
def addEdge(u, v, cost):
    """Record the undirected edge u-v of weight `cost`.

    Each endpoint gets a [neighbour, cost] entry appended to its
    adjacency list in the module-level `graph`.
    """
    for src, dst in ((u, v), (v, u)):
        graph[src].append([dst, cost])
 
# Pre-Processing to calculate
# values of lca[][], dist[][]
def dfs(node, parent, h, cost):
    """Fill level[], lca[][] and dist[][] for `node` and its subtree.

    node   -- vertex being visited
    parent -- vertex we arrived from, or -1 when `node` is the root
    h      -- depth of `node`
    cost   -- weight of the edge (parent, node); ignored for the root

    Every entry of lca[node] and dist[node] is written explicitly
    (-1 / 0 when the 2^i-th ancestor does not exist), so the result
    does not depend on the tables having been pre-filled by the caller.
    """
    # Storing the level of each node
    level[node] = h

    # The 2^0-th ancestor is the direct parent; the root has no edge
    # above it.
    lca[node][0] = parent
    dist[node][0] = cost if parent != -1 else 0

    for i in range(1, lg):
        half = lca[node][i - 1]
        if half != -1:
            # Two jumps of 2^(i-1) make one jump of 2^i; `half` is an
            # ancestor, so its row was completed earlier.
            lca[node][i] = lca[half][i - 1]
            dist[node][i] = dist[node][i - 1] + dist[half][i - 1]
        else:
            # No ancestor that far up.
            lca[node][i] = -1
            dist[node][i] = 0

    # Recurse into the children; the edge back to the parent is skipped.
    for nxt, weight in graph[node]:
        if nxt == parent:
            continue
        dfs(nxt, node, h + 1, weight)
 
# Function to find the distance
# between given nodes u and v
def findDistance(u, v):
    """Print the sum of edge weights on the path between nodes u and v.

    Relies on level[], lca[][] and dist[][] as filled in by dfs().
    """
    total = 0

    # Make v the endpoint that is at least as deep as u.
    if level[u] > level[v]:
        u, v = v, u

    # Lift v, largest jump first, until it is on the level of u.
    for i in range(lg - 1, -1, -1):
        up = lca[v][i]
        if up != -1 and level[up] >= level[u]:
            total += dist[v][i]
            v = up

    # When v landed on u, u is the LCA and nothing more is needed.
    if v != u:
        # Lift both nodes while their 2^i-th ancestors differ; they end
        # up as two distinct children of the LCA.
        for i in range(lg - 1, -1, -1):
            if lca[v][i] != lca[u][i]:
                total += dist[u][i] + dist[v][i]
                u, v = lca[u][i], lca[v][i]

        # One last edge on each side reaches the LCA.
        total += dist[u][0] + dist[v][0]

    print(total)
 
# Driver Code
if __name__ == '__main__':

    # Number of nodes
    n = 5

    # Sample tree as (u, v, cost) triples
    for a, b, w in ((1, 2, 2), (1, 3, 3), (2, 4, 5), (2, 5, 7)):
        addEdge(a, b, w)

    # Rows of the nodes in use start as "no ancestor" (-1), distance 0
    for node in range(1, n + 1):
        lca[node][:] = [-1] * lg
        dist[node][:] = [0] * lg

    # Perform DFS from node 1 (no parent, depth 0, no incoming edge)
    dfs(1, -1, 0, 0)

    # Queries {1, 3}, {2, 3} and {3, 5}, answered in this order
    for a, b in ((1, 3), (2, 3), (3, 5)):
        findDistance(a, b)
     
# This code is contributed by SURENDRA_GANGWAR

C#




// C# program to find distance
// between two nodes using LCA
using System;
using System.Collections.Generic;
// Distance between two nodes of a weighted tree, answered with binary
// lifting: dfs() pre-computes, for every node, its 2^j-th ancestors and
// the path weight up to them; findDistance() then lifts both query
// nodes to their lowest common ancestor while summing those weights.
class GFG
{
 
  // Upper bound on node labels; all tables are indexed by label.
  static readonly int MAX = 1000;
 
  // log2(MAX)
  static readonly int log = 10;
 
  // level[v]: depth of node v as assigned by dfs()
  static int[] level = new int[MAX];

  // lca[v,j]:  2^j-th ancestor of v, or -1 if there is none
  // dist[v,j]: sum of edge weights from v up to that ancestor
  static int[,] lca = new int[MAX,log];
  static int[,] dist = new int[MAX,log];
 
  // Adjacency lists: graph[v] holds {neighbour, cost} pairs.
  // Main() populates the outer list with MAX empty lists before use.
  static List<List<int[]> > graph = new List<List<int[]>>();
 
  // Records the undirected edge (u, v) of weight cost at both endpoints.
  static void addEdge(int u, int v, int cost)
  {
    graph[u].Add(new int[]{ v, cost });
    graph[v].Add(new int[]{ u, cost });
  }
 
  // Depth-first pre-processing that fills level[], lca[,] and dist[,]
  // for node and its whole subtree.
  //   node   - vertex being visited
  //   parent - vertex we arrived from, or -1 for the root
  //   h      - depth of node
  //   cost   - weight of the edge (parent, node); ignored for the root
  // Every entry of lca[node,*] and dist[node,*] is written explicitly
  // (-1 / 0 when the 2^i-th ancestor does not exist), so the result does
  // not depend on the tables having been pre-filled by the caller.
  static void dfs(int node, int parent,
                  int h, int cost)
  {
    level[node] = h;
 
    // The 2^0-th ancestor is the direct parent
    lca[node, 0] = parent;
    dist[node, 0] = (parent != -1) ? cost : 0;
 
    for(int i = 1; i < log; i++)
    {
      int half = lca[node, i - 1];

      if (half != -1)
      {
        // Two jumps of 2^(i-1) make one jump of 2^i; half is an
        // ancestor, so its row was completed earlier
        lca[node, i] = lca[half, i - 1];
        dist[node, i] = dist[node, i - 1] + dist[half, i - 1];
      }
      else
      {
        // No ancestor that far up
        lca[node, i] = -1;
        dist[node, i] = 0;
      }
    }
 
    // Recurse into the children; the edge back to the parent is skipped
    foreach(int[] edge in graph[node])
    {
      if (edge[0] == parent)
        continue;
      dfs(edge[0], node, h + 1, edge[1]);
    }
  }
 
  // Prints the distance (sum of edge weights on the path) between
  // u and v. Requires dfs() to have been run on the tree first.
  static void findDistance(int u, int v)
  {
    int ans = 0;
 
    // Make v the endpoint that is at least as deep as u
    if (level[u] > level[v])
    {
      int temp = u;
      u = v;
      v = temp;
    }
 
    // Lift v, largest jump first, to the level of u
    for(int i = log - 1; i >= 0; i--)
    {
      int up = lca[v, i];

      if (up != -1 && level[up] >= level[u])
      {
        ans += dist[v, i];
        v = up;
      }
    }
 
    // If v != u, u was not an ancestor of v: lift both while their
    // 2^i-th ancestors differ, ending on two children of the LCA
    if (v != u)
    {
      for(int i = log - 1; i >= 0; i--)
      {
        if (lca[v, i] != lca[u, i])
        {
          ans += dist[u, i] + dist[v, i];
 
          v = lca[v, i];
          u = lca[u, i];
        }
      }
 
      // One last edge on each side reaches the LCA
      ans += dist[u, 0] + dist[v, 0];
    }

    Console.WriteLine(ans);
  }
 
  // Driver Code
  public static void Main(String[] args)
  {
 
    // Number of nodes
    int n = 5;
    for(int i = 0; i < MAX; i++)
    {
      graph.Add(new List<int[]>());
    }
 
    // Add edges with their cost
    addEdge(1, 2, 2);
    addEdge(1, 3, 3);
    addEdge(2, 4, 5);
    addEdge(2, 5, 7);
 
    // Initialising lca and dist values
    // with -1 and 0 respectively
    for(int i = 1; i <= n; i++)
    {
      for(int j = 0; j < log; j++)
      {
        lca[i, j] = -1;
        dist[i, j] = 0;
      }
    }
 
    // Perform DFS
    dfs(1, -1, 0, 0);
 
    // Query 1: {1, 3}
    findDistance(1, 3);
 
    // Query 2: {2, 3}
    findDistance(2, 3);
 
    // Query 3: {3, 5}
    findDistance(3, 5);
  }
}
 
// This code is contributed by aashish1995
Output: 
3
5
12

 

Time Complexity: Pre-processing takes O(N logN) time and every query takes O(logN) time. Therefore, the overall time complexity for Q queries is O((N + Q) logN).
 

Attention reader! Don’t stop learning now. Get hold of all the important DSA concepts with the DSA Self Paced Course at a student-friendly price and become industry ready.  Get hold of all the important mathematical concepts for competitive programming with the Essential Maths for CP Course at a student-friendly price.

In case you wish to attend live classes with industry experts, please refer Geeks Classes Live and Geeks Classes Live USA

My Personal Notes arrow_drop_up
Recommended Articles
Page :