Open In App

Extended Disjoint Set Union on Trees

Improve
Improve
Like Article
Like
Save
Share
Report

Prerequisites: DFS, Trees, DSU
Given a tree of N nodes numbered from 1 to N, its E edges, and an array arr[] where arr[i] is the number associated with node i. You are also given Q queries, each containing two integers {V, F}. For each query, consider the subtree rooted at vertex V; the task is to check whether some number occurs exactly F times among the numbers associated with the nodes of that subtree. If yes, print True; otherwise print False.
Examples: 
 

Input: N = 8, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {5, 8}, {5, 6}, {6, 7} }, arr[] = { 11, 2, 7, 1, -3, -1, -1, -3 }, Q = 3, queries[] = { {2, 3}, {5, 2}, {7, 1} } 
Output: 
False 
True 
True 
Explanation: 
Query 1: No number occurs three times in sub-tree 2 
Query 2: Numbers -1 and -3 each occur 2 times in sub-tree 5 
Query 3: Number -1 occurs once in sub-tree 7 
Input: N = 11, E = { {1, 2}, {1, 3}, {2, 4}, {2, 5}, {4, 9}, {4, 8}, {3, 6}, {3, 7}, {4, 10}, {5, 11} }, arr[] = { 2, -1, -12, -1, 6, 14, 7, -1, -2, 13, 12 }, Q = 2, queries[] = { {2, 2}, {4, 2} } 
Output: 
False 
True 
Explanation: 
Query 1: No number occurs exactly 2 times in sub-tree 2 (-1 occurs three times and every other number occurs once) 
Query 2: Number -1 occurs twice in sub-tree 4 
 

 

Naive Approach: The idea is to traverse the tree using DFS and, while inside the sub-tree of V, count the frequency of the number associated with each vertex, storing the counts in a hashmap. After the traversal, we just need to scan the hashmap to check whether the given frequency F exists.
Below is the implementation of the above approach:
 

C++




// C++ code for the above program
#include <bits/stdc++.h>
using namespace std;
 
// Adjacency lists of the tree; vertices are numbered from 1
// (room for up to 100005 vertex ids)
vector<int> adj[100005];

// Forward declaration: query() below calls dfs(), which is
// defined after it
void dfs(int v, int p, bool isQV, int qv);
// num[v] = number associated with vertex v; num[0] = -1 is a
// placeholder because there is no vertex 0
int num[9] = {-1, 11, 2, 7, 1, -3, -1, -1, -3};

// freq[x] = how many vertices of the queried subtree carry number x
// (cleared at the start of every query)
unordered_map<int, int> freq;
 
// Function to add edges in tree
void add(int u, int v)
{
    adj[u].push_back(v);
    adj[v].push_back(u);
}
 
// Function returns boolean value
// representing is there any number
// present in subtree qv having
// frequency qc
bool query(int qv, int qc)
{
 
    freq.clear();
 
    // Start from root
    int v = 1;
 
    // DFS Call
    if (qv == v) {
        dfs(v, 0, true, qv);
    }
    else
        dfs(v, 0, false, qv);
 
    // Check for frequency
    for (auto fq : freq) {
        if (fq.second == qc)
            return true;
    }
    return false;
}
 
// Function to implement DFS
void dfs(int v, int p,
         bool isQV, int qv)
{
    if (isQV) {
 
        // If we are on subtree qv,
        // then increment freq of
        // num[v] in freq
        freq[num[v]]++;
    }
 
    // Recursive DFS Call for
    // adjacency list of node u
    for (int u : adj[v]) {
        if (p != u) {
            if (qv == u) {
                dfs(u, v, true, qv);
            }
            else
                dfs(u, v, isQV, qv);
        }
    }
}
 
// Driver Code
int main()
{
 
    // Given Nodes
    int n = 8;
    for (int i = 1; i <= n; i++)
        adj[i].clear();
 
    // Given edges of tree
    // (root=1)
    add(1, 2);
    add(1, 3);
    add(2, 4);
    add(2, 5);
    add(5, 8);
    add(5, 6);
    add(6, 7);
 
 
    // Total number of queries
    int q = 3;
 
    // Function Call to find each query
    cout << (query(2, 3) ? "true" : "false") << endl;
    cout << (query(5, 2) ? "true" : "false") << endl;
    cout << (query(7, 1) ? "true" : "false") << endl;
 
    return 0;
}


Java




// Java program for the above approach
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
 
public class Main {

    // adj[v] lists the neighbours of vertex v (vertices are 1-based)
    static ArrayList<Integer> adj[];

    // num[v] is the number attached to vertex v (num[0] is a placeholder)
    static int num[];

    // freq.get(x) = occurrences of number x inside the queried subtree
    static HashMap<Integer, Integer> freq;

    // Stores the undirected edge (u, v) in both adjacency lists
    static void add(int u, int v)
    {
        adj[v].add(u);
        adj[u].add(v);
    }

    // Returns true if some number occurs exactly qc times in the
    // subtree rooted at qv. The tree is rooted at vertex 1 and is
    // walked in full on every call.
    static boolean query(int qv, int qc)
    {
        final int root = 1;
        freq = new HashMap<>();

        // Counting starts immediately when the query vertex is the root
        dfs(root, 0, qv == root, qv);

        return freq.containsValue(qc);
    }

    // Visits v (reached from parent p). Vertices are counted into freq
    // only while isQV is true; the flag becomes true when the walk
    // enters vertex qv and stays true for its descendants.
    static void dfs(int v, int p,
                    boolean isQV, int qv)
    {
        if (isQV) {
            freq.merge(num[v], 1, Integer::sum);
        }

        for (int child : adj[v]) {
            if (child == p) {
                continue;
            }
            dfs(child, v, isQV || child == qv, qv);
        }
    }

    // Driver Code: builds the example tree (root = 1) and prints the
    // answer to each sample query
    @SuppressWarnings("unchecked") // generic array creation for adj
    public static void main(String[] args)
    {
        int n = 8;
        adj = new ArrayList[n + 1];
        for (int i = 1; i <= n; i++)
            adj[i] = new ArrayList<>();

        int[][] edges = { { 1, 2 }, { 1, 3 }, { 2, 4 }, { 2, 5 },
                          { 5, 8 }, { 5, 6 }, { 6, 7 } };
        for (int[] e : edges)
            add(e[0], e[1]);

        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };

        int[][] queries = { { 2, 3 }, { 5, 2 }, { 7, 1 } };
        for (int[] qr : queries)
            System.out.println(query(qr[0], qr[1]));
    }
}


Python3




# Python 3 program for the above approach
 
# adj[v] holds the neighbours of vertex v (filled in by the driver)
adj = []

# num[v] is the number associated with vertex v (assigned by the driver)
num = []

# freq[x] counts how often number x occurs in the queried subtree
freq = {}
 
# Record the undirected edge (u, v) in both adjacency lists
def add(u, v):
    for a, b in ((u, v), (v, u)):
        adj[a].append(b)
 
 
# Return True when some number occurs exactly qc times in the subtree
# rooted at qv. The tree is rooted at vertex 1 and is walked in full on
# every call.
def query(qv, qc):
    freq.clear()
    root = 1

    # Counting starts immediately when the query vertex is the root
    dfs(root, 0, qv == root, qv)

    return qc in freq.values()
 
 
# Depth-first walk from v (reached from parent p). While isQV is True,
# every visited vertex adds one to freq[num[v]]; the flag switches on
# when the walk enters vertex qv and stays on for its descendants.
def dfs(v, p, isQV, qv):
    if isQV:
        freq[num[v]] = freq.get(num[v], 0) + 1

    for child in adj[v]:
        if child != p:
            dfs(child, v, isQV or child == qv, qv)
         
     
 
 
# Driver Code: build the example tree (root = 1) and answer the three
# sample queries
if __name__ == '__main__':
    n = 8
    adj.extend([] for _ in range(n + 1))

    # Edges of the example tree
    for u, v in ((1, 2), (1, 3), (2, 4), (2, 5), (5, 8), (5, 6), (6, 7)):
        add(u, v)

    # Number assigned to each vertex (index 0 is a placeholder)
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Queries as (vertex, frequency)
    for qv, qc in ((2, 3), (5, 2), (7, 1)):
        print(query(qv, qc))


C#




using System;
using System.Collections.Generic;
 
public class GFG
{
    // adj[v] = neighbours of vertex v (vertices are numbered 1..n)
    static List<int>[] adj;

    // num[v] = number associated with vertex v (num[0] is a placeholder)
    static int[] num;

    // freq[x] = occurrences of number x inside the queried subtree
    static Dictionary<int, int> freq;

    // Adds the undirected edge (u, v) to both adjacency lists
    static void Add(int u, int v)
    {
        adj[u].Add(v);
        adj[v].Add(u);
    }

    // Returns true if some number occurs exactly qc times in the subtree
    // rooted at qv. The whole tree is walked from root 1 on every call,
    // so each query visits all vertices.
    static bool Query(int qv, int qc)
    {
        freq = new Dictionary<int, int>();

        // Start from the root; counting is on from the start if qv is the root
        int v = 1;
        Dfs(v, 0, qv == v, qv);

        // Check for frequency
        return freq.ContainsValue(qc);
    }

    // DFS from v (reached from parent p). While isQV is true, every visited
    // vertex adds one to freq[num[v]]; the flag turns on when the walk
    // enters vertex qv and stays on for its descendants.
    static void Dfs(int v, int p, bool isQV, int qv)
    {
        if (isQV)
        {
            // TryGetValue rather than GetValueOrDefault: the latter is a
            // CollectionExtensions extension method that only exists on
            // .NET Core 2.0+ / .NET Standard 2.1, so this form also compiles
            // on .NET Framework. count is 0 when the key is missing.
            int count;
            freq.TryGetValue(num[v], out count);
            freq[num[v]] = count + 1;
        }

        foreach (int u in adj[v])
        {
            if (u == p)
            {
                continue;
            }
            Dfs(u, v, isQV || u == qv, qv);
        }
    }

    // Driver Code: builds the example tree (root = 1) and prints the
    // answer to each sample query
    public static void Main(string[] args)
    {
        int n = 8;
        adj = new List<int>[n + 1];
        for (int i = 1; i <= n; i++)
        {
            adj[i] = new List<int>();
        }

        int[,] edges = { { 1, 2 }, { 1, 3 }, { 2, 4 }, { 2, 5 },
                         { 5, 8 }, { 5, 6 }, { 6, 7 } };
        for (int e = 0; e < edges.GetLength(0); e++)
        {
            Add(edges[e, 0], edges[e, 1]);
        }

        // Number assigned to each vertex
        num = new int[] { -1, 11, 2, 7, 1, -3, -1, -1, -3 };

        // Queries as {vertex, frequency}
        Console.WriteLine(Query(2, 3));
        Console.WriteLine(Query(5, 2));
        Console.WriteLine(Query(7, 1));
    }
}


Javascript




// Adjacency lists: adj[v] holds the neighbours of vertex v
// (the driver pushes one empty list for each index 0..n)
const adj = [];

// num[v] is the number associated with vertex v; the driver fills
// it in, with index 0 as a placeholder since there is no vertex 0
const num = [];

// freq[x] counts how often number x occurs in the queried subtree;
// query() replaces it with a fresh object on every call
let freq = {};
 
// Records the undirected edge (u, v) in both adjacency lists
function add(u, v) {
    [[u, v], [v, u]].forEach(([from, to]) => adj[from].push(to));
}
 
// Returns true if some number occurs exactly qc times in the subtree
// rooted at qv. The whole tree is walked from root vertex 1 each call.
function query(qv, qc) {
    const root = 1;
    freq = {};

    // Counting starts immediately when the query vertex is the root
    dfs(root, 0, qv == root, qv);

    return Object.values(freq).some((count) => count === qc);
}
 
// Depth-first walk from v (reached from parent p). While isQV is true,
// each visited vertex adds one to freq[num[v]]; the flag turns on when
// the walk enters vertex qv and stays on for its descendants.
function dfs(v, p, isQV, qv) {
    if (isQV) {
        const key = num[v];
        freq[key] = (freq[key] || 0) + 1;
    }

    adj[v]
        .filter((child) => child !== p)
        .forEach((child) => dfs(child, v, qv == child ? true : isQV, qv));
}
 
// Driver Code: build the example tree (root = 1) and answer the three
// sample queries
const n = 8;
for (let i = 0; i <= n; i++) {
    adj.push([]);
}

// Edges of the example tree
const edges = [[1, 2], [1, 3], [2, 4], [2, 5], [5, 8], [5, 6], [6, 7]];
for (const [u, v] of edges) {
    add(u, v);
}

// Number assigned to each vertex (index 0 is a placeholder)
num.push(-1, 11, 2, 7, 1, -3, -1, -1, -3);

// Queries as [vertex, frequency]
for (const [qv, qc] of [[2, 3], [5, 2], [7, 1]]) {
    console.log(query(qv, qc));
}


Output: 

false
true
true

 

Time Complexity: O(N * Q), since the whole tree is traversed for each query. 
Auxiliary Space: O(N + E + Q)
Efficient Approach: The idea is to apply Extended Disjoint Set Union (also known as DSU on trees, or small-to-large merging) to the above approach:
 

  1. Create an array, size[] to store size of sub-trees.
  2. Create an array of hashmaps, map[] i.e, map[V][X] = total vertices of number X in sub-tree, V.
  3. Calculate the size of each subtree using DFS traversal by calling dfsSize().
  4. Using DFS traversal by calling dfs(V, p), calculate the value of map[V].
  5. In the traversal, to calculate map[V], choose the adjacent vertex of V (other than the parent vertex p) with the maximum subtree size; call it bigU.
     
  6. For the join operation, reuse the map of bigU by reference, i.e., map[V] = map[bigU], instead of copying it.
  7. Finally, merge the maps of all other adjacent vertices u (except the parent vertex p and bigU) into map[V].
  8. Now, check if map[V] contains frequency F or not. If yes then print True else print False.

Below is the implementation of the efficient approach: 
 

Java




// Java program for the above approach
import java.awt.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
 
public class Main {

    // To store edges: adj[v] lists the neighbours of vertex v
    static ArrayList<Integer> adj[];

    // num[v] = number assigned
    // to vertex v
    static int num[];

    // map[v].get(c) = total vertices in subtree v having number c.
    // The same HashMap object is reused along a chain of "biggest"
    // children, so map[v] describes subtree v only until dfs() of
    // v's parent starts merging into it.
    static Map<Integer, Integer> map[];

    // cnt[v].get(f) = how many distinct numbers occur exactly f times
    // in subtree v ("count of counts"). It is kept in step with map[v]
    // and reused together with it, so a query is answered with one
    // lookup instead of scanning every value of map[v].
    static Map<Integer, Integer> cnt[];

    // size[v]=size of subtree v
    static int size[];

    // ans.get(new Point(v, f)) = answer for query {v, f}
    static HashMap<Point, Boolean> ans;

    // qv[v] = frequencies asked about in the queries on vertex v
    static ArrayList<Integer> qv[];

    // Function to add edges
    static void add(int u, int v)
    {
        adj[u].add(v);
        adj[v].add(u);
    }

    // Function to find subtree size
    // of every vertex using dfsSize()
    static void dfsSize(int v, int p)
    {
        size[v]++;

        // Traverse dfsSize recursively
        for (int u : adj[v]) {
            if (p != u) {
                dfsSize(u, v);
                size[v] += size[u];
            }
        }
    }

    // Adds `delta` (> 0) occurrences of number `key` to map[v] and
    // moves `key` from its old frequency bucket to the new one in cnt[v]
    private static void addOccurrences(int v, int key, int delta)
    {
        int before = map[v].getOrDefault(key, 0);
        int after = before + delta;
        map[v].put(key, after);

        if (before > 0) {
            // key no longer occurs exactly `before` times
            int left = cnt[v].get(before) - 1;
            if (left == 0)
                cnt[v].remove(before);
            else
                cnt[v].put(before, left);
        }
        cnt[v].merge(after, 1, Integer::sum);
    }

    // Function to implement DFS Traversal (small-to-large merging).
    // Fills map[v]/cnt[v] for subtree v and records the answers to
    // all queries on v in ans.
    static void dfs(int v, int p)
    {
        int mx = -1, bigU = -1;

        // Process children first and find the one with
        // maximum subtree size
        for (int u : adj[v]) {
            if (u != p) {
                dfs(u, v);
                if (size[u] > mx) {
                    mx = size[u];
                    bigU = u;
                }
            }
        }

        if (bigU != -1) {

            // Reuse the biggest child's maps by reference
            map[v] = map[bigU];
            cnt[v] = cnt[bigU];
        }
        else {

            // Leaf: start with empty maps
            map[v] = new HashMap<Integer, Integer>();
            cnt[v] = new HashMap<Integer, Integer>();
        }

        // Count the current vertex's number
        addOccurrences(v, num[v], 1);

        // Merge the maps of all other children into map[v]
        for (int u : adj[v]) {
            if (u != bigU && u != p) {
                for (Entry<Integer, Integer> pair : map[u].entrySet()) {
                    addOccurrences(v, pair.getKey(), pair.getValue());
                }
            }
        }

        // Answer all queries on vertex v while map[v]/cnt[v]
        // still describe exactly subtree v
        for (int freq : qv[v]) {
            ans.put(new Point(v, freq),
                    cnt[v].getOrDefault(freq, 0) > 0);
        }
    }

    // Function to find answer for each query and print it
    // (queries[i].x = vertex, queries[i].y = frequency)
    @SuppressWarnings("unchecked") // generic array creation
    static void solveQuery(Point queries[],
                           int N, int q)
    {
        // Group the asked frequencies by vertex
        qv = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            qv[i] = new ArrayList<>();

        for (Point p : queries) {
            qv[p.x].add(p.y);
        }

        // calculate size[] of all subtrees
        size = new int[N + 1];
        dfsSize(1, 0);

        // Per-vertex maps filled during dfs
        map = new HashMap[N + 1];
        cnt = new HashMap[N + 1];

        // To store answer of queries
        ans = new HashMap<>();

        // DFS Call
        dfs(1, 0);

        for (Point p : queries) {

            // Print answer for each query
            System.out.println(ans.get(p));
        }
    }

    // Driver Code
    @SuppressWarnings("unchecked") // generic array creation
    public static void main(String[] args)
    {
        int N = 8;
        adj = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++)
            adj[i] = new ArrayList<>();

        // Given edges (root=1)
        add(1, 2);
        add(1, 3);
        add(2, 4);
        add(2, 5);
        add(5, 8);
        add(5, 6);
        add(6, 7);

        // Store number given to vertices
        // set num[0]=-1 because
        // there is no vertex 0
        num = new int[] { -1, 11, 2, 7, 1,
                          -3, -1, -1, -3 };

        // Queries
        int q = 3;

        // To store queries
        Point queries[] = new Point[q];

        // Given Queries
        queries[0] = new Point(2, 3);
        queries[1] = new Point(5, 2);
        queries[2] = new Point(7, 1);

        // Function Call
        solveQuery(queries, N, q);
    }
}


Python3




# Python program for the above approach
 
from typing import List, Dict, Tuple
 
def add(u: int, v: int, adj: List[List[int]]) -> None:
    """Record the undirected edge (u, v) in both adjacency lists of ``adj``."""
    adj[u] += [v]
    adj[v] += [u]
 
# Function to find subtree size
# of every vertex using dfs_size()
def dfs_size(v: int, p: int, adj: List[List[int]], size: List[int]) -> None:
    """Store in size[x] the number of vertices in the subtree of every x below v.

    p is v's parent (0 for the root), used to avoid walking back up.
    """
    total = 1
    for child in (w for w in adj[v] if w != p):
        dfs_size(child, v, adj, size)
        total += size[child]
    size[v] = total
 
 # Function to implement DFS Traversal
def dfs(v: int, p: int, adj: List[List[int]], num: List[int], map: List[Dict[int,int]], qv: List[List[int]], ans: Dict[Tuple[int,int],bool], size: List[int], cnt: "List[Dict[int, int]] | None" = None) -> None:
    """Small-to-large DFS that fills map[v] and answers the queries stored at v.

    After the call, map[v][x] is the number of vertices in the subtree of v
    carrying number x, and ans[(v, f)] is set for every f in qv[v]. map[v]
    may be the very dict object taken over from v's largest child, and it is
    taken over again by v's parent, so it only describes subtree v until the
    parent starts merging into it.

    Args:
        v: current vertex; p: its parent (0 for the root).
        adj: adjacency lists; num: number assigned to each vertex.
        map: per-vertex frequency dicts (filled in here).
        qv: qv[v] lists the frequencies asked about vertex v.
        ans: receives the answer for each (vertex, frequency) query.
        size: subtree sizes, as computed by dfs_size.
        cnt: optional per-vertex "count of counts": cnt[v][f] is how many
            distinct numbers occur exactly f times in the subtree of v.
            Created internally when omitted. Keeping it in step with map[v]
            answers each query with one dict lookup instead of scanning all
            values of map[v].

    The recursion depth equals the tree height, so very deep trees can
    exceed Python's default recursion limit.
    """
    if cnt is None:
        cnt = [None] * len(map)

    mx = -1
    big_u = -1

    # Process children first and find the one with maximum subtree size
    for u in adj[v]:
        if u != p:
            dfs(u, v, adj, num, map, qv, ans, size, cnt)
            if size[u] > mx:
                mx = size[u]
                big_u = u

    if big_u != -1:
        # Take over the biggest child's dicts by reference
        map[v] = map[big_u]
        cnt[v] = cnt[big_u]
    else:
        # Leaf: start with empty dicts
        map[v] = {}
        cnt[v] = {}

    # Count the current vertex's number
    _add_occurrences(map[v], cnt[v], num[v], 1)

    # Merge the dicts of all other children into map[v]
    for u in adj[v]:
        if u != big_u and u != p:
            for key, value in map[u].items():
                _add_occurrences(map[v], cnt[v], key, value)

    # Answer all queries on vertex v
    for freq in qv[v]:
        ans[(v, freq)] = cnt[v].get(freq, 0) > 0


def _add_occurrences(freq_of: Dict[int, int], count_of: Dict[int, int], key: int, delta: int) -> None:
    """Add ``delta`` (> 0) occurrences of ``key`` and update its frequency bucket."""
    before = freq_of.get(key, 0)
    after = before + delta
    freq_of[key] = after
    if before:
        # key no longer occurs exactly `before` times
        left = count_of[before] - 1
        if left:
            count_of[before] = left
        else:
            del count_of[before]
    count_of[after] = count_of.get(after, 0) + 1
 
# Function to find answer for
# each queries
def solve_query(queries: List[Tuple[int,int]], N: int, q: int, adj: List[List[int]], num: List[int]) -> List[bool]:
    """Answer every (vertex, frequency) query on the tree rooted at vertex 1.

    Returns one bool per query, in the order the queries were given: True
    when some number occurs exactly `frequency` times in the subtree of
    `vertex`. ``q`` is accepted for interface compatibility but not used.
    """
    # Group the asked frequencies by vertex
    by_vertex = [[] for _ in range(N + 1)]
    for p in queries:
        by_vertex[p[0]].append(p[1])

    # Subtree sizes drive the choice of the child whose dict is reused
    subtree_size = [0] * (N + 1)
    dfs_size(1, 0, adj, subtree_size)

    subtree_maps = [None] * (N + 1)
    answers_by_query = {}
    dfs(1, 0, adj, num, subtree_maps, by_vertex, answers_by_query, subtree_size)

    answers = []
    for p in queries:
        answers.append(answers_by_query[p])
    return answers
 
if __name__ == '__main__':
    N = 8

    # Adjacency lists, one per vertex (index 0 unused)
    adj = [[] for _ in range(N + 1)]

    # Edges of the example tree (root = 1)
    for u, v in ((1, 2), (1, 3), (2, 4), (2, 5), (5, 8), (5, 6), (6, 7)):
        add(u, v, adj)

    # Number given to each vertex; num[0] = -1 because there is no vertex 0
    num = [-1, 11, 2, 7, 1, -3, -1, -1, -3]

    # Queries as (vertex, frequency)
    queries = [(2, 3), (5, 2), (7, 1)]
    q = len(queries)

    ans = solve_query(queries, N, q, adj, num)
    print(*ans, sep="\n")

# The code is contributed by Arushi Goel.


Output: 

false
true
true

 

Time Complexity: O(N * log² N)
Auxiliary Space: O(N + E + Q) 
 



Last Updated : 06 Apr, 2023
Like Article
Save Article
Previous
Next
Share your thoughts in the comments
Similar Reads