Count Non-Path Triplets in a Tree
Last Updated: 27 Nov, 2023
Given a tree with N vertices numbered from 1 to N, where the i-th edge connects vertex A_i to vertex B_i, find the number of tuples of integers (i, j, k) such that:
- i < j < k and 1 ≤ i, j, k ≤ N.
- The given tree does not contain a simple path that includes all three vertices i, j, and k.
Examples:
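Input: N = 5, A[] = [1, 2, 2, 1], B[] = [2, 3, 4, 5]
Output: 2
Explanation: Two tuples satisfy the conditions: (i, j, k) = (1, 3, 4) and (3, 4, 5). In both, the three vertices lie in three different branches around vertex 2 ({3}, {4} and {1, 5}), so no simple path can visit all of them without passing through vertex 2 twice.
Input: N = 6, A[] = [1, 2, 3, 4, 5], B[] = [2, 3, 4, 5, 6]
Output: 0
Explanation: The tree is the path 1-2-3-4-5-6, so every three vertices lie on a common simple path.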
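To make the condition concrete, here is a brute-force sketch in Python that checks every tuple directly. It relies on a standard property of trees: vertex y lies on the unique path between x and z exactly when d(x, y) + d(y, z) = d(x, z), and three vertices share a simple path exactly when one of them lies between the other two. The function names are hypothetical, and the O(N^3) loop is only meant for checking small inputs such as the examples above.

```python
from collections import deque
from itertools import combinations


def bfs_distances(adj, src, n):
    """Return edge distances from src to every vertex (1-indexed)."""
    dist = [-1] * (n + 1)
    dist[src] = 0
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for w in adj[u]:
            if dist[w] == -1:
                dist[w] = dist[u] + 1
                queue.append(w)
    return dist


def count_non_path_triplets_brute(n, a, b):
    """O(N^3) reference count of tuples (i, j, k) that no simple path covers."""
    adj = [[] for _ in range(n + 1)]
    for u, v in zip(a, b):
        adj[u].append(v)
        adj[v].append(u)
    # dist[s][t] = number of edges on the path from s to t
    dist = [None] + [bfs_distances(adj, s, n) for s in range(1, n + 1)]

    def between(x, y, z):
        # y lies on the unique x-z path iff d(x, y) + d(y, z) == d(x, z)
        return dist[x][y] + dist[y][z] == dist[x][z]

    count = 0
    for i, j, k in combinations(range(1, n + 1), 3):
        # The tuple lies on a path iff one vertex is between the other two
        on_path = between(i, j, k) or between(j, i, k) or between(i, k, j)
        if not on_path:
            count += 1
    return count


print(count_non_path_triplets_brute(5, [1, 2, 2, 1], [2, 3, 4, 5]))  # 2
print(count_non_path_triplets_brute(6, [1, 2, 3, 4, 5], [2, 3, 4, 5, 6]))  # 0
```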
Approach: To solve the problem, use the following observations:
Observations:
A tuple fails the condition exactly when its three vertices lie on a common simple path, so it is easier to count those tuples and subtract them from the total number of tuples, $\binom{N}{3} = N(N-1)(N-2)/6$.
- Condition: A tuple (a, b, c) must be excluded if its three vertices lie on one simple path of the tree. The three vertices do not need to be adjacent on that path.
- Fixed Point: If three vertices lie on a simple path, exactly one of them lies between the other two. Fix this middle vertex v; the other two vertices must then lie in different connected components of the tree once v is removed. Every excluded tuple is counted exactly once, at its middle vertex.
- Connected Components: Removing v splits the tree into k components, where k is the degree of v. Denote their sizes by $m_1, \dots, m_k$; they always satisfy $\sum_i m_i = N - 1$.
- Sum of Fixed Points: The number of excluded tuples with middle vertex v is $\sum_{i<j} m_i m_j$, the number of ways to pick two vertices from two different components. Counted as ordered pairs, this is $2\sum_{i<j} m_i m_j$.
- Transformation: $2\sum_{i<j} m_i m_j = \left(\sum_i m_i\right)^2 - \sum_i m_i^2$, so vertex v contributes $\left(\left(\sum_i m_i\right)^2 - \sum_i m_i^2\right)/2$ excluded tuples.
- Complexity: Root the tree and precompute the subtree size of every vertex. The components around v are then the subtrees of its children plus the $N - \text{size}(v)$ vertices on the parent's side, so the contribution of v takes O(k) time and the total over all vertices is O(N).
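As a worked check of the formula, the following Python sketch evaluates the per-vertex count for Example 1 (edges 1-2, 2-3, 2-4, 1-5). The component sizes in the `components` dictionary were worked out by hand for this example, and the helper names are hypothetical; the sketch also confirms that the squared-sum identity agrees with the direct pairwise sum.

```python
from math import comb


def pairs_in_different_components(sizes):
    """Unordered pairs taken from two different components.

    Uses 2 * sum_{i<j} m_i m_j = (sum m_i)^2 - sum m_i^2.
    """
    s1 = sum(sizes)
    s2 = sum(m * m for m in sizes)
    return (s1 * s1 - s2) // 2


def pairs_explicit(sizes):
    """Same count by summing m_i * m_j over i < j directly (for comparison)."""
    return sum(sizes[i] * sizes[j]
               for i in range(len(sizes))
               for j in range(i + 1, len(sizes)))


# Component sizes left after removing each vertex of Example 1
components = {
    1: [3, 1],     # {2, 3, 4} and {5}
    2: [1, 1, 2],  # {3}, {4} and {1, 5}
    3: [4],        # removing a leaf leaves a single component
    4: [4],
    5: [4],
}

n = 5
on_path = 0
for v, sizes in components.items():
    assert sum(sizes) == n - 1
    assert pairs_in_different_components(sizes) == pairs_explicit(sizes)
    on_path += pairs_in_different_components(sizes)

print(comb(n, 3), on_path, comb(n, 3) - on_path)  # 10 8 2
```

Vertex 1 contributes 3 and vertex 2 contributes 5, so 8 of the 10 tuples lie on a path and 2 remain, matching the expected output.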
Steps to solve the problem:
- Create an adjacency list ‘T’ of size n + 1 (the vertices are 1-indexed). For each edge, add each endpoint to the other endpoint's adjacency list.
- Initialize ‘ans’ to n * (n - 1) * (n - 2) / 6, the total number of tuples (i, j, k) with 1 ≤ i < j < k ≤ N.
- Initialize a vector ‘sz’ of size n + 1 with every element set to 1 (each vertex counts itself). After the DFS, sz[v] holds the size of the subtree rooted at v.
- Define a recursive function ‘dfs’ with parameters ‘v’ (the current vertex) and ‘pre’ (its parent). It computes the subtree size of v and subtracts from ‘ans’ the tuples whose middle vertex is v.
- Inside ‘dfs’, visit every neighbor ‘c’ of ‘v’ except ‘pre’ and call ‘dfs’ recursively with ‘c’ as the current vertex and ‘v’ as the parent.
- After each child call returns, add sz[c] to ‘sum1’, add sz[c] * sz[c] to ‘sum2’, and add sz[c] to sz[v]. Once all children are processed, also account for the component on the parent's side: add n - sz[v] to ‘sum1’ and (n - sz[v]) * (n - sz[v]) to ‘sum2’.
- Decrement ‘ans’ by (sum1 * sum1 - sum2) / 2 (written as >> 1 in the C++ code). This removes every tuple (i, j, k) whose three vertices lie on a simple path with v as the middle vertex.
- Call ‘dfs’ with the root vertex 1 and the parent -1 (any value that is not a vertex number).
- Finally, print ‘ans’, the number of tuples (i, j, k) whose three vertices do not all lie on any simple path in the tree.
Below is the implementation for the above approach:
C++
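#include <iostream>
#include <vector>
using namespace std;
using ll = long long;

// Computes subtree sizes and subtracts, for vertex v,
// the number of tuples whose middle vertex is v.
void dfs(int v, int pre, vector<vector<int> >& T,
         vector<ll>& sz, ll n, ll& ans)
{
    ll sum1 = 0, sum2 = 0;
    for (int c : T[v]) {
        if (c == pre)
            continue;
        dfs(c, v, T, sz, n, ans);
        // The subtree of child c is one component of the tree without v
        sum1 += sz[c];
        sum2 += sz[c] * sz[c];
        sz[v] += sz[c];
    }
    // The remaining n - sz[v] vertices form the parent-side component
    sum1 += n - sz[v];
    sum2 += (n - sz[v]) * (n - sz[v]);
    // Pairs from two different components: ((sum m)^2 - sum m^2) / 2
    ans -= (sum1 * sum1 - sum2) >> 1;
}

int main()
{
    ll n = 5;
    vector<vector<int> > T(n + 1);
    // Edges of Example 1: 1-2, 2-3, 2-4, 1-5
    T[1].push_back(2);
    T[2].push_back(1);
    T[2].push_back(3);
    T[3].push_back(2);
    T[2].push_back(4);
    T[4].push_back(2);
    T[1].push_back(5);
    T[5].push_back(1);
    ll ans = n * (n - 1) * (n - 2) / 6;
    vector<ll> sz(n + 1, 1);
    dfs(1, -1, T, sz, n, ans);
    cout << ans << '\n';
    return 0;
}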
Java
import java.util.*;

public class Main {
    static long ans = 0;

    static void dfs(int v, int pre, List<List<Integer>> T, List<Long> sz, long n) {
        long sum1 = 0, sum2 = 0;
        for (int c : T.get(v)) {
            if (c == pre)
                continue;
            dfs(c, v, T, sz, n);
            sum1 += sz.get(c);
            sum2 += sz.get(c) * sz.get(c);
            sz.set(v, sz.get(v) + sz.get(c));
        }
        // Parent-side component
        sum1 += n - sz.get(v);
        sum2 += (n - sz.get(v)) * (n - sz.get(v));
        ans -= (sum1 * sum1 - sum2) / 2;
    }

    public static void main(String[] args) {
        long n = 5;
        List<List<Integer>> T = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            T.add(new ArrayList<>());
        }
        T.get(1).add(2);
        T.get(2).add(1);
        T.get(2).add(3);
        T.get(3).add(2);
        T.get(2).add(4);
        T.get(4).add(2);
        T.get(1).add(5);
        T.get(5).add(1);
        ans = n * (n - 1) * (n - 2) / 6;
        List<Long> sz = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            sz.add(1L);
        }
        dfs(1, -1, T, sz, n);
        System.out.println(ans);
    }
}
Python3
def dfs(v, pre, T, sz, n, ans):
    sum1 = 0
    sum2 = 0
    for c in T[v]:
        if c == pre:
            continue
        dfs(c, v, T, sz, n, ans)
        # The subtree of child c is one component of the tree without v
        sum1 += sz[c]
        sum2 += sz[c] * sz[c]
        sz[v] += sz[c]
    # The parent-side component has n - sz[v] vertices
    sum1 += n - sz[v]
    sum2 += (n - sz[v]) * (n - sz[v])
    # Remove the tuples whose middle vertex is v
    ans[0] -= (sum1 * sum1 - sum2) // 2


def main():
    n = 5
    T = [[] for _ in range(n + 1)]
    edges = [(1, 2), (2, 3), (2, 4), (1, 5)]
    for A, B in edges:
        T[A].append(B)
        T[B].append(A)
    # One-element list so that dfs can update the answer in place
    ans = [n * (n - 1) * (n - 2) // 6]
    sz = [1] * (n + 1)
    dfs(1, -1, T, sz, n, ans)
    print(ans[0])


if __name__ == "__main__":
    main()
C#
using System;
using System.Collections.Generic;

class GFG
{
    static long ans = 0;

    static void Dfs(int v, int pre, List<List<int>> T, List<long> sz, long n)
    {
        long sum1 = 0, sum2 = 0;
        foreach (int c in T[v])
        {
            if (c == pre)
                continue;
            Dfs(c, v, T, sz, n);
            // The subtree of child c is one component of the tree without v
            sum1 += sz[c];
            sum2 += sz[c] * sz[c];
            sz[v] += sz[c];
        }
        // Parent-side component
        sum1 += n - sz[v];
        sum2 += (n - sz[v]) * (n - sz[v]);
        ans -= (sum1 * sum1 - sum2) / 2;
    }

    public static void Main(string[] args)
    {
        long n = 5;
        List<List<int>> T = new List<List<int>>();
        for (int i = 0; i <= n; i++)
        {
            T.Add(new List<int>());
        }
        T[1].Add(2);
        T[2].Add(1);
        T[2].Add(3);
        T[3].Add(2);
        T[2].Add(4);
        T[4].Add(2);
        T[1].Add(5);
        T[5].Add(1);
        ans = n * (n - 1) * (n - 2) / 6;
        List<long> sz = new List<long>();
        for (int i = 0; i <= n; i++)
        {
            sz.Add(1L);
        }
        Dfs(1, -1, T, sz, n);
        Console.WriteLine(ans);
    }
}
JavaScript
let ans = 0;

function dfs(v, pre, T, sz, n) {
    let sum1 = 0, sum2 = 0;
    for (let c of T[v]) {
        if (c === pre)
            continue;
        dfs(c, v, T, sz, n);
        // The subtree of child c is one component of the tree without v
        sum1 += sz[c];
        sum2 += sz[c] * sz[c];
        sz[v] += sz[c];
    }
    // Parent-side component
    sum1 += n - sz[v];
    sum2 += (n - sz[v]) * (n - sz[v]);
    ans -= (sum1 * sum1 - sum2) / 2;
}

function main() {
    let n = 5;
    let T = [];
    for (let i = 0; i <= n; i++) {
        T.push([]);
    }
    T[1].push(2);
    T[2].push(1);
    T[2].push(3);
    T[3].push(2);
    T[2].push(4);
    T[4].push(2);
    T[1].push(5);
    T[5].push(1);
    ans = n * (n - 1) * (n - 2) / 6;
    let sz = Array.from({ length: n + 1 }, () => 1);
    dfs(1, -1, T, sz, n);
    console.log(ans);
}

main();
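Time Complexity: O(n), since every vertex is visited once and every edge is examined twice.
Auxiliary Space: O(n) for the adjacency list, the ‘sz’ array and the recursion stack.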