Open In App

Minimizing vertex distances in Tree with special operation

Last Updated : 19 Jul, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a tree of N vertices numbered from 0 to N-1 connected by N-1 edges, each of weight 1, minimize the sum of the distances of all vertices from vertex 0 by applying the special operation at most once. In the special operation, you remove any edge of the tree, which splits it into two trees. You then reconnect the endpoint of the removed edge that is no longer connected to vertex 0 (together with the rest of its tree) so that the distances from vertex 0 are as small as possible. The best choice is to attach it directly to vertex 0.

Examples:

Input: n = 5, edges[] = [[0, 1], [1, 2], [2, 3], [3, 4]]
Output: 6
Explanation:  

 

Initially the tree looks like 0->1->2->3->4, and the sum of the distances of all vertices from vertex 0 is computed as follows:

  • 0->1 => dis=1, sum =1
  • 0->2 => dis=2, sum =1+2=3
  • 0->3 => dis=3, sum = 3+3=6
  • 0->4 => dis=4, sum = 6+4=10

To minimize the sum, cut the edge between vertices 2 and 3 and attach vertex 3 (together with vertex 4 below it) to vertex 0:

  • 0->1 => dis=1, sum=1
  • 0->2 => dis=2, sum=1+2=3
  • 0->3 => dis=1, sum=3+1=4
  • 0->4 => dis=2, sum=4+2=6

Here, 6 is the minimum possible sum of the distances of all vertices from vertex 0.

Input: n = 5, edges[] = [[0, 1], [0, 2], [0, 3], [0, 4]]
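Output: 4
Explanation: Every vertex is already adjacent to vertex 0, so the sum 1 + 1 + 1 + 1 = 4 cannot be reduced and no operation is needed.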

Approach: This can be solved with the following idea:

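Root the tree at vertex 0. Suppose we cut the edge between a vertex v at depth d and its parent, then attach v to vertex 0. Every vertex in v's subtree moves d - 1 levels closer to vertex 0, and all other distances stay the same. The new sum is therefore dp[0] - (d - 1) * size[v], where:
- dp[0] is the original sum of distances from vertex 0;
- size[v] is the number of vertices in v's subtree.

The answer is the minimum of dp[0] and this value over all vertices v ≠ 0. All required quantities can be computed with two depth-first traversals, a tree dynamic-programming approach.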

Below are the steps involved in the implementation of the code:

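  • The function dfs computes the dp and size arrays with a depth-first traversal from vertex 0.
    - size[v] is the number of vertices in the subtree rooted at v.
    - dp[v] is the sum of the distances from v to all vertices in its subtree.
    - Both are filled bottom-up: size[v] = 1 + (sum of size[c]) and dp[v] = sum of (dp[c] + size[c]) over the children c of v. This works because every vertex in a child's subtree is one edge farther from v than from c.
    - For Example 1, this gives dp = [10, 6, 3, 1, 0] and size = [5, 4, 3, 2, 1].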

 

  • The function dfs2 evaluates every possible cut. It takes:
    - the current node and its parent;
    - the value h (the depth of the current node minus 1);
    - the adjacency list and the size and dp arrays;
    - the best sum found so far (ans).

    For the current node it computes temp = dp[0] - h * size[node]. This is the total distance after cutting the edge between the node and its parent and attaching the node to vertex 0. It updates ans if temp is smaller, then continues with every child of the current node, passing h + 1.
  • The minDistance function is the main function; it takes the number of nodes and the list of edges.
    - It builds an adjacency list from the edges, initializes the dp and size arrays, and calls dfs to fill them.
    - It initializes ans with dp[0], the sum of distances when no operation is applied.
    - It calls dfs2 with h = 0 for each child of the root to try every cut.
    - Finally, it returns the minimum sum of distances (ans).

Below is the implementation of the above approach:

C++




// C++ code for the above approach:
#include <bits/stdc++.h>
using namespace std;
 
// Calculates the dp[] and size[] arrays for every vertex of the subtree
// entered at `node` (coming from `par`):
//   size[v] = number of vertices in v's subtree (v included),
//   dp[v]  += sum of distances from v to every vertex of its subtree,
// using size[v] = 1 + sum(size[c]) and dp[v] = sum(dp[c] + size[c]) over
// the children c of v.
// An explicit stack replaces recursion so that path-shaped trees with
// ~1e5 vertices cannot overflow the call stack.
void dfs(int node, int par, vector<int> adj[],
         vector<long long int>& dp,
         vector<long long int>& size)
{
    // Pre-order list of (vertex, parent): every vertex precedes all of
    // its descendants.
    vector<pair<int, int> > order;
    vector<pair<int, int> > pending;
    pending.push_back(make_pair(node, par));
    while (!pending.empty()) {
        pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        size[cur.first] = 1;
        for (int child : adj[cur.first]) {
            if (child != cur.second) {
                pending.push_back(make_pair(child, cur.first));
            }
        }
    }

    // Walking the pre-order backwards finishes every child before its
    // parent, so each vertex can fold its totals into its parent.
    // Index 0 is the start vertex itself; its parent lies outside.
    for (size_t i = order.size(); i-- > 1;) {
        int v = order[i].first;
        int p = order[i].second;
        dp[p] += (dp[v] + size[v]);
        size[p] += size[v];
    }
}
 
// Minimize the distance by removing edges.
// Visits every vertex v of the subtree entered at `node` (coming from
// `par`) and lowers `ans` to dp[0] - h(v) * size[v]: the total distance
// after cutting the edge above v and attaching v directly to vertex 0,
// which brings each of the size[v] vertices of v's subtree h(v) levels
// closer. `h` is the value for `node` and grows by one per level
// (minDistance starts it at 0 for the children of vertex 0, so
// h(v) = depth(v) - 1).
// An explicit stack replaces recursion so deep trees cannot overflow the
// call stack.
void dfs2(int node, int par, int h, vector<int> adj[],
          vector<long long int>& size,
          vector<long long int>& dp, long long int& ans)
{
    struct Frame {
        int v;
        int p;
        int h;
    };

    vector<Frame> pending;
    pending.push_back({ node, par, h });
    while (!pending.empty()) {
        Frame f = pending.back();
        pending.pop_back();

        long long int temp = dp[0] - (f.h * size[f.v]);
        ans = min(ans, temp);

        for (int child : adj[f.v]) {
            if (child != f.p) {
                pending.push_back({ child, f.v, f.h + 1 });
            }
        }
    }
}
 
// Return the sum of minimized distance
// of all nodes from vertex 0
long long minDistance(int n, vector<vector<int> >& edges)
{
 
    // code here
    vector<int> adj[n];
    for (int i = 0; i < n - 1; i++) {
        adj[edges[i][0]].push_back(edges[i][1]);
        adj[edges[i][1]].push_back(edges[i][0]);
    }
    vector<long long int> dp(n, 0), size(n, 0);
 
    dfs(0, -1, adj, dp, size);
    long long int ans = dp[0];
 
    for (auto it : adj[0]) {
        dfs2(it, 0, 0, adj, size, dp, ans);
    }
    return ans;
}
 
// Driver code
int main()
{
    int n = 5;
    vector<vector<int> > edges
        = { { 0, 1 }, { 1, 2 }, { 2, 3 }, { 3, 4 } };
 
    // Function call
    long long int ans = minDistance(n, edges);
 
    cout << ans << endl;
    return 0;
}


Java




import java.util.*;
 
public class Main {
    // Calculates the dp[] and size[] arrays for each node.
    public static void dfs(int node, int par,
                           List<Integer>[] adj,
                           List<Long> dp, List<Long> size)
    {
        size.set(node, 1L);
        for (int it : adj[node]) {
            if (it != par) {
                dfs(it, node, adj, dp, size);
                dp.set(node,
                       dp.get(node)
                           + (dp.get(it) + size.get(it)));
                size.set(node,
                         size.get(node) + size.get(it));
            }
        }
    }
 
    // Minimize the distance by removing edges
    public static void dfs2(int node, int par, int h,
                            List<Integer>[] adj,
                            List<Long> size, List<Long> dp,
                            long[] ans)
    {
        long temp = dp.get(0) - (h * size.get(node));
        ans[0] = Math.min(ans[0], temp);
        for (int it : adj[node]) {
            if (it != par) {
                dfs2(it, node, h + 1, adj, size, dp, ans);
            }
        }
    }
 
    // return the sum of minimized distance of all nodes
    // from vertex 0
    public static long
    minDistance(int n, List<List<Integer> > edges)
    {
        List<Integer>[] adj = new List[n];
        for (int i = 0; i < n; i++) {
            adj[i] = new ArrayList<>();
        }
        for (int i = 0; i < n - 1; i++) {
            adj[edges.get(i).get(0)].add(
                edges.get(i).get(1));
            adj[edges.get(i).get(1)].add(
                edges.get(i).get(0));
        }
        List<Long> dp
            = new ArrayList<>(Collections.nCopies(n, 0L));
        List<Long> size
            = new ArrayList<>(Collections.nCopies(n, 0L));
        dfs(0, -1, adj, dp, size);
        long[] ans = { dp.get(0) };
        for (int it : adj[0]) {
            dfs2(it, 0, 0, adj, size, dp, ans);
        }
        return ans[0];
    }
 
    public static void main(String[] args)
    {
        int n = 5;
        List<List<Integer> > edges = new ArrayList<>();
        edges.add(Arrays.asList(0, 1));
        edges.add(Arrays.asList(1, 2));
        edges.add(Arrays.asList(2, 3));
        edges.add(Arrays.asList(3, 4));
        long ans = minDistance(n, edges);
        System.out.println(ans);
    }
}


Python3




from typing import List
# Calculates the dp[] and size[] arrays for each node.
def dfs(node: int, par: int, adj: List[List[int]], dp: List[int], size: List[int]):
  """Fill size[] and dp[] for the subtree entered at ``node`` from ``par``.

  size[v] becomes the number of vertices in v's subtree (v included) and
  dp[v] is increased by the sum of distances from v to every vertex of that
  subtree, via size[v] = 1 + sum(size[c]) and dp[v] = sum(dp[c] + size[c])
  over the children c of v.

  Uses an explicit stack rather than recursion, so trees deeper than
  Python's recursion limit (1000 by default) do not raise RecursionError.
  """
  # Pre-order list of (vertex, parent): every vertex precedes its descendants.
  order = []
  stack = [(node, par)]
  while stack:
    v, p = stack.pop()
    order.append((v, p))
    size[v] = 1
    for child in adj[v]:
      if child != p:
        stack.append((child, v))
  # Backwards over the pre-order, every child is finished before its parent.
  # Entry 0 is the start vertex, whose parent lies outside the subtree.
  for v, p in reversed(order[1:]):
    dp[p] += dp[v] + size[v]
    size[p] += size[v]
       
# Minimize the distance by removing edges
def dfs2(node: int, par: int, h: int, adj: List[List[int]], size: List[int], dp: List[int], ans: List[int]):
  """Lower ans[0] to the best total reachable by one cut in node's subtree.

  For every vertex v in the subtree entered at ``node``, cutting the edge
  above v and attaching v directly to vertex 0 brings each of the size[v]
  vertices of v's subtree h(v) levels closer, giving dp[0] - h(v) * size[v].
  ``h`` is the value for ``node`` and grows by one per level; min_distance
  starts it at 0 for the children of vertex 0, so h(v) = depth(v) - 1.

  Uses an explicit stack rather than recursion, so trees deeper than
  Python's recursion limit (1000 by default) do not raise RecursionError.
  """
  stack = [(node, par, h)]
  while stack:
    v, p, hv = stack.pop()
    ans[0] = min(ans[0], dp[0] - hv * size[v])
    for child in adj[v]:
      if child != p:
        stack.append((child, v, hv + 1))
       
# return the sum of minimized distance of all nodes from vertex 0
def min_distance(n: int, edges: List[List[int]]) -> int:
  """Return the minimum sum of distances from vertex 0 after at most one
  cut-and-reattach operation on the tree with n vertices and the given edges.

  The answer is min(dp[0], dp[0] - (depth(v) - 1) * size[v]) over all
  vertices v != 0. Returns 0 when n <= 1.
  """
  # With at most one vertex there is nothing to move (and no vertex 0 at all
  # when n == 0).
  if n <= 1:
    return 0
  adj = [[] for _ in range(n)]
  for u, v in edges:
    adj[u].append(v)
    adj[v].append(u)
  dp = [0] * n
  size = [0] * n
  dfs(0, -1, adj, dp, size)
  # One-element list so dfs2 can update the running minimum in place.
  ans = [dp[0]]
  # Children of vertex 0 are at depth 1, so they start with h = 0.
  for child in adj[0]:
    dfs2(child, 0, 0, adj, size, dp, ans)
  return ans[0]
 
# Driver code: Example 1 (the path 0-1-2-3-4), whose expected answer is 6.
n = 5
edges = [[v, v + 1] for v in range(n - 1)]
ans = min_distance(n, edges)
print(ans)


C#




// C# code for the above approach:
using System;
using System.Collections.Generic;
 
class GFG {
    // Calculates the dp[] and size[] values for every vertex of the subtree
    // entered at `node` (coming from `par`): size[v] becomes the number of
    // vertices in v's subtree and dp[v] grows by the sum of distances from
    // v to every vertex of it.
    // An explicit stack replaces recursion so that path-shaped trees with
    // ~1e5 vertices cannot overflow the thread stack.
    static void Dfs(int node, int par, List<int>[] adj,
                    List<long> dp, List<long> size)
    {
        // Pre-order list of {vertex, parent}: each vertex precedes its
        // descendants.
        var order = new List<int[]>();
        var stack = new Stack<int[]>();
        stack.Push(new[] { node, par });
        while (stack.Count > 0) {
            int[] cur = stack.Pop();
            order.Add(cur);
            size[cur[0]] = 1;
            foreach(int child in adj[cur[0]])
            {
                if (child != cur[1]) {
                    stack.Push(new[] { child, cur[0] });
                }
            }
        }

        // Backwards over the pre-order, every child is finished before its
        // parent. Index 0 is the start vertex, whose parent lies outside.
        for (int i = order.Count - 1; i > 0; i--) {
            int v = order[i][0];
            int p = order[i][1];
            dp[p] += (dp[v] + size[v]);
            size[p] += size[v];
        }
    }

    // Minimize the distance by removing edges: visits every vertex v of the
    // subtree entered at `node` and lowers `ans` to dp[0] - h(v) * size[v],
    // the total distance after cutting the edge above v and attaching v
    // directly to vertex 0. `h` is the value for `node` and grows by one per
    // level (MinDistance starts it at 0 for the children of vertex 0, so
    // h(v) = depth(v) - 1). Uses an explicit stack instead of recursion.
    static void Dfs2(int node, int par, int h,
                     List<int>[] adj, List<long> size,
                     List<long> dp, ref long ans)
    {
        long total = dp[0];
        // Each entry is {vertex, parent, h}.
        var stack = new Stack<int[]>();
        stack.Push(new[] { node, par, h });
        while (stack.Count > 0) {
            int[] cur = stack.Pop();
            long temp = total - (cur[2] * size[cur[0]]);
            ans = Math.Min(ans, temp);
            foreach(int child in adj[cur[0]])
            {
                if (child != cur[1]) {
                    stack.Push(new[] { child, cur[0], cur[2] + 1 });
                }
            }
        }
    }

    // Return the sum of minimized distance of all nodes from vertex 0, or 0
    // when n <= 1 (nothing to move). Public so it can be reused and tested.
    public static long MinDistance(int n, List<List<int> > edges)
    {
        if (n <= 1) {
            return 0;
        }

        List<int>[] adj = new List<int>[ n ];
        for (int i = 0; i < n; i++) {
            adj[i] = new List<int>();
        }

        foreach(var edge in edges)
        {
            int u = edge[0];
            int v = edge[1];
            adj[u].Add(v);
            adj[v].Add(u);
        }

        List<long> dp = new List<long>(n);
        List<long> size = new List<long>(n);

        for (int i = 0; i < n; i++) {
            dp.Add(0);
            size.Add(0);
        }

        Dfs(0, -1, adj, dp, size);
        long ans = dp[0];

        // Children of vertex 0 are at depth 1, so they start with h = 0.
        foreach(int child in adj[0])
        {
            Dfs2(child, 0, 0, adj, size, dp, ref ans);
        }

        return ans;
    }

    // Driver code
    static void Main(string[] args)
    {
        int n = 5;
        List<List<int> > edges = new List<List<int> >() {
            new List<int>() { 0, 1 },
                new List<int>() { 1, 2 },
                new List<int>() { 2, 3 },
                new List<int>() { 3, 4 }
        };

        // Function call
        long ans = MinDistance(n, edges);

        Console.WriteLine(ans);
    }
}
 
// This code is contributed by shivamgupta0987654321


Javascript




// JAVASCRIPT code for the above approach:
// Calculates the dp[] and size[] arrays for every vertex of the subtree
// entered at `node` (coming from `par`): size[v] becomes the number of
// vertices in v's subtree and dp[v] grows by the sum of distances from v
// to every vertex of it.
// An explicit stack replaces recursion, so path-shaped trees with ~1e5
// vertices cannot throw "Maximum call stack size exceeded".
function dfs(node, par, adj, dp, size) {
    // Pre-order list of [vertex, parent]: each vertex precedes its descendants.
    const order = [];
    const stack = [[node, par]];
    while (stack.length > 0) {
        const [v, p] = stack.pop();
        order.push([v, p]);
        size[v] = 1;
        for (const child of adj[v]) {
            if (child !== p) {
                stack.push([child, v]);
            }
        }
    }
    // Backwards over the pre-order, every child is finished before its
    // parent. Index 0 is the start vertex, whose parent lies outside.
    for (let i = order.length - 1; i > 0; i--) {
        const [v, p] = order[i];
        dp[p] += (dp[v] + size[v]);
        size[p] += size[v];
    }
}
 
// Minimize the distance by removing edges: visits every vertex v of the
// subtree entered at `node` and lowers ans[0] to dp[0] - h(v) * size[v],
// the total distance after cutting the edge above v and attaching v
// directly to vertex 0. `h` is the value for `node` and grows by one per
// level (minDistance starts it at 0 for the children of vertex 0, so
// h(v) = depth(v) - 1).
// An explicit stack replaces recursion to avoid call-stack overflow.
function dfs2(node, par, h, adj, size, dp, ans) {
    const stack = [[node, par, h]];
    while (stack.length > 0) {
        const [v, p, hv] = stack.pop();
        ans[0] = Math.min(ans[0], dp[0] - (hv * size[v]));
        for (const child of adj[v]) {
            if (child !== p) {
                stack.push([child, v, hv + 1]);
            }
        }
    }
}
 
// Return the sum of minimized distance of all nodes from vertex 0: the
// minimum of dp[0] (no operation) and dp[0] - (depth(v) - 1) * size[v]
// over all vertices v != 0. Returns 0 when n <= 1 (nothing to move).
function minDistance(n, edges) {
    if (n <= 1) {
        return 0;
    }
    const adj = Array.from({ length: n }, () => []);
    for (const [u, v] of edges) {
        adj[u].push(v);
        adj[v].push(u);
    }
    const dp = Array(n).fill(0);
    const size = Array(n).fill(0);

    dfs(0, -1, adj, dp, size);
    // One-element array so dfs2 can update the running minimum.
    const ans = [dp[0]];

    // Children of vertex 0 are at depth 1, so they start with h = 0.
    for (const child of adj[0]) {
        dfs2(child, 0, 0, adj, size, dp, ans);
    }
    return ans[0];
}
 
// Driver code: Example 1 (the path 0-1-2-3-4); prints 6.
function main() {
    const n = 5;

    // Consecutive vertices are joined: [0,1], [1,2], [2,3], [3,4].
    const edges = [];
    for (let v = 0; v + 1 < n; v++) {
        edges.push([v, v + 1]);
    }

    console.log(minDistance(n, edges));
}

main();
 
 
// This code is contributed by rambabuguphka


Output

6


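Time Complexity: O(N), since each of the two depth-first traversals visits every vertex and every edge once: O(V + E) = O(N + (N - 1)) = O(N).
Auxiliary Space: O(N) for the adjacency list, the dp and size arrays, and the traversal stack.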



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads