Open In App

Count all k-sum paths in a Binary Tree

Last Updated : 17 Apr, 2024
Improve
Improve
Like Article
Like
Save
Share
Report

Given a binary tree and an integer k, count the number of paths in the tree whose node values sum to k.
A path can start at any node and end at any node, but it must go downward only (from parent to child): it need not start at the root or end at a leaf. Node values may be negative.

Examples:  

Input : k = 5  
Root of the binary tree below:
          1
        /   \
       3     -1
      / \    /  \
     2   1  4    5
        /  / \    \
       1  1   2    6

Output : No of paths with sum equals to 5 are: 8
3 2
3 1 1
1 3 1
4 1
1 -1 4 1
-1 4 2
5
1 -1 5
Input : k = 3
          1
        /   \
       2     -1
      / \    /
     1   2  3
           / \
          2   5
Output : No of paths with sum equals to 3 are: 4
1 2
2 1
1 -1 3
3

Approach : 
Printing every path whose sum equals k has already been discussed in this post using a vector. However, storing paths in a vector while recursing increases both time and space complexity when we only need the count of such paths.
A more efficient solution uses backtracking with an unordered (hash) map of prefix sums.

Steps:  

  • Use an unordered map that records, for each prefix sum on the current root-to-node path, how many times it occurs.
  • For every node, check whether the current sum plus the node's value equals k. If it does, increment the answer by one (the path from the tree root to this node sums to k).
  • Then add to the answer the count stored in the map for the key (current sum + root->data - k): each such earlier prefix marks the start of a downward path that ends at this node and sums to k.
  • Then insert the current sum + root->data value into the map (increment its count).
  • Recursively process the left and right subtrees of the current node.
  • After both subtrees have been traversed, remove (decrement) the current sum + root->data value from the map so that it is not considered for nodes outside the current node's subtree.

Below is the implementation of the above approach:

C++
#include <iostream>
#include <unordered_map>

using namespace std;

// Global running count of k-sum paths: k_paths() increments it and
// printCount() returns it. main() sets it to 0 before counting.
static int res = 0;

// Binary-tree node: an integer payload plus left/right children, which are
// nullptr when absent.
struct Node {
    int data;
    Node *left = nullptr;
    Node *right = nullptr;

    Node(int value) : data(value) {}
};

// Depth-first walk that counts downward paths summing to k, backtracking over
// a map of prefix sums.
//
//   root : current node (nullptr => nothing to do)
//   k    : target path sum
//   p    : prefix sum -> number of times it occurs on the path from the tree
//          root down to root's parent; restored before returning
//   sum  : prefix sum from the tree root down to root's parent
//
// Every matching path found is added to the global `res`.
static void k_paths(Node *root, int k, unordered_map<int, int> &p, int sum) {
    if (root == nullptr)
        return;

    const int prefix = sum + root->data;

    // The path from the tree root down to this node sums to k.
    if (prefix == k)
        res++;

    // Each ancestor prefix equal to (prefix - k) starts a downward path that
    // ends here and sums to k. Use find() rather than operator[] so a lookup
    // miss does not insert a zero-count entry into the map.
    auto match = p.find(prefix - k);
    if (match != p.end())
        res += match->second;

    // Make this prefix visible to descendants only.
    p[prefix]++;

    k_paths(root->left, k, p, prefix);
    k_paths(root->right, k, p, prefix);

    // Backtrack: nodes outside this subtree must not see this prefix. Erase
    // the key once its count reaches zero to keep the map small.
    auto self = p.find(prefix);
    if (--self->second == 0)
        p.erase(self);
}

// Returns the number of downward paths in the tree rooted at `root` whose node
// values sum to k. Resets the global counter first, so repeated calls return
// independent results instead of accumulating.
static int printCount(Node *root, int k) {
    res = 0;

    // Prefix sum -> occurrences on the current root-to-node path.
    unordered_map<int, int> p;
    k_paths(root, k, p, 0);

    return res;
}

// Driver: builds the second example tree from the problem statement, prints
// how many downward paths sum to k, then frees every node.
int main() {
    res = 0;
    Node *root = new Node(1);
    root->left = new Node(2);
    root->left->left = new Node(1);
    root->left->right = new Node(2);
    root->right = new Node(-1);
    root->right->left = new Node(3);
    root->right->left->left = new Node(2);
    root->right->left->right = new Node(5);

    // Same target as the other language versions and the documented output.
    int k = 3;
    cout << "No of paths with sum equals to " << k << " are: " << printCount(root, k) << endl;

    // Deallocate memory (children before parents)
    delete root->right->left->right;
    delete root->right->left->left;
    delete root->right->left;
    delete root->right;
    delete root->left->right;
    delete root->left->left;
    delete root->left;
    delete root;

    return 0;
}
Java
// Java program to count all downward paths
// in a binary tree whose node values sum to k
import java.util.ArrayList;
import java.util.HashMap;

class Graph {

    /** Running count of k-sum paths; {@link #printCount} resets it before each count. */
    static int res;

    /** Binary-tree node holding an integer value and optional children. */
    static class Node {
        int data;
        Node left, right;

        public Node(int data) {
            this.data = data;
            this.left = this.right = null;
        }
    }

    /**
     * Walks the subtree rooted at {@code root} and adds to {@link #res} every
     * downward path that ends in this subtree and sums to {@code k}.
     *
     * @param root current node; {@code null} means nothing to do
     * @param k    target path sum
     * @param p    prefix sum -> number of times it occurs on the path from the
     *             tree root down to {@code root}'s parent; restored on return
     * @param sum  prefix sum from the tree root down to {@code root}'s parent
     */
    static void k_paths(Node root, int k, HashMap<Integer, Integer> p, int sum) {
        if (root == null) {
            return;
        }

        int prefix = sum + root.data;

        // The path from the tree root down to this node sums to k.
        if (prefix == k) {
            res++;
        }

        // Every ancestor prefix equal to (prefix - k) starts a downward path
        // that ends here and sums to k. Single lookup, no null juggling.
        res += p.getOrDefault(prefix - k, 0);

        // Expose this prefix to descendants only.
        p.merge(prefix, 1, Integer::sum);

        k_paths(root.left, k, p, prefix);
        k_paths(root.right, k, p, prefix);

        // Backtrack so nodes outside this subtree cannot see the prefix.
        // Returning null from the remapping function removes the key entirely.
        p.computeIfPresent(prefix, (key, count) -> count == 1 ? null : count - 1);
    }

    /**
     * Returns the number of downward paths in the tree rooted at {@code root}
     * whose node values sum to {@code k}.
     *
     * <p>Resets {@link #res} first, so repeated calls return independent
     * results instead of accumulating.
     *
     * @param root tree root; may be {@code null} (empty tree, result 0)
     * @param k    target path sum
     * @return number of matching downward paths
     */
    static int printCount(Node root, int k) {
        res = 0;
        k_paths(root, k, new HashMap<>(), 0);
        return res;
    }

    // Driver code
    public static void main(String[] args) {
        res = 0;
        Node root = new Node(1);
        root.left = new Node(2);
        root.left.left = new Node(1);
        root.left.right = new Node(2);
        root.right = new Node(-1);
        root.right.left = new Node(3);
        root.right.left.left = new Node(2);
        root.right.left.right = new Node(5);

        int k = 3;
        System.out.printf("No of paths with sum " +
                          "equals to %d are: %d\n", k,
                          printCount(root, k));
    }
}

// This code is contributed by sanjeev2552
Python3
from functools import reduce
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass


@dataclass
class TreeNode:
    """A binary-tree node; ``left``/``right`` default to ``None`` (no child).

    Existing positional construction ``TreeNode(v, l, r)`` still works, and
    leaves can now be written simply as ``TreeNode(v)``.
    """

    value: int  # integer stored at this node; may be negative
    left: Optional["TreeNode"] = None  # left child, or None
    right: Optional["TreeNode"] = None  # right child, or None


def dfs_k_path_sum(
    root: "Optional[TreeNode]", k: int, verbose: bool = False
) -> Tuple[List[int], int]:
    """Count downward paths summing to ``k`` in the subtree rooted at ``root``.

    This is a post-order DFS that visits every node exactly once.

    Args:
        root: subtree root; ``None`` is treated as an empty subtree.
        k: target path sum.
        verbose: when True, print the per-node aggregation trace (off by
            default so the function does not spam stdout).

    Returns:
        A tuple ``(sum_paths, count)``:
          * ``sum_paths`` has one entry per node in the subtree. Each entry is
            the sum of the downward path that starts at ``root`` and ends at
            that node (left subtree first, then right, then ``root`` itself).
          * ``count`` is the number of downward paths anywhere in the subtree
            whose node values sum to ``k``.

    Complexity:
        Each node builds a list as large as its subtree, so total work is the
        sum of subtree sizes: O(n log n) for balanced trees and O(n^2) for
        degenerate ones. Memory is bounded the same way.
    """
    if root is None:
        return ([], 0)

    left_sum_paths, left_res = dfs_k_path_sum(root.left, k, verbose)
    right_sum_paths, right_res = dfs_k_path_sum(root.right, k, verbose)

    # Extend every path that starts at a child up to this node, then add the
    # single-node path consisting of this node alone.
    return_sum_paths = [
        child_sum + root.value for child_sum in left_sum_paths + right_sum_paths
    ] + [root.value]

    # Paths fully inside a child subtree were already counted there; only the
    # paths starting at this node are new.
    return_res = left_res + right_res + return_sum_paths.count(k)

    if verbose:
        print(f"traversing {root} left: {left_sum_paths} right: {right_sum_paths}")
        print(f"agg_paths: {return_sum_paths}, agg_res: {return_res}")

    return (return_sum_paths, return_res)


# Build the first example tree from the problem statement (k = 5, 8 paths),
# nesting the constructors so the code mirrors the tree's shape.
node1 = TreeNode(
    1,
    TreeNode(
        3,
        TreeNode(2, None, None),
        TreeNode(1, TreeNode(1, None, None), None),
    ),
    TreeNode(
        -1,
        TreeNode(4, TreeNode(1, None, None), TreeNode(2, None, None)),
        TreeNode(5, None, TreeNode(6, None, None)),
    ),
)

k = 5
paths, result = dfs_k_path_sum(node1, k)

# Report the raw path sums, the count, and how much space the sums used.
for report_line in (
    paths,
    f"result: {result}",
    f"sum paths: {paths}",
    f"space size is: {len(paths)}",
):
    print(report_line)

# Contribution by Diego Tsutsumi (diego0mpq)
C#
//C# code for the above approach
using System;
using System.Collections.Generic;

class Graph
{
    // Running count of k-sum paths; printCount() resets it before counting.
    static int res;

    // tree structure: an integer payload plus optional left/right children
    class Node
    {
        public int data;
        public Node left, right;

        public Node(int data)
        {
            this.data = data;
            this.left = this.right = null;
        }
    }

    // Walks the subtree rooted at `root` and adds to `res` every downward
    // path that ends in this subtree and sums to k.
    //   p   : prefix sum -> occurrences on the path from the tree root down
    //         to root's parent; restored before returning
    //   sum : prefix sum from the tree root down to root's parent
    static void k_paths(Node root, int k,
        Dictionary<int, int> p, int sum)
    {
        if (root == null)
            return;

        int prefix = sum + root.data;

        // The path from the tree root down to this node sums to k.
        if (prefix == k)
            res++;

        // Each ancestor prefix equal to (prefix - k) starts a downward path
        // that ends here and sums to k. TryGetValue does a single lookup.
        int matches;
        if (p.TryGetValue(prefix - k, out matches))
            res += matches;

        // Expose this prefix to descendants only (seen == 0 when absent).
        int seen;
        p.TryGetValue(prefix, out seen);
        p[prefix] = seen + 1;

        k_paths(root.left, k, p, prefix);
        k_paths(root.right, k, p, prefix);

        // Backtrack so nodes outside this subtree cannot see the prefix;
        // drop the key entirely once no occurrences remain.
        if (p[prefix] == 1)
            p.Remove(prefix);
        else
            p[prefix]--;
    }

    // Returns the number of downward paths in the tree rooted at `root`
    // whose node values sum to k. Resets `res` first so repeated calls
    // return independent results instead of accumulating.
    static int printCount(Node root, int k)
    {
        res = 0;

        // Prefix sum -> occurrences on the current root-to-node path
        Dictionary<int, int> p = new Dictionary<int, int>();
        k_paths(root, k, p, 0);

        return res;
    }

    // Driver code
    static void Main(string[] args)
    {
        res = 0;
        Node root = new Node(1);
        root.left = new Node(2);
        root.left.left = new Node(1);
        root.left.right = new Node(2);
        root.right = new Node(-1);
        root.right.left = new Node(3);
        root.right.left.left = new Node(2);
        root.right.left.right = new Node(5);

        int k = 3;
        Console.WriteLine("No of paths with sum " + 
                          "equals to {0} are: {1}", k,
                          printCount(root, k));
    }
}
//This code is contributed by Potta Lokesh
Javascript
<script>

// Javascript program to count all downward paths
// in a binary tree whose node values sum to k
// Global running count of k-sum paths: k_paths() increments it and
// printCount() returns it. The driver code sets it to 0 before counting.
let res;

// Tree structure
class Node
{
    constructor(data) 
    {
        this.left = null;
        this.right = null;
        this.data = data;
    }
}

// Depth-first walk that counts downward paths summing to k, backtracking over
// a Map of prefix sums.
//   root : current node (null/undefined => nothing to do)
//   k    : target path sum
//   p    : Map of prefix sum -> occurrences on the path from the tree root
//          down to root's parent; restored before returning
//   sum  : prefix sum from the tree root down to root's parent
// Every matching path found is added to the global `res`.
function k_paths(root, k, p, sum)
{
    // Loose comparison deliberately catches both null and undefined.
    if (root == null)
        return;

    const prefix = sum + root.data;

    // The path from the tree root down to this node sums to k.
    if (prefix === k)
        res++;

    // Each ancestor prefix equal to (prefix - k) starts a downward path that
    // ends here and sums to k. One lookup; a missing key counts as 0.
    res += p.get(prefix - k) || 0;

    // Expose this prefix to descendants only.
    p.set(prefix, (p.get(prefix) || 0) + 1);

    k_paths(root.left, k, p, prefix);
    k_paths(root.right, k, p, prefix);

    // Backtrack so nodes outside this subtree cannot see the prefix; drop the
    // key once no occurrences remain to keep the Map small.
    const remaining = p.get(prefix) - 1;
    if (remaining === 0)
        p.delete(prefix);
    else
        p.set(prefix, remaining);
}

// Returns the number of downward paths in the tree rooted at `root` whose
// node values sum to k. Resets the global counter first, so repeated calls
// return independent results instead of accumulating.
function printCount(root, k)
{
    res = 0;

    // Prefix sum -> occurrences on the current root-to-node path
    const p = new Map();
    k_paths(root, k, p, 0);

    return res;
}

// Driver code: build the second example tree (subtrees first, then link them
// under the root) and report how many downward paths sum to k.
res = 0;
let root = new Node(1);

const leftSubtree = new Node(2);
leftSubtree.left = new Node(1);
leftSubtree.right = new Node(2);

const threeSubtree = new Node(3);
threeSubtree.left = new Node(2);
threeSubtree.right = new Node(5);

root.left = leftSubtree;
root.right = new Node(-1);
root.right.left = threeSubtree;

let k = 3;
const pathCount = printCount(root, k);
document.write(`No of paths with sum equals to ${k} are: ${pathCount}`);

// This code is contributed by divyeshrabadiya07

</script>

Output
No of paths with sum equals to 3 are: 4

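Time Complexity: O(N) on average, where N is the number of nodes in the tree (each node performs a constant number of expected-O(1) hash-map operations). The Python version above uses a different approach that costs O(N log N) for balanced trees and O(N^2) for skewed ones.
Space Complexity: O(N), where N is the number of nodes in the tree (prefix-sum map plus recursion stack).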



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads