Open In App

Replace each node of a Binary Tree with the sum of all the nodes present in its diagonal

Last Updated : 19 Dec, 2022
Improve
Improve
Like Article
Like
Save
Share
Report

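Given a binary tree, the task is to replace the value of every node with the sum of all the nodes lying on the same diagonal, and then print the level order traversal of the resulting tree. A diagonal consists of a node together with every node reachable from it by following right-child links only; moving to a left child starts the next diagonal.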

Examples:

Input: 
 

             9
            / \
           6   10
          / \   \
         4   7   11
        / \   \
       3  5   8

Output: 30 21 30 9 21 30 3 9 21 
Explanation: 
 

             30
            / \
          21   30
          / \   \
         9   21  30
        / \   \
       3   9   21

Diagonal traversal of the binary tree:
9 10 11 (sum = 30)
6 7 8 (sum = 21)
4 5 (sum = 9)
3 (sum = 3)

 

 

Input: 
 

             5
            / \
           6   3
          / \   \
         4   9   2

Output: 10 15 10 4 15 10 
Explanation: 
 

             10
            / \
           15  10
          / \   \
         4   15  10

 

 

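Approach: The idea is to perform a diagonal traversal of the binary tree and record the sum of the nodes on each diagonal. Then traverse the tree again and replace each node with the sum for its diagonal. Follow the steps below to solve the problem:

- Give the root diagonal index 0. A left child lies on diagonal d + 1, while a right child stays on the same diagonal d as its parent.
- Traverse the tree (getDiagSum) and add each node's value to a map entry keyed by its diagonal index.
- Traverse the tree again with the same indexing (replaceDiag) and overwrite each node's value with the stored sum for its diagonal.
- Finally, print the updated tree in level order using a queue (levelOrder).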

Below is the implementation of the above approach:

C++




// CPP program to implement
// the above approach
#include <bits/stdc++.h>
using namespace std;
 
// Structure of a tree node
struct TreeNode
{
  int val;                 // value stored in this node
  struct TreeNode *left;   // left child, NULL when absent
  struct TreeNode *right;  // right child, NULL when absent

  // Creates a leaf node holding x (both children NULL).
  TreeNode(int x) : val(x), left(NULL), right(NULL) {}
};
 
// Function to replace each node with
// the sum of nodes at the same diagonal
void  replaceDiag(TreeNode *root, int d,
                  unordered_map<int,int> &diagMap){
 
  // IF root is NULL
  if (!root)
    return;
 
  // Replace nodes
  root->val = diagMap[d];
 
  // Traverse the left subtree
  replaceDiag(root->left, d + 1, diagMap);
 
  // Traverse the right subtree
  replaceDiag(root->right, d, diagMap);
}
 
// Function to find the sum of all the nodes
// at each diagonal of the tree
void getDiagSum(TreeNode *root, int d,
                unordered_map<int,int> &diagMap)
{
 
  // If root is not NULL
  if (!root)
    return;
 
  // If update sum of nodes
  // at current diagonal
  if (diagMap[d] > 0)
    diagMap[d] += root->val;
  else
    diagMap[d] = root->val;
 
  // Traverse the left subtree
  getDiagSum(root->left, d + 1, diagMap);
 
  // Traverse the right subtree
  getDiagSum(root->right, d, diagMap);
}
 
// Function to print the nodes of the tree
// using level order traversal
void levelOrder(TreeNode *root)
{
 
  // Stores node at each level of the tree
  queue<TreeNode*> q;
  q.push(root);
  while (true)
  {
 
    // Stores count of nodes
    // at current level
    int length = q.size();
    if (!length)
      break;
    while (length)
    {
 
      // Stores front element
      // of the queue
      auto temp = q.front();
      q.pop();
 
      cout << temp->val << " ";
 
      // Insert left subtree
      if (temp->left)
        q.push(temp->left);
 
      // Insert right subtree
      if (temp->right)
        q.push(temp->right);
 
      // Update length
      length -= 1;
    }
  }
}
 
// Driver Code
int main()
{
  // Build tree
  TreeNode *root = new TreeNode(5);
  root->left = new TreeNode(6);
  root->right = new TreeNode(3);
  root->left->left =new TreeNode(4);
  root->left->right = new TreeNode(9);
  root->right->right = new TreeNode(2);
 
  // Store sum of nodes at each
  // diagonal of the tree
  unordered_map<int,int> diagMap;
 
  // Find sum of nodes at each
  // diagonal of the tree
  getDiagSum(root, 0, diagMap);
 
  // Replace nodes with the sum
  // of nodes at the same diagonal
  replaceDiag(root, 0, diagMap);
 
  // Print tree
  levelOrder(root);
  return 0;
}
 
// This code is contributed by mohit kumar 29


Java




// Java program to implement
// the above approach
import java.util.ArrayDeque;
import java.util.LinkedList;
import java.util.Queue;
import java.util.TreeMap;
 
class GFG {

  /** Binary tree node holding an {@code int} value and optional children. */
  public static class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int x) {
      val = x;
      left = null;
      right = null;
    }
  }

  /**
   * Replaces the value of every node in the subtree rooted at {@code root} with the sum
   * recorded for that node's diagonal.
   *
   * @param root subtree root; {@code null} is a no-op
   * @param d diagonal index of {@code root}; a left child lies on {@code d + 1}, a right
   *     child stays on {@code d}
   * @param diagMap per-diagonal sums as produced by {@link #getDiagSum}; must contain every
   *     diagonal reachable from {@code root}
   * @throws NullPointerException if a reachable diagonal is missing from {@code diagMap}
   */
  static void replaceDiag(TreeNode root, int d, TreeMap<Integer, Integer> diagMap) {
    if (root == null) {
      return;
    }
    root.val = diagMap.get(d);
    replaceDiag(root.left, d + 1, diagMap);
    replaceDiag(root.right, d, diagMap);
  }

  /**
   * Adds the value of every node in the subtree rooted at {@code root} to the running sum
   * of its diagonal in {@code diagMap}.
   *
   * @param root subtree root; {@code null} is a no-op
   * @param d diagonal index of {@code root} (left child {@code d + 1}, right child {@code d})
   * @param diagMap diagonal index to running sum; updated in place
   */
  static void getDiagSum(TreeNode root, int d, TreeMap<Integer, Integer> diagMap) {
    if (root == null) {
      return;
    }
    // Single lookup: start the diagonal at this value or extend its running total
    Integer sum = diagMap.get(d);
    diagMap.put(d, sum == null ? root.val : sum + root.val);

    getDiagSum(root.left, d + 1, diagMap);
    getDiagSum(root.right, d, diagMap);
  }

  /**
   * Prints the node values in level order (breadth-first, left to right), each followed by a
   * single space. An empty tree prints nothing.
   *
   * @param root tree root; may be {@code null}
   */
  static void levelOrder(TreeNode root) {
    // Guard: a null root would otherwise be dequeued and dereferenced
    if (root == null) {
      return;
    }
    Queue<TreeNode> q = new ArrayDeque<TreeNode>();
    q.add(root);
    while (!q.isEmpty()) {
      TreeNode temp = q.remove();
      System.out.print(temp.val + " ");

      // Children are enqueued left-to-right to preserve level order
      if (temp.left != null) {
        q.add(temp.left);
      }
      if (temp.right != null) {
        q.add(temp.right);
      }
    }
  }

  /** Builds the sample tree, replaces each node with its diagonal sum and prints the result. */
  public static void main(String[] args) {
    TreeNode root = new TreeNode(5);
    root.left = new TreeNode(6);
    root.right = new TreeNode(3);
    root.left.left = new TreeNode(4);
    root.left.right = new TreeNode(9);
    root.right.right = new TreeNode(2);

    // Diagonal index -> sum of node values on that diagonal
    TreeMap<Integer, Integer> diagMap = new TreeMap<Integer, Integer>();

    getDiagSum(root, 0, diagMap);
    replaceDiag(root, 0, diagMap);

    // Expected: 10 15 10 4 15 10
    levelOrder(root);
  }
}
 
// This code is contributed by Saurabh Jaiswal


Python3




# Python program to implement
# the above approach
 
# Structure of a tree node
class TreeNode:
    """A binary tree node.

    Args:
        val: value stored in the node (defaults to 0).
        left: left child node, or None.
        right: right child node, or None.
    """

    def __init__(self, val=0, left=None, right=None):
        self.left, self.right = left, right
        self.val = val
 
# Function to replace each node with 
# the sum of nodes at the same diagonal
def replaceDiag(root, d, diagMap):
    """Overwrite each node's value with the sum recorded for its diagonal.

    Args:
        root: subtree root; a falsy value (None) is a no-op.
        d: diagonal index of ``root``. A left child is on ``d + 1``, and a
            right child stays on ``d``.
        diagMap: dict mapping diagonal index -> sum.

    Raises:
        KeyError: if a reachable diagonal is missing from ``diagMap``.
    """
    node = root
    # Right links stay on diagonal d, so follow them in a loop; each left
    # link opens diagonal d + 1 and is handled recursively.
    while node:
        node.val = diagMap[d]
        replaceDiag(node.left, d + 1, diagMap)
        node = node.right
 
# Function to find the sum of all the nodes
# at each diagonal of the tree
def getDiagSum(root, d, diagMap):
    """Accumulate, per diagonal, the sum of node values in ``root``'s subtree.

    Args:
        root: subtree root; a falsy value (None) is a no-op.
        d: diagonal index of ``root`` (left child -> ``d + 1``,
            right child -> ``d``).
        diagMap: dict updated in place, mapping diagonal index -> running sum.
    """
    node = root
    # Walk the right-child chain (same diagonal) iteratively and recurse
    # into left children, which begin the next diagonal.
    while node:
        if d not in diagMap:
            diagMap[d] = node.val
        else:
            diagMap[d] += node.val
        getDiagSum(node.left, d + 1, diagMap)
        node = node.right
 
# Function to print the nodes of the tree
# using level order traversal
def levelOrder(root):
    """Print node values in level order, each followed by a single space.

    Nodes are visited breadth-first, left to right. An empty tree
    (``root`` is None) prints nothing.

    Uses ``collections.deque`` so each dequeue is O(1). ``list.pop(0)`` is
    O(n), which would make the whole traversal quadratic.
    """
    from collections import deque

    # Guard: without it a None root would be dequeued and dereferenced
    if not root:
        return

    que = deque([root])
    while que:
        temp = que.popleft()
        print(temp.val, end=' ')

        # Children are enqueued left-to-right to preserve level order
        if temp.left:
            que.append(temp.left)
        if temp.right:
            que.append(temp.right)
 
# Driver code
if __name__ == '__main__':

    # Sample tree:
    #        5
    #       / \
    #      6   3
    #     / \   \
    #    4   9   2
    root = TreeNode(5,
                    TreeNode(6, TreeNode(4), TreeNode(9)),
                    TreeNode(3, None, TreeNode(2)))

    # Diagonal index -> sum of the node values on that diagonal
    diagMap = {}

    getDiagSum(root, 0, diagMap)   # pass 1: accumulate per-diagonal sums
    replaceDiag(root, 0, diagMap)  # pass 2: write the sums back into the nodes

    # Expected output: 10 15 10 4 15 10
    levelOrder(root)


C#




// C# program to implement
// the above approach
using System;
using System.Collections.Generic;
 
// Structure of a tree node
public class TreeNode {
  public int val;
  public TreeNode left, right;

  public TreeNode(int x) {
    val = x;
    left = null;
    right = null;
  }
}

class GFG {

  /// <summary>
  /// Replaces every node in the subtree rooted at <paramref name="root"/> with the sum
  /// stored for its diagonal.
  /// </summary>
  /// <param name="root">Subtree root; null is a no-op.</param>
  /// <param name="d">Diagonal index of root: a left child is on d + 1, a right child stays on d.</param>
  /// <param name="diagMap">Per-diagonal sums produced by getDiagSum.</param>
  /// <exception cref="KeyNotFoundException">A reachable diagonal is missing from diagMap.</exception>
  static void replaceDiag(TreeNode root, int d,
                          SortedDictionary<int, int> diagMap) {
    if (root == null)
      return;

    root.val = diagMap[d];
    replaceDiag(root.left, d + 1, diagMap);
    replaceDiag(root.right, d, diagMap);
  }

  /// <summary>
  /// Adds every node value in the subtree rooted at <paramref name="root"/> to the
  /// running sum of its diagonal.
  /// </summary>
  /// <param name="root">Subtree root; null is a no-op.</param>
  /// <param name="d">Diagonal index of root (left child d + 1, right child d).</param>
  /// <param name="diagMap">Diagonal index to running sum; updated in place.</param>
  static void getDiagSum(TreeNode root, int d,
                         SortedDictionary<int, int> diagMap) {
    if (root == null)
      return;

    // Single lookup: TryGetValue leaves sum at 0 when the diagonal is new
    int sum;
    diagMap.TryGetValue(d, out sum);
    diagMap[d] = sum + root.val;

    getDiagSum(root.left, d + 1, diagMap);
    getDiagSum(root.right, d, diagMap);
  }

  /// <summary>
  /// Writes the node values in level order (breadth-first, left to right), each followed
  /// by a single space. An empty tree writes nothing.
  /// </summary>
  /// <remarks>
  /// Uses Queue&lt;T&gt; for O(1) dequeues; List.RemoveAt(0) shifts the whole list each time.
  /// </remarks>
  static void levelOrder(TreeNode root) {
    // Guard: a null root would otherwise be dequeued and dereferenced
    if (root == null)
      return;

    Queue<TreeNode> q = new Queue<TreeNode>();
    q.Enqueue(root);
    while (q.Count > 0) {
      TreeNode temp = q.Dequeue();
      Console.Write(temp.val + " ");

      // Children are enqueued left-to-right to preserve level order
      if (temp.left != null)
        q.Enqueue(temp.left);
      if (temp.right != null)
        q.Enqueue(temp.right);
    }
  }

  // Driver Code: builds the sample tree, replaces each node with its
  // diagonal sum and prints the tree in level order.
  public static void Main(string[] args) {
    TreeNode root = new TreeNode(5);
    root.left = new TreeNode(6);
    root.right = new TreeNode(3);
    root.left.left = new TreeNode(4);
    root.left.right = new TreeNode(9);
    root.right.right = new TreeNode(2);

    // Diagonal index -> sum of node values on that diagonal
    SortedDictionary<int, int> diagMap = new SortedDictionary<int, int>();

    getDiagSum(root, 0, diagMap);
    replaceDiag(root, 0, diagMap);

    // Expected: 10 15 10 4 15 10
    levelOrder(root);
  }
}
 
// This code is contributed by phasing17


Javascript




<script>
 
    // Javascript program to implement the above approach
     
    // Structure of a tree node
    class TreeNode
    {
        constructor(x) {
           this.left = null;
           this.right = null;
           this.val = x;
        }
    }
     
    // Store sum of nodes at each
    // diagonal of the tree
    let diagMap = new Map();
 
    // Function to replace each node with
    // the sum of nodes at the same diagonal
    function replaceDiag(root, d){
 
      // IF root is NULL
      if (root == null)
        return;
 
      // Replace nodes
      if(diagMap.has(d))
      root.val = diagMap.get(d);
 
      // Traverse the left subtree
      replaceDiag(root.left, d + 1);
 
      // Traverse the right subtree
      replaceDiag(root.right, d);
    }
 
    // Function to find the sum of all the nodes
    // at each diagonal of the tree
    function getDiagSum(root, d)
    {
 
      // If root is not NULL
      if(root == null)
        return;
 
      // If update sum of nodes
      // at current diagonal
      if (diagMap.has(d))
        diagMap.set(d, diagMap.get(d) + root.val);
      else
        diagMap.set(d, root.val);
 
      // Traverse the left subtree
      getDiagSum(root.left, d + 1);
 
      // Traverse the right subtree
      getDiagSum(root.right, d);
    }
 
    // Function to print the nodes of the tree
    // using level order traversal
    function levelOrder(root)
    {
 
      // Stores node at each level of the tree
      let q = [];
      q.push(root);
      while (true)
      {
 
        // Stores count of nodes
        // at current level
        let length = q.length;
        if (!length)
          break;
        while (length)
        {
 
          // Stores front element
          // of the queue
          let temp = q[0];
          q.shift();
 
          document.write(temp.val + " ");
 
          // Insert left subtree
          if (temp.left != null)
            q.push(temp.left);
 
          // Insert right subtree
          if (temp.right != null)
            q.push(temp.right);
 
          // Update length
          length -= 1;
        }
      }
    }
     
    // Build tree
    let root = new TreeNode(5);
    root.left = new TreeNode(6);
    root.right = new TreeNode(3);
    root.left.left =new TreeNode(4);
    root.left.right = new TreeNode(9);
    root.right.right = new TreeNode(2);
 
    // Find sum of nodes at each
    // diagonal of the tree
    getDiagSum(root, 0);
 
    // Replace nodes with the sum
    // of nodes at the same diagonal
    replaceDiag(root, 0);
 
    // Print tree
    levelOrder(root);
 
</script>


Output: 

10 15 10 4 15 10

 

Time Complexity: O(N) 
Auxiliary Space: O(N)



Like Article
Suggest improvement
Share your thoughts in the comments

Similar Reads