# Sum of distances of all nodes from a given node

• Difficulty Level : Hard
• Last Updated : 28 Sep, 2021

Given a binary tree and an integer target denoting the value of a node, the task is to find the sum of the distances of all nodes from the given node. The distance between two nodes is the number of edges on the path between them.

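Examples (both use the binary tree in which 1 is the root, 2 and 3 are its children, 4 and 5 are the children of 2, 6 and 7 are the children of 3, and 8 and 9 are the children of 4):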


Input: target = 3

Output: 19
Explanation:

Distance of nodes 1, 6 and 7 from node 3 = 1
Distance of node 2 from node 3 = 2
Distance of nodes 4 and 5 from node 3 = 3
Distance of nodes 8 and 9 from node 3 = 4
Sum of the distances = (1 + 1 + 1) + (2) + (3 + 3) + (4 + 4) = 19.

Input: target = 4

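Output: 18
Explanation:

Distance of nodes 2, 8 and 9 from node 4 = 1
Distance of nodes 1 and 5 from node 4 = 2
Distance of node 3 from node 4 = 3
Distance of nodes 6 and 7 from node 4 = 4
Sum of the distances = (1 + 1 + 1) + (2 + 2) + (3) + (4 + 4) = 18.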
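To double-check the expected outputs independently, the sketch below treats the tree as an undirected graph and runs a breadth-first search (BFS) from the target; the BFS level of each node is its distance from the target. This is a swapped-in brute-force check, not the approach described in this article. It assumes node values are distinct, and the helper names `build_example_tree` and `brute_force_sum_of_distances` are hypothetical. It runs in O(N) time and space for a single target.

```python
from collections import deque


class TreeNode:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def build_example_tree():
    # Tree from the examples: 1 -> (2, 3), 2 -> (4, 5), 3 -> (6, 7), 4 -> (8, 9)
    nodes = {v: TreeNode(v) for v in range(1, 10)}
    for parent, (l, r) in {1: (2, 3), 2: (4, 5), 3: (6, 7), 4: (8, 9)}.items():
        nodes[parent].left = nodes[l]
        nodes[parent].right = nodes[r]
    return nodes[1]


def brute_force_sum_of_distances(root, target):
    # Build an undirected adjacency list: node value -> neighbour values
    adj = {}
    stack = [root]
    while stack:
        node = stack.pop()
        adj.setdefault(node.data, [])
        for child in (node.left, node.right):
            if child is not None:
                adj[node.data].append(child.data)
                adj.setdefault(child.data, []).append(node.data)
                stack.append(child)
    if target not in adj:
        return None  # target value does not occur in the tree

    # BFS from the target; each node's BFS level is its distance
    dist = {target: 0}
    queue = deque([target])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return sum(dist.values())


root = build_example_tree()
print(brute_force_sum_of_distances(root, 3))  # 19
print(brute_force_sum_of_distances(root, 4))  # 18
```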

Naive Approach: The idea is to start at the root and move towards the target one edge at a time. When moving from a node to one of its children, say u, the distance of every node in the subtree rooted at u decreases by 1, and the distance of every other node increases by 1.
Therefore, the sum of distances of all nodes from a node u is given by:

$$\text{sumDists}(u) = \text{sumDists}(\text{parent}(u)) - \text{size}(u) + \big(N - \text{size}(u)\big)$$

where,
• sumDists(u): sum of distances of all nodes from the node u
• sumDists(parent(u)): sum of distances of all nodes from the parent of u
• size(u): number of nodes in the subtree rooted at u, including u itself
• N: total number of nodes in the tree
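As a worked check of the formula on the example tree (N = 9), take size(u) to be the number of nodes in the subtree rooted at u, including u, so size(3) = 3, size(2) = 5 and size(4) = 3. The sum of distances from the root 1 is simply its sum of depths, and applying the formula along the paths 1 → 3 and 1 → 2 → 4 reproduces both example outputs:

$$
\begin{aligned}
\text{sumDists}(1) &= 2\cdot 1 + 4\cdot 2 + 2\cdot 3 = 16\\
\text{sumDists}(3) &= \text{sumDists}(1) - \text{size}(3) + \big(N - \text{size}(3)\big) = 16 - 3 + 6 = 19\\
\text{sumDists}(2) &= \text{sumDists}(1) - \text{size}(2) + \big(N - \text{size}(2)\big) = 16 - 5 + 4 = 15\\
\text{sumDists}(4) &= \text{sumDists}(2) - \text{size}(4) + \big(N - \text{size}(4)\big) = 15 - 3 + 6 = 18
\end{aligned}
$$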

Follow the steps below to solve the problem:

• Create a function, Noofnodes(), that counts the nodes in the subtree rooted at a given node (including that node).
• Create a function, sumofdepth(), that returns the sum of the depths of all nodes measured from the root, which equals sumDists(root). Use a variable sum to store the sum of distances of all nodes from the target.
• Traverse the tree using DFS (Depth First Search), starting at the root with distancesum = sumDists(root), and for each node perform the following:
  • If the current node's value matches target, set sum to distancesum and return.
  • Else:
    • If root->left is not null, find the number of nodes in the left subtree, compute tempSum = distancesum - nodes + (N - nodes), and recur on root->left with tempSum.
    • If root->right is not null, do the same for the right subtree and recur on root->right with its tempSum.
• After the traversal, print sum, the sum of distances of all nodes from the target node.

Below is the implementation of the above approach:

## C++

```cpp
// C++ program for the above approach
#include <cstddef>
#include <iostream>
using namespace std;

// Structure of a binary tree node
class TreeNode {
public:
    int data;
    TreeNode* left;
    TreeNode* right;
};

// Function that allocates a new node with the given data
// and NULL to its left and right pointers
TreeNode* newNode(int data)
{
    // Allocate memory for the node
    TreeNode* Node = new TreeNode();
    Node->data = data;
    Node->left = NULL;
    Node->right = NULL;

    return (Node);
}

// Function which calculates the sum of depths of all nodes
// in the subtree rooted at root, where root is at depth l
int sumofdepth(TreeNode* root, int l)
{
    // Base case
    if (root == NULL)
        return 0;

    // Return recursively
    return l + sumofdepth(root->left, l + 1)
           + sumofdepth(root->right, l + 1);
}

// Function to count the nodes in the subtree rooted at root
int Noofnodes(TreeNode* root)
{
    // Base case
    if (root == NULL)
        return 0;

    // Return recursively
    return Noofnodes(root->left) + Noofnodes(root->right) + 1;
}

// Stores the sum of distances of all nodes from the given node
int sum = 0;

// Function to find the sum of distances of all nodes
// from a given node
void distance(TreeNode* root, int target, int distancesum, int n)
{
    // If the target node matches the current node
    if (root->data == target) {
        sum = distancesum;
        return;
    }

    // If the left child of the current node exists
    if (root->left) {
        // Count the nodes in the left subtree
        int nodes = Noofnodes(root->left);

        // Update sum
        int tempsum = distancesum - nodes + (n - nodes);

        // Recur for the left subtree
        distance(root->left, target, tempsum, n);
    }

    // If the right child of the current node exists
    if (root->right) {
        // Count the nodes in the right subtree
        int nodes = Noofnodes(root->right);

        // Apply the formula given in the approach
        int tempsum = distancesum - nodes + (n - nodes);

        // Recur for the right subtree
        distance(root->right, target, tempsum, n);
    }
}

// Driver code
int main()
{
    // Input tree
    TreeNode* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->left->left->left = newNode(8);
    root->left->left->right = newNode(9);

    int target = 3;

    // Sum of depths of all nodes from the root node
    int distanceroot = sumofdepth(root, 0);

    // Total number of nodes in the tree
    int totalnodes = Noofnodes(root);

    distance(root, target, distanceroot, totalnodes);

    // Print the sum of distances
    cout << sum;

    return 0;
}
```

## Java

```java
// Java program for the above approach
import java.io.*;

class GFG {

    // Structure of a binary tree node
    static class TreeNode {
        int data;
        TreeNode left, right;
    }

    // Function that allocates a new node with the given data
    // and null to its left and right pointers
    static TreeNode newNode(int data)
    {
        TreeNode Node = new TreeNode();
        Node.data = data;
        Node.left = Node.right = null;
        return (Node);
    }

    // Function which calculates the sum of depths of all nodes
    // in the subtree rooted at root, where root is at depth l
    static int sumofdepth(TreeNode root, int l)
    {
        // Base case
        if (root == null)
            return 0;

        // Return recursively
        return l + sumofdepth(root.left, l + 1)
               + sumofdepth(root.right, l + 1);
    }

    // Function to count the nodes in the subtree rooted at root
    static int Noofnodes(TreeNode root)
    {
        // Base case
        if (root == null)
            return 0;

        // Return recursively
        return Noofnodes(root.left) + Noofnodes(root.right) + 1;
    }

    // Stores the sum of distances of all nodes from the given node
    public static int sum = 0;

    // Function to find the sum of distances of all nodes
    // from a given node
    static void distance(TreeNode root, int target,
                         int distancesum, int n)
    {
        // If the target node matches the current node
        if (root.data == target) {
            sum = distancesum;
            return;
        }

        // If the left child of the current node exists
        if (root.left != null) {
            // Count the nodes in the left subtree
            int nodes = Noofnodes(root.left);

            // Update sum
            int tempsum = distancesum - nodes + (n - nodes);

            // Recur for the left subtree
            distance(root.left, target, tempsum, n);
        }

        // If the right child of the current node exists
        if (root.right != null) {
            // Count the nodes in the right subtree
            int nodes = Noofnodes(root.right);

            // Apply the formula given in the approach
            int tempsum = distancesum - nodes + (n - nodes);

            // Recur for the right subtree
            distance(root.right, target, tempsum, n);
        }
    }

    // Driver code
    public static void main(String[] args)
    {
        // Input tree
        TreeNode root = newNode(1);
        root.left = newNode(2);
        root.right = newNode(3);
        root.left.left = newNode(4);
        root.left.right = newNode(5);
        root.right.left = newNode(6);
        root.right.right = newNode(7);
        root.left.left.left = newNode(8);
        root.left.left.right = newNode(9);

        int target = 3;

        // Sum of depths of all nodes from the root node
        int distanceroot = sumofdepth(root, 0);

        // Total number of nodes in the tree
        int totalnodes = Noofnodes(root);

        distance(root, target, distanceroot, totalnodes);

        // Print the sum of distances
        System.out.println(sum);
    }
}

// This code is contributed by Dharanendra L V
```

## Python3

```python
# Python3 program for the above approach

# Structure of a binary tree node
class TreeNode:
    def __init__(self, x):
        self.data = x
        self.left = None
        self.right = None


# Function which calculates the sum of depths of all nodes
# in the subtree rooted at root, where root is at depth l
def sumofdepth(root, l):
    # Base case
    if root is None:
        return 0

    # Return recursively
    return l + sumofdepth(root.left, l + 1) + sumofdepth(root.right, l + 1)


# Function to count the nodes in the subtree rooted at root
def Noofnodes(root):
    # Base case
    if root is None:
        return 0

    # Return recursively
    return Noofnodes(root.left) + Noofnodes(root.right) + 1


# Stores the sum of distances of all nodes from the given node
sum = 0


# Function to find the sum of distances of all nodes
# from a given node
def distance(root, target, distancesum, n):
    global sum

    # If the target node matches the current node
    if root.data == target:
        sum = distancesum
        return

    # If the left child of the current node exists
    if root.left:
        # Count the nodes in the left subtree
        nodes = Noofnodes(root.left)

        # Update sum
        tempsum = distancesum - nodes + (n - nodes)

        # Recur for the left subtree
        distance(root.left, target, tempsum, n)

    # If the right child of the current node exists
    if root.right:
        # Count the nodes in the right subtree
        nodes = Noofnodes(root.right)

        # Apply the formula given in the approach
        tempsum = distancesum - nodes + (n - nodes)

        # Recur for the right subtree
        distance(root.right, target, tempsum, n)


# Driver code
if __name__ == '__main__':
    # Input tree
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(6)
    root.right.right = TreeNode(7)
    root.left.left.left = TreeNode(8)
    root.left.left.right = TreeNode(9)
    target = 3

    # Sum of depths of all nodes from the root node
    distanceroot = sumofdepth(root, 0)

    # Total number of nodes in the tree
    totalnodes = Noofnodes(root)
    distance(root, target, distanceroot, totalnodes)

    # Print the sum of distances
    print(sum)

# This code is contributed by mohit kumar 29.
```

## C#

```csharp
// C# program for the above approach
using System;

public class GFG {

    // Structure of a binary tree node
    class TreeNode {
        public int data;
        public TreeNode left, right;
    }

    // Function that allocates a new node with the given data
    // and null to its left and right pointers
    static TreeNode newNode(int data)
    {
        TreeNode Node = new TreeNode();
        Node.data = data;
        Node.left = Node.right = null;
        return (Node);
    }

    // Function which calculates the sum of depths of all nodes
    // in the subtree rooted at root, where root is at depth l
    static int sumofdepth(TreeNode root, int l)
    {
        // Base case
        if (root == null)
            return 0;

        // Return recursively
        return l + sumofdepth(root.left, l + 1)
               + sumofdepth(root.right, l + 1);
    }

    // Function to count the nodes in the subtree rooted at root
    static int Noofnodes(TreeNode root)
    {
        // Base case
        if (root == null)
            return 0;

        // Return recursively
        return Noofnodes(root.left) + Noofnodes(root.right) + 1;
    }

    // Stores the sum of distances of all nodes from the given node
    public static int sum = 0;

    // Function to find the sum of distances of all nodes
    // from a given node
    static void distance(TreeNode root, int target,
                         int distancesum, int n)
    {
        // If the target node matches the current node
        if (root.data == target) {
            sum = distancesum;
            return;
        }

        // If the left child of the current node exists
        if (root.left != null) {
            // Count the nodes in the left subtree
            int nodes = Noofnodes(root.left);

            // Update sum
            int tempsum = distancesum - nodes + (n - nodes);

            // Recur for the left subtree
            distance(root.left, target, tempsum, n);
        }

        // If the right child of the current node exists
        if (root.right != null) {
            // Count the nodes in the right subtree
            int nodes = Noofnodes(root.right);

            // Apply the formula given in the approach
            int tempsum = distancesum - nodes + (n - nodes);

            // Recur for the right subtree
            distance(root.right, target, tempsum, n);
        }
    }

    // Driver code
    public static void Main(String[] args)
    {
        // Input tree
        TreeNode root = newNode(1);
        root.left = newNode(2);
        root.right = newNode(3);
        root.left.left = newNode(4);
        root.left.right = newNode(5);
        root.right.left = newNode(6);
        root.right.right = newNode(7);
        root.left.left.left = newNode(8);
        root.left.left.right = newNode(9);

        int target = 3;

        // Sum of depths of all nodes from the root node
        int distanceroot = sumofdepth(root, 0);

        // Total number of nodes in the tree
        int totalnodes = Noofnodes(root);

        distance(root, target, distanceroot, totalnodes);

        // Print the sum of distances
        Console.WriteLine(sum);
    }
}

// This code is contributed by shikhasingrajput
```


Output:
19

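Time Complexity: O(N²) in the worst case (for example, a skewed tree), because Noofnodes() recounts the child subtrees of every visited node
Auxiliary Space: O(H) for the recursion stack, where H is the height of the tree; no other extra space is used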
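To see where the quadratic bound comes from, consider a skewed tree (a chain of N nodes) whose deepest node is the target. The DFS visits every node on the chain, and at each non-target node Noofnodes() walks the entire remaining chain below it, which has N - 1, N - 2, ..., 1 nodes:

$$\sum_{k=1}^{N-1} k = \frac{N(N-1)}{2} = O(N^2)$$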

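Efficient Approach: The above approach can be optimized by adding an extra field, size, to the node structure to store the number of nodes in the subtree rooted at that node (including the node itself). A single post-order traversal fills in every size field and, at the same time, computes the sum of depths of all nodes from the root. After that, the size of any subtree is available in constant time, so each application of the formula takes O(1).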
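The single post-order pass relies on two recurrences over the children c of a node v. Here D(v) is a symbol introduced only for this explanation: the sum of depths of the nodes in the subtree of v, measured from v, so D(root) = sumDists(root). Every node in a child's subtree is one level deeper from v than from c, which adds size(c) to the child's own depth sum:

$$
\text{size}(v) = 1 + \sum_{c} \text{size}(c), \qquad D(v) = \sum_{c} \big(D(c) + \text{size}(c)\big)
$$

On the example tree, leaves have size 1 and D = 0, which gives D(4) = (0 + 1) + (0 + 1) = 2, D(2) = (2 + 3) + (0 + 1) = 6, D(3) = (0 + 1) + (0 + 1) = 2, and D(1) = (6 + 5) + (2 + 3) = 16, the sum of depths from the root. These two quantities correspond to the {number of nodes, sum of depths} pair returned by sumofsubtree() in the implementations.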

Below is the implementation of the above approach:

## C++

```cpp
// C++ program for the above approach
#include <cstddef>
#include <iostream>
#include <utility>
using namespace std;

// Structure of a binary tree node
class TreeNode {
public:
    int data, size;
    TreeNode* left;
    TreeNode* right;
};

// Function that allocates a new node with the given data
// and NULL to its left and right pointers
TreeNode* newNode(int data)
{
    TreeNode* Node = new TreeNode();
    Node->data = data;
    Node->left = NULL;
    Node->right = NULL;

    // Return newly created node
    return (Node);
}

// Function that fills in the size field of every node and
// returns {number of nodes, sum of depths} for the subtree
pair<int, int> sumofsubtree(TreeNode* root)
{
    // Initialize a pair that stores
    // the pair {number of nodes, sum of depths}
    pair<int, int> p = make_pair(1, 0);

    // Add the contribution of the left subtree
    if (root->left) {
        pair<int, int> ptemp = sumofsubtree(root->left);

        p.second += ptemp.first + ptemp.second;
        p.first += ptemp.first;
    }

    // Add the contribution of the right subtree
    if (root->right) {
        pair<int, int> ptemp = sumofsubtree(root->right);

        p.second += ptemp.first + ptemp.second;
        p.first += ptemp.first;
    }

    // Fill up the size field
    root->size = p.first;
    return p;
}

// Stores the sum of distances of all nodes from the given node
int sum = 0;

// Function to find the total distance
void distance(TreeNode* root, int target, int distancesum, int n)
{
    // If the target node matches the current node
    if (root->data == target) {
        sum = distancesum;
    }

    // If root->left is not null
    if (root->left) {
        // Update sum
        int tempsum = distancesum - root->left->size
                      + (n - root->left->size);

        // Recur for the left subtree
        distance(root->left, target, tempsum, n);
    }

    // If root->right is not null
    if (root->right) {
        // Apply the formula given in the approach
        int tempsum = distancesum - root->right->size
                      + (n - root->right->size);

        // Recur for the right subtree
        distance(root->right, target, tempsum, n);
    }
}

// Driver code
int main()
{
    // Input tree
    TreeNode* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->left->left->left = newNode(8);
    root->left->left->right = newNode(9);

    int target = 3;

    pair<int, int> p = sumofsubtree(root);

    // Total number of nodes
    int totalnodes = p.first;

    distance(root, target, p.second, totalnodes);

    // Print the sum of distances
    cout << sum << endl;

    return 0;
}
```

## Java

```java
// Java program for the above approach
import java.util.*;

class GFG {

    static class pair {
        int first, second;

        public pair(int first, int second)
        {
            this.first = first;
            this.second = second;
        }
    }

    // Structure of a binary tree node
    static class TreeNode {
        int data, size;
        TreeNode left;
        TreeNode right;
    }

    // Function that allocates a new node with the given data
    // and null to its left and right pointers
    static TreeNode newNode(int data)
    {
        TreeNode Node = new TreeNode();
        Node.data = data;
        Node.left = null;
        Node.right = null;

        // Return newly created node
        return (Node);
    }

    // Function that fills in the size field of every node and
    // returns {number of nodes, sum of depths} for the subtree
    static pair sumofsubtree(TreeNode root)
    {
        // Initialize a pair that stores
        // the pair {number of nodes, sum of depths}
        pair p = new pair(1, 0);

        // Add the contribution of the left subtree
        if (root.left != null) {
            pair ptemp = sumofsubtree(root.left);

            p.second += ptemp.first + ptemp.second;
            p.first += ptemp.first;
        }

        // Add the contribution of the right subtree
        if (root.right != null) {
            pair ptemp = sumofsubtree(root.right);

            p.second += ptemp.first + ptemp.second;
            p.first += ptemp.first;
        }

        // Fill up the size field
        root.size = p.first;
        return p;
    }

    // Stores the sum of distances of all nodes from the given node
    static int sum = 0;

    // Function to find the total distance
    static void distance(TreeNode root, int target,
                         int distancesum, int n)
    {
        // If the target node matches the current node
        if (root.data == target) {
            sum = distancesum;
        }

        // If root.left is not null
        if (root.left != null) {
            // Update sum
            int tempsum = distancesum - root.left.size
                          + (n - root.left.size);

            // Recur for the left subtree
            distance(root.left, target, tempsum, n);
        }

        // If root.right is not null
        if (root.right != null) {
            // Apply the formula given in the approach
            int tempsum = distancesum - root.right.size
                          + (n - root.right.size);

            // Recur for the right subtree
            distance(root.right, target, tempsum, n);
        }
    }

    // Driver code
    public static void main(String[] args)
    {
        // Input tree
        TreeNode root = newNode(1);
        root.left = newNode(2);
        root.right = newNode(3);
        root.left.left = newNode(4);
        root.left.right = newNode(5);
        root.right.left = newNode(6);
        root.right.right = newNode(7);
        root.left.left.left = newNode(8);
        root.left.left.right = newNode(9);
        int target = 3;
        pair p = sumofsubtree(root);

        // Total number of nodes
        int totalnodes = p.first;
        distance(root, target, p.second, totalnodes);

        // Print the sum of distances
        System.out.print(sum + "\n");
    }
}

// This code is contributed by shikhasingrajput
```

## Python3

```python
# Python3 program for the above approach

# Stores the sum of distances of all
# nodes from the given node
sum = 0


# Structure of a binary tree node
class TreeNode:
    def __init__(self, data):
        self.data = data
        self.size = 0
        self.left = None
        self.right = None


# Function that fills in the size field of every node and
# returns [number of nodes, sum of depths] for the subtree
def sumofsubtree(root):
    # Initialize a pair that stores
    # the pair [number of nodes, sum of depths]
    p = [1, 0]

    # Add the contribution of the left subtree
    if root.left:
        ptemp = sumofsubtree(root.left)
        p[1] += ptemp[0] + ptemp[1]
        p[0] += ptemp[0]

    # Add the contribution of the right subtree
    if root.right:
        ptemp = sumofsubtree(root.right)
        p[1] += ptemp[0] + ptemp[1]
        p[0] += ptemp[0]

    # Fill up the size field
    root.size = p[0]
    return p


# Function to find the total distance
def distance(root, target, distancesum, n):
    global sum

    # If the target node matches the current node
    if root.data == target:
        sum = distancesum

    # If root.left is not None
    if root.left:
        # Update sum
        tempsum = distancesum - root.left.size + (n - root.left.size)

        # Recur for the left subtree
        distance(root.left, target, tempsum, n)

    # If root.right is not None
    if root.right:
        # Apply the formula given in the approach
        tempsum = distancesum - root.right.size + (n - root.right.size)

        # Recur for the right subtree
        distance(root.right, target, tempsum, n)


# Driver code
if __name__ == '__main__':
    # Input tree
    root = TreeNode(1)
    root.left = TreeNode(2)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(5)
    root.right.left = TreeNode(6)
    root.right.right = TreeNode(7)
    root.left.left.left = TreeNode(8)
    root.left.left.right = TreeNode(9)

    target = 3

    p = sumofsubtree(root)

    # Total number of nodes
    totalnodes = p[0]

    distance(root, target, p[1], totalnodes)

    # Print the sum of distances
    print(sum)

# This code is contributed by ipg2016107
```

## C#

```csharp
// C# program for the above approach
using System;

public class GFG {

    class pair {
        public int first, second;

        public pair(int first, int second)
        {
            this.first = first;
            this.second = second;
        }
    }

    // Structure of a binary tree node
    class TreeNode {
        public int data, size;
        public TreeNode left;
        public TreeNode right;
    }

    // Function that allocates a new node with the given data
    // and null to its left and right pointers
    static TreeNode newNode(int data)
    {
        TreeNode Node = new TreeNode();
        Node.data = data;
        Node.left = null;
        Node.right = null;

        // Return newly created node
        return (Node);
    }

    // Function that fills in the size field of every node and
    // returns {number of nodes, sum of depths} for the subtree
    static pair sumofsubtree(TreeNode root)
    {
        // Initialize a pair that stores
        // the pair {number of nodes, sum of depths}
        pair p = new pair(1, 0);

        // Add the contribution of the left subtree
        if (root.left != null) {
            pair ptemp = sumofsubtree(root.left);

            p.second += ptemp.first + ptemp.second;
            p.first += ptemp.first;
        }

        // Add the contribution of the right subtree
        if (root.right != null) {
            pair ptemp = sumofsubtree(root.right);

            p.second += ptemp.first + ptemp.second;
            p.first += ptemp.first;
        }

        // Fill up the size field
        root.size = p.first;
        return p;
    }

    // Stores the sum of distances of all nodes from the given node
    static int sum = 0;

    // Function to find the total distance
    static void distance(TreeNode root, int target,
                         int distancesum, int n)
    {
        // If the target node matches the current node
        if (root.data == target) {
            sum = distancesum;
        }

        // If root.left is not null
        if (root.left != null) {
            // Update sum
            int tempsum = distancesum - root.left.size
                          + (n - root.left.size);

            // Recur for the left subtree
            distance(root.left, target, tempsum, n);
        }

        // If root.right is not null
        if (root.right != null) {
            // Apply the formula given in the approach
            int tempsum = distancesum - root.right.size
                          + (n - root.right.size);

            // Recur for the right subtree
            distance(root.right, target, tempsum, n);
        }
    }

    // Driver code
    public static void Main(String[] args)
    {
        // Input tree
        TreeNode root = newNode(1);
        root.left = newNode(2);
        root.right = newNode(3);
        root.left.left = newNode(4);
        root.left.right = newNode(5);
        root.right.left = newNode(6);
        root.right.right = newNode(7);
        root.left.left.left = newNode(8);
        root.left.left.right = newNode(9);
        int target = 3;
        pair p = sumofsubtree(root);

        // Total number of nodes
        int totalnodes = p.first;
        distance(root, target, p.second, totalnodes);

        // Print the sum of distances
        Console.Write(sum + "\n");
    }
}

// This code is contributed by shikhasingrajput
```


Output:
19

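Time Complexity: O(N), since sumofsubtree() and distance() each visit every node once
Auxiliary Space: O(H) for the recursion stack, where H is the height of the tree, plus one extra size field per node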
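The recursive implementations above can exceed CPython's default recursion limit of 1000 frames on deep, skewed trees. The following is a minimal iterative sketch of the same efficient approach, assuming distinct node values. The function name `sum_of_distances_iterative` is hypothetical, and it keeps subtree sizes in a dictionary keyed by node instead of a `size` field. It computes all sizes and the root's depth sum in reverse BFS order (children before parents), then applies the formula top-down from the root until it reaches the target.

```python
from collections import deque


class TreeNode:
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def sum_of_distances_iterative(root, target):
    """Sum of distances from the node whose value is target (assumed
    unique), or None if no such node exists. Uses no recursion."""
    if root is None:
        return None

    # BFS: every parent is appended to `order` before its children
    order = []
    depth = {root: 0}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        order.append(node)
        for child in (node.left, node.right):
            if child is not None:
                depth[child] = depth[node] + 1
                queue.append(child)
    n = len(order)

    # Reverse BFS order handles children before parents, so each
    # subtree size is complete before its parent reads it
    size = {}
    for node in reversed(order):
        size[node] = 1 + sum(size[c] for c in (node.left, node.right) if c is not None)

    # sumDists(root) is the sum of the depths of all nodes
    dists = {root: sum(depth.values())}

    # Top-down: sumDists(child) = sumDists(parent) - size(child) + (n - size(child))
    for node in order:
        if node.data == target:
            return dists[node]
        for child in (node.left, node.right):
            if child is not None:
                dists[child] = dists[node] - size[child] + (n - size[child])
    return None


# Example tree from the article
root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left, root.left.right = TreeNode(4), TreeNode(5)
root.right.left, root.right.right = TreeNode(6), TreeNode(7)
root.left.left.left, root.left.left.right = TreeNode(8), TreeNode(9)

print(sum_of_distances_iterative(root, 3))  # 19
print(sum_of_distances_iterative(root, 4))  # 18
```

Because BFS order lists every parent before its children, the value for a child is always computed before that child is examined, so a single forward pass is enough for the top-down step. Time and space remain O(N).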
