Extended Disjoint Set Union on Trees
  • Last Updated : 14 Aug, 2020

Prerequisites: DFS, Trees, DSU

Given a tree rooted at node 1 with N nodes numbered from 1 to N, E edges, and an array arr[] giving the number associated with each node. You are also given Q queries, each consisting of two integers {V, F}. For each query, consider the sub-tree rooted at vertex V; the task is to check whether some number occurs exactly F times among the numbers associated with the nodes of that sub-tree. If yes, print True; otherwise print False.

Examples:

Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} }
Output:
False
True
True
Explanation:
Query 1: No number occurs exactly 3 times in sub-tree 2
Query 2: Numbers -1 and -3 each occur 2 times in sub-tree 5
Query 3: Number -1 occurs once in sub-tree 7

Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} }
Output:
False
True
Explanation:
Query 1: No number occurs exactly 2 times in sub-tree 2
Query 2: Number -1 occurs twice in sub-tree 4
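
The explanations above all reduce to the same check: count how often each number appears among the vertices of the sub-tree, then test whether any count equals F. The following is a minimal sketch of that check on a plain array of values; the class name FrequencyCheckSketch and the method hasFrequency are hypothetical names chosen for illustration, and the arrays are the values of sub-trees 5 and 2 from the first example, listed by hand.

```java
import java.util.HashMap;
import java.util.Map;

class FrequencyCheckSketch {

    // Returns true if some value occurs exactly f times in values.
    static boolean hasFrequency(int[] values, int f) {
        Map<Integer, Integer> count = new HashMap<>();
        for (int x : values) {
            // Increment the occurrence count of x
            count.merge(x, 1, Integer::sum);
        }
        // Linear scan over the counts
        return count.containsValue(f);
    }

    public static void main(String[] args) {
        // Sub-tree 5 of the first example: vertices 5, 6, 7, 8
        int[] subtree5 = { -3, -1, -1, -3 };
        System.out.println(hasFrequency(subtree5, 2)); // true

        // Sub-tree 2 of the first example: vertices 2, 4, 5, 6, 7, 8
        int[] subtree2 = { 2, 1, -3, -1, -1, -3 };
        System.out.println(hasFrequency(subtree2, 3)); // false
    }
}
```

The approaches in this article differ only in how they obtain the counts for a sub-tree, not in this final check.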



Naive Approach: For each query {V, F}, traverse the tree using DFS and, for every vertex that lies in the sub-tree of V, increment the frequency of its associated number in a hashmap. After the traversal, iterate over the hashmap values to check whether any number has frequency exactly F.

Below is the implementation of the above approach:

Java




// Java program for the above approach
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
  
@SuppressWarnings("unchecked")
public class Main {
  
    // To store edges of tree
    static ArrayList<Integer> adj[];
  
    // To store number associated
    // with vertex
    static int num[];
  
    // To store frequency of number
    static HashMap<Integer, Integer> freq;
  
    // Function to add edges in tree
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }
  
    // Function returns boolean value
    // representing is there any number
    // present in subtree qv having
    // frequency qc
    static boolean query(int qv, int qc)
    {
  
        freq = new HashMap<>();
  
        // Start from root
        int v = 1;
  
        // DFS Call
        if (qv == v) {
            dfs(v, 0, true, qv);
        }
        else
            dfs(v, 0, false, qv);
  
        // Check for frequency
        for (int fq : freq.values()) {
            if (fq == qc)
                return true;
        }
        return false;
    }
  
    // Function to implement DFS
    static void dfs(int v, int p,
                    boolean isQV, int qv)
    {
        if (isQV) {
  
            // If we are on subtree qv,
            // then increment freq of
            // num[v] in freq
            freq.put(num[v],
                     freq.getOrDefault(num[v], 0) + 1);
        }
  
        // Recursive DFS call for every
        // neighbour of v except its parent p
        for (int u : adj[v]) {
            if (p != u) {
                if (qv == u) {
                    dfs(u, v, true, qv);
                }
                else
                    dfs(u, v, isQV, qv);
            }
        }
    }
  
    // Driver Code
    public static void main(String[] args)
    {
  
        // Given Nodes
        int n = 8;
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++)
            adj[i] = new ArrayList<>();
  
        // Given edges of tree
        // (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);
  
        // Number assigned to each vertex
        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };
  
        // Total number of queries
        int q = 3;
  
        // Function Call to find each query
        System.out.println(query(2, 3));
        System.out.println(query(5, 2));
        System.out.println(query(7, 1));
    }
}
Output:
false
true
true


Time Complexity: O(N * Q), since the whole tree is traversed for every query.
Auxiliary Space: O(N + E + Q)

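Efficient Approach: The idea is to apply Extended Disjoint Set Union to the above approach, so that the frequency maps of all sub-trees are built in a single traversal by reusing the map of the largest child and merging the smaller maps into it:

  1. Create an array size[] to store the size of each sub-tree.
  2. Create an array of hashmaps map[], where map[V][X] = number of vertices in sub-tree V whose associated number is X.
  3. Calculate the size of each sub-tree with a DFS traversal by calling dfsSize().
  4. With a second DFS traversal, dfs(V, p), calculate map[V] for every vertex V.
  5. During this traversal, to calculate map[V], choose the adjacent vertex of V with the maximum sub-tree size (bigU), excluding the parent vertex p.
  6. For the join operation, pass the reference of map[bigU] to map[V], i.e., map[V] = map[bigU], so that no entries of the largest child are copied.
  7. Add the number of V itself to map[V], and finally merge the maps of all other adjacent vertices u into map[V], excluding the parent vertex p and bigU.
  8. Now, for each query {V, F}, check whether map[V] contains frequency F. If yes, print True; otherwise print False. This check is done inside dfs() while V is being processed, because the same map object is later modified by the ancestors of V.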
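
The join in steps 6 and 7 can be sketched in isolation as follows. This is a minimal illustration rather than the article's implementation: the class name SmallToLargeMergeSketch and the method mergeSmallIntoLarge are hypothetical, and the sketch picks the larger side by map size, whereas the steps above pick it by sub-tree size.

```java
import java.util.HashMap;
import java.util.Map;

class SmallToLargeMergeSketch {

    // Merges two frequency maps and returns the combined map.
    // The map with fewer entries is poured into the other one, so the
    // returned object is one of the two arguments (mutated in place).
    static Map<Integer, Integer> mergeSmallIntoLarge(
        Map<Integer, Integer> a, Map<Integer, Integer> b) {

        Map<Integer, Integer> large = a.size() >= b.size() ? a : b;
        Map<Integer, Integer> small = (large == a) ? b : a;

        // Only the entries of the smaller map are touched
        for (Map.Entry<Integer, Integer> e : small.entrySet()) {
            large.merge(e.getKey(), e.getValue(), Integer::sum);
        }
        return large;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> big = new HashMap<>();
        big.put(-1, 2);
        big.put(-3, 1);

        Map<Integer, Integer> small = new HashMap<>();
        small.put(-3, 1);

        Map<Integer, Integer> merged = mergeSmallIntoLarge(big, small);
        System.out.println(merged.get(-3)); // 2
        System.out.println(merged == big);  // true: the larger map was reused
    }
}
```

Because the larger map is reused by reference, the map of the child it came from no longer describes that child alone after the merge, which is why queries for a vertex have to be answered before its parent performs the join.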

Below is the implementation of the efficient approach:

Java




// Java program for the above approach
import java.awt.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
  
@SuppressWarnings("unchecked")
public class Main {
  
    // To store edges
    static ArrayList<Integer> adj[];
  
    // num[v] = number assigned
    // to vertex v
    static int num[];
  
    // map[v].get(c) = total vertices
    // in subtree v having number c
    static Map<Integer, Integer> map[];
  
    // size[v]=size of subtree v
    static int size[];
  
    static HashMap<Point, Boolean> ans;
    static ArrayList<Integer> qv[];
  
    // Function to add edges
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }
  
    // Function to find subtree size
    // of every vertex using dfsSize()
    static void dfsSize(int v, int p)
    {
        size[v]++;
  
        // Traverse dfsSize recursively
        for (int u : adj[v]) {
            if (p != u) {
                dfsSize(u, v);
                size[v] += size[u];
            }
        }
    }
  
    // Function to implement DFS Traversal
    static void dfs(int v, int p)
    {
        int mx = -1, bigU = -1;
  
        // Find adjacent vertex with
        // maximum size
        for (int u : adj[v]) {
            if (u != p) {
                dfs(u, v);
                if (size[u] > mx) {
                    mx = size[u];
                    bigU = u;
                }
            }
        }
  
        if (bigU != -1) {
  
            // Reuse the map of the largest child
            // by reference (no entries are copied)
            map[v] = map[bigU];
        }
        else {
  
            // If no adjacent vertex
            // present initialize map[v]
            map[v] = new HashMap<Integer, Integer>();
        }
  
        // Update frequency of current number
        map[v].put(num[v],
                   map[v].getOrDefault(num[v], 0) + 1);
  
        // Add all adjacent vertices
        // maps to map[v]
        for (int u : adj[v]) {
  
            if (u != bigU && u != p) {
  
                for (Entry<Integer, Integer>
                         pair : map[u].entrySet()) {
                    map[v].put(
                        pair.getKey(),
                        pair.getValue()
                            + map[v]
                                  .getOrDefault(
                                      pair.getKey(), 0));
                }
            }
        }
  
        // Store all queries related
        // to vertex v
        for (int freq : qv[v]) {
            ans.put(new Point(v, freq),
                    map[v].containsValue(freq));
        }
    }
  
    // Function to find answer for
    // each queries
    static void solveQuery(Point queries[],
                           int N, int q)
    {
        // Group the queries by vertex: qv[v] holds
        // every frequency asked about subtree v
        qv = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            qv[i] = new ArrayList<>();
  
        for (Point p : queries) {
            qv[p.x].add(p.y);
        }
  
        // Get sizes of all subtrees
        size = new int[N + 1];
  
        // calculate size[]
        dfsSize(1, 0);
  
        // map[v] will hold the frequency map
        // of subtree v while v is processed
        // in the DFS
        map = new HashMap[N + 1];
  
        // To store answer of queries
        ans = new HashMap<>();
  
        // DFS Call
        dfs(1, 0);
  
        for (Point p : queries) {
  
            // Print answer for each query
            System.out.println(ans.get(p));
        }
    }
  
    // Driver Code
    public static void main(String[] args)
    {
        int N = 8;
        adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            adj[i] = new ArrayList<>();
  
        // Given edges (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);
  
        // Store number given to vertices
        // set num[0]=-1 because
        // there is no vertex 0
        num
            = new int[] { -1, 11, 2, 7, 1,
                          -3, -1, -1, -3 };
  
        // Queries
        int q = 3;
  
        // To store queries
        Point queries[] = new Point[q];
  
        // Given Queries
        queries[0] = new Point(2, 3);
        queries[1] = new Point(5, 2);
        queries[2] = new Point(7, 1);
  
        // Function Call
        solveQuery(queries, N, q);
    }
}
Output:
false
true
true


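Time Complexity: O(N * log^2 N) for building the maps. A vertex's entry is copied only when it lies in a sub-tree that is not the largest child, and every such copy at least doubles the size of the sub-tree containing it, so this happens O(log N) times per vertex. In addition, map[V].containsValue(F) scans map[V], so each query costs time proportional to the number of distinct numbers in sub-tree V.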
Auxiliary Space: O(N + E + Q)
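
In the implementation above, each query calls map[v].containsValue(freq), which is a linear scan over the map. One possible variation, not part of the approach described in this article, is to keep a second map from each occurrence count to the number of distinct values having that count, so that a query becomes a single key lookup. The sketch below assumes a hypothetical helper class FrequencyBagSketch with hypothetical methods add, absorb, size and hasFrequency; in a DFS it would take the place of the plain map[] entries, with the smaller bag absorbed into the larger one.

```java
import java.util.HashMap;
import java.util.Map;

class FrequencyBagSketch {

    // value -> number of occurrences
    private final Map<Integer, Integer> freq = new HashMap<>();

    // occurrence count -> how many distinct values have that count
    private final Map<Integer, Integer> valuesWithCount = new HashMap<>();

    // Records k more occurrences of value (k must be positive).
    void add(int value, int k) {
        int old = freq.getOrDefault(value, 0);
        if (old > 0) {
            // value leaves its old bucket; a null result removes
            // the bucket once no value has that count any more
            valuesWithCount.computeIfPresent(
                old, (count, n) -> n == 1 ? null : n - 1);
        }
        freq.put(value, old + k);
        valuesWithCount.merge(old + k, 1, Integer::sum);
    }

    // Pours every entry of other into this bag (other is left unchanged).
    void absorb(FrequencyBagSketch other) {
        for (Map.Entry<Integer, Integer> e : other.freq.entrySet()) {
            add(e.getKey(), e.getValue());
        }
    }

    // Number of distinct values, useful for choosing the larger bag.
    int size() {
        return freq.size();
    }

    // True if some value currently occurs exactly f times.
    boolean hasFrequency(int f) {
        return valuesWithCount.containsKey(f);
    }
}
```

For sub-tree 5 of the first example, adding -3, -1, -1 and -3 one at a time leaves both values with count 2, so hasFrequency(2) is true and hasFrequency(1) is false. The trade-off is one extra hashmap per bag in exchange for constant expected time per query.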
