Open In App

Maximizing Color Differences in Subtrees

Given a tree of n vertices in which every vertex is colored either black or white (in the examples below, 1 denotes white and 0 denotes black), compute for each vertex v the maximum value of (number of white vertices − number of black vertices) over all subtrees of the given tree that contain v. Here a subtree means any connected subgraph of the tree, not necessarily everything below v.

Examples:



Input: n = 9, color = { 0, 1, 1, 1, 0, 0, 0, 0, 1 }, edges = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 }, { 2, 6 }, { 4, 7 }, { 6, 8 }, { 5, 9 } }
Output: 2 2 2 2 2 1 1 0 2

Input: n = 4, color = { 0, 0, 1, 0}, edges = {{1, 2}, {1, 3}, {1, 4}}
Output: 0 -1 1 -1



Approach: To solve the problem follow the below idea:

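The idea is to use a two-pass Depth-First Search (DFS) with re-rooting. Black vertices are given the value −1 and white vertices the value +1, so the answer for a vertex is the largest total value of a connected subgraph that contains it. The first pass roots the tree at the first vertex and computes, for every vertex, dp = its own value plus the dp values of those children whose dp is positive (a child's branch is taken only when it helps). The second pass walks down from the root and, for each edge, moves the root from the parent to the child: the child's contribution is removed from the parent's dp, and what remains of the parent is added to the child's dp if it is positive. At the moment a vertex is visited in the second pass, its dp value is the answer for that vertex. Each pass visits every vertex once, so all answers are obtained in O(n) time.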

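Step-by-step approach:

- Replace every color 0 (black) by −1, so that the sum of the values of a set of vertices equals (white count − black count).
- Build the adjacency list, converting the 1-based edge endpoints to 0-based indices.
- Run the first DFS from vertex 0: dp[v] = color[v] + Σ max(dp[child], 0) over the children of v.
- Run the second DFS from vertex 0: record ans[v] = dp[v]; then, for each child, subtract max(0, dp[child]) from dp[v], add max(0, dp[v]) to dp[child], recurse into the child, and finally undo both changes.
- Print ans[v] for every vertex.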

Below is the implementation of the above approach:




#include <bits/stdc++.h>
using namespace std;
 
// Per-vertex value read by dfs(). main() fills this with 0 (black) /
// 1 (white) and then rewrites every 0 to -1 before the DFS passes run.
vector<int> color;

// dp[v]: after dfs(), color[v] plus every positive dp value among the
// children of v. dfs2() temporarily adjusts these entries while it moves
// the root along an edge and restores them afterwards.
vector<int> dp;

// ans[v]: final result for vertex v, assigned in dfs2().
vector<int> ans;

// Adjacency list representing the tree (main() stores 0-based indices).
vector<vector<int> > adj;
 
// First DFS to calculate the difference in the number of
// white and black vertices in the subtree rooted at each
// vertex
void dfs(int v, int p = -1)
{
 
    // Initialize dp[v] with the color
    // of the current vertex
    dp[v] = color[v];
    for (auto to : adj[v]) {
 
        // Skip the parent in the DFS
        if (to == p)
            continue;
        dfs(to, v);
 
        // Update dp[v] by adding the maximum of dp
        // values from its children
        dp[v] += max(dp[to], 0);
    }
}
 
// Second DFS to compute the final
// result for each vertex
void dfs2(int v, int p = -1)
{
 
    // Initialize ans[v] with dp[v]
    ans[v] = dp[v];
    for (auto to : adj[v]) {
 
        // Skip the parent in the DFS
        if (to == p)
            continue;
 
        // Update dp[v] by subtracting the maximum of dp
        // values from its children
        dp[v] -= max(0, dp[to]);
 
        // Update dp[to] by adding the updated dp[v]
        dp[to] += max(0, dp[v]);
        dfs2(to, v);
 
        // Undo the changes to dp values after the DFS of
        // the subtree rooted at 'to'
        dp[to] -= max(0, dp[v]);
 
        // Revert back the changes to dp[v]
        dp[v] += max(0, dp[to]);
    }
}
 
// Drivers code
int main()
{
 
    int n = 9;
    dp = ans = vector<int>(n);
    adj = vector<vector<int> >(n);
 
    // Input the color of each vertex (0 for black, 1 for
    // white)
 
    color = { 0, 1, 1, 1, 0, 0, 0, 0, 1 };
 
    for (int i = 0; i < n; ++i) {
 
        // Convert 0 to -1 for easier calculation
        if (color[i] == 0)
            color[i] = -1;
    }
 
    // Input the edges of the tree
    vector<vector<int> > edges
        = { { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 },
            { 2, 6 }, { 4, 7 }, { 6, 8 }, { 5, 9 } };
 
    for (int i = 0; i < n - 1; ++i) {
        int x = edges[i][0], y = edges[i][1];
        --x, --y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
 
    // Perform the two DFS to calculate the result
    dfs(0);
    dfs2(0);
 
    // Output the result for each vertex
    for (auto it : ans)
        cout << it << " ";
    cout << endl;
 
    return 0;
}




import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
 
/**
 * For every vertex of a black/white colored tree, computes the largest value of
 * (white vertices - black vertices) over all connected subgraphs that contain the vertex.
 *
 * <p>The work is done by two recursive passes that share state through the static fields
 * below: {@link #dfs(int, int)} fills {@code dp} bottom-up and {@link #dfs2(int, int)} moves
 * the root along every edge to fill {@code ans}. Both passes recurse once per tree level, so
 * the recursion depth equals the height of the tree.
 */
class Main {

    // Per-vertex value used by the DFS passes: -1 for black, 1 for white.
    static List<Integer> color;

    // dp.get(v) after dfs(): color of v plus every positive dp value among the children of v.
    // dfs2() temporarily adjusts entries while re-rooting and restores them afterwards.
    static List<Integer> dp;

    // ans.get(v): final result for vertex v, written by dfs2().
    static List<Integer> ans;

    // Adjacency list of the tree, 0-based vertex indices.
    static List<List<Integer>> adj;

    /**
     * Bottom-up pass: solves every child of {@code v} and then stores in {@code dp} the value
     * of {@code v} plus each child's dp value that is positive.
     *
     * @param v vertex to process (0-based)
     * @param p parent of {@code v}, or -1 when {@code v} is the root
     */
    static void dfs(int v, int p) {
        int total = color.get(v);
        for (int child : adj.get(v)) {
            if (child == p) {
                continue;
            }
            dfs(child, v);

            // A child's branch is attached only when it increases the total.
            total += Math.max(dp.get(child), 0);
        }
        dp.set(v, total);
    }

    /**
     * Top-down re-rooting pass: records {@code dp.get(v)} as the answer for {@code v}, then
     * for each child moves the root across the edge, recurses, and undoes the move.
     *
     * @param v vertex to process (0-based); its dp entry must describe the tree rooted at v
     * @param p parent of {@code v}, or -1 when {@code v} is the root
     */
    static void dfs2(int v, int p) {
        ans.set(v, dp.get(v));
        for (int child : adj.get(v)) {
            if (child == p) {
                continue;
            }

            // Detach the child's branch from v.
            int fromChild = Math.max(0, dp.get(child));
            dp.set(v, dp.get(v) - fromChild);

            // Hang what is left of v below the child, if that helps.
            int fromParent = Math.max(0, dp.get(v));
            dp.set(child, dp.get(child) + fromParent);

            dfs2(child, v);

            // The recursive call never writes dp[v] and restores dp[child], so the same two
            // amounts undo the move exactly.
            dp.set(child, dp.get(child) - fromParent);
            dp.set(v, dp.get(v) + fromChild);
        }
    }

    /**
     * Runs both passes on the given tree and returns the answer for every vertex.
     *
     * <p>The static fields of this class are overwritten by each call.
     *
     * @param colors color of each vertex, 0 for black and 1 for white
     * @param edges  tree edges as pairs of 1-based vertex numbers; the caller must supply a
     *               tree (only the edge count and the endpoint range are checked here)
     * @return array whose element {@code i} is the answer for vertex {@code i + 1};
     *         empty when {@code colors} is empty
     * @throws IllegalArgumentException if the number of edges is not {@code colors.length - 1}
     *                                  or an endpoint is outside {@code 1..colors.length}
     */
    static int[] solve(int[] colors, int[][] edges) {
        int n = colors.length;
        if (n == 0) {
            return new int[0];
        }
        if (edges.length != n - 1) {
            throw new IllegalArgumentException(
                    "a tree with " + n + " vertices needs " + (n - 1)
                            + " edges, got " + edges.length);
        }

        color = new ArrayList<>(n);
        dp = new ArrayList<>(n);
        ans = new ArrayList<>(n);
        adj = new ArrayList<>(n);
        for (int c : colors) {
            // Convert 0 to -1 so that summing values gives (white count - black count).
            color.add(c == 0 ? -1 : c);
            dp.add(0);
            ans.add(0);
            adj.add(new ArrayList<>());
        }

        for (int[] edge : edges) {
            int x = edge[0] - 1;
            int y = edge[1] - 1;
            if (x < 0 || x >= n || y < 0 || y >= n) {
                throw new IllegalArgumentException(
                        "edge endpoint out of range: " + Arrays.toString(edge));
            }
            adj.get(x).add(y);
            adj.get(y).add(x);
        }

        dfs(0, -1);
        dfs2(0, -1);

        int[] result = new int[n];
        for (int v = 0; v < n; v++) {
            result[v] = ans.get(v);
        }
        return result;
    }

    // Driver code: solves the sample tree and prints one answer per vertex.
    public static void main(String[] args) {
        int[] colors = {0, 1, 1, 1, 0, 0, 0, 0, 1};
        int[][] edges = {
            {1, 2}, {1, 3}, {3, 4}, {3, 5}, {2, 6}, {4, 7}, {6, 8}, {5, 9}
        };

        for (int value : solve(colors, edges)) {
            System.out.print(value + " ");
        }
        System.out.println();
    }
}




class Main:
    """Two-pass re-rooting DP on a black/white colored tree.

    For every vertex the result is the largest value of
    (white vertices - black vertices) over all connected subgraphs that
    contain the vertex. State is shared between the recursive passes
    through the class attributes below. Both passes are recursive, so a
    very deep tree can exceed Python's recursion limit.
    """

    # Per-vertex value used by the DFS passes: main() stores -1 for black
    # and 1 for white.
    color = []

    # dp[v] after dfs(): color[v] plus every positive dp value among the
    # children of v. dfs2() temporarily adjusts entries while re-rooting
    # and restores them afterwards.
    dp = []

    # ans[v]: final result for vertex v, written by dfs2().
    ans = []

    # Adjacency list of the tree, 0-based vertex indices.
    adj = []

    @staticmethod
    def dfs(v, p):
        """Bottom-up pass filling dp for v and everything below it.

        v is the vertex to process and p its parent (-1 for the root).
        """
        Main.dp[v] = Main.color[v]
        for to in Main.adj[v]:

            # Skip the parent in the DFS
            if to == p:
                continue
            Main.dfs(to, v)

            # A child's branch is attached only when it increases the total
            Main.dp[v] += max(Main.dp[to], 0)

    @staticmethod
    def dfs2(v, p):
        """Top-down re-rooting pass filling ans for v and everything below it.

        v is the vertex to process and p its parent (-1 for the root).
        On entry dp[v] must describe the tree as if v were its root.
        """
        Main.ans[v] = Main.dp[v]
        for to in Main.adj[v]:

            # Skip the parent in the DFS
            if to == p:
                continue

            # Detach the child's branch from v
            from_child = max(0, Main.dp[to])
            Main.dp[v] -= from_child

            # Hang what is left of v below the child, if that helps
            from_parent = max(0, Main.dp[v])
            Main.dp[to] += from_parent

            Main.dfs2(to, v)

            # The recursive call never writes dp[v] and restores dp[to],
            # so the same two amounts undo the move exactly
            Main.dp[to] -= from_parent
            Main.dp[v] += from_child

    @staticmethod
    def main():
        """Driver code: solve the sample tree and print one answer per vertex."""
        n = 9
        Main.dp = [0] * n
        Main.ans = [0] * n

        # Exactly one adjacency list per vertex
        Main.adj = [[] for _ in range(n)]

        # Input the color of each vertex (0 for black, 1 for white),
        # converting 0 to -1 for easier calculation
        Main.color = [-1 if c == 0 else c for c in [0, 1, 1, 1, 0, 0, 0, 0, 1]]

        # Input the edges of the tree (1-based endpoints)
        edges = [
            [1, 2], [1, 3], [3, 4],
            [3, 5], [2, 6], [4, 7],
            [6, 8], [5, 9]
        ]

        for x, y in edges:
            Main.adj[x - 1].append(y - 1)
            Main.adj[y - 1].append(x - 1)

        # Perform the two DFS to calculate the result
        Main.dfs(0, -1)
        Main.dfs2(0, -1)

        # Output the result for each vertex
        for it in Main.ans:
            print(it, end=" ")
        print()
 
 
# Run the main function only when executed as a script, not on import
if __name__ == "__main__":
    Main.main()
 
 
# This code is contributed by rambabuguphka




using System;
using System.Collections.Generic;
using System.Linq;
 
// Solves the sample black/white tree: for every vertex, the best
// (white - black) value over connected pieces containing that vertex,
// using a bottom-up pass followed by a top-down re-rooting pass.
class GFG
{
    // Per-vertex value: -1 for black, 1 for white (after the remap in Main).
    static int[] color;
    // dp[v] after Dfs: color[v] plus every positive dp value among v's children.
    static int[] dp;
    // ans[v]: final result for vertex v, written by Dfs2.
    static int[] ans;
    // Adjacency list of the tree, 0-based vertex indices.
    static List<int>[] adj;

    // Bottom-up pass; p is the parent of v (-1 for the root).
    static void Dfs(int v, int p)
    {
        int total = color[v];
        foreach (int child in adj[v])
        {
            if (child != p)
            {
                Dfs(child, v);
                // A child's branch is attached only when it increases the total.
                total += Math.Max(dp[child], 0);
            }
        }
        dp[v] = total;
    }

    // Top-down re-rooting pass; p is the parent of v (-1 for the root).
    static void Dfs2(int v, int p)
    {
        ans[v] = dp[v];
        foreach (int child in adj[v])
        {
            if (child == p)
            {
                continue;
            }

            // Detach the child's branch from v.
            int fromChild = Math.Max(0, dp[child]);
            dp[v] -= fromChild;

            // Hang what is left of v below the child, if that helps.
            int fromParent = Math.Max(0, dp[v]);
            dp[child] += fromParent;

            Dfs2(child, v);

            // The recursive call never writes dp[v] and restores dp[child],
            // so the same two amounts undo the move exactly.
            dp[child] -= fromParent;
            dp[v] += fromChild;
        }
    }

    // Driver code
    public static void Main(string[] args)
    {
        const int n = 9;

        // Input colors (0 for black, 1 for white); 0 is converted to -1.
        color = new int[] { 0, 1, 1, 1, 0, 0, 0, 0, 1 };
        for (int i = 0; i < color.Length; i++)
        {
            if (color[i] == 0)
            {
                color[i] = -1;
            }
        }

        dp = new int[n];
        ans = new int[n];
        adj = new List<int>[n];
        for (int i = 0; i < n; i++)
        {
            adj[i] = new List<int>();
        }

        // Edges of the tree, given with 1-based endpoints.
        int[,] edges = {
            { 1, 2 }, { 1, 3 }, { 3, 4 }, { 3, 5 },
            { 2, 6 }, { 4, 7 }, { 6, 8 }, { 5, 9 }
        };
        for (int e = 0; e < edges.GetLength(0); e++)
        {
            int a = edges[e, 0] - 1;
            int b = edges[e, 1] - 1;
            adj[a].Add(b);
            adj[b].Add(a);
        }

        // Perform the two DFS to calculate the result
        Dfs(0, -1);
        Dfs2(0, -1);

        foreach (int value in ans)
        {
            Console.Write(value + " ");
        }
        Console.WriteLine();
    }
}




// Per-vertex value read by dfs(): main() maps the input color 0 (black) to -1 and keeps 1 (white)
let color = [];

// dp[v]: after dfs(), color[v] plus every positive dp value among the children of v.
// dfs2() temporarily adjusts these entries while re-rooting and restores them afterwards.
let dp = [];

// ans[v]: final result for vertex v, assigned in dfs2()
let ans = [];

// Adjacency list representing the tree (main() stores 0-based indices)
let adj = [];
 
// Bottom-up pass over the subtree of v (p is the parent of v, -1 for the root).
// Every child is solved first; dp[v] then becomes color[v] plus the dp value of
// each child whose dp value is positive.
function dfs(v, p = -1) {
    let total = color[v];
    for (const child of adj[v]) {
        if (child !== p) {
            dfs(child, v);
            // A child's branch is attached only when it increases the total
            total += Math.max(dp[child], 0);
        }
    }
    dp[v] = total;
}
 
// Top-down re-rooting pass (p is the parent of v, -1 for the root).
// On entry dp[v] describes the tree as if v were its root, so it is stored as
// the answer for v. For each child the root is moved across the edge, the
// child is processed, and both dp entries are put back.
function dfs2(v, p = -1) {
    ans[v] = dp[v];
    for (const child of adj[v]) {
        if (child === p) {
            continue;
        }

        // Detach the child's branch from v
        const fromChild = Math.max(0, dp[child]);
        dp[v] -= fromChild;

        // Hang what is left of v below the child, if that helps
        const fromParent = Math.max(0, dp[v]);
        dp[child] += fromParent;

        dfs2(child, v);

        // The recursive call never writes dp[v] and restores dp[child],
        // so the same two amounts undo the move exactly
        dp[child] -= fromParent;
        dp[v] += fromChild;
    }
}
 
// Driver code: builds the sample tree, runs both passes and prints the
// answers separated by single spaces.
function main() {
    const n = 9;

    // Input colors (0 for black, 1 for white); 0 is converted to -1
    const inputColors = [0, 1, 1, 1, 0, 0, 0, 0, 1];
    color = inputColors.map((c) => (c === 0 ? -1 : c));

    dp = new Array(n).fill(0);
    ans = new Array(n).fill(0);
    adj = Array.from({ length: n }, () => []);

    // Edges of the tree, given with 1-based endpoints
    const edges = [[1, 2], [1, 3], [3, 4], [3, 5], [2, 6], [4, 7], [6, 8], [5, 9]];
    for (const [a, b] of edges) {
        const x = a - 1;
        const y = b - 1;
        adj[x].push(y);
        adj[y].push(x);
    }

    dfs(0);
    dfs2(0);

    console.log(ans.join(' '));
}

main();
 
// This code is contributed by akshitaguprzj3

Output
2 2 2 2 2 1 1 0 2 






Time Complexity: O(n)
Auxiliary Space: O(n)


Article Tags :