Open In App

Count of distinct colors in a subtree of a Colored Tree with given min frequency for Q queries

Last Updated : 17 Mar, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given an N-ary tree in which every node has a colour, and Q queries, each consisting of two integers A and X. For each query, the task is to count the distinct colours in the subtree rooted at A whose frequency in that subtree is greater than or equal to X.
Examples: 

Input: Tree: 
 

 
           1(1)
         /       \ 
       /          \
     2(2)         5(3)
    /   \       /  |   \
   /     \     /   |    \
3(2)    4(3) 6(2) 7(3)  8(3)


query[] = {{1, 2}, {1, 3}, {1, 4}, {2, 3}, {5, 3}} 
Output: {2, 2, 1, 0, 1} 
Explanation: 
In the subtree rooted at 1, the frequency of colour 1 is 1, of colour 2 is 3 and of colour 3 is 4. So the answer is 2 for the 1st query, 2 for the 2nd query and 1 for the 3rd query. 
For the subtree rooted at 2, the frequency of colour 2 is 2 and of colour 3 is 1. So no colour has a frequency greater than or equal to 3, and the answer is 0. 
For the subtree rooted at 5, the frequency of colour 2 is 1 and of colour 3 is 3. So only colour 3 has a frequency greater than or equal to 3, and the answer is 1. 
 


Naive Approach 

  • For each query, we’ll traverse the whole subtree of the given node.
  • We’ll maintain a map which stores the frequency of every colour in the subtree of given node.
  • Then, traverse the map and count the number of colours whose frequency is greater than or equal to the given X.


Time Complexity: O(Q * N) 
Space Complexity: O(N) (the frequency map is rebuilt for every query and never holds more than N entries)


Approach: (Using Mo’s Algorithm)  

  • We’ll first flatten the tree using an Euler tour.  
  • We’ll give every node two numbers: the position at which the DFS enters it, and the last position handed out before the DFS leaves it. Let’s denote these by tin[node] and tout[node].
  • After the tree has been flattened into an array, the subtree of any node is exactly the subarray whose start and end indexes are tin[node] and tout[node] respectively.
  • Now the question becomes: how many distinct elements have a frequency greater than or equal to X in a given subarray?
  • We’ll use Mo’s algorithm to solve this problem.
  • First, we’ll store the queries and sort them by tin[node] / SQ (the block index), where SQ is about the square root of N, breaking ties by tout[node].
  • As we move the two pointers, we’ll keep the frequency of every colour in one array, and in a second array we’ll keep, for every k, the number of colours whose frequency is at least k. The answer to a query is the value stored at position X of the second array.

C++

// C++ program to count distinct colors
// in a subtree of a Colored Tree
// with given min frequency for Q queries
 
#include <bits/stdc++.h>
using namespace std;
 
// Capacity of every table below; node ids, colour
// values and Euler-tour positions must stay below it.
const int N = 100005;

// v[i]: neighbours of node i (adjacency list)
vector<vector<int> > v(N);

// colour[i]: colour written on node i
vector<int> colour(N);

// in[i]: Euler-tour position given to node i
// when the DFS enters it
vector<int> in(N);

// out[i]: last Euler-tour position that belongs
// to the subtree of node i
vector<int> out(N);

// cnt[c]: occurrences of colour c inside
// the current window
vector<int> cnt(N);

// freq[k]: number of colours that occur at
// least k times inside the current window
vector<int> freq(N);

// flatTree[p]: colour of the node sitting at
// Euler-tour position p
vector<int> flatTree(N);

// number of nodes
int n = 0;

// block size used to order the queries
// (set by solve() to sqrt(n) + 1)
int sq = 0;

// last Euler-tour position handed out so far
int start = 0;
 
// DFS function to find
// order of euler tour
void DFSEuler(int a, int par)
{
 
    // storing the start index
    in[a] = ++start;
 
    // storing colour of node
    // in flatten array
    flatTree[start] = colour[a];
 
    for (int i : v[a]) {
 
        // doing DFS on its child
        // skipping parent
        if (i == par)
            continue;
 
        DFSEuler(i, a);
    }
 
    // out index of the node.
    out[a] = start;
}
 
// comparator for queries
bool comparator(pair<int, int>& a,
                pair<int, int>& b)
{
    // comparator for queries to be
    // sorted according to in[x] / sq
    if (in[a.first] / sq != in[b.first] / sq)
        return in[a.first] < in[b.first];
 
    return out[a.first] < out[b.first];
}
 
// Function to answer the queries
void solve(vector<pair<int, int> > arr,
           int q)
{
    sq = sqrt(n) + 1;
 
    // for storing answers
    vector<int> answer(q);
 
    // for storing indexes of queries
    // in the order of input.
    map<pair<int, int>, int> idx;
 
    for (int i = 0; i < q; i++) {
 
        // storing indexes of queries
        idx[arr[i]] = i;
    }
 
    // doing depth first search to
    // find indexes to flat the
    // tree using euler tour.
    DFSEuler(1, 0);
 
    // After doing Euler tour,
    // subtree of x can be
    // represented as a subarray
    // from in[x] to out[x];
 
    // we'll sort the queries
    // according to the in[i];
    sort(arr.begin(),
         arr.end(),
         comparator);
 
    // two pointers for
    // sliding the window
    int l = 1, r = 0;
 
    for (int i = 0; i < q; i++) {
 
        // finding answer to the query
        int node = arr[i].first,
            x = arr[i].second;
        int id = idx[arr[i]];
 
        while (l > in[node]) {
 
            // decrementing the pointer as
            // it is greater than start
            // and adding answer
            // to our freq array.
            l--;
            cnt[flatTree[l]]++;
            freq[cnt[flatTree[l]]]++;
        }
 
        while (r < out[node]) {
 
            // incrementing pointer as it is
            // less than the end value and
            // adding answer to our freq array.
            r++;
            cnt[flatTree[r]]++;
            freq[cnt[flatTree[r]]]++;
        }
 
        while (l < in[node]) {
 
            // removing the lth node from
            // freq array and incrementing
            // the pointer
            freq[cnt[flatTree[l]]]--;
            cnt[flatTree[l]]--;
            l++;
        }
 
        while (r > out[node]) {
 
            // removing the rth node from
            // freq array and decrementing
            // the pointer
            freq[cnt[flatTree[r]]]--;
            cnt[flatTree[r]]--;
            r--;
        }
 
        // answer to this query
        // is stored at freq[x]
        // freq[x] stores the frequency
        // of nodes greater equal to x
        answer[id] = freq[x];
    }
 
    // printing the queries
    for (int i = 0; i < q; i++)
        cout << answer[i] << " ";
}
 
int main()
{
    // Driver Code
    /*
               1(1)
             /       \
           /          \
         2(2)         5(3)
        /   \       /  |   \
       /     \     /   |    \
    3(2)    4(3) 6(2) 7(3)  8(3)
    */
    n = 8;
    v[1].push_back(2);
    v[2].push_back(1);
    v[2].push_back(3);
    v[3].push_back(2);
    v[2].push_back(4);
    v[4].push_back(2);
    v[1].push_back(5);
    v[5].push_back(1);
    v[5].push_back(6);
    v[6].push_back(5);
    v[5].push_back(7);
    v[7].push_back(5);
    v[5].push_back(8);
    v[8].push_back(5);
 
    colour[1] = 1;
    colour[2] = 2;
    colour[3] = 2;
    colour[4] = 3;
    colour[5] = 3;
    colour[6] = 2;
    colour[7] = 3;
    colour[8] = 3;
 
    vector<pair<int, int> > queries
        = { { 1, 2 },
            { 1, 3 },
            { 1, 4 },
            { 2, 3 },
            { 5, 3 } };
    int q = queries.size();
 
    solve(queries, q);
    return 0;
}

                    

Java

// Java program to count distinct colors
// in a subtree of a Colored Tree
// with given min frequency for Q queries
import java.util.*;
 
public class Main {

    /**
     * Minimal immutable pair, declared here so the program
     * compiles with nothing but java.util. A query is stored
     * as (node, minimum frequency).
     */
    public static final class Pair<K, V> {
        private final K key;
        private final V value;

        public Pair(K key, V value)
        {
            this.key = key;
            this.value = value;
        }

        public K getKey() { return key; }

        public V getValue() { return value; }

        // same accessors under the first/second naming
        public K getFirst() { return key; }

        public V getSecond() { return value; }

        @Override
        public boolean equals(Object other)
        {
            if (this == other) {
                return true;
            }
            if (!(other instanceof Pair)) {
                return false;
            }
            Pair<?, ?> that = (Pair<?, ?>)other;
            return Objects.equals(key, that.key)
                && Objects.equals(value, that.value);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(key, value);
        }
    }

    // capacity of every table below
    static int N = 100005;

    // graph as adjacency list
    @SuppressWarnings("unchecked")
    static List<Integer>[] v = new List[N];

    // colour written on node.
    static int[] colour = new int[N];

    // Euler-tour position given to a node on entry
    static int[] in_time = new int[N];

    // last Euler-tour position inside a node's subtree
    static int[] out_time = new int[N];

    // cnt[c]: occurrences of colour c in the current window
    static int[] cnt = new int[N];

    // freq[k]: colours occurring at least k times
    // in the current window
    static int[] freq = new int[N];

    // flatTree[p]: colour of the node at Euler-tour position p
    static int[] flatTree = new int[N];

    // number of nodes
    static int n = 0;

    // block size used to order the queries
    static int sq = 0;

    // last Euler-tour position handed out so far
    static int start = 0;

    /**
     * Euler tour of the subtree of {@code a}; {@code par} is
     * the node the search arrived from and is never revisited.
     * Fills in_time, out_time and flatTree.
     */
    static void DFSEuler(int a, int par)
    {
        // storing the start index
        start++;
        in_time[a] = start;

        // storing colour of node in flatten array
        flatTree[start] = colour[a];

        for (int i : v[a]) {
            // doing DFS on its child, skipping parent
            if (i == par) {
                continue;
            }
            DFSEuler(i, a);
        }

        // out index of the node.
        out_time[a] = start;
    }

    /**
     * Three-way ordering of two queries for Mo's algorithm:
     * first by the block holding in_time[node], then by
     * out_time[node]. {@code sq} must be non-zero.
     */
    static int comparator(Pair<Integer, Integer> a,
                          Pair<Integer, Integer> b)
    {
        int blockA = in_time[a.getKey()] / sq;
        int blockB = in_time[b.getKey()] / sq;
        if (blockA != blockB) {
            return Integer.compare(blockA, blockB);
        }
        return Integer.compare(out_time[a.getKey()],
                               out_time[b.getKey()]);
    }

    /**
     * Answers the first {@code q} queries of {@code arr} and
     * prints the answers in input order, each followed by one
     * space. A query (node, x) asks how many distinct colours
     * occur at least x times in the subtree of node; the tree
     * is rooted at node 1.
     * x below 1 is treated as 1, x above n yields 0, and q is
     * clamped to the size of the list. The list itself is not
     * modified, and the shared tables are reset on entry.
     */
    static void solve(List<Pair<Integer, Integer> > arr,
                      int q)
    {
        final List<Pair<Integer, Integer> > queries = arr;
        q = Math.max(0, Math.min(q, queries.size()));

        sq = (int)Math.sqrt(n) + 1;

        // forget whatever an earlier run left behind
        start = 0;
        Arrays.fill(cnt, 0);
        Arrays.fill(freq, 0);

        // for storing answers
        int[] answer = new int[q];

        // flatten the tree: the subtree of x becomes the
        // subarray in_time[x] .. out_time[x] of flatTree
        DFSEuler(1, 0);

        // order[k] is the input position of the k-th query to
        // process; sorting positions keeps identical queries
        // apart, so each one receives its own answer
        Integer[] order = new Integer[q];
        for (int i = 0; i < q; i++) {
            order[i] = i;
        }
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer i, Integer j)
            {
                return Main.comparator(queries.get(i),
                                       queries.get(j));
            }
        });

        // current window is flatTree[l .. r], empty at first
        int l = 1, r = 0;

        for (int k = 0; k < q; k++) {
            int id = order[k];
            Pair<Integer, Integer> query = queries.get(id);
            int node = query.getKey();
            int x = query.getValue();

            // grow the window first, shrink it afterwards
            while (l > in_time[node]) {
                l--;
                cnt[flatTree[l]]++;
                freq[cnt[flatTree[l]]]++;
            }
            while (r < out_time[node]) {
                r++;
                cnt[flatTree[r]]++;
                freq[cnt[flatTree[r]]]++;
            }
            while (l < in_time[node]) {
                freq[cnt[flatTree[l]]]--;
                cnt[flatTree[l]]--;
                l++;
            }
            while (r > out_time[node]) {
                freq[cnt[flatTree[r]]]--;
                cnt[flatTree[r]]--;
                r--;
            }

            // keep the index inside 1 .. n
            answer[id] = (x > n) ? 0 : freq[Math.max(x, 1)];
        }

        // printing the answers
        for (int i = 0; i < q; i++)
            System.out.print(answer[i] + " ");
    }

    public static void main(String[] args)
    {
        // Driver Code
        /*
                   1(1)
                 /       \
               /          \
             2(2)         5(3)
            /   \       /  |   \
           /     \     /   |    \
        3(2)    4(3) 6(2) 7(3)  8(3)
        */
        n = 8;
        for (int i = 1; i <= n; i++) {
            v[i] = new ArrayList<Integer>();
        }
        v[1].add(2);
        v[2].add(1);
        v[2].add(3);
        v[3].add(2);
        v[2].add(4);
        v[4].add(2);
        v[1].add(5);
        v[5].add(1);
        v[5].add(6);
        v[6].add(5);
        v[5].add(7);
        v[7].add(5);
        v[5].add(8);
        v[8].add(5);
        colour[1] = 1;
        colour[2] = 2;
        colour[3] = 2;
        colour[4] = 3;
        colour[5] = 3;
        colour[6] = 2;
        colour[7] = 3;
        colour[8] = 3;
        List<Pair<Integer, Integer> > queries
            = Arrays.asList(
                new Pair<Integer, Integer>(1, 2),
                new Pair<Integer, Integer>(1, 3),
                new Pair<Integer, Integer>(1, 4),
                new Pair<Integer, Integer>(2, 3),
                new Pair<Integer, Integer>(5, 3));
        int q = queries.size();
        solve(queries, q);
    }
}

                    

Python3

import math
 
# Capacity of every table below; node ids, colour values and
# Euler-tour positions are used as indexes into them.
N = 100005

# v[i]: neighbours of node i
v = [[] for _ in range(N)]
# colour[i]: colour written on node i
colour = [0] * N
# in_time[i]: Euler-tour position given to node i on entry
in_time = [0] * N
# out_time[i]: last Euler-tour position inside the subtree of node i
out_time = [0] * N
# cnt[c]: occurrences of colour c inside the current window
cnt = [0] * N
# freq[k]: number of colours occurring at least k times in the window
freq = [0] * N
# flatTree[p]: colour of the node at Euler-tour position p
flatTree = [0] * N

# number of nodes, block size for the query order, and the last
# Euler-tour position handed out so far
n = sq = start = 0
 
def DFSEuler(a, par):
    """Euler tour of the subtree of `a`, filling in_time, out_time and flatTree.

    `par` is the node the search arrived from; it is never revisited.
    Nodes are numbered in the order they are first reached, and
    out_time[x] is the last number handed out inside the subtree of x.

    The walk uses an explicit stack rather than recursion, so a tree that
    is thousands of levels deep does not hit the interpreter's recursion
    limit.
    """
    global start

    start += 1
    in_time[a] = start
    flatTree[start] = colour[a]

    # each frame: (node, node we came from, iterator over its neighbours)
    stack = [(a, par, iter(v[a]))]
    while stack:
        node, parent, neighbours = stack[-1]
        for child in neighbours:
            if child == parent:
                continue
            # entering `child`: number it, then descend before
            # looking at the remaining neighbours of `node`
            start += 1
            in_time[child] = start
            flatTree[start] = colour[child]
            stack.append((child, node, iter(v[child])))
            break
        else:
            # no neighbour left: the subtree of `node` is complete
            out_time[node] = start
            stack.pop()
 
def comparator(a, b):
    """Return True when query `a` must be processed before query `b`.

    Queries are (node, x) tuples. They are ordered by the block that
    holds in_time[node] (blocks are `sq` positions wide) and, inside a
    block, by out_time[node].
    """
    pos_a, pos_b = in_time[a[0]], in_time[b[0]]
    if pos_a // sq == pos_b // sq:
        # same block: the smaller right end goes first
        return out_time[a[0]] < out_time[b[0]]
    # different blocks: the earlier position goes first
    return pos_a < pos_b
 
def solve(arr, q):
    """Answer the first `q` queries of `arr` and print the answers.

    Each query is a (node, x) tuple asking how many distinct colours occur
    at least x times in the subtree of `node`; the tree is rooted at node 1
    and read from the module-level v, colour and n. Answers are printed in
    input order, each followed by one space.

    * x below 1 is treated as 1 (every colour present counts).
    * x above n can never be reached, so the answer is 0.
    * q is clamped to len(arr).
    * `arr` is left untouched, and identical queries each get an answer.
    * The shared tables are reset on entry, so repeated calls are safe.
    """
    global n, sq, start

    q = max(0, min(q, len(arr)))
    sq = int(math.sqrt(n)) + 1

    # forget whatever an earlier run left behind
    start = 0
    cnt[:] = [0] * len(cnt)
    freq[:] = [0] * len(freq)

    answer = [0] * q

    # flatten the tree: the subtree of x becomes the subarray
    # in_time[x] .. out_time[x] of flatTree
    DFSEuler(1, 0)

    # Mo's order: by block of the left end, then by right end.
    # Sorting positions (not the queries) keeps duplicates apart
    # and leaves the caller's list in its original order.
    order = sorted(
        range(q),
        key=lambda i: (in_time[arr[i][0]] // sq, out_time[arr[i][0]]),
    )

    # current window is flatTree[l .. r], empty at first
    l, r = 1, 0
    for qi in order:
        node, x = arr[qi]
        lo, hi = in_time[node], out_time[node]

        # grow the window first, shrink it afterwards
        while l > lo:
            l -= 1
            cnt[flatTree[l]] += 1
            freq[cnt[flatTree[l]]] += 1
        while r < hi:
            r += 1
            cnt[flatTree[r]] += 1
            freq[cnt[flatTree[r]]] += 1
        while l < lo:
            freq[cnt[flatTree[l]]] -= 1
            cnt[flatTree[l]] -= 1
            l += 1
        while r > hi:
            freq[cnt[flatTree[r]]] -= 1
            cnt[flatTree[r]] -= 1
            r -= 1

        # freq[k] = colours occurring at least k times;
        # keep the index inside 1 .. n
        answer[qi] = 0 if x > n else freq[max(x, 1)]

    for value in answer:
        print(value, end=" ")
 
if __name__ == '__main__':
    # Sample tree, written as node(colour):
    #
    #            1(1)
    #          /      \
    #      2(2)        5(3)
    #     /    \     /   |   \
    #  3(2)  4(3)  6(2) 7(3) 8(3)
    n = 8

    # every edge is stored in both directions
    for a, b in ((1, 2), (2, 3), (2, 4), (1, 5), (5, 6), (5, 7), (5, 8)):
        v[a].append(b)
        v[b].append(a)

    # colours of nodes 1 .. 8
    for node, shade in enumerate((1, 2, 2, 3, 3, 2, 3, 3), 1):
        colour[node] = shade

    # queries as (node, minimum frequency)
    queries = [(1, 2), (1, 3), (1, 4), (2, 3), (5, 3)]
    q = len(queries)
    solve(queries, q)

                    

C#

// C# program to count distinct colors
// in a subtree of a Colored Tree
// with given min frequency for Q queries
using System;
using System.Collections.Generic;
using System.Linq;
 
// Equality comparer that treats two arrays as equal when they
// have the same length and equal elements at every position.
// Two null arrays are equal to each other.
class ArrayEqualityComparer<T> : IEqualityComparer<T[]> {
    // Element-wise comparison of two arrays
    public bool Equals(T[] x, T[] y)
    {
        // same instance, or both null
        if (ReferenceEquals(x, y))
            return true;
        if (x == null || y == null)
            return false;
        if (x.Length != y.Length)
            return false;
        for (int i = 0; i < x.Length; i++) {
            if (!EqualityComparer<T>.Default.Equals(x[i],
                                                    y[i]))
                return false;
        }
        return true;
    }

    // Hash code that depends on the elements AND their order,
    // so { 1, 2 } and { 2, 1 } no longer collide by design.
    // A null array hashes to 0, matching Equals(null, null).
    public int GetHashCode(T[] obj)
    {
        if (obj == null)
            return 0;

        // wrap-around on overflow is intended here
        unchecked
        {
            int hashCode = 17;
            foreach(T item in obj)
            {
                hashCode
                    = hashCode * 31
                      + EqualityComparer<T>.Default.GetHashCode(
                          item);
            }
            return hashCode;
        }
    }
}
 
class GFG {
    // capacity of every table below
    private static readonly int N = 100005;

    // v[i]: neighbours of node i
    private static readonly List<int>[] v
        = Enumerable.Range(0, N)
              .Select(_ => new List<int>())
              .ToArray();
    // colour[i]: colour written on node i
    private static readonly int[] colour = new int[N];
    // Euler-tour position given to a node on entry
    private static readonly int[] in_time = new int[N];
    // last Euler-tour position inside a node's subtree
    private static readonly int[] out_time = new int[N];
    // cnt[c]: occurrences of colour c in the current window
    private static readonly int[] cnt = new int[N];
    // freq[k]: colours occurring at least k times in the window
    private static readonly int[] freq = new int[N];
    // flatTree[p]: colour of the node at Euler-tour position p
    private static readonly int[] flatTree = new int[N];

    // number of nodes
    private static int n = 0;
    // block size used to order the queries
    private static int sq = 0;
    // last Euler-tour position handed out so far
    private static int start = 0;

    // DFS function to find
    // order of euler tour
    private static void DFSEuler(int a, int par)
    {
        start += 1;
        in_time[a] = start;
        flatTree[start] = colour[a];

        foreach(int i in v[a])
        {
            // skipping parent
            if (i == par) {
                continue;
            }
            DFSEuler(i, a);
        }

        out_time[a] = start;
    }

    // Three-way comparison of two queries { node, x } for Mo's
    // algorithm: first by the block holding in_time[node], then
    // by out_time[node]. Returns 0 for queries that tie, which
    // the sort needs in order to behave consistently.
    private static int Comparator(int[] a, int[] b)
    {
        int blockA = in_time[a[0]] / sq;
        int blockB = in_time[b[0]] / sq;
        if (blockA != blockB) {
            return blockA.CompareTo(blockB);
        }
        return out_time[a[0]].CompareTo(out_time[b[0]]);
    }

    // Answers the first q queries of arr and prints the answers
    // in input order, each followed by one space. A query
    // { node, x } asks how many distinct colours occur at least
    // x times in the subtree of node (tree rooted at node 1).
    // x below 1 is treated as 1, x above n yields 0, q is
    // clamped to arr.Length, and arr is left unmodified.
    private static void Solve(int[][] arr, int q)
    {
        q = Math.Max(0, Math.Min(q, arr.Length));
        sq = (int)Math.Floor(Math.Sqrt(n)) + 1;

        // forget whatever an earlier run left behind
        start = 0;
        Array.Clear(cnt, 0, cnt.Length);
        Array.Clear(freq, 0, freq.Length);

        int[] answer = new int[q];

        // doing depth first search to
        // find indexes to flat the
        // tree using euler tour.
        DFSEuler(1, 0);

        // After doing Euler tour,
        // subtree of x can be
        // represented as a subarray
        // from in_time[x] to out_time[x];

        // order[k] is the input position of the k-th query to
        // process; sorting positions keeps identical queries
        // apart, so each one receives its own answer
        int[] order = Enumerable.Range(0, q).ToArray();
        Array.Sort(order,
                   (i, j) => Comparator(arr[i], arr[j]));

        // current window is flatTree[l .. r], empty at first
        int l = 1;
        int r = 0;
        for (int k = 0; k < q; k++) {
            int id = order[k];
            int node = arr[id][0];
            int x = arr[id][1];

            // grow the window first, shrink it afterwards
            while (l > in_time[node]) {
                l -= 1;
                cnt[flatTree[l]] += 1;
                freq[cnt[flatTree[l]]] += 1;
            }
            while (r < out_time[node]) {
                r += 1;
                cnt[flatTree[r]] += 1;
                freq[cnt[flatTree[r]]] += 1;
            }
            while (l < in_time[node]) {
                freq[cnt[flatTree[l]]] -= 1;
                cnt[flatTree[l]] -= 1;
                l += 1;
            }
            while (r > out_time[node]) {
                freq[cnt[flatTree[r]]] -= 1;
                cnt[flatTree[r]] -= 1;
                r -= 1;
            }

            // keep the index inside 1 .. n
            answer[id] = (x > n) ? 0 : freq[Math.Max(x, 1)];
        }

        foreach(var val in answer) Console.Write(val + " ");
    }

    // Driver code
    public static void Main(string[] args)
    {
        // assign the static field; a local declaration here
        // would leave the field at 0
        n = 8;
        v[1].Add(2);
        v[2].Add(1);
        v[2].Add(3);
        v[3].Add(2);
        v[2].Add(4);
        v[4].Add(2);
        v[1].Add(5);
        v[5].Add(1);
        v[5].Add(6);
        v[6].Add(5);
        v[5].Add(7);
        v[7].Add(5);
        v[5].Add(8);
        v[8].Add(5);
        colour[1] = 1;
        colour[2] = 2;
        colour[3] = 2;
        colour[4] = 3;
        colour[5] = 3;
        colour[6] = 2;
        colour[7] = 3;
        colour[8] = 3;

        int[][] queries
            = { new int[] { 1, 2 }, new int[] { 1, 3 },
                new int[] { 1, 4 }, new int[] { 2, 3 },
                new int[] { 5, 3 } };
        int q = queries.Length;

        // Function call
        Solve(queries, q);
    }
}

                    

Javascript

// JavaScript program to count distinct colors
// in a subtree of a Colored Tree
// with given min frequency for Q queries
// Capacity of every table below; node ids, colour values and
// Euler-tour positions are used as indexes into them.
const N = 100005;

// v[i]: neighbours of node i
const v = new Array(N).fill(null).map(() => []);
// colour[i]: colour written on node i
const colour = Array.from({ length: N }, () => 0);
// in_time[i]: Euler-tour position given to node i on entry
const in_time = Array.from({ length: N }, () => 0);
// out_time[i]: last Euler-tour position inside the subtree of node i
const out_time = Array.from({ length: N }, () => 0);
// cnt[c]: occurrences of colour c inside the current window
const cnt = Array.from({ length: N }, () => 0);
// freq[k]: number of colours occurring at least k times in the window
const freq = Array.from({ length: N }, () => 0);
// flatTree[p]: colour of the node at Euler-tour position p
const flatTree = Array.from({ length: N }, () => 0);

// number of nodes, block size for the query order, and the
// last Euler-tour position handed out so far
let n = 0, sq = 0, start = 0;
 
// Euler tour of the subtree of `a`; `par` is the node the search
// arrived from and is never revisited. Fills in_time, out_time and
// flatTree and advances `start`.
// The walk keeps its own stack instead of recursing, so a very
// deep tree cannot exhaust the engine's call stack.
function DFSEuler(a, par) {
  // number a node on entry and record its colour at that position
  const enter = (node) => {
    start += 1;
    in_time[node] = start;
    flatTree[start] = colour[node];
  };

  // three parallel stacks: the node, the node it was reached
  // from, and the index of its next unvisited neighbour
  const nodes = [a];
  const parents = [par];
  const cursor = [0];
  enter(a);

  while (nodes.length > 0) {
    const top = nodes.length - 1;
    const node = nodes[top];
    const adj = v[node];

    if (cursor[top] < adj.length) {
      const child = adj[cursor[top]];
      cursor[top] += 1;
      if (child === parents[top]) {
        continue;
      }
      // descend into the child before the remaining neighbours
      enter(child);
      nodes.push(child);
      parents.push(node);
      cursor.push(0);
    } else {
      // no neighbour left: the subtree of `node` is complete
      out_time[node] = start;
      nodes.pop();
      parents.pop();
      cursor.pop();
    }
  }
}
 
// "Comes before" test for two queries [node, x] in Mo's order:
// queries are ordered by the block that holds in_time[node]
// (blocks are `sq` positions wide) and, inside a block, by
// out_time[node]. Returns a boolean.
function comparator(a, b) {
  const posA = in_time[a[0]];
  const posB = in_time[b[0]];
  const sameBlock = Math.floor(posA / sq) === Math.floor(posB / sq);
  return sameBlock ? out_time[a[0]] < out_time[b[0]] : posA < posB;
}
 
// Answers the first `q` queries of `arr` and logs the answers, in
// input order, as one space-separated line. A query [node, x] asks
// how many distinct colours occur at least x times in the subtree
// of `node`; the tree is rooted at node 1.
//   * x below 1 is treated as 1 (every colour present counts)
//   * x above n can never be reached, so the answer is 0
//   * q is clamped to arr.length
//   * arr is left untouched; identical queries each get an answer
// The shared tables are reset on entry, so repeated calls are safe.
function solve(arr, q) {
  q = Math.max(0, Math.min(q, arr.length));
  sq = Math.floor(Math.sqrt(n)) + 1;

  // forget whatever an earlier run left behind
  start = 0;
  cnt.fill(0);
  freq.fill(0);

  const answer = new Array(q).fill(0);

  // flatten the tree: the subtree of x becomes the subarray
  // in_time[x] .. out_time[x] of flatTree
  DFSEuler(1, 0);

  // order[k] is the input position of the k-th query to process.
  // The sort callback must return a negative, zero or positive
  // number, so the boolean predicate is asked in both directions.
  const order = Array.from({ length: q }, (_, i) => i);
  order.sort((i, j) => {
    if (comparator(arr[i], arr[j])) {
      return -1;
    }
    return comparator(arr[j], arr[i]) ? 1 : 0;
  });

  // current window is flatTree[l .. r], empty at first
  let l = 1;
  let r = 0;
  for (const id of order) {
    const [node, x] = arr[id];

    // grow the window first, shrink it afterwards
    while (l > in_time[node]) {
      l -= 1;
      cnt[flatTree[l]] += 1;
      freq[cnt[flatTree[l]]] += 1;
    }
    while (r < out_time[node]) {
      r += 1;
      cnt[flatTree[r]] += 1;
      freq[cnt[flatTree[r]]] += 1;
    }
    while (l < in_time[node]) {
      freq[cnt[flatTree[l]]] -= 1;
      cnt[flatTree[l]] -= 1;
      l += 1;
    }
    while (r > out_time[node]) {
      freq[cnt[flatTree[r]]] -= 1;
      cnt[flatTree[r]] -= 1;
      r -= 1;
    }

    // freq[k] = colours occurring at least k times;
    // keep the index inside 1 .. n
    answer[id] = x > n ? 0 : freq[Math.max(x, 1)];
  }

  console.log(answer.join(' '));
}
 
// Sample tree, written as node(colour):
//
//            1(1)
//          /      \
//      2(2)        5(3)
//     /    \     /   |   \
//  3(2)  4(3)  6(2) 7(3) 8(3)
n = 8;

// every edge is stored in both directions
for (const [a, b] of [[1, 2], [2, 3], [2, 4], [1, 5], [5, 6], [5, 7], [5, 8]]) {
  v[a].push(b);
  v[b].push(a);
}

// colours of nodes 1 .. 8
for (const [i, shade] of [1, 2, 2, 3, 3, 2, 3, 3].entries()) {
  colour[i + 1] = shade;
}

// queries as [node, minimum frequency]
const queries = [[1, 2], [1, 3], [1, 4], [2, 3], [5, 3]];
const q = queries.length;
solve(queries, q);
 
//contributed by adityasharmadev01

                    

Output: 
2 2 1 0 1

 

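Time Complexity: $O((N + Q)\sqrt{N} + Q \log Q)$ — sorting the queries costs $O(Q \log Q)$, the left pointer moves $O(\sqrt{N})$ per query, and the right pointer moves $O(N)$ per block over $O(\sqrt{N})$ blocks.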
Auxiliary Space: O(N)



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads