Skip to content
Related Articles

Related Articles

Improve Article
Save Article
Like Article

Extended Disjoint Set Union on Trees

  • Last Updated : 01 Dec, 2021

Prerequisites: DFS, Trees, DSU
Given a tree of N nodes numbered from 1 to N with E edges, and an array arr[] which gives the number associated with each node. You are also given Q queries, each containing 2 integers {V, F}. For each query, consider the subtree rooted at vertex V; the task is to check whether any number occurs exactly F times among the numbers associated with the nodes of that subtree. If yes then print True else print False.
Examples: 
 

Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} } 
Output: 
False 
True 
True 
Explanation: 
Query 1: No number occurs three times in sub-tree 2 
Query 2: Number -1 and -3 occurs 2 times in sub-tree 5 
Query 3: Number -1 occurs once in sub-tree 7 
Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} } 
Output: 
False 
True 
Explanation: 
Query 1: No number occurs exactly 2 times in sub-tree 2 
Query 2: Number -1 occurs twice in sub-tree 4 
 

Attention reader! Don’t stop learning now. Get hold of all the important DSA concepts with the DSA Self Paced Course at a student-friendly price and become industry ready.  To complete your preparation from learning a language to DS Algo and many more,  please refer Complete Interview Preparation Course.

In case you wish to attend live classes with experts, please refer DSA Live Classes for Working Professionals and Competitive Programming Live for Students.

 



Naive Approach: The idea is to traverse the tree using DFS and count the frequency of the number associated with each vertex in the subtree of V, storing the counts in a hashmap. After the traversal, we just need to scan the hashmap to check whether any number has the given frequency.
Below is the implementation of the above approach:
 

Java




// Java program for the above approach
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
 
@SuppressWarnings("unchecked")
public class Main {
 
    // To store edges of tree
    static ArrayList<Integer> adj[];
 
    // To store number associated
    // with vertex
    static int num[];
 
    // To store frequency of number
    static HashMap<Integer, Integer> freq;
 
    // Function to add edges in tree
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }
 
    // Function returns boolean value
    // representing is there any number
    // present in subtree qv having
    // frequency qc
    static boolean query(int qv, int qc)
    {
 
        freq = new HashMap<>();
 
        // Start from root
        int v = 1;
 
        // DFS Call
        if (qv == v) {
            dfs(v, 0, true, qv);
        }
        else
            dfs(v, 0, false, qv);
 
        // Check for frequency
        for (int fq : freq.values()) {
            if (fq == qc)
                return true;
        }
        return false;
    }
 
    // Function to implement DFS
    static void dfs(int v, int p,
                    boolean isQV, int qv)
    {
        if (isQV) {
 
            // If we are on subtree qv,
            // then increment freq of
            // num[v] in freq
            freq.put(num[v],
                     freq.getOrDefault(num[v], 0) + 1);
        }
 
        // Recursive DFS Call for
        // adjacency list of node u
        for (int u : adj[v]) {
            if (p != u) {
                if (qv == u) {
                    dfs(u, v, true, qv);
                }
                else
                    dfs(u, v, isQV, qv);
            }
        }
    }
 
    // Driver Code
    public static void main(String[] args)
    {
 
        // Given Nodes
        int n = 8;
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++)
            adj[i] = new ArrayList<>();
 
        // Given edges of tree
        // (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);
 
        // Number assigned to each vertex
        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };
 
        // Total number of queries
        int q = 3;
 
        // Function Call to find each query
        System.out.println(query(2, 3));
        System.out.println(query(5, 2));
        System.out.println(query(7, 1));
    }
}

Python3




# Python 3 program for the above approach

# Adjacency lists of the tree, indexed by vertex
adj = []

# num[v] is the number attached to vertex v
num = []

# Occurrence count of every number met in the
# subtree of the query currently being answered
freq = {}
 
def add(u, v):
    """Register the undirected edge u - v in the adjacency lists."""
    for src, dst in ((u, v), (v, u)):
        adj[src].append(dst)
 
 
def query(qv, qc):
    """Tell whether some number occurs exactly qc times in subtree qv.

    The shared ``freq`` table is emptied, refilled by one DFS over the
    whole tree starting at the root (vertex 1), and then searched for
    the requested count.
    """
    freq.clear()

    # Counting is already active at the root when the root itself
    # is the queried vertex
    root = 1
    dfs(root, 0, qv == root, qv)

    return qc in freq.values()
 
 
def dfs(v, p, isQV, qv):
    """Walk the tree from v (parent p), counting numbers inside subtree qv.

    isQV says whether v lies inside the subtree rooted at qv; only such
    vertices add their number to the shared ``freq`` table.
    """
    if isQV:
        freq[num[v]] = freq.get(num[v], 0) + 1

    for child in adj[v]:
        # Never walk back to the parent
        if child == p:
            continue

        # Counting switches on when qv is entered and is inherited
        # unchanged by every other child
        dfs(child, v, True if child == qv else isQV, qv)
         
     
 
 
# Driver Code
if __name__ == '__main__':

    # Number of vertices; slot 0 stays unused because
    # vertices are numbered from 1
    n = 8
    adj.extend([] for _ in range(n + 1))

    # Edges of the tree (root=1)
    for a, b in ((1, 2), (1, 3), (2, 4), (2, 5),
                 (5, 8), (5, 6), (6, 7)):
        add(a, b)

    # Number assigned to each vertex
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Total number of queries
    q = 3

    # Answer each (vertex, frequency) query
    for vertex, wanted in ((2, 3), (5, 2), (7, 1)):
        print(query(vertex, wanted))
Output: 
false
true
true

 

Time Complexity: O(N * Q) Since in each query, tree needs to be traversed. 
Auxiliary Space: O(N + E + Q)
Efficient Approach: The idea is to use Extended Disjoint Set Union to the above approach:
 

  1. Create an array, size[] to store size of sub-trees.
  2. Create an array of hashmaps, map[] i.e, map[V][X] = total vertices of number X in sub-tree, V.
  3. Calculate the size of each subtree using DFS traversal by calling dfsSize().
  4. Using DFS traversal by calling dfs(V, p), calculate the value of map[V].
  5. In the traversal, to calculate map[V], choose the adjacent vertex of V having maximum size ( bigU ) except parent vertex, p
     
  6. For join operation pass the reference of map[bigU] to map[V] i.e, map[V] = map[bigU].
  7. And at last, merge the maps of all adjacent vertices, u, into map[V], except the parent vertex, p, and the bigU vertex.
  8. Now, check if map[V] contains frequency F or not. If yes then print True else print False.

Below is the implementation of the efficient approach: 
 

Java




// Java program for the above approach
import java.awt.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
 
@SuppressWarnings("unchecked")
public class Main {

    // Adjacency lists of the tree, indexed by vertex
    static ArrayList<Integer> adj[];

    // num[v] = number assigned
    // to vertex v
    static int num[];

    // map[v].get(c) = total vertices
    // in subtree v having number c.
    // NOTE: after dfs(v, p) returns, map[v] may be the
    // very same object as the map of an ancestor (see
    // the reference hand-over in dfs), so it must not
    // be read again for v afterwards
    static Map<Integer, Integer> map[];

    // size[v]=size of subtree v
    static int size[];

    // Answer of every query, keyed by
    // Point(vertex, frequency)
    static HashMap<Point, Boolean> ans;

    // qv[v] = frequencies asked about in the
    // queries on vertex v
    static ArrayList<Integer> qv[];

    // Function to add the undirected edge u - v
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    // Function to find subtree size
    // of every vertex using dfsSize().
    // v is the current vertex, p its parent; size[]
    // must be all zeros before the first call
    static void dfsSize(int v, int p)
    {
        // Count v itself
        size[v]++;

        // Traverse dfsSize recursively and add the
        // size of every child subtree
        for (int u : adj[v]) {
            if (p != u) {
                dfsSize(u, v);
                size[v] += size[u];
            }
        }
    }

    // Function to implement DFS Traversal.
    // Builds map[v] for the subtree of v (parent p)
    // and answers the queries stored in qv[v]
    static void dfs(int v, int p)
    {
        int mx = -1, bigU = -1;

        // Process every child first, and find the
        // adjacent vertex with maximum size
        for (int u : adj[v]) {
            if (u != p) {
                dfs(u, v);
                if (size[u] > mx) {
                    mx = size[u];
                    bigU = u;
                }
            }
        }

        if (bigU != -1) {

            // Passing referencing: v takes over the map
            // object of its biggest child instead of
            // copying it, so map[bigU] and map[v] are
            // the same object from here on
            map[v] = map[bigU];
        }
        else {

            // If no adjacent vertex
            // present initialize map[v]
            map[v] = new HashMap<Integer, Integer>();
        }

        // Update frequency of current number
        map[v].put(num[v],
                   map[v].getOrDefault(num[v], 0) + 1);

        // Add all adjacent vertices
        // maps to map[v], skipping the parent and
        // bigU (whose entries are already in map[v])
        for (int u : adj[v]) {

            if (u != bigU && u != p) {

                for (Entry<Integer, Integer>
                         pair : map[u].entrySet()) {
                    map[v].put(
                        pair.getKey(),
                        pair.getValue()
                            + map[v]
                                  .getOrDefault(
                                      pair.getKey(), 0));
                }
            }
        }

        // Store all queries related
        // to vertex v. This has to happen here, before
        // the parent of v gets a chance to reuse and
        // modify this map object
        for (int freq : qv[v]) {
            ans.put(new Point(v, freq),
                    map[v].containsValue(freq));
        }
    }

    // Function to find answer for
    // each queries.
    // queries[i] is Point(vertex, frequency), N is the
    // number of vertices; q is not read by this function
    static void solveQuery(Point queries[],
                           int N, int q)
    {
        // Group the queries by vertex:
        // qv[v] collects the frequencies asked for v
        qv = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            qv[i] = new ArrayList<>();

        for (Point p : queries) {
            qv[p.x].add(p.y);
        }

        // Get sizes of all subtrees
        size = new int[N + 1];

        // calculate size[] with vertex 1 as root
        dfsSize(1, 0);

        // Map will be used to store
        // answers for current vertex
        // on dfs
        map = new HashMap[N + 1];

        // To store answer of queries
        ans = new HashMap<>();

        // DFS Call
        dfs(1, 0);

        for (Point p : queries) {

            // Print answer for each query,
            // in the order the queries were given
            System.out.println(ans.get(p));
        }
    }

    // Driver Code
    public static void main(String[] args)
    {
        int N = 8;
        adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            adj[i] = new ArrayList<>();

        // Given edges (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Store number given to vertices
        // set num[0]=-1 because
        // there is no vertex 0
        num
            = new int[] { -1, 11, 2, 7, 1,
                          -3, -1, -1, -3 };

        // Queries
        int q = 3;

        // To store queries
        Point queries[] = new Point[q];

        // Given Queries as Point(vertex, frequency)
        queries[0] = new Point(2, 3);
        queries[1] = new Point(5, 2);
        queries[2] = new Point(7, 1);

        // Function Call
        solveQuery(queries, N, q);
    }
}

Python3




# Python 3 program for the above approach

# Adjacency lists of the tree, indexed by vertex
adj = []

# num[v] is the number attached to vertex v
num = []

# mp[v][c] = total vertices
# in subtree v having number c
mp = {}

# size[v] = size of subtree v
size = []

# Answer of every query, keyed by (vertex, frequency)
ans = {}

# qv[v] = frequencies asked about in the queries on vertex v
qv = []
 
def add(u, v):
    """Register the undirected edge u - v in the adjacency lists."""
    for src, dst in ((u, v), (v, u)):
        adj[src].append(dst)
 
 
def dfsSize(v, p):
    """Accumulate into size[v] the vertex count of the subtree of v.

    p is the parent of v and is skipped during the walk. The entries
    of ``size`` are incremented, so they must be zero beforehand.
    """
    # Count v itself
    size[v] += 1

    for child in adj[v]:
        # Never walk back to the parent
        if child == p:
            continue
        dfsSize(child, v)
        size[v] += size[child]
         
     
 
 
def dfs(v, p):
    """Build mp[v] for the subtree of v and answer the queries on v.

    v is the current vertex and p its parent. Children are processed
    first; v then takes over the dictionary of its largest child and
    folds the dictionaries of the remaining children into it. For each
    frequency in qv[v], ans[(v, frequency)] records whether some number
    occurs exactly that many times in the subtree of v.
    """
    mx = -1
    bigU = -1

    # Process every child and remember the one with the
    # largest subtree
    for u in adj[v]:
        if u != p:
            dfs(u, v)
            if size[u] > mx:
                mx = size[u]
                bigU = u

    if bigU != -1:
        # Passing reference: reuse the biggest child's dictionary
        # instead of copying it. mp[bigU] and mp[v] are the same
        # object from here on, so mp[bigU] must not be read again.
        mp[v] = mp[bigU]
    else:
        # Leaf vertex: start from an empty dictionary
        mp[v] = {}

    # Count the number attached to v itself
    mp[v][num[v]] = mp[v].get(num[v], 0) + 1

    # Fold the dictionaries of the remaining children into mp[v]
    for u in adj[v]:
        if u not in (bigU, p):
            for value, count in mp[u].items():
                mp[v][value] = mp[v].get(value, 0) + count

    # Answer the queries on v now, before the parent of v reuses and
    # changes this dictionary. A query asks about an occurrence count,
    # so the dictionary *values* are searched, not its keys.
    if qv[v]:
        counts = set(mp[v].values())
        for freq in qv[v]:
            ans[(v, freq)] = freq in counts
 
def solveQuery(queries, N, q):
    """Answer every (vertex, frequency) query and print the results.

    queries is a sequence of (vertex, frequency) pairs and N is the
    number of vertices; q is accepted but not read. One line is
    printed per query, in the order the queries were given.
    """
    global size, mp, qv, ans

    # Group the queries by vertex:
    # qv[v] collects the frequencies asked for v
    qv = [[] for _ in range(N + 1)]
    for item in queries:
        qv[item[0]].append(item[1])

    # Subtree sizes, computed with vertex 1 as root
    size = [0 for _ in range(N + 1)]
    dfsSize(1, 0)

    # Fresh per-vertex frequency tables and answer store
    mp = {}
    ans = {}

    # DFS Call
    dfs(1, 0)

    # Print answer for each query
    for item in queries:
        print(ans[item])
     
 
 
# Driver Code
if __name__ == '__main__':
    N = 8

    # One adjacency list per vertex; slot 0 stays unused
    adj = [[] for i in range(N + 1)]

    # Given edges (root=1)
    for a, b in ((1, 2), (1, 3), (2, 4), (2, 5),
                 (5, 8), (5, 6), (6, 7)):
        add(a, b)

    # Store number given to vertices
    # set num[0]=-1 because
    # there is no vertex 0
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Queries
    q = 3

    # Given Queries as (vertex, frequency)
    queries = [(2, 3), (5, 2), (7, 1)]

    # Function Call
    solveQuery(queries, N, q)
Output: 
false
true
true

 

Time Complexity: O(N*logN^2)
Auxiliary Space: O(N + E + Q) 

 




My Personal Notes arrow_drop_up
Recommended Articles
Page :

Start Your Coding Journey Now!