Extended Disjoint Set Union on Trees
• Last Updated : 14 Aug, 2020

Prerequisites: DFS, Trees, DSU

Given a tree of N nodes numbered from 1 to N and rooted at node 1, its list of edges E, and an array arr[] in which the i-th entry is the number associated with node i. You are also given Q queries, each consisting of 2 integers {V, F}. For each query, consider the subtree rooted at vertex V; the task is to check whether some number occurs exactly F times among the numbers associated with the nodes of that subtree. If yes then print True, else print False.

Examples:

Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} }
Output:
False
True
True
Explanation:
Query 1: No number occurs three times in sub-tree 2
Query 2: Numbers -1 and -3 each occur 2 times in sub-tree 5
Query 3: Number -1 occurs once in sub-tree 7

Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} }
Output:
False
True
Explanation:
Query 1: No number occurs exactly 2 times in sub-tree 2
Query 2: Number -1 occurs twice in sub-tree 4
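
The check made in each explanation above can be sketched in isolation. The snippet below is only an illustration under stated assumptions: the class name `FrequencyCheckSketch` and the helper `hasFrequency` are hypothetical names that do not appear in the article, and the values of each sub-tree are listed by hand from the first example instead of being collected by a traversal.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class FrequencyCheckSketch {

    // Hypothetical helper: returns true if some value
    // occurs exactly f times in the given list
    static boolean hasFrequency(List<Integer> values, int f)
    {
        Map<Integer, Integer> freq = new HashMap<>();

        // Count occurrences of every value
        for (int x : values) {
            freq.merge(x, 1, Integer::sum);
        }

        // True if any count equals f
        return freq.containsValue(f);
    }

    public static void main(String[] args)
    {
        // Numbers of nodes 5, 6, 7, 8 (sub-tree 5, first example)
        List<Integer> subtree5 = Arrays.asList(-3, -1, -1, -3);

        // Numbers of nodes 2, 4, 5, 6, 7, 8 (sub-tree 2, first example)
        List<Integer> subtree2 = Arrays.asList(2, 1, -3, -1, -1, -3);

        System.out.println(hasFrequency(subtree5, 2)); // true
        System.out.println(hasFrequency(subtree2, 3)); // false
    }
}
```

The two approaches described next differ only in how the numbers of a sub-tree are gathered; the final test on the frequency map stays the same.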


Naive Approach: For each query {V, F}, traverse the tree using DFS and count the frequency of the number associated with every vertex in the subtree of V, storing the counts in a hashmap. After the traversal, scan the hashmap to check whether any number has frequency F.

Below is the implementation of the above approach:

## Java

```java
// Java program for the above approach
import java.util.ArrayList;
import java.util.HashMap;

@SuppressWarnings("unchecked")
public class Main {

    // To store edges of tree
    static ArrayList<Integer>[] adj;

    // To store number associated
    // with vertex
    static int[] num;

    // To store frequency of number
    static HashMap<Integer, Integer> freq;

    // Function to add edges in tree
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    // Function returns boolean value
    // representing is there any number
    // present in subtree qv having
    // frequency qc
    static boolean query(int qv, int qc)
    {

        freq = new HashMap<>();

        // Start from root
        int v = 1;

        // DFS Call
        if (qv == v) {
            dfs(v, 0, true, qv);
        }
        else
            dfs(v, 0, false, qv);

        // Check for frequency
        for (int fq : freq.values()) {
            if (fq == qc)
                return true;
        }
        return false;
    }

    // Function to implement DFS
    static void dfs(int v, int p,
                    boolean isQV, int qv)
    {
        if (isQV) {

            // If we are on subtree qv,
            // then increment freq of
            // num[v] in freq
            freq.put(num[v],
                     freq.getOrDefault(num[v], 0) + 1);
        }

        // Recursive DFS Call for
        // adjacency list of node v
        for (int u : adj[v]) {
            if (p != u) {
                if (qv == u) {
                    dfs(u, v, true, qv);
                }
                else
                    dfs(u, v, isQV, qv);
            }
        }
    }

    // Driver Code
    public static void main(String[] args)
    {

        // Given Nodes
        int n = 8;
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++)
            adj[i] = new ArrayList<>();

        // Given edges of tree
        // (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Number assigned to each vertex
        // (num[0] is a placeholder, there is no vertex 0)
        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };

        // Total number of queries
        int q = 3;

        // Function Call to find each query
        System.out.println(query(2, 3));
        System.out.println(query(5, 2));
        System.out.println(query(7, 1));
    }
}
```
Output:
```
false
true
true
```

Time Complexity: O(N * Q), since the whole tree is traversed once for every query.
Auxiliary Space: O(N + E + Q)

Efficient Approach: The idea is to apply Extended Disjoint Set Union (also known as DSU on tree or small-to-large merging), so that the frequency maps of all subtrees are built in a single traversal and every query is answered from them:

1. Create an array size[] to store the size of each subtree.
2. Create an array of hashmaps map[], where map[V][X] = number of vertices in the subtree of V whose associated number is X.
3. Calculate the size of each subtree with a DFS traversal by calling dfsSize().
4. With a second DFS traversal, dfs(V, p), calculate map[V] for every vertex V, where p is the parent of V.
5. In this traversal, to calculate map[V], choose the adjacent vertex of V with the maximum subtree size (bigU), excluding the parent vertex p.
6. For the join operation, pass the reference of map[bigU] to map[V], i.e., map[V] = map[bigU], and then count the number of V itself in map[V]. If V has no child, map[V] starts as an empty map.
7. Finally, merge the maps of all other adjacent vertices u into map[V], excluding the parent vertex p and the vertex bigU.
8. For each query {V, F}, check whether map[V] contains the frequency F as a value. If yes then print True, else print False.
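
The join in steps 6 and 7 can be sketched on its own. This is a minimal illustration, not the article's implementation: the class `SmallToLargeMergeSketch` and the helper `merge` are hypothetical names, and the sketch picks the larger map by comparing map sizes, whereas the steps above pick the child with the largest subtree size. It also assumes the two maps are distinct objects.

```java
import java.util.HashMap;
import java.util.Map;

class SmallToLargeMergeSketch {

    // Hypothetical helper: copies the entries of the smaller
    // map into the larger one and returns the larger map,
    // which is modified in place (no new map is allocated)
    static Map<Integer, Integer> merge(Map<Integer, Integer> a,
                                       Map<Integer, Integer> b)
    {
        Map<Integer, Integer> large = a.size() >= b.size() ? a : b;
        Map<Integer, Integer> small = (large == a) ? b : a;

        // Add every count of the smaller map to the larger map
        for (Map.Entry<Integer, Integer> e : small.entrySet()) {
            large.merge(e.getKey(), e.getValue(), Integer::sum);
        }
        return large;
    }

    public static void main(String[] args)
    {
        // Frequency map of sub-tree 6 in the first example
        // (nodes 6 and 7 both hold -1)
        Map<Integer, Integer> child6 = new HashMap<>();
        child6.put(-1, 2);

        // Frequency map of sub-tree 8 (node 8 holds -3)
        Map<Integer, Integer> child8 = new HashMap<>();
        child8.put(-3, 1);

        // Join the children, then count node 5's own number (-3)
        Map<Integer, Integer> map5 = merge(child6, child8);
        map5.merge(-3, 1, Integer::sum);

        // Query {5, 2}: does any number occur exactly twice?
        System.out.println(map5.containsValue(2)); // true
    }
}
```

Reusing the larger map instead of copying it is what keeps the total work low: an entry is only copied when it moves into a map at least as large as the one it leaves.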

Below is the implementation of the efficient approach:

## Java

```java
// Java program for the above approach
import java.awt.Point;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

@SuppressWarnings("unchecked")
public class Main {

    // To store edges
    static ArrayList<Integer>[] adj;

    // num[v] = number assigned
    // to vertex v
    static int[] num;

    // map[v].get(c) = total vertices
    // in subtree v having number c
    static Map<Integer, Integer>[] map;

    // size[v] = size of subtree v
    static int[] size;

    // ans.get(new Point(v, f)) = answer
    // of query {v, f}
    static HashMap<Point, Boolean> ans;

    // qv[v] = frequencies asked for vertex v
    static ArrayList<Integer>[] qv;

    // Function to add edges
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    // Function to find subtree size
    // of every vertex using dfsSize()
    static void dfsSize(int v, int p)
    {
        size[v]++;

        // Traverse dfsSize recursively
        for (int u : adj[v]) {
            if (p != u) {
                dfsSize(u, v);
                size[v] += size[u];
            }
        }
    }

    // Function to implement DFS Traversal
    static void dfs(int v, int p)
    {
        int mx = -1, bigU = -1;

        // Find adjacent vertex with
        // maximum size
        for (int u : adj[v]) {
            if (u != p) {
                dfs(u, v);
                if (size[u] > mx) {
                    mx = size[u];
                    bigU = u;
                }
            }
        }

        if (bigU != -1) {

            // Passing reference
            map[v] = map[bigU];
        }
        else {

            // If no adjacent vertex
            // present initialize map[v]
            map[v] = new HashMap<>();
        }

        // Update frequency of current number
        map[v].put(num[v],
                   map[v].getOrDefault(num[v], 0) + 1);

        // Add all adjacent vertices
        // maps to map[v]
        for (int u : adj[v]) {

            if (u != bigU && u != p) {

                for (Entry<Integer, Integer>
                         pair : map[u].entrySet()) {
                    map[v].put(
                        pair.getKey(),
                        pair.getValue()
                            + map[v]
                                  .getOrDefault(
                                      pair.getKey(), 0));
                }
            }
        }

        // Store all queries related
        // to vertex v
        for (int freq : qv[v]) {
            ans.put(new Point(v, freq),
                    map[v].containsValue(freq));
        }
    }

    // Function to find answer for
    // each queries
    static void solveQuery(Point[] queries,
                           int N, int q)
    {
        // Add all queries to qv, where
        // qv[v] lists the frequencies
        // asked for vertex v
        qv = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            qv[i] = new ArrayList<>();

        for (Point p : queries) {
            qv[p.x].add(p.y);
        }

        // Get sizes of all subtrees
        size = new int[N + 1];

        // calculate size[]
        dfsSize(1, 0);

        // Map will be used to store
        // answers for current vertex
        // on dfs
        map = new HashMap[N + 1];

        // To store answer of queries
        ans = new HashMap<>();

        // DFS Call
        dfs(1, 0);

        for (Point p : queries) {

            // Print answer for each query
            System.out.println(ans.get(p));
        }
    }

    // Driver Code
    public static void main(String[] args)
    {
        int N = 8;
        adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            adj[i] = new ArrayList<>();

        // Given edges (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Store number given to vertices
        // set num[0]=-1 because
        // there is no vertex 0
        num
            = new int[] { -1, 11, 2, 7, 1,
                          -3, -1, -1, -3 };

        // Queries
        int q = 3;

        // To store queries
        Point[] queries = new Point[q];

        // Given Queries
        queries[0] = new Point(2, 3);
        queries[1] = new Point(5, 2);
        queries[2] = new Point(7, 1);

        // Function Call
        solveQuery(queries, N, q);
    }
}
```
Output:
```
false
true
true
```

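Time Complexity: O(N * log²N) as an upper bound for building the maps. With hashmaps the merging needs O(N * logN) expected map operations, because an entry is copied only into a map at least as large as the one it comes from, which can happen O(logN) times per entry. In addition, each query scans the values of map[V] through containsValue().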
Auxiliary Space: O(N + E + Q)
