Open In App

Sum of subtree depths for every node of a given Binary Tree

Improve
Improve
Like Article
Like
Save
Share
Report

Given a binary tree consisting of N nodes, the task is to find, for every node, the sum of the depths of all nodes in its subtree (measured from that node), and print the total of these sums over all nodes.

Examples:

Input:

Output: 26
Explanation:

The leaf nodes having value 8, 9, 5, 6, and 7 have the sum of subtree depths equal to 0.
Node 4 has a total of 2 nodes in its subtree, both in the same level. Therefore, sum of subtree depth is equal to 2.
Node 2 has a total of 4 nodes in its subtree, 2 in each level. Therefore, the sum of subtree depth is equal to 6.
Node 3 has a total of 2 nodes in its subtree, both at the same level. Therefore, the sum of subtree depth is equal to 2.
Node 1 has a total of 8 nodes in its subtree, 2 in the first level, 4 in the 2nd level, and 2 in the last. Therefore, sum of subtree depths is equal to 16.
Therefore, the total sum of all subtree depths is equal to 2 + 6 + 2 + 16 = 26

Input:

Output: 5

 

Naive Approach: The simplest approach is to traverse the tree and for every node, calculate the sum of depths of all nodes from that node recursively and print the final sum obtained.

Below is the implementation of the above approach:

C++




// C++ program for the above approach
#include <bits/stdc++.h>
using namespace std;
 
// Binary tree node. The children default to nullptr, so a node created
// with `new TreeNode()` starts out as a leaf.
class TreeNode {
public:
    int data{};
    TreeNode* left{nullptr};
    TreeNode* right{nullptr};
};

// Allocate (with `new`) a leaf node holding `data` and return it.
TreeNode* newNode(int data)
{
    TreeNode* created = new TreeNode();
    created->data = data;
    created->left = nullptr;
    created->right = nullptr;
    return created;
}

// Return the sum of the depths of every node in the subtree rooted at
// `root`, where `root` itself is counted at depth `d` and each child is one
// level deeper than its parent. An empty subtree contributes 0.
int sumofdepth(TreeNode* root, int d)
{
    if (!root)
        return 0;

    const int childDepth = d + 1;
    const int leftSum = sumofdepth(root->left, childDepth);
    const int rightSum = sumofdepth(root->right, childDepth);
    return d + leftSum + rightSum;
}

// Naive approach: for every node v, add sumofdepth(v, 0), i.e. the depth
// sum of v's subtree measured from v itself. Every node re-walks its whole
// subtree, so a chain-shaped tree costs O(N^2) work.
int sumofallsubtrees(TreeNode* root)
{
    int total = 0;

    // Explicit worklist of subtree roots still to process; the visiting
    // order does not matter because the contributions are simply added.
    std::vector<TreeNode*> pending;
    if (root)
        pending.push_back(root);

    while (!pending.empty()) {
        TreeNode* node = pending.back();
        pending.pop_back();

        total += sumofdepth(node, 0);

        if (node->left)
            pending.push_back(node->left);
        if (node->right)
            pending.push_back(node->right);
    }
    return total;
}
 
// Driver Code
int main()
{
    // Given Tree
    TreeNode* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->left->left->left = newNode(8);
    root->left->left->right = newNode(9);
 
    // Function Call
    cout << sumofallsubtrees(root);
 
    return 0;
}


Java




// Java program for the above approach
class GFG{

// Node of a binary tree; both children are null for a leaf.
static class TreeNode
{
    int data;
    TreeNode left, right;
}

// Create a leaf node holding the given value.
static TreeNode newNode(int data)
{
    TreeNode created = new TreeNode();
    created.data = data;
    created.left = null;
    created.right = null;
    return created;
}

// Sum of the depths of all nodes in the subtree rooted at `root`, where
// `root` itself is counted at depth `d`. An empty subtree yields 0.
static int sumofdepth(TreeNode root, int d)
{
    if (root == null)
        return 0;

    int childDepth = d + 1;
    int fromLeft = sumofdepth(root.left, childDepth);
    int fromRight = sumofdepth(root.right, childDepth);
    return d + fromLeft + fromRight;
}

// Naive approach: sum of sumofdepth(v, 0) over every node v of the tree.
static int sumofallsubtrees(TreeNode root)
{
    if (root == null)
        return 0;

    int here = sumofdepth(root, 0);
    int inLeft = sumofallsubtrees(root.left);
    int inRight = sumofallsubtrees(root.right);
    return here + inLeft + inRight;
}

// Driver Code: sample tree from the problem statement
public static void main(String[] args)
{
    TreeNode tree = newNode(1);
    tree.left = newNode(2);
    tree.right = newNode(3);

    TreeNode two = tree.left;
    TreeNode three = tree.right;
    two.left = newNode(4);
    two.right = newNode(5);
    three.left = newNode(6);
    three.right = newNode(7);

    TreeNode four = two.left;
    four.left = newNode(8);
    four.right = newNode(9);

    System.out.println(sumofallsubtrees(tree));
}
}
 
// This code is contributed by Dharanendra L V


Python3




# Python3 program for the above approach
 
# Binary tree node
class TreeNode:
    """A binary tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None


def sumofdepth(root, d):
    """Return the sum of node depths in the subtree rooted at ``root``.

    ``root`` itself is counted at depth ``d`` and every child is one level
    deeper than its parent. An empty subtree (``None``) contributes 0.
    """
    if root is None:
        return 0
    child_depth = d + 1
    return d + sum(sumofdepth(child, child_depth)
                   for child in (root.left, root.right))


def sumofallsubtrees(root):
    """Return the sum of ``sumofdepth(node, 0)`` over every node of the tree.

    This is the naive approach: each node re-walks its own subtree, so a
    chain-shaped tree costs O(N^2) work.
    """
    total = 0
    stack = [] if root is None else [root]
    while stack:
        node = stack.pop()
        total += sumofdepth(node, 0)
        stack.extend(child for child in (node.left, node.right)
                     if child is not None)
    return total
 
# Driver Code: build the sample tree from the problem statement and print
# the sum of subtree depths over all of its nodes.
if __name__ == '__main__':

    root = TreeNode(1)
    root.left, root.right = TreeNode(2), TreeNode(3)
    root.left.left, root.left.right = TreeNode(4), TreeNode(5)
    root.right.left, root.right.right = TreeNode(6), TreeNode(7)
    root.left.left.left, root.left.left.right = TreeNode(8), TreeNode(9)

    print(sumofallsubtrees(root))
 
# This code is contributed by ipg2016107


C#




// C# program for the above approach
using System;
public class GFG{

// Node of a binary tree; both children are null for a leaf.
class TreeNode
{
    public int data;
    public TreeNode left, right;
}

// Create a leaf node holding the given value.
static TreeNode newNode(int data)
{
    TreeNode created = new TreeNode();
    created.data = data;
    created.left = null;
    created.right = null;
    return created;
}

// Sum of the depths of all nodes in the subtree rooted at `root`, where
// `root` itself is counted at depth `d`. An empty subtree yields 0.
static int sumofdepth(TreeNode root, int d)
{
    if (root == null)
        return 0;

    int childDepth = d + 1;
    int fromLeft = sumofdepth(root.left, childDepth);
    int fromRight = sumofdepth(root.right, childDepth);
    return d + fromLeft + fromRight;
}

// Naive approach: sum of sumofdepth(v, 0) over every node v of the tree.
static int sumofallsubtrees(TreeNode root)
{
    if (root == null)
        return 0;

    int here = sumofdepth(root, 0);
    int inLeft = sumofallsubtrees(root.left);
    int inRight = sumofallsubtrees(root.right);
    return here + inLeft + inRight;
}

// Driver Code: sample tree from the problem statement
public static void Main(String[] args)
{
    TreeNode tree = newNode(1);
    tree.left = newNode(2);
    tree.right = newNode(3);

    TreeNode two = tree.left;
    TreeNode three = tree.right;
    two.left = newNode(4);
    two.right = newNode(5);
    three.left = newNode(6);
    three.right = newNode(7);

    TreeNode four = two.left;
    four.left = newNode(8);
    four.right = newNode(9);

    Console.WriteLine(sumofallsubtrees(tree));
}
}
 
// This code is contributed by shikhasingrajput


Javascript




<script>
    // Javascript program for the above approach
     
    // Binary tree node
    class TreeNode
    {
        constructor(data) {
           this.left = null;
           this.right = null;
           this.data = data;
        }
    }
     
    // Function that allocates a new
    // node with data and NULL to its
    // left and right pointers
    function newNode(data)
    {
        let Node = new TreeNode(data);
        return (Node);
    }
 
    // Function to find the sum of depths of
    // all nodes in subtree of the current node
    function sumofdepth(root, d)
    {
 
        // If NULL node then return 0
        if (root == null)
            return 0;
 
        // Recursively find the sum of
        // depths of all nodes in the
        // left and right subtree
        return d + sumofdepth(root.left, d + 1) +
                  sumofdepth(root.right, d + 1);
    }
 
    // Function to calculate the sum
    // of depth of all the subtrees
    function sumofallsubtrees(root)
    {
 
        // If root is NULL return 0
        if (root == null)
            return 0;
 
        // Find the total depth sum of
        // current node and recursively
        return sumofdepth(root, 0) +
               sumofallsubtrees(root.left) +
               sumofallsubtrees(root.right);
    }
     
    // Given Tree
    let root = newNode(1);
    root.left = newNode(2);
    root.right = newNode(3);
    root.left.left = newNode(4);
    root.left.right = newNode(5);
    root.right.left = newNode(6);
    root.right.right = newNode(7);
    root.left.left.left = newNode(8);
    root.left.left.right = newNode(9);
      
    // Function Call
    document.write(sumofallsubtrees(root));
 
// This code is contributed by suresh07.
</script>


Output: 

26

 

Time complexity: O(N^2)
Auxiliary Space: O(N)

Efficient Approach: The above approach can be optimized based on the following observations: 

Suppose X and Y are the numbers of nodes in the left and right subtrees respectively, and S1 and S2 are the sums of depths of the nodes in the left and right subtrees (each measured from that subtree's own root). 
Then the sum of depths for the current node's subtree is S1 + S2 + X + Y, because each of those X + Y nodes is one level deeper when measured from the current node.

Follow the steps below to solve the problem:

  • Define a recursive DFS function say sumofsubtree(root) as:
    • Initialize a pair p = (1, 0): p.first stores the number of nodes in the current node's subtree (initially just the node itself) and p.second stores the sum of depths of all nodes in that subtree, measured from the current node.
    • If the left child is not NULL, recursively call sumofsubtree(root->left), which returns (count, depthSum) for the left subtree. Increment p.first by count and increment p.second by depthSum + count, since every node of the left subtree is one level deeper when measured from the current node.
    • If the right child is not NULL, recursively call sumofsubtree(root->right), which returns (count, depthSum) for the right subtree. Increment p.first by count and increment p.second by depthSum + count, for the same reason.
    • Increment the total ans by p.second and return pair p.
  • Call the DFS function sumofsubtree(root) and print the accumulated result ans.

Below is the implementation of the above approach:

C++




// C++ program for the above approach
#include <bits/stdc++.h>
using namespace std;
 
// Binary tree node. The children default to nullptr, so a node created
// with `new TreeNode()` starts out as a leaf.
class TreeNode {
public:
    int data{};
    TreeNode* left{nullptr};
    TreeNode* right{nullptr};
};

// Function to allocate (with `new`) a leaf node holding `data`.
TreeNode* newNode(int data)
{
    TreeNode* Node = new TreeNode();
    Node->data = data;
    Node->left = nullptr;
    Node->right = nullptr;

    return (Node);
}

// Post-order DFS over the subtree rooted at `root`.
//
// Returns {count, depthSum}: `count` is the number of nodes in the subtree
// and `depthSum` is the sum of their depths measured from `root` (root
// itself at depth 0). As a side effect, adds the depthSum of every visited
// node to `ans`, so after a call on the tree root `ans` has grown by the
// sum of subtree depths over all nodes.
//
// An empty subtree (root == nullptr) returns {0, 0} and leaves `ans`
// unchanged, so calling this on an empty tree is safe.
pair<int, int> sumofsubtree(TreeNode* root,
                            int& ans)
{
    if (root == nullptr)
        return make_pair(0, 0);

    // The current node alone: one node, at depth 0.
    pair<int, int> p = make_pair(1, 0);

    // Recurse into both children; empty children contribute {0, 0}.
    const pair<int, int> l = sumofsubtree(root->left, ans);
    const pair<int, int> r = sumofsubtree(root->right, ans);

    // Every node of a child's subtree is one level deeper when measured
    // from `root`, so that child's depth sum grows by its node count.
    p.second += (l.second + l.first) + (r.second + r.first);
    p.first += l.first + r.first;

    // Add the depth sum of the current subtree to the overall answer.
    ans += p.second;

    return p;
}
 
// Driver Code
int main()
{
    // Given Tree
    TreeNode* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->left->left->left = newNode(8);
    root->left->left->right = newNode(9);
 
    int ans = 0;
 
    sumofsubtree(root, ans);
 
    // Print the result
    cout << ans;
 
    return 0;
}


Java




// Java program for the above approach
import java.util.*;
class GFG
{
    // Running total of subtree depth sums, updated by sumofsubtree().
    // main() resets it to 0 before the top-level call.
    static int ans;

    // (count, depthSum) result of sumofsubtree for one subtree
    static class pair
    {
        int first, second;
        public pair(int first, int second)
        {
            this.first = first;
            this.second = second;
        }
    }

// Binary tree node
static class TreeNode
{
    int data;
    TreeNode left;
    TreeNode right;
}

// Function to allocate a new
// node with the given data and
// null in its left and right pointers
static TreeNode newNode(int data)
{
    TreeNode Node = new TreeNode();
    Node.data = data;
    Node.left = null;
    Node.right = null;
    return (Node);
}

// Post-order DFS over the subtree rooted at `root`.
// Returns (count, depthSum): the number of nodes in the subtree and the sum
// of their depths measured from `root` (root itself at depth 0). Adds the
// depthSum of every visited node to `ans`.
// A null root returns (0, 0) and leaves `ans` unchanged, so an empty tree
// is handled instead of throwing NullPointerException.
static pair sumofsubtree(TreeNode root)
{
    if (root == null)
        return new pair(0, 0);

    // The current node alone: one node, at depth 0.
    pair p = new pair(1, 0);

    // Recurse into both children; null children contribute (0, 0).
    pair l = sumofsubtree(root.left);
    pair r = sumofsubtree(root.right);

    // Every node of a child's subtree is one level deeper when measured
    // from `root`, so that child's depth sum grows by its node count.
    p.second += (l.second + l.first) + (r.second + r.first);
    p.first += l.first + r.first;

    // Add the depth sum of the current subtree to the overall answer.
    ans += p.second;

    return p;
}

// Driver Code
public static void main(String[] args)
{
    // Given Tree
    TreeNode root = newNode(1);
    root.left = newNode(2);
    root.right = newNode(3);
    root.left.left = newNode(4);
    root.left.right = newNode(5);
    root.right.left = newNode(6);
    root.right.right = newNode(7);
    root.left.left.left = newNode(8);
    root.left.left.right = newNode(9);
    ans = 0;
    sumofsubtree(root);

    // Print the result
    System.out.print(ans);
}
}
 
// This code is contributed by shikhasingrajput


Python3




# Python3 program for the above approach
 
# Binary tree node
class TreeNode:
    """A binary tree node holding ``data`` and optional ``left``/``right`` children."""

    # Constructor to set the data of
    # the newly created tree node
    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None

# Running total of subtree depth sums, updated by sumofsubtree().
ans = 0

def newNode(data):
    """Allocate a new leaf node holding ``data``."""
    Node = TreeNode(data)
    return (Node)

def sumofsubtree(root):
    """Post-order DFS returning ``[count, depth_sum]`` for the subtree at ``root``.

    ``count`` is the number of nodes in the subtree and ``depth_sum`` is the
    sum of their depths measured from ``root`` (``root`` itself at depth 0).
    As a side effect, the ``depth_sum`` of every visited node is added to the
    module-level ``ans``.

    ``None`` (an empty subtree) returns ``[0, 0]`` and leaves ``ans``
    unchanged, so an empty tree no longer raises ``AttributeError``.
    """
    global ans

    if root is None:
        return [0, 0]

    # The current node alone: one node, at depth 0.
    p = [1, 0]

    # Empty children contribute [0, 0], so no per-child None check is needed.
    for child in (root.left, root.right):
        count, depth_sum = sumofsubtree(child)
        # Every node below `child` is one level deeper when measured from
        # `root`, so the child's depth sum grows by its node count.
        p[1] += depth_sum + count
        p[0] += count

    # Add the depth sum of the current subtree to the overall answer.
    ans += p[1]

    return p
 
# Driver Code: build the sample tree from the problem statement, run the DFS
# (which accumulates into the module-level ``ans``) and print the result.
root = newNode(1)
root.left, root.right = newNode(2), newNode(3)
root.left.left, root.left.right = newNode(4), newNode(5)
root.right.left, root.right.right = newNode(6), newNode(7)
root.left.left.left, root.left.left.right = newNode(8), newNode(9)

sumofsubtree(root)
print(ans)
 
# This code is contributed by decode2207.


C#




// C# program for the above approach
using System;
 
public class GFG
{
    // Running total of subtree depth sums, updated by sumofsubtree().
    // Main resets it to 0 before the top-level call.
    static int ans;

    // (count, depthSum) result of sumofsubtree for one subtree
    class pair
    {
        public int first, second;
        public pair(int first, int second)
        {
            this.first = first;
            this.second = second;
        }
    }

// Binary tree node
class TreeNode
{
    public int data;
    public TreeNode left;
    public TreeNode right;
}

// Function to allocate a new
// node with the given data and
// null in its left and right pointers
static TreeNode newNode(int data)
{
    TreeNode Node = new TreeNode();
    Node.data = data;
    Node.left = null;
    Node.right = null;
    return (Node);
}

// Post-order DFS over the subtree rooted at `root`.
// Returns (count, depthSum): the number of nodes in the subtree and the sum
// of their depths measured from `root` (root itself at depth 0). Adds the
// depthSum of every visited node to `ans`.
// A null root returns (0, 0) and leaves `ans` unchanged, so an empty tree
// is handled instead of throwing NullReferenceException.
static pair sumofsubtree(TreeNode root)
{
    if (root == null)
        return new pair(0, 0);

    // The current node alone: one node, at depth 0.
    pair p = new pair(1, 0);

    // Recurse into both children; null children contribute (0, 0).
    pair l = sumofsubtree(root.left);
    pair r = sumofsubtree(root.right);

    // Every node of a child's subtree is one level deeper when measured
    // from `root`, so that child's depth sum grows by its node count.
    p.second += (l.second + l.first) + (r.second + r.first);
    p.first += l.first + r.first;

    // Add the depth sum of the current subtree to the overall answer.
    ans += p.second;

    return p;
}

// Driver Code
public static void Main(String[] args)
{
    // Given Tree
    TreeNode root = newNode(1);
    root.left = newNode(2);
    root.right = newNode(3);
    root.left.left = newNode(4);
    root.left.right = newNode(5);
    root.right.left = newNode(6);
    root.right.right = newNode(7);
    root.left.left.left = newNode(8);
    root.left.left.right = newNode(9);
    ans = 0;
    sumofsubtree(root);

    // Print the result
    Console.Write(ans);
}
}
 
  
 
// This code contributed by shikhasingrajput


Javascript




<script>
      // JavaScript program for the above approach
      var ans;
      class pair {
        constructor(first, second) {
          this.first = first;
          this.second = second;
        }
      }
 
      // Binary tree node
      class TreeNode {
        constructor() {
          this.data = 0;
          this.left = null;
          this.right = null;
        }
      }
 
      // Function to allocate a new
      // node with the given data and
      // null in its left and right pointers
      function newNode(data) {
        var Node = new TreeNode();
        Node.data = data;
        Node.left = null;
        Node.right = null;
        return Node;
      }
 
      // DFS function to calculate the sum
      // of depths of all subtrees depth sum
      function sumofsubtree(root) {
        // Store total number of node in
        // its subtree and total sum of
        // depth in its subtree
        var p = new pair(1, 0);
 
        // Check if left is not null
        if (root.left != null) {
          // Call recursively the DFS
          // function for left child
          var ptemp = sumofsubtree(root.left);
 
          // Increment the sum of depths
          // by ptemp.first+p.temp.first
          p.second += ptemp.first + ptemp.second;
 
          // Increment p.first by count
          // of noded in left subtree
          p.first += ptemp.first;
        }
 
        // Check if right is not null
        if (root.right != null) {
          // Call recursively the DFS
          // function for right child
          var ptemp = sumofsubtree(root.right);
 
          // Increment the sum of depths
          // by ptemp.first+p.temp.first
          p.second += ptemp.first + ptemp.second;
 
          // Increment p.first by count of
          // nodes in right subtree
          p.first += ptemp.first;
        }
 
        // Increment the result by total
        // sum of depth in current subtree
        ans += p.second;
 
        // Return p
        return p;
      }
 
      // Driver Code
      // Given Tree
      var root = newNode(1);
      root.left = newNode(2);
      root.right = newNode(3);
      root.left.left = newNode(4);
      root.left.right = newNode(5);
      root.right.left = newNode(6);
      root.right.right = newNode(7);
      root.left.left.left = newNode(8);
      root.left.left.right = newNode(9);
      ans = 0;
      sumofsubtree(root);
 
      // Print the result
      document.write(ans);
       
      // This code is contributed by rdtank.
    </script>


Output: 

26

 

Time Complexity: O(N)
Auxiliary Space: O(N)

 



Last Updated : 17 Sep, 2021
Like Article
Save Article
Previous
Next
Share your thoughts in the comments
Similar Reads