Open In App

Number of pairs with a given sum in a Binary Search Tree

Given a Binary Search Tree (BST) and a number X, find the number of distinct pairs of distinct nodes in the BST whose values sum to X. All node values in the tree are distinct.

Examples: 



Input : X = 5
          5 
        /   \ 
       3     7 
      / \   / \ 
     2   4 6   8    
Output : 1
{2, 3} is the only possible pair. 
Thus, the answer is equal to 1.

Input : X = 6
      1
       \
        2
         \
          3
           \
            4
             \
              5
   
Output : 2
Possible pairs are {{1, 5}, {2, 4}}.

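Naive Approach: Hash all the elements of the BST, or convert the BST to a sorted array with an in-order traversal. Then count the pairs with the standard technique for counting pairs with a given sum in an array (a hash-set lookup, or two pointers on the sorted array).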
Time Complexity: O(N). 
Space Complexity: O(N).

Space-Optimized Approach: The idea is to apply the two-pointer technique directly on the BST. Maintain a forward iterator and a backward iterator that traverse the BST in in-order and reverse in-order, respectively. Each iterator is an explicit stack that holds at most H nodes, where H is the height of the tree.



  1. Create forward and backward iterators for the BST. Let v1 and v2 be the values of the nodes they currently point to.
  2. At each step:
    • If v1 + v2 = X, a pair is found, so increase the count by 1.
    • If v1 + v2 is less than or equal to X, advance the forward iterator to the next element.
    • If v1 + v2 is greater than X, move the backward iterator to the previous element.
  3. Repeat step 2 until both iterators point to the same node. Each step moves exactly one iterator by one position, so the two iterators are guaranteed to meet.

Below is the implementation of the above approach:  




// C++ implementation of the approach
#include <bits/stdc++.h>
using namespace std;
 
// Node of Binary tree
// A single node of the binary search tree: an integer key plus links to
// its left and right subtrees, both of which start out empty.
struct node {
    int data;
    node* left;
    node* right;

    node(int data) : data(data), left(NULL), right(NULL) {}
};
 
// Function to find a pair
int cntPairs(node* root, int x)
{
    // Stack to store nodes for
    // forward and backward iterator
    stack<node *> it1, it2;
 
    // Initializing forward iterator
    node* c = root;
    while (c != NULL)
        it1.push(c), c = c->left;
 
    // Initializing backward iterator
    c = root;
    while (c != NULL)
        it2.push(c), c = c->right;
 
    // Variable to store final answer
    int ans = 0;
 
    // two pointer technique
    while (it1.top() != it2.top()) {
 
        // Variables to store the
        // value of the nodes current
        // iterators are pointing to.
        int v1 = it1.top()->data;
        int v2 = it2.top()->data;
 
        // If we find a pair
        // then count is increased by 1
        if (v1 + v2 == x)
            ans++;
 
        // Moving forward iterator
        if (v1 + v2 <= x) {
            c = it1.top()->right;
            it1.pop();
            while (c != NULL)
                it1.push(c), c = c->left;
        }
 
        // Moving backward iterator
        else {
            c = it2.top()->left;
            it2.pop();
            while (c != NULL)
                it2.push(c), c = c->right;
        }
    }
 
    // Returning final answer
    return ans;
}
 
// Driver code
int main()
{
    /*    5
        /   \
       3     7
      / \   / \
     2   4 6   8 */
    node* root = new node(5);
    root->left = new node(3);
    root->right = new node(7);
    root->left->left = new node(2);
    root->left->right = new node(4);
    root->right->left = new node(6);
    root->right->right = new node(8);
 
    int x = 10;
 
    cout << cntPairs(root, x);
 
    return 0;
}




// Java implementation of the approach
import java.util.*;
 
class GFG
{

    // Node of a binary search tree: a value plus left/right child links.
    static class node
    {
        int data;
        node left;
        node right;

        node(int data)
        {
            this.data = data;
            left = null;
            right = null;
        }
    }

    /**
     * Counts the pairs of distinct nodes in the BST rooted at {@code root}
     * whose values sum to {@code x}. Node values are assumed to be distinct.
     *
     * <p>Two-pointer technique on the tree: {@code fwd} yields values in
     * ascending (in-order) order and {@code bwd} in descending (reverse
     * in-order) order. Each is an explicit stack of at most H nodes, so extra
     * space is O(H) for tree height H and time is O(N).
     *
     * @param root root of the BST; may be {@code null}
     * @param x    target sum
     * @return the number of matching pairs; 0 for an empty tree
     */
    static int cntPairs(node root, int x)
    {
        // An empty tree has no pairs; without this guard the loop below
        // would dereference the top of an empty iterator.
        if (root == null)
        {
            return 0;
        }

        Deque<node> fwd = new ArrayDeque<>();
        Deque<node> bwd = new ArrayDeque<>();
        pushLeftPath(fwd, root);   // fwd.peek() is now the minimum
        pushRightPath(bwd, root);  // bwd.peek() is now the maximum

        int ans = 0;

        // Each iteration advances exactly one iterator by one position,
        // so the two must eventually land on the same node.
        while (fwd.peek() != bwd.peek())
        {
            // Add in long: int addition wraps silently on overflow, which
            // would both invent and miss pairs for large values.
            long sum = (long) fwd.peek().data + bwd.peek().data;

            if (sum == x)
            {
                ans++;
            }

            if (sum <= x)
            {
                // Advance forward iterator to the in-order successor.
                pushLeftPath(fwd, fwd.pop().right);
            }
            else
            {
                // Advance backward iterator to the in-order predecessor.
                pushRightPath(bwd, bwd.pop().left);
            }
        }

        return ans;
    }

    // Pushes c and its chain of left descendants onto stack.
    private static void pushLeftPath(Deque<node> stack, node c)
    {
        for (; c != null; c = c.left)
        {
            stack.push(c);
        }
    }

    // Pushes c and its chain of right descendants onto stack.
    private static void pushRightPath(Deque<node> stack, node c)
    {
        for (; c != null; c = c.right)
        {
            stack.push(c);
        }
    }

    // Driver code
    public static void main(String[] args)
    {
        /*        5
                /   \
               3     7
              / \   / \
             2   4 6   8   */
        node root = new node(5);
        root.left = new node(3);
        root.right = new node(7);
        root.left.left = new node(2);
        root.left.right = new node(4);
        root.right.left = new node(6);
        root.right.right = new node(8);

        int x = 10;

        System.out.print(cntPairs(root, x));
    }
}
 
// This code is contributed by Rajput-Ji




# Python implementation of above algorithm
 
# Utility class to create a node
class node:
    """A binary-tree node holding a key and two (initially empty) child links."""

    def __init__(self, key):
        self.data = key
        self.left = None
        self.right = None
 
# Function to find a pair
def cntPairs(root, x):
    """Count pairs of distinct nodes in a BST whose values sum to ``x``.

    Uses two explicit stacks as forward (in-order) and backward
    (reverse in-order) iterators, so extra space is O(h) for tree height h.
    Node values are assumed to be distinct, as the problem statement requires.

    Args:
        root: Root node (an object with ``data``, ``left`` and ``right``
            attributes), or None for an empty tree.
        x: Target sum.

    Returns:
        The number of matching pairs; 0 for an empty tree.
    """
    # Guard the empty tree: indexing [-1] on an empty stack would raise.
    if root is None:
        return 0

    def push_left(stack, cur):
        # Push cur and its chain of left descendants.
        while cur is not None:
            stack.append(cur)
            cur = cur.left

    def push_right(stack, cur):
        # Push cur and its chain of right descendants.
        while cur is not None:
            stack.append(cur)
            cur = cur.right

    fwd, bwd = [], []
    push_left(fwd, root)   # fwd[-1] is the minimum
    push_right(bwd, root)  # bwd[-1] is the maximum

    ans = 0
    # Each step moves exactly one iterator by one position, so they must
    # meet; compare by identity since we are asking "same node?".
    while fwd[-1] is not bwd[-1]:
        total = fwd[-1].data + bwd[-1].data
        if total == x:
            ans += 1
        if total <= x:
            # Advance forward iterator to the in-order successor.
            push_left(fwd, fwd.pop().right)
        else:
            # Advance backward iterator to the in-order predecessor.
            push_right(bwd, bwd.pop().left)
    return ans
 
# Driver code
#
# Builds the sample BST
#
#         5
#       /   \
#      3     7
#     / \   / \
#    2   4 6   8
#
# and prints the number of node pairs that sum to 10.
root = node(5)
root.left, root.right = node(3), node(7)
root.left.left, root.left.right = node(2), node(4)
root.right.left, root.right.right = node(6), node(8)

x = 10

print(cntPairs(root, x))
 
# This code is contributed by Arnab Kundu




// C# implementation of the approach
using System;
using System.Collections.Generic;
 
class GFG
{

    // Node of a binary search tree: a value plus left/right child links.
    public class node
    {
        public int data;
        public node left;
        public node right;

        public node(int data)
        {
            this.data = data;
            left = null;
            right = null;
        }
    }

    // Counts the pairs of distinct nodes in the BST rooted at root whose
    // values sum to x (node values are assumed distinct). Returns 0 for an
    // empty tree.
    //
    // fwd walks the tree in ascending (in-order) order and bwd in
    // descending (reverse in-order) order, each as an explicit stack of at
    // most H nodes, so extra space is O(H) for tree height H.
    static int cntPairs(node root, int x)
    {
        // Stack.Peek() throws on an empty stack, so handle the empty tree
        // up front.
        if (root == null)
        {
            return 0;
        }

        Stack<node> fwd = new Stack<node>();
        Stack<node> bwd = new Stack<node>();
        PushLeftPath(fwd, root);   // fwd.Peek() is now the minimum
        PushRightPath(bwd, root);  // bwd.Peek() is now the maximum

        // Number of matching pairs found so far.
        int ans = 0;

        // Each iteration advances exactly one iterator by one position, so
        // the two must eventually land on the same node (reference equality).
        while (fwd.Peek() != bwd.Peek())
        {
            // Add in long: int addition is unchecked by default and would
            // wrap, inventing or missing pairs for large values.
            long sum = (long)fwd.Peek().data + bwd.Peek().data;

            if (sum == x)
            {
                ans++;
            }

            if (sum <= x)
            {
                // Advance forward iterator to the in-order successor.
                PushLeftPath(fwd, fwd.Pop().right);
            }
            else
            {
                // Advance backward iterator to the in-order predecessor.
                PushRightPath(bwd, bwd.Pop().left);
            }
        }

        return ans;
    }

    // Pushes c and its chain of left descendants onto stack.
    private static void PushLeftPath(Stack<node> stack, node c)
    {
        for (; c != null; c = c.left)
        {
            stack.Push(c);
        }
    }

    // Pushes c and its chain of right descendants onto stack.
    private static void PushRightPath(Stack<node> stack, node c)
    {
        for (; c != null; c = c.right)
        {
            stack.Push(c);
        }
    }

    // Driver code
    public static void Main(String[] args)
    {
        /*        5
                /   \
               3     7
              / \   / \
             2   4 6   8   */
        node root = new node(5);
        root.left = new node(3);
        root.right = new node(7);
        root.left.left = new node(2);
        root.left.right = new node(4);
        root.right.left = new node(6);
        root.right.right = new node(8);

        int x = 10;

        Console.Write(cntPairs(root, x));
    }
}
 
// This code is contributed by Rajput-Ji




<script>
 
// Javascript implementation of the approach
 
// Node of Binary tree
// A binary search tree node: a value plus left/right child links, both of
// which start out empty.
class node
{
    constructor(data)
    {
        Object.assign(this, { data: data, left: null, right: null });
    }
}
 
// Function to find a pair
// Counts pairs of distinct BST nodes whose values add up to x.
//
// Two cursors walk the tree from opposite ends: `asc` yields values in
// ascending (in-order) order and `desc` in descending (reverse in-order)
// order. Each is an explicit stack of pending nodes, so extra space is
// bounded by the tree height. An empty tree yields 0.
function cntPairs(root, x)
{
    const asc = [];
    const desc = [];

    // Push n and its chain of left (resp. right) descendants.
    const descendLeft = (stack, n) => {
        for (; n != null; n = n.left) stack.push(n);
    };
    const descendRight = (stack, n) => {
        for (; n != null; n = n.right) stack.push(n);
    };
    const top = (stack) => stack[stack.length - 1];

    descendLeft(asc, root);
    descendRight(desc, root);

    let count = 0;

    // Stop once both cursors rest on the same node.
    while (top(asc) != top(desc))
    {
        const sum = top(asc).data + top(desc).data;

        if (sum == x)
            count++;

        if (sum <= x)
            descendLeft(asc, asc.pop().right);    // step to in-order successor
        else
            descendRight(desc, desc.pop().left);  // step to in-order predecessor
    }

    return count;
}
 
// Driver code: builds the sample BST
//
//         5
//       /   \
//      3     7
//     / \   / \
//    2   4 6   8
//
// and writes the number of node pairs summing to 10 into the page.
var root = new node(5);

root.left = new node(3);
root.left.left = new node(2);
root.left.right = new node(4);

root.right = new node(7);
root.right.left = new node(6);
root.right.right = new node(8);

var x = 10;

document.write(cntPairs(root, x));
 
// This code is contributed by noob2000
 
</script>

Output: 
3

 

Time complexity: O(N) 
Space complexity: O(H) where H is the height of BST
 


Article Tags :