Open In App

Sum of nodes within K distance from target

Last Updated : 11 Nov, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a binary tree, the value of a target node in it, and a positive integer K, find the sum of the values of all nodes within distance K of the target node (the target node's own value is included in the sum). The distance between two nodes is the number of edges on the path between them.

Examples:

Input: target = 9, K = 1,  
Binary Tree =
              1
            /   \
           2     9
          /     / \
         4     5   7
        / \       / \
       8   19    20  11
      /   /  \
     30  40   50
Output: 22
Explanation: The nodes within distance 1 of 9 are 9, 5, 7 and 1, so the sum is 9 + 5 + 7 + 1 = 22

Input: target = 40,  K = 2,  
Binary Tree =
              1
            /   \
           2     9
          /     / \
         4     5   7
        / \       / \
       8   19    20  11
      /   /  \
     30  40   50
Output: 113
Explanation: The nodes within distance 2 of 40 are 40, 19, 50 and 4, so the sum is
40 + 19 + 50 + 4 = 113

 

Approach: This problem can be solved using hashing and Depth-First-Search based on the following idea:

Store the parent of each node in a hash map so that the tree can also be walked upward. Then run a DFS from the target, moving to the left child, the right child and the parent, and add up the values of all nodes within distance K of the target.

Follow the steps mentioned below to implement the approach:

  • Create a hash table (say par) to store the parent of each node.
  • Perform a DFS and store the parent of each node.
  • Now find the target in the tree.
  • Create a hash table to mark the visited nodes.
  • Start a DFS from target, tracking the current distance from it:
    • If the distance exceeds K, or the node is null or already visited, return.
    • Otherwise add the node's value to the final sum and mark the node as visited.
    • Continue the DFS to the node's neighbours (left child, right child, and the parent found via par), each at distance + 1.
  • Return the sum of all the nodes within K distance from the target.

Below is the implementation of the above approach:

C++




// C++ code to implement above approach
 
#include <bits/stdc++.h>
using namespace std;
 
// A binary-tree node: an int payload plus left/right child links,
// which are null when the child is absent.
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val) : data(val), left(nullptr), right(nullptr) {}
};
 
// Function for marking the parent node
// for all the nodes using DFS
void dfs(Node* root,
         unordered_map<Node*, Node*>& par)
{
    if (root == 0)
        return;
    if (root->left != 0)
        par[root->left] = root;
    if (root->right != 0)
        par[root->right] = root;
    dfs(root->left, par);
    dfs(root->right, par);
}
 
// Function calling for finding the sum
void dfs3(Node* root, int h, int& sum, int k,
          unordered_map<Node*, int>& vis,
          unordered_map<Node*, Node*>& par)
{
    if (h == k + 1)
        return;
    if (root == 0)
        return;
    if (vis[root])
        return;
    sum += root->data;
    vis[root] = 1;
    dfs3(root->left, h + 1, sum, k, vis, par);
    dfs3(root->right, h + 1, sum, k, vis, par);
    dfs3(par[root], h + 1, sum, k, vis, par);
}
 
// Function for finding the target node in the tree.
// Returns the first node in pre-order (node, left subtree, right subtree)
// whose data equals `target`, or nullptr if no node matches. Every path
// returns a value: running off the end of a non-void function is
// undefined behaviour in C++.
Node* dfs2(Node* root, int target)
{
    if (root == nullptr)
        return nullptr;
    if (root->data == target)
        return root;
    if (Node* found = dfs2(root->left, target))
        return found;
    // Left subtree had no match: the right subtree's result (possibly
    // nullptr) is the answer.
    return dfs2(root->right, target);
}
 
// Function to find the sum at distance K
int sum_at_distK(Node* root, int target,
                 int k)
{
    // Hash Table to store
    // the parent of a node
    unordered_map<Node*, Node*> par;
 
    // Make the parent of root node as NULL
    // since it does not have any parent
    par[root] = 0;
 
    // Mark the parent node for all the
    // nodes using DFS
    dfs(root, par);
 
    // Find the target node in the tree
    Node* node = dfs2(root, target);
 
    // Hash Table to mark
    // the visited nodes
    unordered_map<Node*, int> vis;
 
    int sum = 0;
 
    // DFS call to find the sum
    dfs3(node, 0, sum, k, vis, par);
    return sum;
}
 
// Driver Code
int main()
{
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
    root->right->right->right
        = new Node(11);
    root->left->left->left->left
        = new Node(30);
    root->left->left->right->left
        = new Node(40);
    root->left->left->right->right
        = new Node(50);
 
    int target = 9, K = 1;
 
    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
}


Java




// Java code to implement above approach
import java.util.*;
 
public class Main {
    /** Structure of a tree node: an int payload plus optional left/right children. */
    static class Node {
        int data;
        Node left;
        Node right;

        Node(int val) {
            this.data = val;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Records a child -> parent entry in {@code par} for every node below {@code root}. The root
     * itself gets no entry here; {@link #sum_at_distK} seeds it with {@code null}.
     */
    static void dfs(Node root, HashMap<Node, Node> par) {
        if (root == null) {
            return;
        }
        if (root.left != null) {
            par.put(root.left, root);
        }
        if (root.right != null) {
            par.put(root.right, root);
        }
        dfs(root.left, par);
        dfs(root.right, par);
    }

    /**
     * Running total written by {@link #dfs3} and reset by {@link #sum_at_distK}. Because this is
     * shared static state, concurrent calls to {@code sum_at_distK} are not safe.
     */
    static int sum;

    /**
     * Walks the tree as an undirected graph from {@code root} (edges to the left child, the right
     * child and the parent recorded in {@code par}) and adds to {@link #sum} the value of every
     * node not yet in {@code vis} whose distance from the start node is at most {@code k}.
     *
     * @param root current node; {@code null} means there is no node here
     * @param h distance of {@code root} from the start node
     * @param k largest distance to include
     * @param vis nodes already added to the sum; updated in place
     * @param par child -> parent map; a missing entry is treated as "no parent"
     */
    static void dfs3(Node root, int h, int k,
            HashMap<Node, Integer> vis,
            HashMap<Node, Node> par) {
        // '>' rather than '== k + 1': with k < -1 an equality test never fires and the walk
        // would sum the entire tree.
        if (h > k || root == null || vis.containsKey(root)) {
            return;
        }
        sum += root.data;
        vis.put(root, 1);
        dfs3(root.left, h + 1, k, vis, par);
        dfs3(root.right, h + 1, k, vis, par);
        dfs3(par.get(root), h + 1, k, vis, par);
    }

    /**
     * Returns the first node in pre-order (node, left subtree, right subtree) whose value equals
     * {@code target}, or {@code null} if there is none.
     */
    static Node dfs2(Node root, int target) {
        if (root == null) {
            return null;
        }
        if (root.data == target) {
            return root;
        }
        Node fromLeft = dfs2(root.left, target);
        return fromLeft != null ? fromLeft : dfs2(root.right, target);
    }

    /**
     * Returns the sum of the values of all nodes within distance {@code k} of the first node (in
     * pre-order) whose value is {@code target}.
     *
     * @param root root of the tree; may be {@code null}
     * @param target value identifying the start node
     * @param k largest distance to include
     * @return the sum, or 0 if no node holds {@code target} or {@code k < 0}
     */
    static int sum_at_distK(Node root, int target,
                 int k) {
        // Parent links; the root has no parent.
        HashMap<Node, Node> par = new HashMap<>();
        par.put(root, null);
        dfs(root, par);

        Node node = dfs2(root, target);

        // Nodes already added to the sum.
        HashMap<Node, Integer> vis = new HashMap<>();

        sum = 0;
        dfs3(node, 0, k, vis, par);
        return sum;
    }

    public static void main(String[] args) {
        // Taking Input: the example tree from the problem statement
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sum_at_distK(root, target, K));
    }
}
 
// This code has been contributed by Sachin Sahara (sachin801)


Python3




# python program to implement above approach
# A binary-tree node: an integer payload plus optional left/right children
# (None when absent).
class Node:
    def __init__(self, val):
        self.data = val
        self.left = self.right = None
 
 
# Record a child -> parent entry in `par` for every node below `root`.
# The root itself gets no entry here; the caller is expected to seed it.
def dfs(root, par):
    if root is None:
        return
    for child in (root.left, root.right):
        if child is not None:
            par[child] = root
    dfs(root.left, par)
    dfs(root.right, par)
 
 
# Running total written by dfs3 (module-level state shared across calls).
summ = 0


def dfs3(root, h, k, vis, par):
    """Add to the global ``summ`` every unvisited node within distance ``k``.

    Walks the tree as an undirected graph: from each node it moves to the
    left child, the right child and the parent (``par[root]``).

    Args:
        root: current node, or None when there is no node here.
        h: distance of ``root`` from the start node.
        k: largest distance to include.
        vis: dict of nodes already counted (value 1); updated in place.
        par: child -> parent mapping; every node reached must be a key.
    """
    global summ
    # '>' rather than '== k+1': with k < -1 an equality test never fires
    # and the walk would sum the whole tree.
    if h > k:
        return
    if root is None:
        return
    if vis.get(root) == 1:
        return
    summ += root.data
    vis[root] = 1
    dfs3(root.left, h + 1, k, vis, par)
    dfs3(root.right, h + 1, k, vis, par)
    dfs3(par[root], h + 1, k, vis, par)
 
 
# Return the first node in pre-order (node, left subtree, right subtree)
# whose data equals `target`, or None when no node matches.
def dfs2(root, target):
    if root is None:
        return None
    if root.data == target:
        return root
    found = dfs2(root.left, target)
    if found is None:
        found = dfs2(root.right, target)
    return found
         
 
# function to find the sum at distance k
def sum_at_distK(root, target, k):
    """Return the sum of all node values within distance ``k`` of ``target``.

    The result is also stored in the module-level ``summ`` (which the driver
    prints). ``summ`` is reset first, so repeated calls do not accumulate.
    """
    global summ

    # hash table of child -> parent links. The root's parent is None (not 0)
    # so that dfs3's `root is None` check stops the walk above the root
    # instead of trying to read `.data` from a 0 sentinel.
    par = {root: None}

    # fill in the parent of every other node using DFS
    dfs(root, par)

    # find the target node in the tree
    node = dfs2(root, target)

    # hash table marking the nodes already added to the sum
    vis = {}

    summ = 0
    dfs3(node, 0, k, vis, par)
    return summ
 
 
# driver program: build the example tree and print the sum of all nodes
# within distance 1 of the node holding 9
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)

target, K = 9, 1

sum_at_distK(root, target, K)
print(summ)
 
# this code is contributed by Yash Agarwal(yashagarwal2852002)


C#




// C# code to implement above approach
 
using System;
using System.Collections.Generic;
 
public class GFG {

  // A binary-tree node: an int payload plus optional left/right children.
  class Node {
    public int data;
    public Node left;
    public Node right;
    public Node(int val)
    {
      this.data = val;
      this.left = null;
      this.right = null;
    }
  }

  // Records a child -> parent entry in par for every node below root. The
  // root itself is seeded by sum_at_distK. Dictionary.Add throws on a
  // duplicate key, so each node must be reached only once.
  static void dfs(Node root, Dictionary<Node, Node> par)
  {
    if (root == null)
      return;
    if (root.left != null)
      par.Add(root.left, root);
    if (root.right != null)
      par.Add(root.right, root);
    dfs(root.left, par);
    dfs(root.right, par);
  }

  // Running total written by dfs3 and reset by sum_at_distK. Because it is
  // shared static state, concurrent calls to sum_at_distK are not safe.
  static int sum;

  // Walks the tree as an undirected graph from root (edges to the left
  // child, the right child and the parent in par), adding to sum every node
  // not yet in vis whose distance h from the start node is at most k.
  static void dfs3(Node root, int h, int k,
                   Dictionary<Node, int> vis,
                   Dictionary<Node, Node> par)
  {
    // '>' rather than '== k + 1': with k < -1 an equality test never fires
    // and the walk would sum the whole tree.
    if (h > k)
      return;
    if (root == null)
      return;
    if (vis.ContainsKey(root))
      return;
    sum += root.data;
    vis.Add(root, 1);
    dfs3(root.left, h + 1, k, vis, par);
    dfs3(root.right, h + 1, k, vis, par);
    dfs3(par[root], h + 1, k, vis, par);
  }

  // Returns the first node in pre-order (node, left subtree, right subtree)
  // whose data equals target, or null if there is none.
  static Node dfs2(Node root, int target)
  {
    if (root == null)
      return null;
    if (root.data == target)
      return root;
    Node fromLeft = dfs2(root.left, target);
    return fromLeft ?? dfs2(root.right, target);
  }

  // Returns the sum of the values of all nodes within distance k of the
  // first node (pre-order) whose data equals target. Returns 0 for an empty
  // tree, when no node matches, or when k is negative.
  static int sum_at_distK(Node root, int target, int k)
  {
    // Dictionary.Add rejects a null key, so an empty tree is handled here.
    if (root == null)
      return 0;

    // child -> parent links; the root has no parent.
    Dictionary<Node, Node> par
      = new Dictionary<Node, Node>();
    par.Add(root, null);
    dfs(root, par);

    // Find the start node of the search.
    Node node = dfs2(root, target);

    // Nodes already added to the sum.
    Dictionary<Node, int> vis
      = new Dictionary<Node, int>();

    sum = 0;
    dfs3(node, 0, k, vis, par);
    return sum;
  }

  static public void Main()
  {
    // Example tree from the problem statement.
    Node root = new Node(1);
    root.left = new Node(2);
    root.right = new Node(9);
    root.left.left = new Node(4);
    root.right.left = new Node(5);
    root.right.right = new Node(7);
    root.left.left.left = new Node(8);
    root.left.left.right = new Node(19);
    root.right.right.left = new Node(20);
    root.right.right.right = new Node(11);
    root.left.left.left.left = new Node(30);
    root.left.left.right.left = new Node(40);
    root.left.left.right.right = new Node(50);

    int target = 9, K = 1;

    // Function call
    Console.Write(sum_at_distK(root, target, K));
  }
}
 
// This code is contributed by lokesh(lokeshmvs21).


Javascript




       // JavaScript code for the above approach
       // Structure of a tree node
       class Node {
           constructor(val) {
               this.data = val;
               this.left = null;
               this.right = null;
           }
       }
 
       // Function for marking the parent node
       // for all the nodes using DFS
       function dfs(root, par) {
           if (root === null) return;
           if (root.left !== null) par.set(root.left, root);
           if (root.right !== null) par.set(root.right, root);
           dfs(root.left, par);
           dfs(root.right, par);
       }
 
       let sum = 0;
 
       // Function calling for finding the sum
       function dfs3(root, h, k, vis, par) {
           if (h === k + 1) return;
           if (root === null) return;
           if (vis.has(root)) return;
           sum += root.data;
           vis.set(root, 1);
           dfs3(root.left, h + 1, k, vis, par);
           dfs3(root.right, h + 1, k, vis, par);
           if (par.get(root) !== null && vis.has(par.get(root))) {
               dfs3(par.get(root), h + 1, k, vis, par);
           }
       }
 
       // Function for finding
       // the target node in the tree
       function dfs2(root, target) {
           if (root === null) return null;
           if (root.data === target) return root;
           let node1 = dfs2(root.left, target);
           let node2 = dfs2(root.right, target);
           if (node1 !== null) return node1;
           if (node2 !== null) return node2;
           return null;
       }
 
       function sumAtDistK(root, target, k)
       {
        
           // Map to store the parent of a node
           let par = new Map();
 
           // Make the parent of root node as NULL
           // since it does not have any parent
           par.set(root, null);
 
           // Mark the parent node for all the
           // nodes using DFS
           dfs(root, par);
 
           // Find the target node in the tree
           let node = dfs2(root, target);
 
           // Map to mark the visited nodes
           let vis = new Map();
           sum = 1;
 
           // DFS call to find the sum
           dfs3(node, 0, k, vis, par);
           return sum;
       }
 
       // Taking Input
       let root = new Node(1);
       root.left = new Node(2);
       root.right = new Node(9);
       root.left.left = new Node(4);
       root.right.left = new Node(5);
       root.right.right = new Node(7);
       root.left.left.left = new Node(8);
       root.left.left.right = new Node(19);
       root.right.right.left = new Node(20);
       root.right.right.right = new Node(11);
       root.left.left.left.left = new Node(30);
       root.left.left.right.left = new Node(40);
       root.left.left.right.right = new Node(50);
 
       let target = 9;
       let K = 1;
 
       console.log(sumAtDistK(root, target, K));
 
// This code is contributed by Potta Lokesh


Output

22






Time Complexity: O(N) where N is the number of nodes in the tree
Auxiliary Space: O(N)

Approach using BFS:-

  • We will be using level order traversal to find the sum of nodes

Implementation:-

  • First we will find the target node using level order traversal.
  • While finding the target node we will store the parent of each node so that we can move towards the parent of the node as well.
  • Finally, we run a second BFS from the target node, moving in every direction (to both children and to the parent) level by level up to distance K, and add the value of each node reached to the answer.

C++




// C++ code to implement above approach
 
#include <bits/stdc++.h>
using namespace std;
 
// A binary-tree node: an int payload plus left/right child links,
// which are null when the child is absent.
struct Node {
    int data;
    Node* left;
    Node* right;
    Node(int val) : data(val), left(nullptr), right(nullptr) {}
};
 
// Function to find the sum at distance K
int sum_at_distK(Node* root, int target,
                 int k)
{
  //variable to store answer
  int ans = 0;
   
  //queue for bfs
  queue<Node*> q;
   
  q.push(root);
   
  //to store target node
  Node* need;
   
  //map to store parent of each node
  unordered_map<Node*, Node*> m;
   
  //bfs
  while(q.size()){
     
    int s = q.size();
     
    //traversing to current level
    for(int i=0;i<s;i++){
       
      Node* temp = q.front();
       
      q.pop();
       
      //if target value found
      if(temp->data==target) need=temp;
       
      if(temp->left){
        q.push(temp->left);
        m[temp->left]=temp;
      }
       
      if(temp->right){
        q.push(temp->right);
        m[temp->right]=temp;
      }
       
    }
     
  }
   
  //map to store occurrence of a node
  //that is the node has taken or not
  unordered_map<Node*, int> mm;
   
  q.push(need);
   
  //to store current distance
  int c = 0;
   
  while(q.size()){
     
    int s = q.size();
     
    for(int i=0;i<s;i++){
       
      Node* temp = q.front();
       
      q.pop();
       
      mm[temp] = 1;
       
      ans+=temp->data;
       
      //moving left
      if(temp->left&&mm[temp->left]==0){
        q.push(temp->left);
      }
       
      //moving right
      if(temp->right&&mm[temp->right]==0){
        q.push(temp->right);
      }
       
      //movinf to parent
      if(m[temp]&&mm[m[temp]]==0){
        q.push(m[temp]);
      }
       
    }
     
    c++;
    if(c>k)break;
     
  }
  return ans;
}
 
// Driver Code
int main()
{
    // Taking Input
    Node* root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(9);
    root->left->left = new Node(4);
    root->right->left = new Node(5);
    root->right->right = new Node(7);
    root->left->left->left = new Node(8);
    root->left->left->right = new Node(19);
    root->right->right->left = new Node(20);
    root->right->right->right
        = new Node(11);
    root->left->left->left->left
        = new Node(30);
    root->left->left->right->left
        = new Node(40);
    root->left->left->right->right
        = new Node(50);
 
    int target = 9, K = 1;
 
    // Function call
    cout << sum_at_distK(root, target, K);
    return 0;
}
//code contributed by shubhamrajput6156


Java




import java.util.*;
 
// Structure of a tree node: an int payload plus optional left/right children.
class Node {
    int data;
    Node left;
    Node right;

    public Node(int val) {
        this.data = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {

    /**
     * Returns the sum of the values of all nodes within distance {@code k} of the node whose value
     * is {@code target}, where distance counts edges to a child or to the parent.
     *
     * <p>If several nodes hold {@code target}, the last one met in level order is used.
     *
     * @param root root of the tree; may be {@code null}
     * @param target value identifying the start node
     * @param k largest distance to include
     * @return the sum, or 0 if the tree is empty, no node holds {@code target}, or {@code k < 0}
     */
    public static int sumAtDistK(Node root, int target, int k) {
        if (root == null || k < 0) {
            return 0;
        }
        Map<Node, Node> parentMap = new HashMap<>();
        Node need = findTargetAndParents(root, target, parentMap);
        // Target absent: nothing to sum (a null start node would be dereferenced below).
        if (need == null) {
            return 0;
        }
        return sumWithinDistance(need, k, parentMap);
    }

    /** Level-order scan that fills {@code parentMap} and returns the last node holding target. */
    private static Node findTargetAndParents(Node root, int target, Map<Node, Node> parentMap) {
        Deque<Node> queue = new ArrayDeque<>();
        queue.add(root);
        Node need = null;
        while (!queue.isEmpty()) {
            Node temp = queue.poll();
            if (temp.data == target) {
                need = temp;
            }
            if (temp.left != null) {
                queue.add(temp.left);
                parentMap.put(temp.left, temp);
            }
            if (temp.right != null) {
                queue.add(temp.right);
                parentMap.put(temp.right, temp);
            }
        }
        return need;
    }

    /** BFS outward from {@code start} over child and parent edges, summing levels 0..k. */
    private static int sumWithinDistance(Node start, int k, Map<Node, Node> parentMap) {
        Set<Node> visited = new HashSet<>();
        Deque<Node> queue = new ArrayDeque<>();
        visited.add(start);
        queue.add(start);
        int ans = 0;
        for (int distance = 0; distance <= k && !queue.isEmpty(); distance++) {
            for (int remaining = queue.size(); remaining > 0; remaining--) {
                Node temp = queue.poll();
                ans += temp.data;
                enqueueIfUnvisited(temp.left, visited, queue);
                enqueueIfUnvisited(temp.right, visited, queue);
                enqueueIfUnvisited(parentMap.get(temp), visited, queue);
            }
        }
        return ans;
    }

    /** Enqueues {@code node} unless it is null or was already visited (marks it visited). */
    private static void enqueueIfUnvisited(Node node, Set<Node> visited, Deque<Node> queue) {
        if (node != null && visited.add(node)) {
            queue.add(node);
        }
    }

    // Driver code
    public static void main(String[] args) {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        System.out.println(sumAtDistK(root, target, K));
    }
}


Python3




from collections import deque
 
class Node:
    # A binary-tree node: an integer payload plus optional left/right
    # children (None when absent).
    def __init__(self, val):
        self.data = val
        self.left = self.right = None
 
# Function to find the sum at distance K
def sum_at_distK(root, target, k):
    """Return the sum of all node values within distance ``k`` of ``target``.

    Distance counts edges, moving either to a child or to the parent.
    Returns 0 when the tree is empty, no node holds ``target``, or ``k`` is
    negative. If several nodes hold ``target``, the last one met in level
    order is used.
    """
    if root is None or k < 0:
        return 0

    # Pass 1: level-order scan recording each node's parent and the target.
    parent = {}
    need = None
    q = deque([root])
    while q:
        node = q.popleft()
        if node.data == target:
            need = node
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                q.append(child)

    # Target absent: nothing to sum (a None start node would be
    # dereferenced below).
    if need is None:
        return 0

    # Pass 2: BFS outward from the target, one distance level at a time.
    # Nodes are marked when enqueued, so none is queued twice.
    seen = {need}
    q.append(need)
    ans = 0
    for _ in range(k + 1):
        if not q:
            break
        for _ in range(len(q)):
            node = q.popleft()
            ans += node.data
            for nxt in (node.left, node.right, parent.get(node)):
                if nxt is not None and nxt not in seen:
                    seen.add(nxt)
                    q.append(nxt)
    return ans
 
# Driver Code
# Build the example tree and print the sum of all nodes within distance 1
# of the node holding 9.
root = Node(1)
root.left, root.right = Node(2), Node(9)
root.left.left = Node(4)
root.right.left, root.right.right = Node(5), Node(7)
root.left.left.left, root.left.left.right = Node(8), Node(19)
root.right.right.left, root.right.right.right = Node(20), Node(11)
root.left.left.left.left = Node(30)
root.left.left.right.left, root.left.left.right.right = Node(40), Node(50)

target, K = 9, 1

print(sum_at_distK(root, target, K))


C#




using System;
using System.Collections.Generic;
 
// A binary-tree node: an int payload plus optional left/right children
// (null when absent).
class Node
{
    public int data;
    public Node left;
    public Node right;

    public Node(int val)
    {
        data = val;
        left = null;
        right = null;
    }
}
 
class GFG
{
    // Returns the sum of the values of all nodes within distance k of the
    // node whose data equals target (distance counts edges to a child or to
    // the parent). Returns 0 when the tree is empty, no node holds target,
    // or k is negative. If several nodes hold target, the last one met in
    // level order is used.
    public static int SumAtDistK(Node root, int target, int k)
    {
        if (root == null || k < 0)
        {
            return 0;
        }

        // Pass 1: level-order scan recording each node's parent and the target.
        var parentMap = new Dictionary<Node, Node>();
        Node need = null;
        var q = new Queue<Node>();
        q.Enqueue(root);
        while (q.Count > 0)
        {
            Node temp = q.Dequeue();
            if (temp.data == target)
            {
                need = temp;
            }
            foreach (Node child in new[] { temp.left, temp.right })
            {
                if (child != null)
                {
                    parentMap[child] = temp;
                    q.Enqueue(child);
                }
            }
        }

        // Target absent: nothing to sum (a null start node would be
        // dereferenced below).
        if (need == null)
        {
            return 0;
        }

        // Pass 2: BFS outward from the target, one distance level at a time.
        // Nodes are marked when enqueued, so none is queued twice.
        // TryGetValue is used instead of GetValueOrDefault, which is not
        // available on .NET Framework.
        var visited = new HashSet<Node> { need };
        q.Enqueue(need);
        int ans = 0;
        for (int distance = 0; distance <= k && q.Count > 0; distance++)
        {
            for (int remaining = q.Count; remaining > 0; remaining--)
            {
                Node temp = q.Dequeue();
                ans += temp.data;
                Node parent;
                parentMap.TryGetValue(temp, out parent);
                foreach (Node next in new[] { temp.left, temp.right, parent })
                {
                    if (next != null && visited.Add(next))
                    {
                        q.Enqueue(next);
                    }
                }
            }
        }
        return ans;
    }

    // Driver code
    public static void Main(string[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(9);
        root.left.left = new Node(4);
        root.right.left = new Node(5);
        root.right.right = new Node(7);
        root.left.left.left = new Node(8);
        root.left.left.right = new Node(19);
        root.right.right.left = new Node(20);
        root.right.right.right = new Node(11);
        root.left.left.left.left = new Node(30);
        root.left.left.right.left = new Node(40);
        root.left.left.right.right = new Node(50);

        int target = 9, K = 1;

        // Function call
        Console.WriteLine(SumAtDistK(root, target, K));
    }
}


Javascript




// A binary-tree node: a value plus left/right children (null when absent).
class Node {
    constructor(val) {
        Object.assign(this, { data: val, left: null, right: null });
    }
}
 
// Function to find the sum at distance K.
// Returns the sum of the values of all nodes within distance `k` of the node
// whose data equals `target` (distance counts edges to a child or to the
// parent). Returns 0 when the tree is empty, no node holds `target`, or `k`
// is negative. If several nodes hold `target`, the last one met in level
// order is used.
function sum_at_distK(root, target, k) {
    if (!root || k < 0) return 0;

    // Pass 1: level-order scan recording each node's parent and the target.
    // Walking an index through a growing array avoids Array#shift, which is
    // O(n) per call and would make the scan quadratic.
    const parent = new Map();
    let need = null;
    const order = [root];
    for (let i = 0; i < order.length; i++) {
        const node = order[i];
        if (node.data === target) need = node;
        for (const child of [node.left, node.right]) {
            if (child) {
                parent.set(child, node);
                order.push(child);
            }
        }
    }

    // Target absent: nothing to sum (a null start node would be
    // dereferenced below).
    if (need === null) return 0;

    // Pass 2: BFS outward from the target, one distance level at a time.
    // Nodes are marked when enqueued, so none is queued twice.
    const seen = new Set([need]);
    let level = [need];
    let ans = 0;
    for (let dist = 0; dist <= k && level.length > 0; dist++) {
        const nextLevel = [];
        for (const node of level) {
            ans += node.data;
            for (const next of [node.left, node.right, parent.get(node)]) {
                if (next && !seen.has(next)) {
                    seen.add(next);
                    nextLevel.push(next);
                }
            }
        }
        level = nextLevel;
    }
    return ans;
}
 
// Driver Code
// Build the example tree and print the sum of all nodes within distance 1
// of the node holding 9.
let root = new Node(1);
root.left = new Node(2);
root.left.left = new Node(4);
root.left.left.left = new Node(8);
root.left.left.left.left = new Node(30);
root.left.left.right = new Node(19);
root.left.left.right.left = new Node(40);
root.left.left.right.right = new Node(50);
root.right = new Node(9);
root.right.left = new Node(5);
root.right.right = new Node(7);
root.right.right.left = new Node(20);
root.right.right.right = new Node(11);

let target = 9, K = 1;

console.log(sum_at_distK(root, target, K));


Output

22






Time Complexity:- O(N) Where N is the number of nodes
Auxiliary Space:- O(N)



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads