Open In App

Count the number of palindromic triplets

Given a tree of N nodes numbered 0, 1, …, N − 1. The structure of the tree is given by an array p[] of size N, where p[i] is the parent (direct ancestor) of the i-th node and p[0] = -1, since 0 is the root of the tree. You are also given a string C of size N, where C[i] is the character written on the i-th node. The task is to find the number of ordered triplets of pairwise-distinct nodes (a, b, c) such that node b lies on the simple path from node a to node c and C[a], C[b], C[c] form a palindrome. Because the string has length 3, this simply means C[a] = C[c]; the middle character C[b] can be anything.

Examples:



Input: N = 5, p[] = {-1, 0, 0, 1, 1}, C = "accaa"
Output: 8
Explanation: The tree looks like:

      0'a'
     /    \
  1'c'    2'c'
  /  \
3'a'  4'a'



Possible triplets are

  • (1, 0, 2), (2, 0, 1),
  • (3, 1, 0), (4, 1, 0),
  • (0, 1, 4), (0, 1, 3),
  • (3, 1, 4), (4, 1, 3)

Input: N = 3, p[] = {-1, 0, 0}, C = “abc”
Output: 0
Explanation: The tree looks like:

   0'a'
  /    \
1'b'   2'c'

No triplets are possible.

Approach: This can be solved with the following idea:

We need to count the ordered triplets of nodes whose characters form a palindrome, i.e. whose two endpoints carry the same character. Fix the middle node b and root the tree at 0. Every valid pair of endpoints (a, c) for this b is of one of two kinds:

1. a and c lie in two different child subtrees of b.
2. One of them lies in b's subtree (excluding b itself) and the other lies outside b's subtree.

Using DFS, we maintain for every node the number of nodes of each character in its subtree. We count both kinds of pairs at every node, multiplying by 2 because (a, b, c) and (c, b, a) are different triplets.

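Below are the steps involved:

- Count how many nodes carry each character in the whole tree (total[]).
- Build the child lists from p[] and run a post-order DFS from the root 0.
- At node cur, merge each child's character counts into dp[cur]. Before merging a child, add dp[cur][i] * dp[child][i] * 2 to the answer for every character i. These are the pairs taken from two different child subtrees.
- After all children are merged, add dp[cur][i] * (total[i] − dp[cur][i] − [C[cur] = i]) * 2 for every character i. These are the pairs with one endpoint inside the subtree and one outside; the bracket term excludes cur itself.
- Finally, add cur's own character to dp[cur] so its parent sees the complete subtree.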

Below is the implementation of the code:




// C++ Implementation
#include <bits/stdc++.h>
#include <iostream>
using namespace std;
 
using ll = long long;

// dp[v][i]: number of nodes labelled 'a' + i counted so far in the subtree of v
vector<vector<ll> > dp;

// Running total of palindromic triplets accumulated by rec()
ll ans = 0;
 
// Function to iterate in adj using dp
void rec(int cur, int par, string& C,
         vector<long long>& total,
         vector<vector<int> >& adj)
{
 
    // Iterate in adj
    for (auto next : adj[cur]) {
 
        // If next node is parent itself
        if (par == next)
            continue;
 
        rec(next, cur, C, total, adj);
 
        // Iterate in 26 Alphabets
 
        // Case 1
        for (int i = 0; i < 26; i++) {
 
            ans += dp[cur][i] * dp[next][i] * 2;
            dp[cur][i] += dp[next][i];
        }
    }
 
    // Again iterate there
 
    // Case 2
    for (int i = 0; i < 26; i++) {
        if (i == (int)(C[cur] - 'a'))
            ans += dp[cur][i]
                   * (1LL * total[i] - dp[cur][i] - 1) * 2;
        else
            ans += dp[cur][i]
                   * (1LL * total[i] - dp[cur][i]) * 2;
    }
 
    dp[cur][C[cur] - 'a']++;
}
 
// Function to count number of possible palindromic
// triplets
long long solve(int N, vector<int> p, string C)
{
 
    // Intialise a dp
    dp = vector<vector<ll> >(N + 1, vector<ll>(26, 0));
 
    vector<long long> arr(26, 0);
 
    // Iterate in given string
    for (auto c : C)
        arr++;
 
    // Create a adj vector
    vector<vector<int> > adj(N, vector<int>());
    for (int i = 1; i < N; i++) {
        adj[p[i]].push_back(i);
    }
 
    // Function to iterate in created vectors
    rec(0, -1, C, arr, adj);
 
    // Return the count of triplets
    return ans;
}
 
// Driver code
int main()
{
 
    int N = 5;
    vector<int> p = { -1, 0, 0, 1, 1 };
    string C = "accaa";
 
    // Function call
    cout << solve(N, p, C);
    return 0;
}




import java.util.Arrays;
import java.util.Vector;
import java.util.Collections;
 
public class Main {

    /** Number of supported labels: node characters must be in 'a'..'z'. */
    private static final int ALPHABET = 26;

    /**
     * Counts ordered triplets (a, b, c) of pairwise-distinct nodes such that b lies on the simple
     * path from a to c and C[a], C[b], C[c] reads as a palindrome, i.e. C[a] == C[c].
     *
     * <p>Nodes are processed iteratively, children before parents, so very deep trees cannot
     * overflow the call stack. All state is local, so repeated calls are independent.
     *
     * @param N number of nodes; nodes are numbered 0..N-1 and 0 is the root
     * @param p parent array; p[i] is the parent of node i for i >= 1 (p[0] is ignored)
     * @param C node labels; C.charAt(i) is the lowercase letter written on node i
     * @return the number of ordered palindromic triplets
     * @throws IllegalArgumentException if N is negative, p or C is shorter than N, a label is
     *     not in 'a'..'z', or p does not describe a tree rooted at node 0
     */
    static long solve(int N, int[] p, String C) {
        if (N < 0) {
            throw new IllegalArgumentException("N must be non-negative: " + N);
        }
        if (N == 0) {
            return 0;
        }
        if (p.length < N || C.length() < N) {
            throw new IllegalArgumentException(
                    "p and C must have at least N=" + N + " entries (p.length=" + p.length
                            + ", C.length()=" + C.length() + ")");
        }

        // Letter index of every node, and how many nodes carry each letter in the whole tree.
        // Only the N node labels are counted; extra characters in C are not nodes.
        int[] letter = new int[N];
        long[] total = new long[ALPHABET];
        for (int v = 0; v < N; v++) {
            int ch = C.charAt(v) - 'a';
            if (ch < 0 || ch >= ALPHABET) {
                throw new IllegalArgumentException(
                        "label of node " + v + " is not in 'a'..'z': '" + C.charAt(v) + "'");
            }
            letter[v] = ch;
            total[ch]++;
        }

        // Children stored as singly linked lists: firstChild[v] -> nextSibling[..] -> ... -> -1.
        int[] firstChild = new int[N];
        int[] nextSibling = new int[N];
        Arrays.fill(firstChild, -1);
        for (int v = 1; v < N; v++) {
            int parent = p[v];
            if (parent < 0 || parent >= N || parent == v) {
                throw new IllegalArgumentException("invalid parent p[" + v + "] = " + parent);
            }
            nextSibling[v] = firstChild[parent];
            firstChild[parent] = v;
        }

        int[] order = preorder(firstChild, nextSibling, N);

        // below[v][i] = number of nodes with letter i in v's subtree, excluding v until v has
        // been processed; v is then added so that its parent sees the complete subtree.
        long[][] below = new long[N][];
        long count = 0;
        // Reverse preorder visits every node after all of its descendants.
        for (int k = N - 1; k >= 0; k--) {
            int cur = order[k];
            long[] acc = new long[ALPHABET];

            // Case 1: a and c in two different child subtrees of cur, with b = cur.
            for (int child = firstChild[cur]; child != -1; child = nextSibling[child]) {
                long[] sub = below[child];
                for (int i = 0; i < ALPHABET; i++) {
                    count += 2 * acc[i] * sub[i]; // x2: (a, cur, c) and (c, cur, a)
                    acc[i] += sub[i];
                }
                below[child] = null; // merged into cur; release it early
            }

            // Case 2: one endpoint strictly inside cur's subtree, the other outside it.
            // cur itself is not in acc yet, so it is removed from the outside count.
            for (int i = 0; i < ALPHABET; i++) {
                long outside = total[i] - acc[i] - (i == letter[cur] ? 1 : 0);
                count += 2 * acc[i] * outside;
            }

            acc[letter[cur]]++;
            below[cur] = acc;
        }
        return count;
    }

    /**
     * Returns the nodes reachable from root 0 in preorder, so every node precedes its descendants.
     *
     * @throws IllegalArgumentException if some node is not reachable from 0 (p contains a cycle)
     */
    private static int[] preorder(int[] firstChild, int[] nextSibling, int n) {
        int[] order = new int[n];
        // Each node has a single parent, so it is pushed at most once and n slots suffice.
        int[] stack = new int[n];
        int top = 0;
        int size = 0;
        stack[top++] = 0;
        while (top > 0) {
            int v = stack[--top];
            order[size++] = v;
            for (int child = firstChild[v]; child != -1; child = nextSibling[child]) {
                stack[top++] = child;
            }
        }
        if (size != n) {
            throw new IllegalArgumentException(
                    "p does not describe a tree rooted at 0: only " + size + " of " + n
                            + " nodes are reachable");
        }
        return order;
    }

    // Driver code
    public static void main(String[] args) {

        int N = 5;
        int[] p = {-1, 0, 0, 1, 1};
        String C = "accaa";

        // Function call
        System.out.println(solve(N, p, C));
    }
}




# Python Implementation
from typing import List
 
# Function to iterate in adj using dp
def rec(cur, par, C, total, adj, dp, ans):
    """Process the subtree of ``cur`` (entered from ``par``) and count triplets.

    For every node ``v`` in the subtree, this adds to ``ans[0]`` the number of
    ordered palindromic triplets whose middle node is ``v``. It leaves
    ``dp[v][i]`` equal to the number of nodes labelled ``chr(ord('a') + i)`` in
    ``v``'s subtree (``v`` included).

    An explicit stack replaces recursion, so deep trees (for example a path of
    thousands of nodes) do not hit Python's recursion limit.

    Args:
        cur: root of the subtree to process.
        par: node ``cur`` was entered from (-1 for the tree root). It is skipped
            wherever it appears in an adjacency list.
        C: node labels; ``C[v]`` is the lowercase letter on node ``v``.
        total: ``total[i]`` is the number of nodes in the whole tree with letter ``i``.
        adj: adjacency lists (children lists, or undirected neighbour lists).
        dp: per-node letter counters, updated in place; the rows of this
            subtree's nodes must be all zero on entry.
        ans: single-element list used as a mutable accumulator.
    """
    # Preorder of (node, parent) pairs: each node appears before its descendants.
    order = []
    stack = [(cur, par)]
    while stack:
        node, parent = stack.pop()
        order.append((node, parent))
        for child in adj[node]:
            if child != parent:
                stack.append((child, node))

    # Reverse preorder handles every node after all of its descendants.
    for node, parent in reversed(order):
        acc = dp[node]

        # Case 1: endpoints in two different child subtrees of node.
        for child in adj[node]:
            if child == parent:
                continue
            sub = dp[child]
            for i in range(26):
                ans[0] += acc[i] * sub[i] * 2  # x2: both orders of (a, c)
                acc[i] += sub[i]

        # Case 2: one endpoint strictly inside the subtree, the other outside.
        # node itself is not in acc yet, so it is removed from the outside count.
        own = ord(C[node]) - ord('a')
        for i in range(26):
            outside = total[i] - acc[i] - (1 if i == own else 0)
            ans[0] += acc[i] * outside * 2

        acc[own] += 1
 
 
# Function to count the number of possible palindromic triplets
def solve(N, p, C):
    """Count ordered palindromic triplets in the tree described by ``p`` and ``C``.

    A triplet (a, b, c) of distinct nodes is counted when b lies on the simple
    path from a to c and C[a] == C[c] (so C[a], C[b], C[c] is a palindrome).

    Args:
        N: number of nodes, numbered 0..N-1 with 0 as the root.
        p: parent list; ``p[i]`` is the parent of node ``i`` for i >= 1
            (``p[0]`` is ignored).
        C: node labels; ``C[i]`` is the lowercase letter on node ``i``.

    Returns:
        The number of ordered triplets (0 for an empty tree).

    Raises:
        ValueError: if N is negative or a node label is not in 'a'..'z'.
    """
    if N < 0:
        raise ValueError(f"N must be non-negative, got {N}")
    if N == 0:
        return 0

    # Initialize dp
    dp = [[0] * 26 for _ in range(N + 1)]

    # Count each letter over the N node labels only (extra characters are not nodes).
    # Without the range check an uppercase label would silently index a negative slot.
    total = [0] * 26
    for c in C[:N]:
        idx = ord(c) - ord('a')
        if not 0 <= idx < 26:
            raise ValueError(f"node label {c!r} is not a lowercase letter a-z")
        total[idx] += 1

    # Create an adj vector (children lists)
    adj = [[] for _ in range(N)]
    for i in range(1, N):
        adj[p[i]].append(i)

    ans = [0]

    # DFS from the root accumulates the answer into ans[0]
    rec(0, -1, C, total, adj, dp, ans)

    # Return the count of triplets
    return ans[0]
 
 
# Driver code: run the first sample from the problem statement
if __name__ == "__main__":
    N, p, C = 5, [-1, 0, 0, 1, 1], "accaa"
    print(solve(N, p, C))




using System;
using System.Collections.Generic;
using System.Linq;
 
public class GFG {
    // Number of supported labels: node characters must be in 'a'..'z'.
    private const int Alphabet = 26;

    /// <summary>
    /// Counts ordered triplets (a, b, c) of pairwise-distinct nodes such that b lies on
    /// the simple path from a to c and C[a], C[b], C[c] is a palindrome (C[a] == C[c]).
    /// Nodes are processed iteratively, children before parents, so deep trees cannot
    /// overflow the call stack. All state is local, so repeated calls are independent.
    /// </summary>
    /// <param name="N">Number of nodes, numbered 0..N-1 with 0 as the root.</param>
    /// <param name="p">p[i] is the parent of node i for i &gt;= 1 (p[0] is ignored).</param>
    /// <param name="C">C[i] is the lowercase letter written on node i.</param>
    /// <returns>The number of ordered palindromic triplets.</returns>
    /// <exception cref="ArgumentException">
    /// N is negative, p or C is shorter than N, a label is not in 'a'..'z', or p does
    /// not describe a tree rooted at node 0.
    /// </exception>
    public static long solve(int N, int[] p, string C)
    {
        if (N < 0)
            throw new ArgumentOutOfRangeException(nameof(N), "N must be non-negative.");
        if (N == 0)
            return 0;
        if (p.Length < N || C.Length < N)
            throw new ArgumentException("p and C must have at least N entries.");

        // Letter of every node, and how many of the N nodes carry each letter.
        int[] letter = new int[N];
        long[] total = new long[Alphabet];
        for (int v = 0; v < N; v++) {
            int ch = C[v] - 'a';
            if (ch < 0 || ch >= Alphabet)
                throw new ArgumentException(
                    $"Label of node {v} is not in 'a'..'z': '{C[v]}'.", nameof(C));
            letter[v] = ch;
            total[ch]++;
        }

        // Children stored as singly linked lists: firstChild[v] -> nextSibling[..] -> -1.
        int[] firstChild = new int[N];
        int[] nextSibling = new int[N];
        for (int v = 0; v < N; v++)
            firstChild[v] = -1;
        for (int v = 1; v < N; v++) {
            int parent = p[v];
            if (parent < 0 || parent >= N || parent == v)
                throw new ArgumentException($"Invalid parent p[{v}] = {parent}.", nameof(p));
            nextSibling[v] = firstChild[parent];
            firstChild[parent] = v;
        }

        int[] order = Preorder(firstChild, nextSibling, N);

        // below[v][i] = number of nodes with letter i in v's subtree (v included once
        // v has been processed).
        long[][] below = new long[N][];
        long count = 0;
        // Reverse preorder visits every node after all of its descendants.
        for (int k = N - 1; k >= 0; k--) {
            int cur = order[k];
            long[] acc = new long[Alphabet];

            // Case 1: endpoints in two different child subtrees of cur.
            for (int child = firstChild[cur]; child != -1; child = nextSibling[child]) {
                long[] sub = below[child];
                for (int i = 0; i < Alphabet; i++) {
                    count += 2 * acc[i] * sub[i]; // x2: both orders of (a, c)
                    acc[i] += sub[i];
                }
                below[child] = null; // merged into cur; release it early
            }

            // Case 2: one endpoint strictly inside cur's subtree, the other outside.
            // cur is not in acc yet, so it is removed from the outside count.
            for (int i = 0; i < Alphabet; i++) {
                long outside = total[i] - acc[i] - (i == letter[cur] ? 1 : 0);
                count += 2 * acc[i] * outside;
            }

            acc[letter[cur]]++;
            below[cur] = acc;
        }
        return count;
    }

    // Returns the nodes reachable from root 0 in preorder (each node before its
    // descendants); throws if some node is unreachable, i.e. p contains a cycle.
    private static int[] Preorder(int[] firstChild, int[] nextSibling, int n)
    {
        int[] order = new int[n];
        // Each node has a single parent, so it is pushed at most once.
        int[] stack = new int[n];
        int top = 0;
        int size = 0;
        stack[top++] = 0;
        while (top > 0) {
            int v = stack[--top];
            order[size++] = v;
            for (int child = firstChild[v]; child != -1; child = nextSibling[child])
                stack[top++] = child;
        }
        if (size != n)
            throw new ArgumentException(
                $"p does not describe a tree rooted at 0: only {size} of {n} nodes are reachable.");
        return order;
    }

    // Driver code
    public static void Main(string[] args)
    {
        int N = 5;
        int[] p = { -1, 0, 0, 1, 1 };
        string C = "accaa";

        // Function call
        Console.WriteLine(solve(N, p, C));
    }
}




// Running count of palindromic triplets; a BigInt so large totals stay exact.
let ans = BigInt(0);
// dp[v][i]: number of nodes labelled 'a' + i counted so far in the subtree of v.
let dp;
 
/**
 * Processes the subtree of `cur` (entered from `par`, -1 at the root).
 *
 * For every node v of the subtree, adds to the global `ans` the ordered
 * palindromic triplets whose middle node is v. Leaves dp[v][i] equal to the
 * number of nodes labelled 'a' + i in v's subtree (v included).
 *
 * An explicit stack replaces recursion, so deep trees cannot exceed the
 * JavaScript call-stack limit.
 *
 * @param {number} cur root of the subtree to process
 * @param {number} par node `cur` was entered from; skipped in adjacency lists
 * @param {string} C node labels (lowercase letters)
 * @param {number[]} total total[i] = number of nodes in the tree with letter i
 * @param {number[][]} adj adjacency lists
 */
function rec(cur, par, C, total, adj) {
    // Preorder of [node, parent] pairs: every node appears before its descendants.
    const order = [];
    const stack = [[cur, par]];
    while (stack.length > 0) {
        const [node, parent] = stack.pop();
        order.push([node, parent]);
        for (const child of adj[node]) {
            if (child !== parent) stack.push([child, node]);
        }
    }

    // Reverse preorder handles every node after all of its descendants.
    for (let k = order.length - 1; k >= 0; k--) {
        const [node, parent] = order[k];
        const acc = dp[node];

        // Case 1: endpoints in two different child subtrees of node.
        for (const child of adj[node]) {
            if (child === parent) continue;
            const sub = dp[child];
            for (let i = 0; i < 26; i++) {
                ans += BigInt(acc[i]) * BigInt(sub[i]) * 2n;
                acc[i] += sub[i];
            }
        }

        // Case 2: one endpoint strictly inside the subtree, the other outside.
        // node itself is not in acc yet, so it is removed from the outside count.
        const own = C.charCodeAt(node) - 'a'.charCodeAt(0);
        for (let i = 0; i < 26; i++) {
            const outside = total[i] - acc[i] - (i === own ? 1 : 0);
            ans += BigInt(acc[i]) * BigInt(outside) * 2n;
        }

        acc[own]++;
    }
}
 
/**
 * Counts ordered triplets (a, b, c) of distinct nodes with b on the simple path
 * from a to c and C[a] === C[c] (so C[a], C[b], C[c] is a palindrome).
 *
 * @param {number} N number of nodes, numbered 0..N-1 with root 0
 * @param {number[]} p p[i] is the parent of node i for i >= 1 (p[0] is ignored)
 * @param {string} C C[i] is the lowercase letter written on node i
 * @returns {string} the count as a decimal string (computed with BigInt)
 * @throws {RangeError} if a node label is not in 'a'..'z'
 */
function solve(N, p, C) {
    // Reset the shared accumulator so repeated calls do not add up.
    ans = 0n;
    if (N <= 0) return ans.toString();

    dp = Array.from({ length: N + 1 }, () => Array(26).fill(0));
    const arr = Array(26).fill(0);

    // Count how many of the N nodes carry each letter.
    for (let v = 0; v < N; v++) {
        const idx = C.charCodeAt(v) - 'a'.charCodeAt(0);
        // Also rejects NaN, which charCodeAt returns when C is shorter than N.
        if (!(idx >= 0 && idx < 26)) {
            throw new RangeError(`label of node ${v} is not in 'a'..'z'`);
        }
        arr[idx]++;
    }

    // Children lists built from the parent array.
    const adj = Array.from({ length: N }, () => []);
    for (let i = 1; i < N; i++) {
        adj[p[i]].push(i);
    }

    rec(0, -1, C, arr, adj);
    return ans.toString(); // Convert BigInt to string for output
}
 
// Driver code: run the first sample from the problem statement.
const [N, p, C] = [5, [-1, 0, 0, 1, 1], "accaa"];
console.log(solve(N, p, C));
 
 
// This code is contributed by akshitaguprzj3

Output
8











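Time Complexity: O(26 · N), i.e. O(N) for the fixed alphabet of 26 lowercase letters
Auxiliary Space: O(26 · N) for the per-node character counts, plus the DFS stack (up to the tree height)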


Article Tags :