Number of ways to divide a Binary tree into two halves
Last Updated :
28 Jan, 2022
Given a binary tree, the task is to count the number of ways to remove a single edge from the tree such that the tree gets divided into two halves with equal sum.
Examples:
Input:
      1
     / \
   -1   -1
          \
           1
Output: 1
The only way is to remove the edge between the root and its right child.
After that we get two sub-trees, each with sum = 0.
   1
  /
-1
and
-1
  \
   1
will be the two sub-trees.
Input:
      1
     / \
   -1   -1
          \
           -1
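Output: 2
The total sum is -2. Removing the edge from the root to its left child, or the edge from the root's right child to its right child, leaves two sub-trees that each sum to -1.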
A simple solution is to remove each edge of the tree in turn and check whether that splits the tree into two halves with the same sum. If it does, we increase the final answer by 1. This takes O(N²) time in the worst case, where “N” is the number of nodes in the tree.
Efficient approach:
- Create a variable ‘sum’ and store in it the sum of all the elements of the Binary tree. This takes O(N) time with a single traversal of the tree. If ‘sum’ is odd, no edge can split the tree into two equal halves, so the answer is 0.
- Now perform the following steps recursively for every node, starting from the root node:
- Find the sum of all the elements of its right sub-tree (“R”). If it’s equal to half of the total sum, we increase the count by 1. This is because removing the edge connecting the current node with its right child will divide the tree into two trees with equal sum.
- Find the sum of all the elements of its left sub-tree (“L”). If it’s equal to half of the total sum, we increase the count by 1.
Below is the implementation of the above approach:
C++
#include <bits/stdc++.h>
using namespace std;
// A binary-tree node holding an integer value and two (possibly null) child links.
struct node {
    int data;
    node* left;
    node* right;

    // Creates a leaf node carrying `data` with no children.
    node(int data) : data(data), left(nullptr), right(nullptr) {}
};
// Returns the sum of every value stored in the subtree rooted at `curr`.
// An empty subtree (null pointer) contributes 0.
int findSum(node* curr)
{
    if (!curr) {
        return 0;
    }
    const int leftTotal = findSum(curr->left);
    const int rightTotal = findSum(curr->right);
    return curr->data + leftTotal + rightTotal;
}
// Post-order walk over the subtree rooted at `curr`.
// Returns the sum of that subtree. For every edge from `curr` to an existing child,
// increments `ans` when the child's subtree holds exactly half of `sum`, because
// cutting that edge would split the whole tree into two parts with equal sums.
// A null `curr` contributes 0 and counts nothing.
int checkSum(node* curr, int sum, int& ans)
{
    if (curr == nullptr)
        return 0;

    int l = 0, r = 0;
    if (curr->left != nullptr) {
        l = checkSum(curr->left, sum, ans);
        // Double in 64-bit arithmetic so a large subtree sum cannot overflow int (UB).
        if (2LL * l == sum)
            ++ans;
    }
    if (curr->right != nullptr) {
        r = checkSum(curr->right, sum, ans);
        if (2LL * r == sum)
            ++ans;
    }
    return l + r + curr->data;
}
// Counts the edges whose removal splits the tree rooted at `root` into two
// parts with equal sums. Returns 0 for an empty tree or when the total sum is odd.
int cntWays(node* root)
{
    if (root == nullptr)
        return 0;
    const int sum = findSum(root);
    // Test against 0, not 1: in C++ a negative odd sum gives sum % 2 == -1.
    if (sum % 2 != 0)
        return 0;
    int ans = 0;
    checkSum(root, sum, ans);
    return ans;
}
int main()
{
node* root = new node(1);
root->left = new node(-1);
root->right = new node(-1);
root->right->right = new node(1);
cout << cntWays(root);
return 0;
}
|
Java
class GFG
{
    // Binary-tree node holding an int value and two (possibly null) children.
    static class node
    {
        int data;
        node left;
        node right;
        node(int data)
        {
            this.data = data;
            left = null;
            right = null;
        }
    }

    // Running count of valid edges; reset by cntWays and incremented by checkSum.
    static int ans;

    // Returns the sum of all values in the subtree rooted at curr (0 when curr is null).
    static int findSum(node curr)
    {
        if (curr == null)
            return 0;
        return curr.data + findSum(curr.left) + findSum(curr.right);
    }

    // Post-order walk: returns the subtree sum of curr and increments ans for each
    // existing child whose subtree sum is exactly half of sum (cutting that edge
    // splits the tree into two equal-sum parts). A null curr contributes 0.
    static int checkSum(node curr, int sum)
    {
        if (curr == null)
            return 0;
        int l = 0, r = 0;
        if (curr.left != null)
        {
            l = checkSum(curr.left, sum);
            // Widen to long so doubling cannot wrap around and produce a false match.
            if (2L * l == sum)
                ans++;
        }
        if (curr.right != null)
        {
            r = checkSum(curr.right, sum);
            if (2L * r == sum)
                ans++;
        }
        return l + r + curr.data;
    }

    // Counts edges whose removal splits the tree into two equal-sum halves.
    // Returns 0 for an empty tree or when the total sum is odd.
    static int cntWays(node root)
    {
        if (root == null)
            return 0;
        ans = 0;
        int sum = findSum(root);
        // Test against 0, not 1: Java's % keeps the dividend's sign, so -3 % 2 == -1.
        if (sum % 2 != 0)
            return 0;
        checkSum(root, sum);
        return ans;
    }

    public static void main(String[] args)
    {
        node root = new node(1);
        root.left = new node(-1);
        root.right = new node(-1);
        root.right.right = new node(1);
        System.out.print(cntWays(root));
    }
}
|
C#
using System;
class GFG
{
    // Binary-tree node holding an int value and two (possibly null) children.
    public class node
    {
        public int data;
        public node left;
        public node right;
        public node(int data)
        {
            this.data = data;
            left = null;
            right = null;
        }
    }

    // Running count of valid edges; reset by cntWays and incremented by checkSum.
    static int ans;

    // Returns the sum of all values in the subtree rooted at curr (0 when curr is null).
    static int findSum(node curr)
    {
        if (curr == null)
            return 0;
        return curr.data + findSum(curr.left) + findSum(curr.right);
    }

    // Post-order walk: returns the subtree sum of curr and increments ans for each
    // existing child whose subtree sum is exactly half of sum (cutting that edge
    // splits the tree into two equal-sum parts). A null curr contributes 0.
    static int checkSum(node curr, int sum)
    {
        if (curr == null)
            return 0;
        int l = 0, r = 0;
        if (curr.left != null)
        {
            l = checkSum(curr.left, sum);
            // Widen to long so doubling cannot wrap around and produce a false match.
            if (2L * l == sum)
                ans++;
        }
        if (curr.right != null)
        {
            r = checkSum(curr.right, sum);
            if (2L * r == sum)
                ans++;
        }
        return l + r + curr.data;
    }

    // Counts edges whose removal splits the tree into two equal-sum halves.
    // Returns 0 for an empty tree or when the total sum is odd.
    public static int cntWays(node root)
    {
        if (root == null)
            return 0;
        ans = 0;
        int sum = findSum(root);
        // Test against 0, not 1: C#'s % keeps the dividend's sign, so -3 % 2 == -1.
        if (sum % 2 != 0)
            return 0;
        checkSum(root, sum);
        return ans;
    }

    public static void Main(String[] args)
    {
        node root = new node(1);
        root.left = new node(-1);
        root.right = new node(-1);
        root.right.right = new node(1);
        Console.Write(cntWays(root));
    }
}
|
Javascript
<script>
class node {
constructor(val) {
this .data = val;
this .prev = null ;
this .next = null ;
}
}
var ans;
function findSum( curr) {
if (curr == null )
return 0;
return curr.data + findSum(curr.left) + findSum(curr.right);
}
function checkSum( curr , sum) {
var l = 0, r = 0;
if (curr.left != null ) {
l = checkSum(curr.left, sum);
if (2 * l == sum)
ans++;
}
if (curr.right != null ) {
r = checkSum(curr.right, sum);
if (2 * r == sum)
ans++;
}
return l + r + curr.data;
}
function cntWays( root) {
if (root == null )
return 0;
ans = 0;
var sum = findSum(root);
if (sum % 2 == 1)
return 0;
checkSum(root, sum);
return ans;
}
// Build the first example tree and print the number of equal-sum splits.
// Declared with `var` so the assignment does not create an implicit global (an error in strict mode).
var root = new node(1);
root.left = new node(-1);
root.right = new node(-1);
root.right.right = new node(1);
document.write(cntWays(root));
</script>
|
Python3
class node:
    """A binary-tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
def findSum(curr):
    """Return the sum of all values in the subtree rooted at ``curr`` (0 if ``curr`` is None)."""
    if curr is None:
        return 0
    return curr.data + findSum(curr.left) + findSum(curr.right)
def checkSum(curr, s):
    """Return the sum of the subtree rooted at ``curr``.

    Side effect: increments the module-level ``ans`` once for every existing
    child whose subtree sum is exactly half of ``s``, because cutting the edge
    to that child splits the tree into two parts with equal sums. ``ans`` must
    already be defined (``cntWays`` initialises it).
    """
    global ans
    l = 0
    r = 0
    if curr.left is not None:
        l = checkSum(curr.left, s)
        if 2 * l == s:
            ans += 1
    if curr.right is not None:
        r = checkSum(curr.right, s)
        if 2 * r == s:
            ans += 1
    return l + r + curr.data
def cntWays(root):
    """Count the edges whose removal splits the tree into two equal-sum halves.

    Returns 0 for an empty tree or when the total sum is odd.
    """
    global ans
    if root is None:
        return 0
    ans = 0
    s = findSum(root)
    # Python's % is non-negative for a positive divisor, so this also catches negative odd sums.
    if s % 2:
        return 0
    checkSum(root, s)
    return ans
if __name__ == '__main__':
    # Build the first example tree and print the number of equal-sum splits.
    root = node(1)
    root.left = node(-1)
    root.right = node(-1)
    root.right.right = node(1)
    print(cntWays(root))
|
The time complexity of this approach is O(N) and the space complexity is O(H), where “N” is the number of nodes in the Binary tree and “H” is the height of the Binary Tree (the maximum recursion depth).
Share your thoughts in the comments
Please Login to comment...