Open In App

Count all k-sum paths in a Binary Tree

Last Updated : 13 May, 2024
Improve
Improve
Like Article
Like
Save
Share
Report

Given a binary tree and an integer k, count the number of paths in the tree whose node values sum to k.
A path may start at any node and end at any node, but it must go downward only (from a parent to its child). It need not start at the root or end at a leaf. Node values may be negative.

Examples:  

Input : k = 5
Root of below binary tree:
           1
         /   \
        3     -1
       / \    / \
      2   1  4   5
         /  / \   \
        1  1   2   6

Output : No of paths with sum equals to 5 are: 8
3 2
3 1 1
1 3 1
4 1
1 -1 4 1
-1 4 2
5
1 -1 5
Input : k = 3
Root of below binary tree:
        1
      /   \
     2    -1
    / \   /
   1   2 3
      / \
     2   5
Output : No of paths with sum equals to 3 are : 4

Approach : 
Printing every path whose sum equals k has already been discussed in a separate post, which stores the paths in a vector. However, if we only need the number of such paths, storing paths in a vector while recursing needlessly increases both time and space complexity.
A more efficient solution uses backtracking together with an unordered map of prefix sums.

Steps:  

  • Maintain an unordered map that records, for every prefix sum on the current root-to-node path, how many times it occurs.
  • At each node compute the running sum currSum = (sum of the values above this node) + root->data. If currSum equals k, the path from the tree's root to this node sums to k, so increment the answer by one.
  • Also add map[currSum − k] to the answer. Each earlier prefix sum equal to currSum − k marks an ancestor; the downward path that starts just below that ancestor and ends at the current node sums to exactly k.
  • Insert currSum into the map (increment its count).
  • Recursively process the left and right subtrees of the current node.
  • After the right subtree has also been traversed, remove currSum from the map (decrement its count). This keeps it from being counted for nodes outside the current node's subtree.
  • Negative node values need no special handling. The map looks up exact differences of prefix sums, so it does not matter if the running sum decreases or temporarily becomes negative.

Below is the implementation of the above approach:

C++
#include <iostream>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace std;

/// A binary-tree node holding an integer value and non-owning child links.
struct TreeNode {
    int value;
    TreeNode* left;
    TreeNode* right;
    TreeNode(int val): value(val), left(nullptr), right(nullptr) {}
};

/// Recursive worker for dfs_k_path_sum.
///
/// @param node          current node (may be null)
/// @param k             target sum
/// @param prefix        sum of the values from the tree root down to node's parent
/// @param prefix_counts occurrences of each prefix sum on the current
///                      root-to-node path (plus 0 for the empty prefix)
/// @param values        receives the visited node values in post-order
/// @param count         incremented by the number of k-sum paths ending at each node
static void count_k_paths_from(const TreeNode* node, int k, long long prefix,
                               std::unordered_map<long long, int>& prefix_counts,
                               std::vector<int>& values, int& count) {
    if (node == nullptr)
        return;

    prefix += node->value;
    // A downward path ending at `node` sums to k exactly when it starts just
    // below an ancestor (or at the root) whose prefix sum is prefix - k.
    const auto match = prefix_counts.find(prefix - k);
    if (match != prefix_counts.end())
        count += match->second;

    ++prefix_counts[prefix];
    count_k_paths_from(node->left, k, prefix, prefix_counts, values, count);
    count_k_paths_from(node->right, k, prefix, prefix_counts, values, count);
    // Backtrack: nodes outside this subtree must not see this prefix.
    --prefix_counts[prefix];

    values.push_back(node->value);
}

/// Counts the downward paths in the tree rooted at `root` whose node values
/// sum to `k`. A path may start and end at any node but may only move from a
/// parent to one of its children; negative values are allowed.
///
/// Uses running prefix sums stored in a hash map with backtracking, so each
/// node is visited once. Prefix sums are accumulated in `long long` to reduce
/// the risk of overflow on long paths.
///
/// @param root root of the tree (may be null)
/// @param k    target sum
/// @return {values, count}: `values` holds the node values in post-order
///         (left subtree, right subtree, node); `count` is the number of
///         downward paths whose sum equals `k`.
std::pair<std::vector<int>, int> dfs_k_path_sum(TreeNode* root, int k) {
    std::vector<int> values;
    std::unordered_map<long long, int> prefix_counts;
    // The empty prefix lets paths that start at the root be matched.
    prefix_counts[0] = 1;
    int count = 0;

    count_k_paths_from(root, k, 0, prefix_counts, values, count);
    return {std::move(values), count};
}

int main() {
    TreeNode* node5 = new TreeNode(1);
    TreeNode* node4 = new TreeNode(1);
    TreeNode* node3 = new TreeNode(2);
    TreeNode* node2 = new TreeNode(3);
    node2->left = node3;
    node2->right = node4;
    TreeNode* node11 = new TreeNode(6);
    TreeNode* node10 = new TreeNode(5);
    node10->right = node11;
    TreeNode* node9 = new TreeNode(2);
    TreeNode* node8 = new TreeNode(1);
    TreeNode* node7 = new TreeNode(4);
    node7->left = node8;
    node7->right = node9;
    TreeNode* node6 = new TreeNode(-1);
    node6->left = node7;
    node6->right = node10;
    TreeNode* node1 = new TreeNode(1);
    node1->left = node2;
    node1->right = node6;

    int k = 5;
    auto [paths, result] = dfs_k_path_sum(node1, k);
    cout << "Result: " << result << endl;
    cout << "Paths: ";
    for (auto path : paths) {
        cout << path << " ";
    }
    cout << endl;

    return 0;
}
Java
import java.util.ArrayList;
import java.util.List;

/** Binary-tree node holding an integer value and links to its two children. */
class TreeNode {
    int value;
    TreeNode left, right;

    /** Creates a leaf node; both children start as {@code null}. */
    TreeNode(int val) {
        this.value = val;
        this.left = null;
        this.right = null;
    }
}

public class Main {
    /**
     * Finds every downward path in the tree rooted at {@code root} whose node
     * values sum to {@code k}. A path may start and end at any node but may
     * only move from a parent to one of its children.
     *
     * <p>Uses running prefix sums kept in a hash map with backtracking, so
     * each node is visited once.
     *
     * @param root root of the tree (may be {@code null})
     * @param k    target sum
     * @return a list containing {@code k} once per matching path, so its size
     *         is the number of k-sum paths
     */
    public static List<Integer> dfs_k_path_sum(TreeNode root, int k) {
        List<Integer> paths = new ArrayList<>();
        java.util.Map<Long, Integer> prefixCounts = new java.util.HashMap<>();
        // The empty prefix (sum 0) lets paths that start at the root be matched.
        prefixCounts.put(0L, 1);
        collectPaths(root, k, 0L, prefixCounts, paths);
        return paths;
    }

    /**
     * Recursive worker for {@link #dfs_k_path_sum}: records the k-sum paths
     * that end inside the subtree of {@code node}.
     *
     * @param prefix       sum of the values from the root down to node's parent
     * @param prefixCounts occurrences of each prefix sum on the current root-to-node path
     */
    private static void collectPaths(TreeNode node, int k, long prefix,
                                     java.util.Map<Long, Integer> prefixCounts,
                                     List<Integer> paths) {
        if (node == null)
            return;

        long sum = prefix + node.value;
        // Each ancestor prefix equal to sum - k starts a k-sum path ending here.
        int matches = prefixCounts.getOrDefault(sum - k, 0);
        for (int i = 0; i < matches; i++)
            paths.add(k);

        prefixCounts.merge(sum, 1, Integer::sum);
        collectPaths(node.left, k, sum, prefixCounts, paths);
        collectPaths(node.right, k, sum, prefixCounts, paths);
        // Backtrack so nodes outside this subtree do not see this prefix.
        prefixCounts.merge(sum, -1, Integer::sum);
    }

    public static void main(String[] args) {
        // Example tree from the problem statement (k = 5, expected answer: 8).
        TreeNode node5 = new TreeNode(1);
        TreeNode node4 = new TreeNode(1);
        node4.left = node5; // the 1 below the right child of 3
        TreeNode node3 = new TreeNode(2);
        TreeNode node2 = new TreeNode(3);
        node2.left = node3;
        node2.right = node4;
        TreeNode node11 = new TreeNode(6);
        TreeNode node10 = new TreeNode(5);
        node10.right = node11;
        TreeNode node9 = new TreeNode(2);
        TreeNode node8 = new TreeNode(1);
        TreeNode node7 = new TreeNode(4);
        node7.left = node8;
        node7.right = node9;
        TreeNode node6 = new TreeNode(-1);
        node6.left = node7;
        node6.right = node10;
        TreeNode node1 = new TreeNode(1);
        node1.left = node2;
        node1.right = node6;

        int k = 5;
        List<Integer> paths = dfs_k_path_sum(node1, k);
        System.out.println("No of paths with sum equals to " + k + " are: " + paths.size());
    }
}
Python
class TreeNode:
    """Binary-tree node storing ``value`` and links to its two children."""

    def __init__(self, val):
        # Children start empty; callers link them after construction.
        self.value = val
        self.left, self.right = None, None

def dfs_k_path_sum(root, k):
    """Count the downward paths in the tree rooted at ``root`` whose values sum to ``k``.

    A path may start and end at any node but may only move from a parent to one
    of its children. The running sum along the current root-to-node path is
    tracked together with a dict of how often each prefix sum occurs on that
    path. A path ending at the current node sums to ``k`` exactly when an
    earlier prefix equals ``running - k``. Each node is visited once.

    Args:
        root: root node (with ``value``, ``left`` and ``right`` attributes), or None.
        k: target sum.

    Returns:
        A ``(values, count)`` tuple. ``values`` lists the node values in
        post-order (left subtree, right subtree, node). ``count`` is the number
        of downward paths whose sum is ``k``.
    """
    values = []
    # The empty prefix (sum 0) lets paths that start at the root be matched.
    prefix_counts = {0: 1}
    count = 0

    def visit(node, prefix):
        nonlocal count
        if node is None:
            return
        running = prefix + node.value
        count += prefix_counts.get(running - k, 0)
        prefix_counts[running] = prefix_counts.get(running, 0) + 1
        visit(node.left, running)
        visit(node.right, running)
        # Backtrack so nodes outside this subtree do not see this prefix.
        prefix_counts[running] -= 1
        values.append(node.value)

    visit(root, 0)
    return values, count

# Example tree construction (first example from the problem statement):
#            1
#          /   \
#         3     -1
#        / \    / \
#       2   1  4   5
#          /  / \   \
#         1  1   2   6
node5 = TreeNode(1)
node4 = TreeNode(1)
node4.left = node5  # the 1 below the right child of 3
node3 = TreeNode(2)
node2 = TreeNode(3)
node2.left = node3
node2.right = node4
node11 = TreeNode(6)
node10 = TreeNode(5)
node10.right = node11
node9 = TreeNode(2)
node8 = TreeNode(1)
node7 = TreeNode(4)
node7.left = node8
node7.right = node9
node6 = TreeNode(-1)
node6.left = node7
node6.right = node10
node1 = TreeNode(1)
node1.left = node2
node1.right = node6

# Target sum
k = 5

# Call the function and print the number of k-sum paths (expected: 8)
_, result = dfs_k_path_sum(node1, k)
print("No of paths with sum equals to", k, "are:", result)

# This code is contributed by Ayush Mishra
C#
using System;
using System.Collections.Generic;

/// <summary>Binary-tree node holding an integer value and links to its two children.</summary>
public class TreeNode {
    public int value;
    public TreeNode left, right;

    /// <summary>Creates a leaf node; <c>left</c> and <c>right</c> start as null.</summary>
    public TreeNode(int val) => value = val;
}

public class Program {
    /// <summary>
    /// Counts the downward paths in the tree rooted at <paramref name="root"/>
    /// whose node values sum to <paramref name="k"/>. A path may start and end at
    /// any node but may only move from a parent to one of its children.
    /// Uses running prefix sums in a dictionary with backtracking, so each node
    /// is visited once.
    /// </summary>
    /// <param name="root">Root of the tree; may be null.</param>
    /// <param name="k">Target sum.</param>
    /// <returns>
    /// Item1: the node values in post-order (left subtree, right subtree, node).
    /// Item2: the number of downward paths whose sum is <paramref name="k"/>.
    /// </returns>
    public static (List<int>, int) DfsKPathSum(TreeNode root, int k) {
        var values = new List<int>();
        // Seeding the empty prefix (sum 0) lets paths that start at the root match.
        var prefixCounts = new Dictionary<long, int> { [0] = 1 };
        int count = 0;
        CountFrom(root, k, 0, prefixCounts, values, ref count);
        return (values, count);
    }

    /// <summary>Recursive worker for <see cref="DfsKPathSum"/>.</summary>
    /// <param name="prefix">Sum of the values from the root down to the parent of <paramref name="node"/>.</param>
    /// <param name="prefixCounts">Occurrences of each prefix sum on the current root-to-node path.</param>
    private static void CountFrom(TreeNode node, int k, long prefix,
                                  Dictionary<long, int> prefixCounts,
                                  List<int> values, ref int count) {
        if (node == null)
            return;

        long sum = prefix + node.value;
        // Each ancestor prefix equal to sum - k starts a k-sum path ending here.
        if (prefixCounts.TryGetValue(sum - k, out int matches))
            count += matches;

        prefixCounts.TryGetValue(sum, out int seen);
        prefixCounts[sum] = seen + 1;
        CountFrom(node.left, k, sum, prefixCounts, values, ref count);
        CountFrom(node.right, k, sum, prefixCounts, values, ref count);
        // Backtrack so nodes outside this subtree do not see this prefix.
        prefixCounts[sum]--;

        values.Add(node.value);
    }

    public static void Main() {
        // Example tree from the problem statement (k = 5, expected answer: 8).
        TreeNode node5 = new TreeNode(1);
        TreeNode node4 = new TreeNode(1);
        node4.left = node5; // the 1 below the right child of 3
        TreeNode node3 = new TreeNode(2);
        TreeNode node2 = new TreeNode(3);
        node2.left = node3;
        node2.right = node4;
        TreeNode node11 = new TreeNode(6);
        TreeNode node10 = new TreeNode(5);
        node10.right = node11;
        TreeNode node9 = new TreeNode(2);
        TreeNode node8 = new TreeNode(1);
        TreeNode node7 = new TreeNode(4);
        node7.left = node8;
        node7.right = node9;
        TreeNode node6 = new TreeNode(-1);
        node6.left = node7;
        node6.right = node10;
        TreeNode node1 = new TreeNode(1);
        node1.left = node2;
        node1.right = node6;

        int k = 5;
        var (_, count) = DfsKPathSum(node1, k);
        Console.WriteLine("No of paths with sum equals to " + k + " are: " + count);
    }
}
Javascript
/** Binary-tree node holding a value and links to its two children (null when absent). */
class TreeNode {
    constructor(val) {
        Object.assign(this, { value: val, left: null, right: null });
    }
}

/**
 * Counts the downward paths in the tree rooted at `root` whose node values sum
 * to `k`. A path may start and end at any node but may only move from a parent
 * to one of its children.
 *
 * Running prefix sums are kept in a Map with backtracking, so each node is
 * visited once.
 *
 * @param {?{value: number, left: ?Object, right: ?Object}} root - tree root, or null
 * @param {number} k - target sum
 * @returns {[number[], number]} `[values, count]`: node values in post-order
 *   (left subtree, right subtree, node) and the number of k-sum paths
 */
function dfs_k_path_sum(root, k) {
    const values = [];
    // The empty prefix (sum 0) lets paths that start at the root be matched.
    const prefixCounts = new Map([[0, 1]]);
    let count = 0;

    const visit = (node, prefix) => {
        if (!node)
            return;
        const sum = prefix + node.value;
        // Each ancestor prefix equal to sum - k starts a k-sum path ending here.
        count += prefixCounts.get(sum - k) || 0;
        prefixCounts.set(sum, (prefixCounts.get(sum) || 0) + 1);
        visit(node.left, sum);
        visit(node.right, sum);
        // Backtrack so nodes outside this subtree do not see this prefix.
        prefixCounts.set(sum, prefixCounts.get(sum) - 1);
        values.push(node.value);
    };

    visit(root, 0);
    return [values, count];
}

// Example tree construction (first example from the problem statement):
//            1
//          /   \
//         3     -1
//        / \    / \
//       2   1  4   5
//          /  / \   \
//         1  1   2   6
const node5 = new TreeNode(1);
const node4 = new TreeNode(1);
node4.left = node5; // the 1 below the right child of 3
const node3 = new TreeNode(2);
const node2 = new TreeNode(3);
node2.left = node3;
node2.right = node4;
const node11 = new TreeNode(6);
const node10 = new TreeNode(5);
node10.right = node11;
const node9 = new TreeNode(2);
const node8 = new TreeNode(1);
const node7 = new TreeNode(4);
node7.left = node8;
node7.right = node9;
const node6 = new TreeNode(-1);
node6.left = node7;
node6.right = node10;
const node1 = new TreeNode(1);
node1.left = node2;
node1.right = node6;

// Count the k-sum paths (expected: 8) and print the result.
const k = 5;
const [, result] = dfs_k_path_sum(node1, k);
console.log(`No of paths with sum equals to ${k} are: ${result}`);

// This code is contributed by Ayush Mishra

Output
No of paths with sum equals to 5 are: 8

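Time Complexity: O(N) expected, where N is the number of nodes in the tree. Each node is visited once and does O(1) expected hash-map work.
Space Complexity: O(N), where N is the number of nodes in the tree. The hash map and the recursion stack each use O(H), where H ≤ N is the height of the tree.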



Like Article
Suggest improvement
Share your thoughts in the comments

Similar Reads