Extended Disjoint Set Union on Trees
Prerequisites: DFS, Trees, DSU
Given a tree of N nodes numbered from 1 to N with E edges, and an array arr[] where arr[i - 1] is the number associated with node i. You are also given Q queries, each containing 2 integers {V, F}. For each query, consider the subtree rooted at vertex V; the task is to check whether there exists a number that occurs exactly F times among the numbers associated with the nodes of that subtree. If yes then print True else print False.
Examples:
Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} }
Output:
False
True
True
Explanation:
Query 1: No number occurs three times in sub-tree 2
Query 2: Numbers -1 and -3 each occur 2 times in sub-tree 5
Query 3: Number -1 occurs once in sub-tree 7
Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} }
Output:
False
True
Explanation:
Query 1: No number occurs exactly 2 times in sub-tree 2
Query 2: Number -1 occurs twice in sub-tree 4
Naive Approach: The idea is to traverse the tree using DFS, count the frequency of the number associated with each vertex in the subtree of V, and store the counts in a hashmap. After the traversal, we just need to scan the hashmap to check whether any number has the given frequency.
Below is the implementation of the above approach:
Java
// Java program for the above approach import java.util.ArrayList; import java.util.HashMap; import java.util.Map; @SuppressWarnings ( "unchecked" ) public class Main { // To store edges of tree static ArrayList<Integer> adj[]; // To store number associated // with vertex static int num[]; // To store frequency of number static HashMap<Integer, Integer> freq; // Function to add edges in tree static void add( int u, int v) { adj[u].add(v); adj[v].add(u); } // Function returns boolean value // representing is there any number // present in subtree qv having // frequency qc static boolean query( int qv, int qc) { freq = new HashMap<>(); // Start from root int v = 1 ; // DFS Call if (qv == v) { dfs(v, 0 , true , qv); } else dfs(v, 0 , false , qv); // Check for frequency for ( int fq : freq.values()) { if (fq == qc) return true ; } return false ; } // Function to implement DFS static void dfs( int v, int p, boolean isQV, int qv) { if (isQV) { // If we are on subtree qv, // then increment freq of // num[v] in freq freq.put(num[v], freq.getOrDefault(num[v], 0 ) + 1 ); } // Recursive DFS Call for // adjacency list of node u for ( int u : adj[v]) { if (p != u) { if (qv == u) { dfs(u, v, true , qv); } else dfs(u, v, isQV, qv); } } } // Driver Code public static void main(String[] args) { // Given Nodes int n = 8 ; adj = new ArrayList[n + 1 ]; for ( int i = 1 ; i <= n; i++) adj[i] = new ArrayList<>(); // Given edges of tree // (root=1) add( 1 , 2 ); add( 1 , 3 ); add( 2 , 4 ); add( 2 , 5 ); add( 5 , 8 ); add( 5 , 6 ); add( 6 , 7 ); // Number assigned to each vertex num = new int [] { - 1 , 11 , 2 , 7 , 1 , - 3 , - 1 , - 1 , - 3 }; // Total number of queries int q = 3 ; // Function Call to find each query System.out.println(query( 2 , 3 )); System.out.println(query( 5 , 2 )); System.out.println(query( 7 , 1 )); } } |
Python3
# Python 3 program for the naive approach: one full DFS per query

# adj[v] = neighbours of vertex v (the tree is rooted at vertex 1)
adj = []

# num[v] = number associated with vertex v (index 0 is unused)
num = []

# freq[x] = occurrences of x in the subtree of the query
# that is currently being answered
freq = dict()


def add(u, v):
    """Add the undirected edge (u, v) to the tree."""
    adj[u].append(v)
    adj[v].append(u)


def query(qv, qc):
    """Return True if some number occurs exactly qc times in subtree qv.

    The whole tree is walked from the root (vertex 1) on every call and
    the module-level ``freq`` dictionary is rebuilt from scratch.
    """
    freq.clear()

    # Start from the root; counting is already "on" when the queried
    # vertex is the root itself
    root = 1
    dfs(root, 0, qv == root, qv)

    # Is there any number whose frequency equals qc?
    return qc in freq.values()


def dfs(v, p, isQV, qv):
    """Count into ``freq`` the numbers of the vertices inside subtree qv.

    v    -- vertex to start from
    p    -- parent of v (0 for the root)
    isQV -- True when v already lies inside the subtree of qv
    qv   -- the queried vertex

    The traversal uses an explicit stack instead of recursion so that a
    deep (path-shaped) tree cannot exceed Python's recursion limit.
    """
    stack = [(v, p, isQV)]
    while stack:
        node, parent, inside = stack.pop()

        if inside:
            # node belongs to the subtree of qv: count num[node]
            freq[num[node]] = freq.get(num[node], 0) + 1

        # Counting switches on as soon as the queried vertex is
        # entered and then stays on for everything below it
        for u in adj[node]:
            if u != parent:
                stack.append((u, node, inside or u == qv))


# Driver Code
if __name__ == '__main__':

    # Given Nodes
    n = 8
    for _ in range(n + 1):
        adj.append([])

    # Given edges of tree (root=1)
    add(1, 2)
    add(1, 3)
    add(2, 4)
    add(2, 5)
    add(5, 8)
    add(5, 6)
    add(6, 7)

    # Number assigned to each vertex
    # (num[0] is a placeholder, there is no vertex 0)
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Function Call to answer each query {V, F}
    for v_query, f_query in ((2, 3), (5, 2), (7, 1)):
        print(query(v_query, f_query))
false true true
Time Complexity: O(N * Q) Since in each query, tree needs to be traversed.
Auxiliary Space: O(N + E + Q)
Efficient Approach: The idea is to use Extended Disjoint Set Union to the above approach:
- Create an array, size[] to store size of sub-trees.
- Create an array of hashmaps, map[] i.e, map[V][X] = total vertices of number X in sub-tree, V.
- Calculate the size of each subtree using DFS traversal by calling dfsSize().
- Using DFS traversal by calling dfs(V, p), calculate the value of map[V].
- In the traversal, to calculate map[V], choose the adjacent vertex of V having maximum size ( bigU ) except parent vertex, p.
- For join operation pass the reference of map[bigU] to map[V] i.e, map[V] = map[bigU].
- Finally, merge the maps of all the other adjacent vertices u into map[V], skipping the parent vertex p and the vertex bigU.
- Now, check if map[V] contains frequency F or not. If yes then print True else print False.
Below is the implementation of the efficient approach:
Java
// Java program for the above approach import java.awt.*; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; @SuppressWarnings ( "unchecked" ) public class Main { // To store edges static ArrayList<Integer> adj[]; // num[v] = number assigned // to vertex v static int num[]; // map[v].get(c) = total vertices // in subtree v having number c static Map<Integer, Integer> map[]; // size[v]=size of subtree v static int size[]; static HashMap<Point, Boolean> ans; static ArrayList<Integer> qv[]; // Function to add edges static void add( int u, int v) { adj[u].add(v); adj[v].add(u); } // Function to find subtree size // of every vertex using dfsSize() static void dfsSize( int v, int p) { size[v]++; // Traverse dfsSize recursively for ( int u : adj[v]) { if (p != u) { dfsSize(u, v); size[v] += size[u]; } } } // Function to implement DFS Traversal static void dfs( int v, int p) { int mx = - 1 , bigU = - 1 ; // Find adjacent vertex with // maximum size for ( int u : adj[v]) { if (u != p) { dfs(u, v); if (size[u] > mx) { mx = size[u]; bigU = u; } } } if (bigU != - 1 ) { // Passing referencing map[v] = map[bigU]; } else { // If no adjacent vertex // present initialize map[v] map[v] = new HashMap<Integer, Integer>(); } // Update frequency of current number map[v].put(num[v], map[v].getOrDefault(num[v], 0 ) + 1 ); // Add all adjacent vertices // maps to map[v] for ( int u : adj[v]) { if (u != bigU && u != p) { for (Entry<Integer, Integer> pair : map[u].entrySet()) { map[v].put( pair.getKey(), pair.getValue() + map[v] .getOrDefault( pair.getKey(), 0 )); } } } // Store all queries related // to vertex v for ( int freq : qv[v]) { ans.put( new Point(v, freq), map[v].containsValue(freq)); } } // Function to find answer for // each queries static void solveQuery(Point queries[], int N, int q) { // Add all queries to qv // where i<qv[v].size() qv = new ArrayList[N + 1 ]; for ( int i = 1 ; i <= N; i++) qv[i] = new ArrayList<>(); for 
(Point p : queries) { qv[p.x].add(p.y); } // Get sizes of all subtrees size = new int [N + 1 ]; // calculate size[] dfsSize( 1 , 0 ); // Map will be used to store // answers for current vertex // on dfs map = new HashMap[N + 1 ]; // To store answer of queries ans = new HashMap<>(); // DFS Call dfs( 1 , 0 ); for (Point p : queries) { // Print answer for each query System.out.println(ans.get(p)); } } // Driver Code public static void main(String[] args) { int N = 8 ; adj = new ArrayList[N + 1 ]; for ( int i = 1 ; i <= N; i++) adj[i] = new ArrayList<>(); // Given edges (root=1) add( 1 , 2 ); add( 1 , 3 ); add( 2 , 4 ); add( 2 , 5 ); add( 5 , 8 ); add( 5 , 6 ); add( 6 , 7 ); // Store number given to vertices // set num[0]=-1 because // there is no vertex 0 num = new int [] { - 1 , 11 , 2 , 7 , 1 , - 3 , - 1 , - 1 , - 3 }; // Queries int q = 3 ; // To store queries Point queries[] = new Point[q]; // Given Queries queries[ 0 ] = new Point( 2 , 3 ); queries[ 1 ] = new Point( 5 , 2 ); queries[ 2 ] = new Point( 7 , 1 ); // Function Call solveQuery(queries, N, q); } } |
Python3
# Python 3 program for the efficient approach
# (small-to-large merging of per-subtree frequency maps)

# adj[v] = neighbours of vertex v (the tree is rooted at vertex 1)
adj = []

# num[v] = number associated with vertex v (index 0 is unused)
num = []

# mp[v][c] = total vertices in subtree v having number c.
# NOTE: mp[v] shares its dict with the largest child of v, so it is
# only valid at the moment v is processed.
mp = dict()

# size[v] = size of subtree v
size = []

# ans[(v, f)] = answer of query {v, f}
ans = dict()

# qv[v] = frequencies asked about in the queries on vertex v
qv = []


def add(u, v):
    """Add the undirected edge (u, v) to the tree."""
    adj[u].append(v)
    adj[v].append(u)


def _preorder(v, p):
    """Return (vertex, parent) pairs of subtree v, parents before children."""
    order = []
    stack = [(v, p)]
    while stack:
        node, parent = stack.pop()
        order.append((node, parent))
        for u in adj[node]:
            if u != parent:
                stack.append((u, node))
    return order


def dfsSize(v, p):
    """Compute size[] for every vertex in the subtree of v (parent p).

    Implemented without recursion so that deep trees do not exceed
    Python's recursion limit.
    """
    # Walking the pre-order backwards visits children before parents
    for node, parent in reversed(_preorder(v, p)):
        size[node] += 1
        # The start vertex does not report to its parent, exactly as
        # a recursive call on v would leave size[p] untouched
        if node != v:
            size[parent] += size[node]


def dfs(v, p):
    """Build mp[] for the subtree of v (parent p) and answer its queries.

    For every vertex the dict of its largest child is reused by
    reference and only the dicts of the smaller children are copied
    into it. Each query (vertex, freq) stored in qv[] is answered into
    ans[(vertex, freq)] as soon as the dict of that vertex is complete.
    """
    # Children are processed before their parent
    for node, parent in reversed(_preorder(v, p)):
        children = [u for u in adj[node] if u != parent]

        # Child with the largest subtree (first one on ties),
        # -1 when node is a leaf
        bigU = max(children, key=size.__getitem__, default=-1)

        # Passing the reference of the largest child's dict;
        # a leaf starts with an empty dict
        counts = mp[bigU] if bigU != -1 else dict()
        mp[node] = counts

        # Update frequency of current number
        counts[num[node]] = counts.get(num[node], 0) + 1

        # Add the dicts of all remaining children
        for u in children:
            if u != bigU:
                for value, occurrences in mp[u].items():
                    counts[value] = counts.get(value, 0) + occurrences

        # Answer all queries related to this vertex. The frequencies
        # are the VALUES of the dict (the keys are the numbers), so
        # the membership test must look at counts.values().
        for freq in qv[node]:
            ans[(node, freq)] = freq in counts.values()


def solveQuery(queries, N, q):
    """Answer every query (V, F) on a tree of N vertices, one per line.

    queries -- sequence of (vertex, frequency) pairs
    N       -- number of vertices; the tree is rooted at vertex 1
    q       -- number of queries
    """
    global size, mp, qv, ans

    # Group the queried frequencies by vertex
    qv = [[] for _ in range(N + 1)]
    for p in queries:
        qv[p[0]].append(p[1])

    # Get sizes of all subtrees
    size = [0] * (N + 1)
    dfsSize(1, 0)

    # Per-vertex frequency dicts, filled in by dfs()
    mp = dict()

    # To store answer of queries
    ans = dict()

    # DFS Call
    dfs(1, 0)

    # Print answer for each query in the given order
    for p in queries:
        print(ans[(p[0], p[1])])


# Driver Code
if __name__ == '__main__':

    N = 8
    adj = []
    for i in range(N + 1):
        adj.append([])

    # Given edges (root=1)
    add(1, 2)
    add(1, 3)
    add(2, 4)
    add(2, 5)
    add(5, 8)
    add(5, 6)
    add(6, 7)

    # Store number given to vertices;
    # num[0]=-1 because there is no vertex 0
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Queries
    q = 3

    # To store queries
    queries = []

    # Given Queries
    queries.append((2, 3))
    queries.append((5, 2))
    queries.append((7, 1))

    # Function Call
    solveQuery(queries, N, q)
false true true
Time Complexity: O(N * (log N)^2)
Auxiliary Space: O(N + E + Q)
Please Login to comment...