# Sum of nodes within K distance from target

• Difficulty Level : Hard
• Last Updated : 14 Mar, 2023

Given a binary tree, the value of a target node in it, and a positive integer K, the task is to find the sum of all nodes within distance K of the target node (including the value of the target node itself).

Examples:

Input: target = 9, K = 1
Binary Tree:

                 1
               /   \
              2     9
             /     / \
            4     5   7
           / \       / \
          8   19    20  11
         /   /  \
       30  40    50

Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22.

Input: target = 40, K = 2
Binary Tree:

                 1
               /   \
              2     9
             /     / \
            4     5   7
           / \       / \
          8   19    20  11
         /   /  \
       30  40    50

Output: 113
Explanation: The nodes within distance 2 of 40 are 40 (the target), 19 (its parent), 50 (its sibling) and 4 (its grandparent), so the sum is 40 + 19 + 50 + 4 = 113.
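To double-check both examples independently of the approach below, here is a small brute-force sketch in Python (one of the article's languages). It assumes node values are unique, and `build_example_tree`, `root_paths` and `brute_force_sum` are hypothetical helpers, not part of the article's code. It records the root-to-node path of every node; the distance between two nodes is then the number of edges from each of them up to their lowest common ancestor (the last node of the shared path prefix). This costs O(N · h) for a tree of height h, so it is only meant as a cross-check.

```python
from typing import Dict, List, Optional


class Node:
    """Minimal binary tree node with the same fields as the article's Node."""

    def __init__(self, val: int) -> None:
        self.data = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build_example_tree() -> Node:
    """Hypothetical helper: the tree drawn in both examples."""
    root = Node(1)
    root.left, root.right = Node(2), Node(9)
    root.left.left = Node(4)
    root.right.left, root.right.right = Node(5), Node(7)
    root.left.left.left, root.left.left.right = Node(8), Node(19)
    root.right.right.left, root.right.right.right = Node(20), Node(11)
    root.left.left.left.left = Node(30)
    root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)
    return root


def root_paths(root: Node) -> Dict[Node, List[Node]]:
    """Map every node to the list of nodes on its root-to-node path."""
    paths: Dict[Node, List[Node]] = {}
    stack = [(root, [root])]
    while stack:
        node, path = stack.pop()
        paths[node] = path
        for child in (node.left, node.right):
            if child is not None:
                stack.append((child, path + [child]))
    return paths


def brute_force_sum(root: Node, target: int, k: int) -> int:
    """Sum the nodes whose distance to the target is at most k."""
    paths = root_paths(root)
    # Assumes node values are unique, as in the examples.
    tpath = next(p for n, p in paths.items() if n.data == target)
    total = 0
    for node, path in paths.items():
        # Length of the shared prefix; its last node is the lowest common ancestor.
        common = 0
        while (common < min(len(path), len(tpath))
               and path[common] is tpath[common]):
            common += 1
        # Edges from each node up to the lowest common ancestor.
        dist = (len(path) - common) + (len(tpath) - common)
        if dist <= k:
            total += node.data
    return total


root = build_example_tree()
print(brute_force_sum(root, 9, 1))   # 22
print(brute_force_sum(root, 40, 2))  # 113
```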

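Approach: This problem can be solved using hashing and Depth-First Search (DFS), based on the following idea:

Store the parent of each node in a hash table. With parent links available, the tree behaves like an undirected graph in which every node is connected to its left child, its right child and its parent. Run a DFS from the target over these links, stop whenever the distance exceeds K, and add up the values of all the nodes reached. Because a tree has exactly one simple path between any two nodes, the depth at which this DFS first reaches a node is exactly that node's distance from the target.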
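The key idea, viewing the tree as an undirected graph once parents are known, can be sketched as follows. This is a minimal Python illustration; `build_parent_map` and `neighbours` are hypothetical names, and the parent map is filled with an explicit stack rather than the recursive `dfs` used in the article's implementations.

```python
from typing import Dict, List, Optional


class Node:
    def __init__(self, val: int) -> None:
        self.data = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def build_parent_map(root: Optional[Node]) -> Dict[Node, Optional[Node]]:
    """Hypothetical helper: record each node's parent (None for the root)."""
    par: Dict[Node, Optional[Node]] = {}
    if root is None:
        return par
    par[root] = None
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                par[child] = node
                stack.append(child)
    return par


def neighbours(node: Node, par: Dict[Node, Optional[Node]]) -> List[Node]:
    """Hypothetical helper: the node's edges in the undirected view."""
    return [n for n in (node.left, node.right, par[node]) if n is not None]


# Demo on the subtree rooted at 4 from the examples.
four = Node(4)
four.left, four.right = Node(8), Node(19)
four.left.left = Node(30)
four.right.left, four.right.right = Node(40), Node(50)
par = build_parent_map(four)
print([n.data for n in neighbours(four.right, par)])  # [40, 50, 4]
```

Node 19 thus has three neighbours (its two children and its parent), which is exactly the set of moves the DFS from the target is allowed to make.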

Follow the steps mentioned below to implement the approach:

• Create a hash table (say par) to store the parent of each node.
• Perform a DFS from the root and store the parent of each node in par (the root's parent is NULL).
• Find the target node in the tree.
• Create a hash table (say vis) to mark the visited nodes.
• Start a DFS from the target, tracking the current distance h:
  • If h exceeds K, or the node is NULL or already visited, return.
  • Otherwise, add the node's value to the final sum and mark it as visited.
  • Continue the DFS to its neighbours (left child, right child and, with the help of par, its parent) with distance h + 1.
• Return the sum of all the nodes within K distance from the target.
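The steps above can be sketched iteratively in Python, with explicit stacks instead of recursion so that very deep (skewed) trees do not run into CPython's default recursion limit of 1000. This is a variant, not the article's code: it merges the parent-marking pass and the target search into a single traversal, and `sum_within_k` is a hypothetical name. It assumes node values are unique, as in the examples. Each node is marked visited when it is pushed, so it is pushed at most once, and since a tree has a unique path between any two nodes, the distance stored with it is exact.

```python
from typing import Dict, List, Optional, Set, Tuple


class Node:
    def __init__(self, val: int) -> None:
        self.data = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def sum_within_k(root: Optional[Node], target: int, k: int) -> int:
    if root is None:
        return 0

    # Steps 1-3: record every node's parent and locate the target in one pass.
    par: Dict[Node, Optional[Node]] = {root: None}
    start: Optional[Node] = None
    stack: List[Node] = [root]
    while stack:
        node = stack.pop()
        if start is None and node.data == target:
            start = node
        for child in (node.left, node.right):
            if child is not None:
                par[child] = node
                stack.append(child)
    if start is None:
        return 0  # the target value is not in the tree

    # Steps 4-5: DFS outward from the target, carrying the distance h.
    visited: Set[Node] = {start}
    todo: List[Tuple[Node, int]] = [(start, 0)]
    total = 0
    while todo:
        node, h = todo.pop()
        total += node.data
        if h == k:
            continue  # neighbours would be at distance k + 1
        for nxt in (node.left, node.right, par[node]):
            if nxt is not None and nxt not in visited:
                visited.add(nxt)
                todo.append((nxt, h + 1))
    return total


# Usage on the tree from the examples.
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)
print(sum_within_k(root, 9, 1))   # 22
print(sum_within_k(root, 40, 2))  # 113
```

Returning 0 when the target is missing mirrors the article's implementations, whose final DFS does nothing when the target search returns NULL.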

Below is the implementation of the above approach:

## C++

```cpp
// C++ code to implement above approach
#include <iostream>
#include <unordered_map>
using namespace std;

// Structure of a tree node
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val)
    {
        this->data = val;
        this->left = 0;
        this->right = 0;
    }
};

// Function for marking the parent node
// for all the nodes using DFS
void dfs(Node* root,
         unordered_map<Node*, Node*>& par)
{
    if (root == 0)
        return;
    if (root->left != 0)
        par[root->left] = root;
    if (root->right != 0)
        par[root->right] = root;
    dfs(root->left, par);
    dfs(root->right, par);
}

// Function for finding the sum: adds every
// unvisited node at distance h <= k from the target
void dfs3(Node* root, int h, int& sum, int k,
          unordered_map<Node*, int>& vis,
          unordered_map<Node*, Node*>& par)
{
    if (h == k + 1)
        return;
    if (root == 0)
        return;
    if (vis[root])
        return;
    sum += root->data;
    vis[root] = 1;
    dfs3(root->left, h + 1, sum, k, vis, par);
    dfs3(root->right, h + 1, sum, k, vis, par);
    dfs3(par[root], h + 1, sum, k, vis, par);
}

// Function for finding
// the target node in the tree
Node* dfs2(Node* root, int target)
{
    if (root == 0)
        return 0;
    if (root->data == target)
        return root;
    Node* node1 = dfs2(root->left, target);
    Node* node2 = dfs2(root->right, target);
    if (node1 != 0)
        return node1;
    if (node2 != 0)
        return node2;
    // Target not found in this subtree
    return 0;
}

// Function to find the sum at distance K
int sum_at_distK(Node* root, int target,
                 int k)
{
    // Hash Table to store
    // the parent of a node
    unordered_map<Node*, Node*> par;

    // Make the parent of root node as NULL
    // since it does not have any parent
    par[root] = 0;

    // Mark the parent node for all the
    // nodes using DFS
    dfs(root, par);

    // Find the target node in the tree
    Node* node = dfs2(root, target);

    // Hash Table to mark
    // the visited nodes
    unordered_map<Node*, int> vis;

    int sum = 0;

    // DFS call to find the sum
    dfs3(node, 0, sum, k, vis, par);
    return sum;
}

// Driver Code
int main()
{
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
    root->right->right->right
        = new Node(11);
    root->left->left->left->left
        = new Node(30);
    root->left->left->right->left
        = new Node(40);
    root->left->left->right->right
        = new Node(50);

    int target = 9, K = 1;

    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
}
```

## Java

```java
// Java code to implement above approach
import java.util.*;

public class Main {
    // Structure of a tree node
    static class Node {
        int data;
        Node left;
        Node right;
        Node(int val)
        {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    // Function for marking the parent node
    // for all the nodes using DFS
    static void dfs(Node root,
            HashMap<Node, Node> par)
    {
        if (root == null)
            return;
        if (root.left != null)
            par.put(root.left, root);
        if (root.right != null)
            par.put(root.right, root);
        dfs(root.left, par);
        dfs(root.right, par);
    }

    static int sum;

    // Function for finding the sum
    static void dfs3(Node root, int h, int k,
            HashMap<Node, Integer> vis,
            HashMap<Node, Node> par)
    {
        if (h == k + 1)
            return;
        if (root == null)
            return;
        if (vis.containsKey(root))
            return;
        sum += root.data;
        vis.put(root, 1);
        dfs3(root.left, h + 1, k, vis, par);
        dfs3(root.right, h + 1, k, vis, par);
        dfs3(par.get(root), h + 1, k, vis, par);
    }

    // Function for finding
    // the target node in the tree
    static Node dfs2(Node root, int target)
    {
        if (root == null)
            return null;
        if (root.data == target)
            return root;
        Node node1 = dfs2(root.left, target);
        Node node2 = dfs2(root.right, target);
        if (node1 != null)
            return node1;
        if (node2 != null)
            return node2;
        return null;
    }

    static int sum_at_distK(Node root, int target,
                 int k)
    {
        // Hash Map to store
        // the parent of a node
        HashMap<Node, Node> par = new HashMap<>();

        // Make the parent of root node as NULL
        // since it does not have any parent
        par.put(root, null);

        // Mark the parent node for all the
        // nodes using DFS
        dfs(root, par);

        // Find the target node in the tree
        Node node = dfs2(root, target);

        // Hash Map to mark
        // the visited nodes
        HashMap<Node, Integer> vis = new HashMap<>();

        sum = 0;

        // DFS call to find the sum
        dfs3(node, 0, k, vis, par);
        return sum;
    }

    public static void main(String args[]) {
        // Taking Input
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right
            = new Node(11);
        root.left.left.left.left
            = new Node(30);
        root.left.left.right.left
            = new Node(40);
        root.left.left.right.right
            = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sum_at_distK(root, target, K));
    }
}

// This code has been contributed by Sachin Sahara (sachin801)
```

## Python3

```python
# Python program to implement above approach

# Structure of tree node
class Node:
    def __init__(self, val):
        self.data = val
        self.left = None
        self.right = None


# Function for marking the parent node
# for all the nodes using DFS
def dfs(root, par):
    if root is None:
        return
    if root.left is not None:
        par[root.left] = root
    if root.right is not None:
        par[root.right] = root
    dfs(root.left, par)
    dfs(root.right, par)


# Function for finding the sum
summ = 0
def dfs3(root, h, k, vis, par):
    global summ
    if h == k + 1:
        return
    if root is None:
        return
    if vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    dfs3(par[root], h + 1, k, vis, par)


# Function for finding
# the target node in the tree
def dfs2(root, target):
    if root is None:
        return None
    if root.data == target:
        return root
    node1 = dfs2(root.left, target)
    node2 = dfs2(root.right, target)
    if node1 is not None:
        return node1
    if node2 is not None:
        return node2
    return None


# Function to find the sum at distance k
def sum_at_distK(root, target, k):
    global summ
    # Hash table to store
    # the parent of a node
    par = {}

    # Make the parent of root node None
    # since it does not have any parent
    par[root] = None

    # Mark the parent node for all the
    # nodes using DFS
    dfs(root, par)

    # Find the target node in the tree
    node = dfs2(root, target)

    # Hash table to mark the visited nodes
    vis = {}

    # Reset the running sum, then DFS from the target
    summ = 0
    dfs3(node, 0, k, vis, par)
    return summ


# Driver program
root = Node(1)
root.left = Node(2)
root.right = Node(9)
root.left.left = Node(4)
root.right.left = Node(5)
root.right.right = Node(7)
root.left.left.left = Node(8)
root.left.left.right = Node(19)
root.right.right.left = Node(20)
root.right.right.right = Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left = Node(40)
root.left.left.right.right = Node(50)

target = 9
K = 1

# Function call
print(sum_at_distK(root, target, K))

# This code is contributed by Yash Agarwal(yashagarwal2852002)
```

## C#

```csharp
// C# code to implement above approach

using System;
using System.Collections.Generic;

public class GFG {

  // Structure of a tree node
  class Node {
    public int data;
    public Node left;
    public Node right;
    public Node(int val)
    {
      this.data = val;
      this.left = null;
      this.right = null;
    }
  }

  // Function for marking the parent node
  // for all the nodes using DFS
  static void dfs(Node root, Dictionary<Node, Node> par)
  {
    if (root == null)
      return;
    if (root.left != null)
      par.Add(root.left, root);
    if (root.right != null)
      par.Add(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
  }

  static int sum;

  // Function for finding the sum
  static void dfs3(Node root, int h, int k,
                   Dictionary<Node, int> vis,
                   Dictionary<Node, Node> par)
  {
    if (h == k + 1)
      return;
    if (root == null)
      return;
    if (vis.ContainsKey(root))
      return;
    sum += root.data;
    vis.Add(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    dfs3(par[root], h + 1, k, vis, par);
  }

  // Function for finding
  // the target node in the tree
  static Node dfs2(Node root, int target)
  {
    if (root == null)
      return null;
    if (root.data == target)
      return root;
    Node node1 = dfs2(root.left, target);
    Node node2 = dfs2(root.right, target);
    if (node1 != null)
      return node1;
    if (node2 != null)
      return node2;
    return null;
  }

  static int sum_at_distK(Node root, int target, int k)
  {

    // Hash Map to store
    // the parent of a node
    Dictionary<Node, Node> par
      = new Dictionary<Node, Node>();

    // Make the parent of root node as NULL
    // since it does not have any parent
    par.Add(root, null);

    // Mark the parent node for all the
    // nodes using DFS
    dfs(root, par);

    // Find the target node in the tree
    Node node = dfs2(root, target);

    // Hash Map to mark
    // the visited nodes
    Dictionary<Node, int> vis
      = new Dictionary<Node, int>();

    sum = 0;

    // DFS call to find the sum
    dfs3(node, 0, k, vis, par);
    return sum;
  }

  static public void Main()
  {

    // Code
    Node root = new Node(1);
    root.left = new Node(2);
    root.right = new Node(9);
    root.left.left = new Node(4);
    root.right.left = new Node(5);
    root.right.right = new Node(7);
    root.left.left.left = new Node(8);
    root.left.left.right = new Node(19);
    root.right.right.left = new Node(20);
    root.right.right.right = new Node(11);
    root.left.left.left.left = new Node(30);
    root.left.left.right.left = new Node(40);
    root.left.left.right.right = new Node(50);

    int target = 9, K = 1;

    // Function call
    Console.Write(sum_at_distK(root, target, K));
  }
}

// This code is contributed by lokesh(lokeshmvs21).
```

## Javascript

```javascript
// JavaScript code for the above approach
// Structure of a tree node
class Node {
    constructor(val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

// Function for marking the parent node
// for all the nodes using DFS
function dfs(root, par) {
    if (root === null) return;
    if (root.left !== null) par.set(root.left, root);
    if (root.right !== null) par.set(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
}

let sum = 0;

// Function for finding the sum
function dfs3(root, h, k, vis, par) {
    if (h === k + 1) return;
    if (root === null) return;
    if (vis.has(root)) return;
    sum += root.data;
    vis.set(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    // Move up to the parent (null for the root)
    dfs3(par.get(root), h + 1, k, vis, par);
}

// Function for finding
// the target node in the tree
function dfs2(root, target) {
    if (root === null) return null;
    if (root.data === target) return root;
    let node1 = dfs2(root.left, target);
    let node2 = dfs2(root.right, target);
    if (node1 !== null) return node1;
    if (node2 !== null) return node2;
    return null;
}

function sumAtDistK(root, target, k)
{
    // Map to store the parent of a node
    let par = new Map();

    // Make the parent of root node null
    // since it does not have any parent
    par.set(root, null);

    // Mark the parent node for all the
    // nodes using DFS
    dfs(root, par);

    // Find the target node in the tree
    let node = dfs2(root, target);

    // Map to mark the visited nodes
    let vis = new Map();
    sum = 0;

    // DFS call to find the sum
    dfs3(node, 0, k, vis, par);
    return sum;
}

// Taking Input
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(9);
root.left.left = new Node(4);
root.right.left = new Node(5);
root.right.right = new Node(7);
root.left.left.left = new Node(8);
root.left.left.right = new Node(19);
root.right.right.left = new Node(20);
root.right.right.right = new Node(11);
root.left.left.left.left = new Node(30);
root.left.left.right.left = new Node(40);
root.left.left.right.right = new Node(50);

let target = 9;
let K = 1;

console.log(sumAtDistK(root, target, K));

// This code is contributed by Potta Lokesh
```

Output

`22`

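Time Complexity: O(N), where N is the number of nodes in the tree. Marking parents and searching for the target each visit every node at most once, and the final DFS adds each node to the sum at most once thanks to the visited table (each node has at most three neighbours, so the number of calls also stays O(N)).
Auxiliary Space: O(N) for the parent and visited hash tables, plus the recursion stack, which can grow to O(N) for a skewed tree.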
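To see where the O(N) bound comes from, the sketch below instruments the three passes of the approach (parent marking, target search, bounded DFS) and counts how many nodes each one processes. `count_passes` is a hypothetical helper written in Python; it follows the recursive structure of the implementations above, except that its target search stops as soon as a match is found. Each count is at most N, so the total is at most 3N.

```python
from typing import Dict


class Node:
    def __init__(self, val, left=None, right=None):
        self.data, self.left, self.right = val, left, right


def count_passes(root: Node, target: int, k: int) -> Dict[str, int]:
    """Count the nodes processed by each of the three passes."""
    counts = {"parents": 0, "find": 0, "sum": 0}
    par = {root: None}

    def mark(node):  # pass 1: parent marking
        if node is None:
            return
        counts["parents"] += 1
        for child in (node.left, node.right):
            if child is not None:
                par[child] = node
        mark(node.left)
        mark(node.right)

    def find(node):  # pass 2: target search (preorder)
        if node is None:
            return None
        counts["find"] += 1
        if node.data == target:
            return node
        return find(node.left) or find(node.right)

    vis = set()

    def collect(node, h):  # pass 3: DFS from the target, depth <= k
        if h == k + 1 or node is None or node in vis:
            return
        counts["sum"] += 1
        vis.add(node)
        for nxt in (node.left, node.right, par[node]):
            collect(nxt, h + 1)

    mark(root)
    collect(find(root), 0)
    return counts


# The tree from the examples, built with nested constructors.
n = Node
root = n(1, n(2, n(4, n(8, n(30)), n(19, n(40), n(50)))),
         n(9, n(5), n(7, n(20), n(11))))
print(count_passes(root, 9, 1))
```

For target 9 and K = 1 on the example tree (N = 13), the parent pass processes all 13 nodes and the bounded DFS processes exactly the 4 nodes 9, 5, 7 and 1.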
