Open In App

Sum of all nodes with smaller values at a distance K from a given node in a BST

Last Updated : 16 Oct, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a Binary Search Tree (BST), the value of a target node in it, and an integer K, the task is to find the sum of all nodes that are exactly K edges away from the target node and whose value is less than the target node's value.

Examples:

Input: BST built by inserting 3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8 (in this order), target = 7, K = 2

Output: 11
Explanation:
The nodes at distance K (= 2) from node 7 are 1, 4, 6 and 9. Of these, only 1, 4 and 6 are smaller than 7, so the sum is 1 + 4 + 6 = 11.

Input: the same BST, target = 5, K = 1

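Output: 4
Explanation:
The nodes at distance 1 from node 5 are 4, 6 and 7. Only 4 is smaller than 5, so the sum is 4.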

Approach: The given problem can be solved by performing DFS Traversal for K distance below the target node and perform the DFS Traversal upward K distance from the target node. Follow the steps below to solve the problem:

  • Define a function kDistanceDownSum(root, k, &sum) and perform the following steps:
    • For the Base Case, check if the root is nullptr or k is less than 0; if so, return from the function.
    • If the value of k equals 0, then add root->data to the variable sum and return.
    • Otherwise, call kDistanceDownSum(root->left, k - 1, sum) and kDistanceDownSum(root->right, k - 1, sum) for the left and right subtrees.
  • Define a function kDistanceSum(root, target, k, &sum). It returns the distance from root down to the target node, or -1 if the target is not in this subtree. Perform the following steps:
    • For the Base Case, check if the root is nullptr; if so, return -1.
    • If root->data equals target, then call kDistanceDownSum(root->left, k - 1, sum) and return 0. This adds the nodes below the target; only its left subtree can hold smaller values.
    • If target is less than root->data, set dl to the value returned by kDistanceSum(root->left, target, k, sum). If dl is -1, return -1.
    • Otherwise, root (and its whole right subtree) is greater than the target, so nothing is added at this node. Still return (dl + 1), because a higher ancestor may be smaller than the target.
    • Otherwise (target is greater than root->data), set dr to the value returned by kDistanceSum(root->right, target, k, sum). If dr is -1, return -1.
    • If k equals (dr + 1), then add root->data to sum. Otherwise, call kDistanceDownSum(root->left, k - dr - 2, sum), because every node in root's left subtree is smaller than the target. Return (dr + 1).
  • Call kDistanceSum(root, target, K, sum) with sum initialized to 0, and print the value of sum as the resultant sum.

Following is the implementation of the above approach:

C++




// C++ program for the above approach
 
#include <bits/stdc++.h>
using namespace std;
 
// Structure of Tree: one BST node with an integer key and two
// child links.
struct TreeNode {

    int data;
    TreeNode* left;
    TreeNode* right;

    // Creates a leaf holding `data`; both children start out empty.
    TreeNode(int data)
        : data(data), left(nullptr), right(nullptr)
    {
    }
};
 
// Function to add the node to the sum
// below the target node
void kDistanceDownSum(TreeNode* root,
                      int k, int& sum)
{
 
    // Base Case
    if (root == NULL || k < 0)
        return;
 
    // If Kth distant node is reached
    if (k == 0) {
        sum += root->data;
        return;
    }
 
    // Recur for the left and the
    // right subtrees
    kDistanceDownSum(root->left,
                     k - 1, sum);
    kDistanceDownSum(root->right,
                     k - 1, sum);
}
 
// Function to find the K distant nodes
// from target node, it returns -1 if
// target node is not present in tree
int kDistanceSum(TreeNode* root,
                 int target,
                 int k, int& sum)
{
    // Base Case 1
    if (root == NULL)
        return -1;
 
    // If target is same as root.
    if (root->data == target) {
        kDistanceDownSum(root->left,
                         k - 1, sum);
        return 0;
    }
 
    // Recur for the left subtree
    int dl = -1;
 
    // Tree is BST so reduce the
    // search space
    if (target < root->data) {
        dl = kDistanceSum(root->left,
                          target, k, sum);
    }
 
    // Check if target node was found
    // in left subtree
    if (dl != -1) {
 
        // If root is at distance k from
        // the target
        if (dl + 1 == k)
            sum += root->data;
 
        // Node less than target will be
        // present in left
        return -1;
    }
 
    // When node is not present in the
    // left subtree
    int dr = -1;
    if (target > root->data) {
        dr = kDistanceSum(root->right,
                          target, k, sum);
    }
 
    if (dr != -1) {
 
        // If Kth distant node is reached
        if (dr + 1 == k)
            sum += root->data;
 
        // Node less than target at k
        // distance maybe present in the
        // left tree
        else
            kDistanceDownSum(root->left,
                             k - dr - 2, sum);
 
        return 1 + dr;
    }
 
    // If target was not present in the
    // left nor in right subtree
    return -1;
}
 
// Function to insert a node in BST
TreeNode* insertNode(int data,
                     TreeNode* root)
{
    // If root is NULL
    if (root == NULL) {
        TreeNode* node = new TreeNode(data);
        return node;
    }
 
    // Insert the data in right half
    else if (data > root->data) {
        root->right = insertNode(
            data, root->right);
    }
 
    // Insert the data in left half
    else if (data <= root->data) {
        root->left = insertNode(
            data, root->left);
    }
 
    // Return the root node
    return root;
}
 
// Prints the sum of all nodes that are exactly K edges away from the
// node holding `target` and whose value is smaller than `target`.
// Nothing is added when target is absent, so 0 is printed then.
void findSum(TreeNode* root, int target,
             int K)
{
    int total = 0;          // accumulated by kDistanceSum
    kDistanceSum(root, target, K, total);
    cout << total;
}
 
// Driver Code
int main()
{
    TreeNode* root = NULL;
    int N = 11;
    int tree[] = { 3, 1, 7, 0, 2, 5,
                   10, 4, 6, 9, 8 };
 
    // Create the Tree
    for (int i = 0; i < N; i++) {
        root = insertNode(tree[i], root);
    }
 
    int target = 7;
    int K = 2;
    findSum(root, target, K);
 
    return 0;
}


Java




// Java program for the above approach
import java.util.*;
 
public class GFG{
    // Running total filled in by kDistanceDownSum / kDistanceSum;
    // findSum resets it before every query.
    static int sum;

    // Structure of Tree: one BST node.
    static class TreeNode {

        int data;
        TreeNode left;
        TreeNode right;

        // Constructor
        TreeNode(int data)
        {
            this.data = data;
            this.left = null;
            this.right = null;
        }
    }

    // Adds to sum every node that is exactly k edges below root.
    static void kDistanceDownSum(TreeNode root, int k)
    {
        // Base Case
        if (root == null || k < 0)
            return;

        // If Kth distant node is reached
        if (k == 0) {
            sum += root.data;
            return;
        }

        // Recur for the left and the right subtrees
        kDistanceDownSum(root.left, k - 1);
        kDistanceDownSum(root.right, k - 1);
    }

    // Searches the BST for target and adds to sum every node that is
    // exactly k edges away from it and smaller than target.
    // Returns the distance from root down to the target node, or -1 if
    // the target is not in this subtree.
    //
    // An ancestor with the target in its LEFT subtree is larger than the
    // target (as is its right subtree), so nothing is added there, but the
    // distance is still passed up because a higher ancestor may be smaller.
    // An ancestor with the target in its RIGHT subtree is smaller than the
    // target, and so is its whole left subtree.
    static int kDistanceSum(TreeNode root, int target, int k)
    {
        // Base Case: target not in an empty subtree
        if (root == null)
            return -1;

        // Target found: only its left subtree can hold smaller values
        if (root.data == target) {
            kDistanceDownSum(root.left, k - 1);
            return 0;
        }

        if (target < root.data) {
            int dl = kDistanceSum(root.left, target, k);
            if (dl == -1)
                return -1;

            // root > target: add nothing here, keep propagating
            return dl + 1;
        }

        // target > root.data: the target can only be on the right
        int dr = kDistanceSum(root.right, target, k);
        if (dr == -1)
            return -1;

        if (dr + 1 == k)
            // root itself is k edges away and smaller than target
            sum += root.data;
        else
            // root's left subtree is entirely smaller than target
            kDistanceDownSum(root.left, k - dr - 2);

        return dr + 1;
    }

    // Function to insert a node in BST: larger keys go right, smaller
    // or equal keys go left. Returns the (possibly new) root.
    static TreeNode insertNode(int data, TreeNode root)
    {
        if (root == null)
            return new TreeNode(data);

        if (data > root.data)
            root.right = insertNode(data, root.right);
        else
            root.left = insertNode(data, root.left);

        return root;
    }

    // Prints the sum of the nodes exactly K edges away from the node
    // holding target whose values are smaller than target (0 if none).
    static void findSum(TreeNode root, int target, int K)
    {
        // Stores the sum of nodes having values < target at K distance
        sum = 0;

        kDistanceSum(root, target, K);

        // Print the resultant sum
        System.out.print(sum);
    }

    // Driver Code
    public static void main(String[] args)
    {
        TreeNode root = null;
        int N = 11;
        int tree[] = { 3, 1, 7, 0, 2, 5,
                       10, 4, 6, 9, 8 };

        // Create the Tree
        for (int i = 0; i < N; i++) {
            root = insertNode(tree[i], root);
        }

        int target = 7;
        int K = 2;
        findSum(root, target, K);
    }
}
 
// This code is contributed by 29AjayKumar


Python3




# python 3 program for the above approach
 
# Running total shared (via `global sum`) by kDistanceDownSum and
# kDistanceSum. NOTE: this name shadows the builtin sum() in this module.
sum = 0
 
class Node:
    """Structure of Tree: a binary search tree node.

    Holds ``data`` plus ``left``/``right`` child links, both initially None.
    """

    def __init__(self, data):
        self.data = data
        self.left, self.right = None, None
 
def kDistanceDownSum(root, k):
    """Add to the module-level ``sum`` every node exactly ``k`` edges below ``root``.

    Nothing is added when ``root`` is None or ``k`` is negative.
    """
    global sum
    if root is None or k < 0:
        return
    if k != 0:
        # Spend one unit of distance per edge while descending.
        for child in (root.left, root.right):
            kDistanceDownSum(child, k - 1)
        return
    sum += root.data
 
def kDistanceSum(root, target, k):
    """Add to the global ``sum`` every node at distance ``k`` from ``target``
    whose value is smaller than ``target``.

    Returns the distance (in edges) from ``root`` down to the target node,
    or -1 when the target is not in this subtree.

    BST facts used (insertNode sends values <= a node to its left):
      * an ancestor that has the target in its LEFT subtree is larger than
        the target, as is its whole right subtree, so nothing there counts;
        the distance is still passed up because a higher ancestor may be
        smaller than the target;
      * an ancestor that has the target in its RIGHT subtree is smaller than
        the target, as is its whole left subtree.
    """
    global sum
    # Base Case: empty subtree, target not found
    if root is None:
        return -1

    # Target found: only its left subtree can hold smaller values
    if root.data == target:
        kDistanceDownSum(root.left, k - 1)
        return 0

    if target < root.data:
        dl = kDistanceSum(root.left, target, k)
        if dl == -1:
            return -1
        # root > target: add nothing here, but keep propagating the distance
        return dl + 1

    # target > root.data: the target can only be in the right subtree
    dr = kDistanceSum(root.right, target, k)
    if dr == -1:
        return -1

    if dr + 1 == k:
        # root itself is k edges away and smaller than target
        sum += root.data
    else:
        # root's left subtree is entirely smaller than target
        kDistanceDownSum(root.left, k - dr - 2)
    return dr + 1
 
def insertNode(data, root):
    """Insert ``data`` into the BST rooted at ``root`` and return the root.

    Larger values go right and smaller-or-equal values go left. A value that
    compares neither way (e.g. NaN) leaves the tree unchanged.
    """
    if root is None:
        return Node(data)

    node = root
    while True:
        if data > node.data:
            if node.right is None:
                node.right = Node(data)
                break
            node = node.right
        elif data <= node.data:
            if node.left is None:
                node.left = Node(data)
                break
            node = node.left
        else:
            break
    return root
 
def findSum(root, target, K):
    """Print the sum of nodes at distance ``K`` from ``target`` that are
    smaller than ``target`` (0 if there are none).

    The module-level ``sum`` is reset first; previously it was not, so a
    second call printed the results of earlier queries added together.
    """
    global sum
    # Stores the sum of nodes having values < target at K distance
    sum = 0
    kDistanceSum(root, target, K)

    # Print the resultant sum
    print(sum)
 
# Driver Code: builds the sample BST and prints the answer for
# target = 7, K = 2 (expected output: 11).
if __name__ == '__main__':
    tree = [3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8]

    # Create the Tree
    root = None
    for value in tree:
        root = insertNode(value, root)

    target, K = 7, 2
    findSum(root, target, K)

    # This code is contributed by SURENDRA_GANGWAR.


C#




// C# program for the above approach
using System;
 
public class GFG{
    // Running total filled in by kDistanceDownSum / kDistanceSum;
    // findSum resets it before every query.
    internal static int sum;

    // Structure of Tree: one BST node.
    public class TreeNode {
        public int data;
        public TreeNode left;
        public TreeNode right;

        // Constructor
        public TreeNode(int data)
        {
            this.data = data;
            this.left = null;
            this.right = null;
        }
    }

    // Adds to sum every node that is exactly k edges below root.
    static void kDistanceDownSum(TreeNode root, int k)
    {
        // Base Case
        if (root == null || k < 0)
            return;

        // If Kth distant node is reached
        if (k == 0) {
            sum += root.data;
            return;
        }

        // Recur for the left and the right subtrees
        kDistanceDownSum(root.left, k - 1);
        kDistanceDownSum(root.right, k - 1);
    }

    // Searches the BST for target and adds to sum every node that is
    // exactly k edges away from it and smaller than target.
    // Returns the distance from root down to the target node, or -1 if
    // the target is not in this subtree.
    //
    // An ancestor with the target in its LEFT subtree is larger than the
    // target (as is its right subtree), so nothing is added there, but the
    // distance is still passed up because a higher ancestor may be smaller.
    // An ancestor with the target in its RIGHT subtree is smaller than the
    // target, and so is its whole left subtree.
    internal static int kDistanceSum(TreeNode root, int target, int k)
    {
        // Base Case: target not in an empty subtree
        if (root == null)
            return -1;

        // Target found: only its left subtree can hold smaller values
        if (root.data == target) {
            kDistanceDownSum(root.left, k - 1);
            return 0;
        }

        if (target < root.data) {
            int dl = kDistanceSum(root.left, target, k);
            if (dl == -1)
                return -1;

            // root > target: add nothing here, keep propagating
            return dl + 1;
        }

        // target > root.data: the target can only be on the right
        int dr = kDistanceSum(root.right, target, k);
        if (dr == -1)
            return -1;

        if (dr + 1 == k)
            // root itself is k edges away and smaller than target
            sum += root.data;
        else
            // root's left subtree is entirely smaller than target
            kDistanceDownSum(root.left, k - dr - 2);

        return dr + 1;
    }

    // Function to insert a node in BST: larger keys go right, smaller
    // or equal keys go left. Returns the (possibly new) root.
    internal static TreeNode insertNode(int data, TreeNode root)
    {
        if (root == null)
            return new TreeNode(data);

        if (data > root.data)
            root.right = insertNode(data, root.right);
        else
            root.left = insertNode(data, root.left);

        return root;
    }

    // Prints the sum of the nodes exactly K edges away from the node
    // holding target whose values are smaller than target (0 if none).
    static void findSum(TreeNode root, int target, int K)
    {
        // Stores the sum of nodes having values < target at K distance
        sum = 0;

        kDistanceSum(root, target, K);

        // Print the resultant sum
        Console.Write(sum);
    }

    // Driver Code
    public static void Main(String[] args)
    {
        TreeNode root = null;
        int N = 11;
        int []tree = { 3, 1, 7, 0, 2, 5,
                       10, 4, 6, 9, 8 };

        // Create the Tree
        for (int i = 0; i < N; i++) {
            root = insertNode(tree[i], root);
        }

        int target = 7;
        int K = 2;
        findSum(root, target, K);
    }
}
 
// This code is contributed by gauravrajput1


Javascript




<script>
// Javascript program for the above approach
 
// Global running total: updated by kDistanceDownSum / kDistanceSum and
// printed by findSum.
let sum = 0;
 
class TreeNode {
  /**
   * Structure of Tree: a BST node. Every argument is optional: `data`
   * defaults to the empty string and both child links default to null.
   */
  constructor(data = "", left = null, right = null) {
    Object.assign(this, { data, left, right });
  }
}
 
// Adds to the global `sum` every node exactly k edges below `root`;
// does nothing for a missing node or a negative k.
function kDistanceDownSum(root, k)
{
  if (root == null || k < 0) return;

  if (k != 0) {
    // Spend one unit of distance per edge while descending.
    [root.left, root.right].forEach((child) => kDistanceDownSum(child, k - 1));
    return;
  }

  sum += root.data;
}
 
// Searches the BST for `target` and adds to the global `sum` every node
// that is exactly k edges away from it and smaller than it.
// Returns the distance from `root` down to the target node, or -1 if the
// target is not in this subtree.
//
// An ancestor that has the target in its LEFT subtree is larger than the
// target (and so is its right subtree), so nothing is added there, but
// the distance must still be passed up because a higher ancestor may be
// smaller. An ancestor that has the target in its RIGHT subtree is
// smaller than the target, and so is its entire left subtree.
function kDistanceSum(root, target, k) {
  // Base Case: empty subtree
  if (root == null) return -1;

  // Target found: only its left subtree can hold smaller values
  if (root.data == target) {
    kDistanceDownSum(root.left, k - 1);
    return 0;
  }

  if (target < root.data) {
    const dl = kDistanceSum(root.left, target, k);
    if (dl == -1) return -1;
    // root > target: add nothing here, but keep propagating the distance
    return dl + 1;
  }

  // target > root.data: the target can only be in the right subtree
  const dr = kDistanceSum(root.right, target, k);
  if (dr == -1) return -1;

  if (dr + 1 == k) {
    // root itself is k edges away and smaller than target
    sum += root.data;
  } else {
    // root's left subtree is entirely smaller than target
    kDistanceDownSum(root.left, k - dr - 2);
  }
  return dr + 1;
}
 
// Function to insert a node in BST: larger values go right, smaller or
// equal values go left. Returns the (possibly new) root. A value that
// compares neither way (e.g. NaN) leaves the tree unchanged.
function insertNode(data, root) {
  if (root == null) return new TreeNode(data);

  let node = root;
  for (;;) {
    const side = data > node.data ? "right" : data <= node.data ? "left" : null;
    if (side === null) break;
    if (node[side] == null) {
      node[side] = new TreeNode(data);
      break;
    }
    node = node[side];
  }
  return root;
}
 
// Writes (via document.write) the sum of all nodes that are exactly K
// edges from `target` and smaller than it. The global `sum` is reset
// first; previously repeated calls accumulated earlier results.
function findSum(root, target, K) {
  // Stores the sum of nodes having values < target at K distance
  sum = 0;
  kDistanceSum(root, target, K);

  // Print the resultant sum
  document.write(sum);
}
 
// Driver Code: builds the sample BST and prints the answer for
// target = 7, K = 2 (expected output: 11).

let root = null;
let N = 11;
let tree = [3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8];

// Create the Tree from the first N keys
tree.slice(0, N).forEach((value) => {
  root = insertNode(value, root);
});

let target = 7;
let K = 2;
findSum(root, target, K);
</script>


Output

11







Time Complexity: O(N) where N is the number of nodes in given binary tree.
Auxiliary Space: O(h) where h is the height of binary tree due to recursion call stack.

Approach using BFS:-

  • We will be using level order traversal to find the sum of smaller value nodes

Implementation:-

  • First we will find the target node using level order traversal.
  • While finding the target node we will store the parent of each node so that we can move towards the parent of the node as well.
  • After this, we traverse outward from the target node in every direction (to both children and to the parent), one level at a time up to distance K. We add to our answer the values of the nodes at distance exactly K that are smaller than the target.

C++




// C++ program for the above approach
 
#include <bits/stdc++.h>
using namespace std;
 
// Structure of Tree: a BST node; children are empty until linked.
struct TreeNode {

    int data;
    TreeNode* left = nullptr;
    TreeNode* right = nullptr;

    // Stores the key; the default member initializers clear both links.
    TreeNode(int data) : data(data) {}
};
 
// Function to insert a node in BST
TreeNode* insertNode(int data,
                     TreeNode* root)
{
    // If root is NULL
    if (root == NULL) {
        TreeNode* node = new TreeNode(data);
        return node;
    }
 
    // Insert the data in right half
    else if (data > root->data) {
        root->right = insertNode(
            data, root->right);
    }
 
    // Insert the data in left half
    else if (data <= root->data) {
        root->left = insertNode(
            data, root->left);
    }
 
    // Return the root node
    return root;
}
 
// Function to find the sum of K distant
// nodes from the target node having
// value less than target node
void findSum(TreeNode* root, int target,
             int K)
{
    //variable to store answer
    int ans = 0;
 
    //queue for bfs
    queue<TreeNode*> q;
 
    q.push(root);
 
    //to store target node
    TreeNode* need;
 
    //map to store parent of each node
    unordered_map<TreeNode*, TreeNode*> m;
 
    //bfs
    while(q.size()){
 
      int s = q.size();
 
      //traversing to current level
      for(int i=0;i<s;i++){
 
        TreeNode* temp = q.front();
 
        q.pop();
 
        //if target value found
        if(temp->data==target) need=temp;
 
        if(temp->left){
          q.push(temp->left);
          m[temp->left]=temp;
        }
 
        if(temp->right){
          q.push(temp->right);
          m[temp->right]=temp;
        }
 
      }
 
    }
 
    //map to store occurrence of a node
    //that is the node has taken or not
    unordered_map<TreeNode*, int> mm;
 
    q.push(need);
 
    //to store current distance
    int c = 0;
 
    while(q.size()){
 
      int s = q.size();
 
      for(int i=0;i<s;i++){
 
        TreeNode* temp = q.front();
 
        q.pop();
 
        mm[temp] = 1;
         
        if(temp->data<target and c==K)
        ans+=temp->data;
 
        //moving left
        if(temp->left&&mm[temp->left]==0){
          q.push(temp->left);
        }
 
        //moving right
        if(temp->right&&mm[temp->right]==0){
          q.push(temp->right);
        }
 
        //movinf to parent
        if(m[temp]&&mm[m[temp]]==0){
          q.push(m[temp]);
        }
 
      }
 
      c++;
      if(c>K)break;
 
    }
    cout<<ans<<endl;
 
}
 
// Driver Code
int main()
{
    TreeNode* root = NULL;
    int N = 11;
    int tree[] = { 3, 1, 7, 0, 2, 5,
                   10, 4, 6, 9, 8 };
 
    // Create the Tree
    for (int i = 0; i < N; i++) {
        root = insertNode(tree[i], root);
    }
 
    int target = 7;
    int K = 2;
    findSum(root, target, K);
 
    return 0;
}
//code contributed by shubhamrajput6156


Java




import java.util.*;
 
// Structure of Tree: a BST node with an int key and two child links.
class TreeNode {

    int data;
    TreeNode left;
    TreeNode right;

    // Stores the key; `left` and `right` keep their default value null.
    TreeNode(int data)
    {
        this.data = data;
    }
}
 
public class Main {
 
    // Function to insert a node in BST
    public static TreeNode insertNode(int data,
                                      TreeNode root)
    {
        // If root is null
        if (root == null) {
            TreeNode node = new TreeNode(data);
            return node;
        }
 
        // Insert the data in right half
        else if (data > root.data) {
            root.right = insertNode(
                data, root.right);
        }
 
        // Insert the data in left half
        else if (data <= root.data) {
            root.left = insertNode(
                data, root.left);
        }
 
        // Return the root node
        return root;
    }
 
    // Function to find the sum of K distant
    // nodes from the target node having
    // value less than target node
    public static void findSum(TreeNode root, int target,
                               int K)
    {
        // variable to store answer
        int ans = 0;
 
        // queue for bfs
        Queue<TreeNode> q = new LinkedList<>();
 
        q.add(root);
 
        // to store target node
        TreeNode need = null;
 
        // map to store parent of each node
        HashMap<TreeNode, TreeNode> m = new HashMap<>();
 
        // bfs
        while (!q.isEmpty()) {
 
            int s = q.size();
 
            // traversing to current level
            for (int i = 0; i < s; i++) {
 
                TreeNode temp = q.peek();
                q.remove();
 
                // if target value found
                if (temp.data == target)
                    need = temp;
 
                if (temp.left != null) {
                    q.add(temp.left);
                    m.put(temp.left, temp);
                }
 
                if (temp.right != null) {
                    q.add(temp.right);
                    m.put(temp.right, temp);
                }
            }
        }
 
        // map to store occurrence of a node
        // that is the node has taken or not
        HashMap<TreeNode, Integer> mm = new HashMap<>();
 
        q.add(need);
 
        // to store current distance
        int c = 0;
 
        while (!q.isEmpty()) {
 
            int s = q.size();
 
            for (int i = 0; i < s; i++) {
 
                TreeNode temp = q.peek();
                q.remove();
 
                mm.put(temp, 1);
 
                if (temp.data < target && c == K)
                    ans += temp.data;
 
                // moving left
                if (temp.left != null && mm.getOrDefault(temp.left, 0) == 0) {
                    q.add(temp.left);
                }
 
                // moving right
                if (temp.right != null && mm.getOrDefault(temp.right, 0) == 0) {
                    q.add(temp.right);
                }
 
                // moving to parent
                if (m.get(temp) != null && mm.getOrDefault(m.get(temp), 0) == 0) {
                    q.add(m.get(temp));
                }
            }
 
            c++;
            if (c > K)
                break;
        }
        System.out.println(ans);
    }
 
    // Driver Code
    public static void main(String[] args)
    {
        TreeNode root = null;
        int N = 11;
        int[] tree = { 3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8 };
 
        // Create the Tree
        for (int i = 0; i < N; i++) {
            root = insertNode(tree[i], root);
        }
 
        int target = 7;
        int K = 2;
        findSum(root, target, K);
    }
}


Python




from collections import deque
 
class TreeNode:
    """Structure of Tree: a BST node with ``data`` and two child links."""

    def __init__(self, data):
        self.data, self.left, self.right = data, None, None
 
def insertNode(data, root):
    """Insert ``data`` into the BST rooted at ``root`` and return the root.

    Values greater than a node go right; everything else goes left.
    """
    if root is None:
        return TreeNode(data)

    node = root
    while True:
        side = 'right' if data > node.data else 'left'
        child = getattr(node, side)
        if child is None:
            setattr(node, side, TreeNode(data))
            return root
        node = child
 
def findSum(root, target, K):
    """Print the sum of the nodes at distance exactly ``K`` from the node
    holding ``target`` whose values are smaller than ``target``.

    Two BFS passes: the first walks the whole tree in level order, recording
    each node's parent and locating the target node; the second expands
    outward from the target (children and parent alike) one level per unit
    of distance.  Prints 0 when the tree is empty or ``target`` is absent;
    previously both cases raised AttributeError on ``None.data``.
    """
    ans = 0

    # Pass 1: level-order traversal that records parents and finds target.
    need = None
    parent = {}
    q = deque([root] if root is not None else [])
    while q:
        temp = q.popleft()
        if temp.data == target:
            need = temp
        for child in (temp.left, temp.right):
            if child is not None:
                parent[child] = temp
                q.append(child)

    # Target missing (or empty tree): nothing is at distance K from it.
    if need is None:
        print(ans)
        return

    # Pass 2: BFS from the target; nodes are marked when enqueued.
    seen = {need}
    q.append(need)
    dist = 0
    while q and dist <= K:
        for _ in range(len(q)):
            temp = q.popleft()
            if dist == K:
                # At distance K: count smaller values, go no further.
                if temp.data < target:
                    ans += temp.data
                continue
            for nxt in (temp.left, temp.right, parent.get(temp)):
                if nxt is not None and nxt not in seen:
                    seen.add(nxt)
                    q.append(nxt)
        dist += 1

    print(ans)
 
# Driver Code: builds the sample BST and prints the answer for
# target = 7, K = 2 (expected output: 11).
if __name__ == "__main__":
    tree = [3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8]

    # Create the Tree
    root = None
    for value in tree:
        root = insertNode(value, root)

    target, K = 7, 2
    findSum(root, target, K)


C#




using System;
using System.Collections.Generic;
 
namespace BinaryTreeKDistanceSum
{
    // A BST node holding an integer key and two child links.
    class TreeNode
    {
        public int data;
        public TreeNode left;
        public TreeNode right;

        // Constructor
        public TreeNode(int data)
        {
            this.data = data;
            this.left = null;
            this.right = null;
        }
    }

    class Program
    {
        // Inserts data into the BST: larger keys go right, smaller or
        // equal keys go left. Returns the (possibly new) root.
        internal static TreeNode InsertNode(int data, TreeNode root)
        {
            if (root == null)
                return new TreeNode(data);

            if (data > root.data)
                root.right = InsertNode(data, root.right);
            else
                root.left = InsertNode(data, root.left);

            return root;
        }

        // Prints the sum of the nodes at distance exactly K from the node
        // holding target whose values are smaller than target.
        // Pass 1 walks the tree in level order, recording parents and
        // locating the target; pass 2 expands outward from the target
        // (children and parent alike) one level per unit of distance.
        // Prints 0 when the tree is empty or target is absent; previously
        // both cases threw (null dereference / null dictionary key).
        internal static void FindSum(TreeNode root, int target, int K)
        {
            int ans = 0;
            TreeNode need = null;
            var parent = new Dictionary<TreeNode, TreeNode>();
            var q = new Queue<TreeNode>();
            if (root != null)
                q.Enqueue(root);

            // pass 1: level-order traversal
            while (q.Count > 0)
            {
                TreeNode temp = q.Dequeue();

                if (temp.data == target) need = temp;

                if (temp.left != null)
                {
                    parent[temp.left] = temp;
                    q.Enqueue(temp.left);
                }
                if (temp.right != null)
                {
                    parent[temp.right] = temp;
                    q.Enqueue(temp.right);
                }
            }

            if (need == null)
            {
                Console.WriteLine(ans);
                return;
            }

            // pass 2: BFS from the target; nodes are marked when enqueued
            var seen = new HashSet<TreeNode> { need };
            q.Enqueue(need);

            for (int dist = 0; dist <= K && q.Count > 0; dist++)
            {
                for (int n = q.Count; n > 0; n--)
                {
                    TreeNode temp = q.Dequeue();

                    if (dist == K)
                    {
                        // at distance K: count smaller values, go no further
                        if (temp.data < target) ans += temp.data;
                        continue;
                    }

                    TreeNode up;
                    parent.TryGetValue(temp, out up);
                    foreach (TreeNode next in new[] { temp.left, temp.right, up })
                    {
                        if (next != null && seen.Add(next))
                            q.Enqueue(next);
                    }
                }
            }
            Console.WriteLine(ans);
        }

        static void Main(string[] args)
        {
            TreeNode root = null;
            int N = 11;
            int[] tree = { 3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8 };

            for (int i = 0; i < N; i++)
            {
                root = InsertNode(tree[i], root);
            }

            int target = 7;
            int K = 2;
            FindSum(root, target, K);
        }
    }
}


Javascript




/**
 * A binary-tree node: one stored value plus links to its two children.
 * A freshly constructed node has no children (both links are null).
 */
class TreeNode {
    /** @param {*} value Value stored in this node. */
    constructor(value) {
        this.data = value;  // value compared when inserting into / searching the BST
        this.left = null;   // left child, or null when absent
        this.right = null;  // right child, or null when absent
    }
}
 
/**
 * Inserts `data` into the binary search tree rooted at `root` and returns
 * the root of the resulting tree.
 *
 * Values greater than a node's value go to its right subtree. Values less
 * than or equal to it go to the left, so duplicates end up on the left.
 *
 * @param {*} data              Value to insert.
 * @param {TreeNode|null} root  Current root, or null for an empty tree.
 * @returns {TreeNode|null} Root of the tree after insertion.
 */
function insertNode(data, root) {
    if (root === null) {
        return new TreeNode(data);
    }

    // Walk down the search path and hang the new node on the first empty
    // child slot reached.
    let cursor = root;
    while (true) {
        if (data > cursor.data) {
            if (cursor.right === null) {
                cursor.right = new TreeNode(data);
                break;
            }
            cursor = cursor.right;
        } else if (data <= cursor.data) {
            if (cursor.left === null) {
                cursor.left = new TreeNode(data);
                break;
            }
            cursor = cursor.left;
        } else {
            // Neither comparison holds (e.g. NaN): the tree is left unchanged.
            break;
        }
    }

    return root;
}
 
/**
 * Sums every node that lies exactly K edges away from the node holding
 * `target` and whose value is strictly less than `target`.
 *
 * Distance is counted along tree edges in both directions (down to children
 * and up to the parent). Ancestors and nodes in sibling subtrees therefore
 * count too. The sum is printed with console.log and also returned.
 *
 * @param {TreeNode|null} root  Root of the BST; may be null (empty tree).
 * @param {number} target       Value of the target node.
 * @param {number} K            Required distance, in edges.
 * @returns {number} The sum. It is 0 when the tree is empty, when no node
 *     holds `target`, or when no qualifying node exists at distance K.
 */
function findSum(root, target, K) {
    let ans = 0; // Running sum of qualifying node values
    let need = null; // The target node, once found
    const parentMap = new Map(); // child node -> parent node

    // Pass 1: level-order walk over the whole tree. It locates the target
    // node and records each node's parent so that pass 2 can also move
    // upward. A read index is used instead of Array#shift, which is O(n) per
    // call and would make the traversal quadratic on large trees.
    const order = root === null ? [] : [root];
    for (let head = 0; head < order.length; head++) {
        const temp = order[head];

        if (temp.data === target) {
            need = temp; // with duplicate values, the last match in level order wins
        }
        if (temp.left !== null) {
            order.push(temp.left);
            parentMap.set(temp.left, temp);
        }
        if (temp.right !== null) {
            order.push(temp.right);
            parentMap.set(temp.right, temp);
        }
    }

    // Pass 2: expand outward from the target one ring (distance) at a time.
    // This pass is skipped when no node holds `target`, so the sum stays 0
    // instead of dereferencing a missing node.
    if (need !== null) {
        const visited = new Set([need]);
        let frontier = [need];
        let distance = 0;

        while (frontier.length > 0 && distance <= K) {
            if (distance === K) {
                for (const node of frontier) {
                    if (node.data < target) {
                        ans += node.data;
                    }
                }
                break;
            }

            const next = [];
            for (const node of frontier) {
                // Neighbours are both children plus the parent (undefined for the root).
                for (const neighbour of [node.left, node.right, parentMap.get(node)]) {
                    if (neighbour != null && !visited.has(neighbour)) {
                        visited.add(neighbour);
                        next.push(neighbour);
                    }
                }
            }
            frontier = next;
            distance++;
        }
    }

    console.log(ans); // Print the final sum
    return ans;
}
 
// Driver code: builds the example BST and prints the sum of the nodes that
// are K = 2 edges away from the node 7 and hold a value smaller than 7.
function main() {
    const values = [3, 1, 7, 0, 2, 5, 10, 4, 6, 9, 8];

    // Build the BST by inserting the values in the listed order.
    let root = null;
    for (const value of values) {
        root = insertNode(value, root);
    }

    const target = 7;
    const K = 2;
    findSum(root, target, K); // logs the resulting sum
}

main();


Output

11







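Time Complexity: O(N), since each of the two breadth-first traversals visits every node at most once.
Auxiliary Space: O(N), for the BFS queue, the parent map, and the record of visited nodes.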



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads