Replace each node of a Binary Tree with the sum of all the nodes present in its diagonal
Last Updated: 15 Jan, 2021

Given a Binary Tree, the task is to print the level order traversal of the tree after replacing the value of each node of the tree with the sum of all the nodes on the same diagonal.

Examples:

Input:

```
        9
       / \
      6   10
     / \    \
    4   7    11
   / \   \
  3   5   8
```

Output: 30 21 30 9 21 30 3 9 21
Explanation:

```
          30
         /  \
       21    30
      /  \     \
     9    21    30
    / \     \
   3   9     21
```

Diagonal traversal of the binary tree, with the sum of each diagonal:
9 10 11 (sum 30)
6 7 8 (sum 21)
4 5 (sum 9)
3 (sum 3)
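To make the diagonal numbering concrete, here is a minimal Python sketch that groups the node values of the first example by diagonal. It assumes a hypothetical `Node` class (the same shape as `TreeNode` in the implementation below) and the rule the implementation uses: the root is on diagonal 0, a right child stays on its parent's diagonal, and a left child moves to diagonal d + 1.

```python
from collections import defaultdict


class Node:
    """Hypothetical minimal tree node, used only for this illustration."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def diagonals(root):
    """Return node values grouped by diagonal index (0, 1, 2, ...).

    A right child stays on its parent's diagonal; a left child
    moves to the next diagonal (d + 1).
    """
    groups = defaultdict(list)

    def visit(node, d):
        if node is None:
            return
        groups[d].append(node.val)
        visit(node.left, d + 1)   # left child: next diagonal
        visit(node.right, d)      # right child: same diagonal

    visit(root, 0)
    return [groups[d] for d in sorted(groups)]


# First example tree from above.
root = Node(9,
            Node(6, Node(4, Node(3), Node(5)), Node(7, None, Node(8))),
            Node(10, None, Node(11)))

diags = diagonals(root)
print(diags)                    # [[9, 10, 11], [6, 7, 8], [4, 5], [3]]
print([sum(g) for g in diags])  # [30, 21, 9, 3]
```

Note that node 3 forms a diagonal of its own, so it keeps its value 3 in the output tree.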

Input:

```
        5
       / \
      6   3
     / \   \
    4   9   2
```

Output: 10 15 10 4 15 10
Explanation:

```
         10
        /  \
      15    10
     /  \     \
    4    15    10
```

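Approach: The idea is to perform a diagonal traversal of the binary tree and store the sum of the nodes on each diagonal. Then traverse the tree again and replace each node with the sum of the nodes on its diagonal. Follow the steps below to solve the problem:

- Assign the root to diagonal 0. A right child lies on the same diagonal as its parent, while a left child lies on the next diagonal (d + 1).
- Traverse the tree (getDiagSum) and add the value of each node to diagMap[d], a hash map from diagonal index to the sum of that diagonal.
- Traverse the tree a second time (replaceDiag), computing d with the same rule, and set the value of each node to diagMap[d].
- Finally, print the level order traversal of the modified tree (levelOrder).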
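The two passes can also be sketched without a node class. This minimal Python version assumes a hypothetical representation in which a tree is either `None` or a tuple `(value, left, right)`, and it builds a new tree instead of modifying nodes in place. The names `diagonal_sums`, `replace` and `level_order` are illustrative only.

```python
from collections import deque

# Hypothetical representation for this sketch: a tree is either None
# or a tuple (value, left_subtree, right_subtree).


def diagonal_sums(tree, d=0, sums=None):
    """First pass: accumulate the sum of every diagonal into a dict."""
    if sums is None:
        sums = {}
    if tree is not None:
        val, left, right = tree
        sums[d] = sums.get(d, 0) + val
        diagonal_sums(left, d + 1, sums)   # left child: next diagonal
        diagonal_sums(right, d, sums)      # right child: same diagonal
    return sums


def replace(tree, sums, d=0):
    """Second pass: rebuild the tree with each value replaced by its diagonal sum."""
    if tree is None:
        return None
    _, left, right = tree
    return (sums[d], replace(left, sums, d + 1), replace(right, sums, d))


def level_order(tree):
    """Return node values in level (breadth-first) order."""
    out = []
    queue = deque([tree] if tree is not None else [])
    while queue:
        val, left, right = queue.popleft()
        out.append(val)
        for child in (left, right):
            if child is not None:
                queue.append(child)
    return out


# Second example tree from above.
tree = (5, (6, (4, None, None), (9, None, None)), (3, None, (2, None, None)))
result = replace(tree, diagonal_sums(tree))
print(level_order(result))  # [10, 15, 10, 4, 15, 10]
```

Returning a new tuple tree leaves the input untouched; the implementation below instead overwrites `val` on the existing nodes.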

Below is the implementation of the above approach:

## C++

```cpp
// C++ program to replace each node of a binary tree
// with the sum of all the nodes on its diagonal
#include <iostream>
#include <queue>
#include <unordered_map>
using namespace std;

// Structure of a tree node
struct TreeNode
{
  int val;
  TreeNode *left, *right;
  TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
};

// Function to replace each node with
// the sum of nodes at the same diagonal
void replaceDiag(TreeNode *root, int d,
                 unordered_map<int, int> &diagMap)
{
  // If root is NULL
  if (!root)
    return;

  // Replace the node's value with its diagonal sum
  root->val = diagMap[d];

  // Left child lies on the next diagonal
  replaceDiag(root->left, d + 1, diagMap);

  // Right child lies on the same diagonal
  replaceDiag(root->right, d, diagMap);
}

// Function to find the sum of all the nodes
// at each diagonal of the tree
void getDiagSum(TreeNode *root, int d,
                unordered_map<int, int> &diagMap)
{
  // If root is NULL
  if (!root)
    return;

  // Update sum of nodes at the current diagonal.
  // operator[] starts a missing key at 0, so this is
  // also correct when a partial sum is zero or negative.
  diagMap[d] += root->val;

  // Traverse the left subtree (next diagonal)
  getDiagSum(root->left, d + 1, diagMap);

  // Traverse the right subtree (same diagonal)
  getDiagSum(root->right, d, diagMap);
}

// Function to print the nodes of the tree
// using level order traversal
void levelOrder(TreeNode *root)
{
  // Nothing to print for an empty tree
  if (!root)
    return;

  // Stores nodes at each level of the tree
  queue<TreeNode *> q;
  q.push(root);
  while (!q.empty())
  {
    // Stores count of nodes at current level
    int length = q.size();
    while (length--)
    {
      // Stores front element of the queue
      TreeNode *temp = q.front();
      q.pop();

      cout << temp->val << " ";

      // Insert left subtree
      if (temp->left)
        q.push(temp->left);

      // Insert right subtree
      if (temp->right)
        q.push(temp->right);
    }
  }
}

// Driver Code
int main()
{
  // Build tree
  TreeNode *root = new TreeNode(5);
  root->left = new TreeNode(6);
  root->right = new TreeNode(3);
  root->left->left = new TreeNode(4);
  root->left->right = new TreeNode(9);
  root->right->right = new TreeNode(2);

  // Store sum of nodes at each diagonal of the tree
  unordered_map<int, int> diagMap;

  // Find sum of nodes at each diagonal of the tree
  getDiagSum(root, 0, diagMap);

  // Replace nodes with the sum of nodes at the same diagonal
  replaceDiag(root, 0, diagMap);

  // Print tree
  levelOrder(root);
  return 0;
}

// This code is contributed by mohit kumar 29
```

## Python3

```python
# Python program to implement
# the above approach
from collections import deque


# Structure of a tree node
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


# Function to replace each node with
# the sum of nodes at the same diagonal
def replaceDiag(root, d, diagMap):

    # If root is None
    if not root:
        return

    # Replace the node's value with its diagonal sum
    root.val = diagMap[d]

    # Left child lies on the next diagonal
    replaceDiag(root.left, d + 1, diagMap)

    # Right child lies on the same diagonal
    replaceDiag(root.right, d, diagMap)


# Function to find the sum of all the nodes
# at each diagonal of the tree
def getDiagSum(root, d, diagMap):

    # If root is None
    if not root:
        return

    # Update sum of nodes at current diagonal
    diagMap[d] = diagMap.get(d, 0) + root.val

    # Traverse the left subtree (next diagonal)
    getDiagSum(root.left, d + 1, diagMap)

    # Traverse the right subtree (same diagonal)
    getDiagSum(root.right, d, diagMap)


# Function to print the nodes of the tree
# using level order traversal
def levelOrder(root):

    # Nothing to print for an empty tree
    if not root:
        return

    # deque gives O(1) removal from the front
    que = deque([root])
    while que:
        temp = que.popleft()
        print(temp.val, end=' ')

        # Insert left subtree
        if temp.left:
            que.append(temp.left)

        # Insert right subtree
        if temp.right:
            que.append(temp.right)


# Driver code
if __name__ == '__main__':

    # Build tree
    root = TreeNode(5)
    root.left = TreeNode(6)
    root.right = TreeNode(3)
    root.left.left = TreeNode(4)
    root.left.right = TreeNode(9)
    root.right.right = TreeNode(2)

    # Store sum of nodes at each
    # diagonal of the tree
    diagMap = {}

    # Find sum of nodes at each
    # diagonal of the tree
    getDiagSum(root, 0, diagMap)

    # Replace nodes with the sum
    # of nodes at the same diagonal
    replaceDiag(root, 0, diagMap)

    # Print tree
    levelOrder(root)
```
Output:
`10 15 10 4 15 10`

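Time Complexity: O(N), where N is the number of nodes: the two diagonal passes and the level order traversal each visit every node once, and hash-map operations take O(1) on average.
Auxiliary Space: O(N) for the map of diagonal sums and the level-order queue, plus O(H) for the recursion stack, where H is the height of the tree.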
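The recursive functions use stack space proportional to the height of the tree, and for a degenerate (chain-shaped) tree the height equals N. In Python, the recursive `getDiagSum` would therefore raise `RecursionError` on a chain of a few thousand nodes under CPython's default recursion limit of 1000. As a hedged sketch, both passes can use an explicit stack of (node, diagonal) pairs instead; `replace_diagonal_sums_iterative`, `_push_children` and `chain` are hypothetical helper names, not part of the article.

```python
class TreeNode:
    """Same shape as the TreeNode class in the Python implementation."""

    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def _push_children(stack, node, d):
    # Left child moves to the next diagonal; right child stays on d.
    if node.left is not None:
        stack.append((node.left, d + 1))
    if node.right is not None:
        stack.append((node.right, d))


def replace_diagonal_sums_iterative(root):
    """Replace every value with its diagonal sum, without recursion."""
    if root is None:
        return None
    sums = {}
    stack = [(root, 0)]
    while stack:                      # pass 1: sum each diagonal
        node, d = stack.pop()
        sums[d] = sums.get(d, 0) + node.val
        _push_children(stack, node, d)
    stack = [(root, 0)]
    while stack:                      # pass 2: overwrite the values
        node, d = stack.pop()
        node.val = sums[d]
        _push_children(stack, node, d)
    return root


def chain(n, side):
    """Build a chain of n nodes with value 1, linked via 'left' or 'right'."""
    root = node = TreeNode(1)
    for _ in range(n - 1):
        child = TreeNode(1)
        setattr(node, side, child)
        node = child
    return root


deep = replace_diagonal_sums_iterative(chain(5000, "right"))
print(deep.val)  # 5000: every node of a right-leaning chain is on diagonal 0
```

A left-leaning chain behaves the opposite way: each node sits on its own diagonal, so every value stays 1. The asymptotic bounds are unchanged; only the call stack is replaced by a heap-allocated list.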

