Replace each node by the sum of all nodes in the same level of a Binary Tree

Last Updated: 24 Feb, 2021

Given a Binary Tree, the task is to replace the value of each node with the sum of the values of all nodes present at the same level.

Examples:

Input: 
 

             9
            / \
           6   10
          / \   \
         4   7   11
        / \   \
       3  5   8

Output: 
 

             9
            / \
          16   16
          / \   \
        22   22  22
        / \   \
      16  16   16

Explanation: 
The sum of nodes at level 1: 9
The sum of nodes at level 2: 6 + 10 = 16
The sum of nodes at level 3: 4 + 7 + 11 = 22
The sum of nodes at level 4: 3 + 5 + 8 = 16
 



 

Input: 
 

             5
            / \
           6   3
          / \   \
         4   9   2

Output: 
 

             5
            / \
           9   9
          / \   \
        15   15   15

 

 

Approach: The idea is to use Level Order Traversal to find the sum of the nodes at each level of the tree, and then replace every node on that level with this sum. Follow the steps below to solve the problem:

  • Initialize two queues, say que1 and que2. que1 is used to calculate the sum of the nodes at each level, and que2 is used to replace each node with the sum of all the nodes at the same level.
  • Use Level Order Traversal with que1 to calculate the sum of all the nodes at the current level of the tree.
  • Use Level Order Traversal with que2 to replace each node of that same level with the computed sum.
  • Finally, print the tree using Level Order Traversal.

Below is the implementation of the above approach:

C++14

// C++ program to implement
// the above approach
#include <bits/stdc++.h>
using namespace std;
 
// Node of a binary tree: an int payload plus links to the left and
// right children (a null pointer when a child is absent).
struct TreeNode
{
  int val = 0;
  TreeNode *left = nullptr;
  TreeNode *right = nullptr;

  // Creates a childless node holding x
  TreeNode(int x) : val(x) {}
};
 
// Function to replace each node of the tree
// with the sum of nodes in the same level
TreeNode* replaceLvl(TreeNode *root){
 
  // If root is null
  if (!root)
    return root;
 
  // Queue to store nodes
  // of each level of the Tree
  queue<TreeNode*> que1;
  queue<TreeNode*> que2;
  que1.push(root);
  que2.push(root);
 
  while (true)
  {
 
    // Stores length of que1
    int lenque1 = que1.size();
 
    // Stores length of que2
    int lenque2 = que2.size();
 
    // Stores sum of nodes at
    // each level of the tree
    int lvlsum = 0;
 
    if (lenque1 == 0)
      break;
 
    // Traverse the tree and store
    // the sum at each level
    while (lenque1 > 0)
    {
 
      // Stores the front element
      // of the queue
      auto temp = que1.front();
      que1.pop();
 
      // Update lvlsum
      lvlsum += temp->val;
 
      // If left subtree not null
      if (temp->left)
        que1.push(temp->left);
 
      // If right subtree not null
      if (temp->right)
        que1.push(temp->right);
 
      // Update lenque1
      lenque1 -= 1;
    }
 
    // Replace all the nodes of at
    // each level of the tree
    while (lenque2 > 0)
    {
 
      // Stores front element
      // of the queue
      auto temp = que2.front();
      que2.pop();
 
      // Replace current node with
      // the sum of nodes
      temp->val = lvlsum;
 
      // If left subtree is not null
      if (temp->left)
        que2.push(temp->left);
 
      // If right subtree is not null
      if (temp->right)
        que2.push(temp->right);
 
      // Update lenque2
      lenque2 -= 1;
    }
  }
  return root;
}
 
// Function to print level order
// traversal of the Binary Tree
void printLvl(TreeNode *root)
{
 
  // Push root node into queue
  queue<TreeNode*> que;
  que.push(root);
 
  while (true)
  {
 
    // Stores count of nodes
    // at each level
    int length = que.size();
 
    // If no nodes left
    if (length == 0)
      break;
 
    while (length)
    {
 
      //Stores front element
      // of the queue
      auto temp = que.front();
      que.pop();
      cout << temp->val << " ";
 
      // If left subtree is not null
      if (temp->left)
        que.push(temp->left);
 
      // If right subtree is not null
      if (temp->right)
        que.push(temp->right);
 
      // Update length
      length -= 1;
    }
    cout << endl;
  }
}
 
// Driver Code
int main()
{
 
  TreeNode *root = new TreeNode(4);
  root->left = new TreeNode(5);
  root->right = new TreeNode(7);
  root->left->left = new TreeNode(1);
  root->left->right = new TreeNode(3);
  root->right->right = new TreeNode(5);
 
  // To update the tree
  root = replaceLvl(root);
 
  // To display the updated tree
  printLvl(root);
 
  return 0;
}
 
// This code is contributed by mohit kumar 29

Java

// Java program for above approach
import java.util.*;
import java.lang.*;
 
class GFG
{
  // Structure of a Node: an int value and links to the two children
  static class TreeNode
  {
    int val;
    TreeNode left, right;

    TreeNode(int key)
    {
      val = key;
      left = null;
      right = null;
    }
  }

  /**
   * Replaces the value of each node with the sum of the values of all
   * nodes on the same level, modifying the tree in place.
   *
   * Two queues walk the tree level by level in lock-step: que1 computes
   * the sum of the current level, then que2 visits the same nodes and
   * stores that sum. Sums use int arithmetic, so they wrap on overflow.
   *
   * @param root root of the tree (may be null)
   * @return the same root reference, or null if root is null
   */
  static TreeNode replaceLvl(TreeNode root)
  {
    // If root is null
    if (root == null)
      return root;

    // ArrayDeque gives O(1) queue operations. It never holds null here
    // because children are enqueued only when non-null.
    Queue<TreeNode> que1 = new ArrayDeque<>();
    Queue<TreeNode> que2 = new ArrayDeque<>();
    que1.add(root);
    que2.add(root);

    while (!que1.isEmpty())
    {
      // Stores sum of nodes at the current level
      int lvlsum = 0;

      // Traverse the current level with que1 and accumulate its sum
      for (int lenque1 = que1.size(); lenque1 > 0; lenque1--)
      {
        TreeNode temp = que1.poll();
        lvlsum += temp.val;

        if (temp.left != null)
          que1.add(temp.left);
        if (temp.right != null)
          que1.add(temp.right);
      }

      // Replace every node of the same level (held in que2) with lvlsum
      for (int lenque2 = que2.size(); lenque2 > 0; lenque2--)
      {
        TreeNode temp = que2.poll();
        temp.val = lvlsum;

        if (temp.left != null)
          que2.add(temp.left);
        if (temp.right != null)
          que2.add(temp.right);
      }
    }
    return root;
  }

  /**
   * Prints the tree in level order, one line per level, with every value
   * followed by a space. An empty tree prints nothing; previously a null
   * root caused a NullPointerException.
   *
   * @param root root of the tree (may be null)
   */
  static void printLvl(TreeNode root)
  {
    if (root == null)
      return;

    // Push root node into queue
    Queue<TreeNode> que = new ArrayDeque<>();
    que.add(root);

    while (!que.isEmpty())
    {
      // Visit exactly the nodes of the current level
      for (int length = que.size(); length > 0; length--)
      {
        TreeNode temp = que.poll();
        System.out.print(temp.val + " ");

        if (temp.left != null)
          que.add(temp.left);
        if (temp.right != null)
          que.add(temp.right);
      }
      System.out.println();
    }
  }

  // Driver function
  public static void main (String[] args) {
    TreeNode root = new TreeNode(4);
    root.left = new TreeNode(5);
    root.left.left = new TreeNode(1);
    root.left.right = new TreeNode(3);
    root.right = new TreeNode(7);
    root.right.right = new TreeNode(5);

    // To update the tree
    root = replaceLvl(root);

    // To display the updated tree
    printLvl(root);
  }
}
 
// This code is contributed by offbeat

Python3

# Python program to implement
# the above approach
 
# Structure of the binary tree
class TreeNode:
    """A binary tree node: a value plus optional left/right child nodes."""

    def __init__(self, val=0, left=None, right=None):
        # Payload first, then the two (possibly None) child links
        self.val, self.left, self.right = val, left, right
 
# Function to replace each node of the tree
# with the sum of nodes in the same level
def replaceLvl(root):
    """Replace every node's value with the sum of all values on its level.

    Two BFS queues advance level by level in lock-step: ``que1`` computes
    the sum of the current level, then ``que2`` visits the same nodes and
    stores that sum into them. The tree is modified in place.

    ``collections.deque`` is used so each dequeue is O(1). ``list.pop(0)``
    costs O(len) per call and would make the traversal quadratic on wide
    trees.

    Args:
        root: root node of the tree, or ``None``.

    Returns:
        The same ``root`` object, or ``None`` for an empty tree.
    """
    from collections import deque

    # If root is null
    if not root:
        return

    # Queues holding the nodes of the current level
    que1 = deque([root])
    que2 = deque([root])

    while que1:
        # Stores sum of nodes at the current level
        lvlsum = 0

        # Sum the current level while enqueueing the next one
        for _ in range(len(que1)):
            temp = que1.popleft()
            lvlsum += temp.val
            if temp.left:
                que1.append(temp.left)
            if temp.right:
                que1.append(temp.right)

        # que2 still holds exactly the same level: overwrite it with lvlsum
        for _ in range(len(que2)):
            temp = que2.popleft()
            temp.val = lvlsum
            if temp.left:
                que2.append(temp.left)
            if temp.right:
                que2.append(temp.right)

    return root
 
# Function to print level order
# traversal of the Binary Tree
def printLvl(root):
    """Print the tree level by level, one line per level.

    Each value is followed by a single space and each level ends with a
    newline. An empty tree (``root is None``) prints nothing. Without this
    guard, ``None`` was enqueued and ``temp.val`` raised AttributeError.

    Args:
        root: root node of the tree, or ``None``.
    """
    from collections import deque

    if root is None:
        return

    # O(1) popleft, unlike list.pop(0)
    que = deque([root])
    while que:
        # Visit exactly the nodes of the current level
        for _ in range(len(que)):
            temp = que.popleft()
            print(temp.val, end=' ')
            if temp.left:
                que.append(temp.left)
            if temp.right:
                que.append(temp.right)
        print()
 
# Driver Code

# Sample tree, built bottom-up with nested constructors:
#        4
#       / \
#      5   7
#     / \   \
#    1   3   5
root = TreeNode(4,
                TreeNode(5, TreeNode(1), TreeNode(3)),
                TreeNode(7, right=TreeNode(5)))

# Overwrite each node with its level sum, then show the result level by level
root = replaceLvl(root)
printLvl(root)

Output:

4 
12 12 
9 9 9

 

Time Complexity: O(N)
Auxiliary Space: O(N)

Attention reader! Don’t stop learning now. Get hold of all the important DSA concepts with the DSA Self Paced Course at a student-friendly price and become industry ready.

My Personal Notes arrow_drop_up
Recommended Articles
Page :