Count all Grandparent-Parent-Child Triplets in a binary tree whose sum is greater than X

Given an integer X and a binary tree, the task is to count the number of triplets of nodes that have a grandparent -> parent -> child relationship and whose sum is greater than X.


Example:

Input: X = 100

               10
         /           \
       1              22
    /    \          /    \
  35      4       15       67
 /  \    /  \    /  \     /  \
57  38  9   10  110 312  131 414
                    /          \
                   8            39
Output: 6
The triplets are:
22 -> 15 -> 110
22 -> 15 -> 312
15 -> 312 -> 8
22 -> 67 -> 131
22 -> 67 -> 414
67 -> 414 -> 39
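To double-check the expected output, a brute-force sketch can enumerate every grandparent -> parent -> child triple directly and keep those whose sum exceeds X. The helper names `children`, `triplets` and `build_example` are hypothetical and are not part of the implementation later in this article.

```python
from typing import Iterator, List, Optional, Tuple


class Node:
    """Minimal binary tree node (same fields as the article's Node class)."""

    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def children(node: Node) -> List[Node]:
    # Non-empty children of a node, left first
    return [c for c in (node.left, node.right) if c is not None]


def triplets(node: Optional[Node]) -> Iterator[Tuple[int, int, int]]:
    """Yield every (grandparent, parent, child) value triple in the tree."""
    if node is None:
        return
    # Triples in which this node is the grandparent
    for parent in children(node):
        for child in children(parent):
            yield (node.val, parent.val, child.val)
    # Then the triples rooted further down
    for c in children(node):
        yield from triplets(c)


def build_example() -> Node:
    # The tree from the example above
    return Node(10,
                Node(1,
                     Node(35, Node(57), Node(38)),
                     Node(4, Node(9), Node(10))),
                Node(22,
                     Node(15, Node(110), Node(312, Node(8))),
                     Node(67, Node(131), Node(414, None, Node(39)))))


if __name__ == "__main__":
    x = 100
    big = [t for t in triplets(build_example()) if sum(t) > x]
    for g, p, c in big:
        print(f"{g} -> {p} -> {c}")
    print(len(big))  # 6
```

This visits each node's grandchildren explicitly, so it is easy to trust but does not generalize to other window sizes the way a rolling sum does.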



Approach: The problem can be solved with a rolling (sliding-window) sum whose period is 3, matching the three levels of a grandparent -> parent -> child chain.
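To illustrate the rolling sum on a single root-to-leaf path, the sketch below slides a window of 3 values along the path 10 -> 22 -> 67 -> 414 -> 39 from the example. `rolling_sums` is a hypothetical helper written only for this illustration.

```python
from typing import List


def rolling_sums(path: List[int], period: int = 3) -> List[int]:
    """For each position i, return the sum of path[i - period + 1 .. i].

    Near the start of the path the window is shorter, so the first
    period - 1 entries are partial sums.
    """
    sums = []
    window = 0
    for i, val in enumerate(path):
        window += val
        if i >= period:
            window -= path[i - period]  # value that just left the window
        sums.append(window)
    return sums


if __name__ == "__main__":
    # Root-to-leaf path 10 -> 22 -> 67 -> 414 -> 39 from the example
    sums = rolling_sums([10, 22, 67, 414, 39])
    print(sums)  # [10, 32, 99, 503, 520]
    # Only entries with a full window of 3 correspond to triplets
    print([s for i, s in enumerate(sums) if i >= 2 and s > 100])  # [503, 520]
```

The first two entries are partial windows, which is why a node is only counted once at least 3 values are on the path. Note that 10 -> 22 -> 67 sums to 99 and is therefore not counted.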

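  1. Traverse the tree depth-first (the implementation below uses preorder). What matters is that each node's rolling sum is pushed when the node is entered, before its subtrees are explored, so that the stack always mirrors the path from the root to the current node.
  2. Maintain a stack of rolling sums with period 3: the entry for a node is the sum of that node's value and the values of its (up to) two nearest ancestors.
  3. Whenever the stack holds at least 3 elements and its top value is greater than X, the current node completes a grandparent -> parent -> child triplet whose sum is large enough, so we increment the result by 1.
  4. When we move back up the recursion tree, we pop the stack so that the rolling sums of the lower levels are removed.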
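Steps 1 to 4 can be sketched with a plain Python list that stores the node values on the current path instead of rolling sums; because the window is only 3 wide, `sum(path[-3:])` takes constant time. This is a simplification for illustration, and `count_triplets` is a hypothetical name.

```python
from typing import List, Optional


class Node:
    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def count_triplets(root: Optional[Node], x: int) -> int:
    path: List[int] = []  # values on the current root-to-node path (the stack)
    count = 0

    def dfs(node: Optional[Node]) -> None:
        nonlocal count
        if node is None:
            return
        path.append(node.val)            # push on the way down
        if len(path) >= 3 and sum(path[-3:]) > x:
            count += 1                   # grandparent + parent + node > x
        dfs(node.left)
        dfs(node.right)
        path.pop()                       # pop on the way back up

    dfs(root)
    return count


if __name__ == "__main__":
    # Small hypothetical tree: 50 -> 30 -> 40 is the only triplet (sum 120)
    root = Node(50, Node(30, Node(40)), Node(5))
    print(count_triplets(root, 100))  # 1
```

Keeping raw values means the element leaving the window is simply `path[-4]`, so nothing has to be recovered from the rolling sums themselves.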

Below is the implementation of the above approach:


# Python3 implementation of the approach
  
# Class to store node information
class Node:
    def __init__(self, val = None, left = None, right = None):
        self.val = val
        self.left = left
        self.right = right
  
# Stack that holds one rolling sum per node on the current root-to-node path
class Stack:
    def __init__(self):
        self.stack = []   # rolling sums (node + up to two nearest ancestors)
        self.values = []  # actual node values, kept parallel to self.stack

    @property
    def size(self):
        return len(self.stack)

    def top(self):
        # Rolling sum of the most recently pushed node, or 0 if empty
        return self.stack[-1] if self.stack else 0

    def push(self, val, rolling_sum):
        self.values.append(val)
        self.stack.append(rolling_sum)

    def pop(self):
        # Remove the entry of the node we are leaving; no-op if empty
        if self.stack:
            self.stack.pop()
            self.values.pop()

    # Period is 3 to satisfy grandparent-parent-child relationship
    def rolling_push(self, val, period = 3):

        # Find the index of the element that leaves the window
        to_remove_idx = self.size - period

        # If index is out of bounds then we remove nothing, i.e, 0.
        # The actual value is read from self.values: the difference of two
        # consecutive rolling sums equals values[i] - values[i - period]
        # (not values[i]) once i >= period, so it cannot be used to
        # recover the element that leaves the window on deep paths.
        to_remove = 0 if to_remove_idx < 0 else self.values[to_remove_idx]

        # New rolling sum = current value + previous rolling sum
        # - element that left the window (top() is 0 on an empty stack)
        self.push(val, val + self.top() - to_remove)

    def show(self):
        for item in self.stack:
            print(item)
  
  
# Global variables used by count_with_greater_sum()
count = 0
s = Stack()
  
def count_with_greater_sum(root_node, x):
    global s, count
  
    # Empty subtree: nothing to push, nothing to count
    if not root_node:
        return
  
    s.rolling_push(root_node.val)
  
    if s.size >= 3 and s.top() > x:
        count += 1
  
    count_with_greater_sum(root_node.left, x)
    count_with_greater_sum(root_node.right, x)
  
    # Moving up the tree so pop the last element
    s.pop()
  
  
if __name__ == '__main__':
    root = Node(10)
  
    root.left = Node(1)
    root.right = Node(22)
  
    root.left.left = Node(35)
    root.left.right = Node(4)
  
    root.right.left = Node(15)
    root.right.right = Node(67)
  
    root.left.left.left = Node(57)
    root.left.left.right = Node(38)
  
    root.left.right.left = Node(9)
    root.left.right.right = Node(10)
  
    root.right.left.left = Node(110)
    root.right.left.right = Node(312)
  
    root.right.right.left = Node(131)
    root.right.right.right = Node(414)
  
    root.right.left.right.left = Node(8)
  
    root.right.right.right.right = Node(39)
  
    count_with_greater_sum(root, 100)
  
    print(count)



Output:

6
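A variant that avoids both the stack and the global variables passes the parent and grandparent values down the recursion. This is a swapped-in alternative, not the rolling-sum method described above, and the `count_triplets` signature and `build_example` helper are hypothetical.

```python
from typing import Optional


class Node:
    def __init__(self, val: int, left: "Optional[Node]" = None,
                 right: "Optional[Node]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def count_triplets(node: Optional[Node], x: int,
                   parent: Optional[int] = None,
                   grandparent: Optional[int] = None) -> int:
    """Count grandparent -> parent -> child triplets with sum > x.

    Each call receives the values of its parent and grandparent
    (None when they do not exist) instead of reading them from a stack.
    """
    if node is None:
        return 0
    here = 0
    if (parent is not None and grandparent is not None
            and grandparent + parent + node.val > x):
        here = 1
    # Children see this node as their parent and our parent as grandparent
    return (here
            + count_triplets(node.left, x, node.val, parent)
            + count_triplets(node.right, x, node.val, parent))


def build_example() -> Node:
    # The tree from the example at the top of the article
    return Node(10,
                Node(1,
                     Node(35, Node(57), Node(38)),
                     Node(4, Node(9), Node(10))),
                Node(22,
                     Node(15, Node(110), Node(312, Node(8))),
                     Node(67, Node(131), Node(414, None, Node(39)))))


if __name__ == "__main__":
    print(count_triplets(build_example(), 100))  # 6
```

Like the stack version, this runs in O(n) time with recursion depth equal to the tree height, but the only extra state is two integers per call.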

