import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
/**
 * Assigns edge weights to a tree so that the sum, over every unordered pair of vertices, of the
 * total weight on the path between them is as large as possible.
 *
 * <p>Vertices are numbered {@code 1..n}. Index {@code 0} of the adjacency array is unused and
 * serves as the "no parent" sentinel when the traversal starts at vertex {@code 1}.
 */
class Tree {

    /**
     * Records the undirected edge {@code x - y}: appends {@code {x, y}} to {@code edges} and adds
     * each endpoint to the other's adjacency list.
     *
     * @param edges edge list to append to
     * @param tree  adjacency lists indexed by vertex
     * @param x     one endpoint
     * @param y     the other endpoint
     */
    static void addEdge(List<int[]> edges, List<Integer>[] tree, int x, int y) {
        edges.add(new int[] {x, y});
        tree[x].add(y);
        tree[y].add(x);
    }

    /**
     * Stores in {@code dp[v]} the size of the subtree rooted at {@code v}, for every vertex
     * {@code v} reachable from {@code node} without stepping onto {@code parent}, with the tree
     * rooted at {@code node}.
     *
     * <p>Runs iteratively with an explicit stack, so very deep trees (e.g. a long path) cannot
     * overflow the call stack. {@code dp[parent]} is never written.
     *
     * @param edges  unused; kept so existing callers continue to compile
     * @param tree   adjacency lists indexed by vertex
     * @param node   root of the traversal
     * @param parent neighbour of {@code node} that is not entered (0 when the root has none)
     * @param dp     output array; must be indexable by every reachable vertex
     */
    static void dfs(List<int[]> edges, List<Integer>[] tree, int node, int parent, int[] dp) {
        // Pre-order list of {vertex, parentOfVertex}: every vertex precedes all its descendants.
        List<int[]> order = new ArrayList<>();
        Deque<int[]> stack = new ArrayDeque<>();
        stack.push(new int[] {node, parent});
        while (!stack.isEmpty()) {
            int[] frame = stack.pop();
            int v = frame[0];
            dp[v] = 1;
            order.add(frame);
            for (int child : tree[v]) {
                if (child != frame[1]) {
                    stack.push(new int[] {child, v});
                }
            }
        }
        // Walk backwards so each subtree size is final before it is added to its parent.
        // Index 0 is the root, whose size must not be folded into `parent`.
        for (int i = order.size() - 1; i > 0; i--) {
            int[] frame = order.get(i);
            dp[frame[1]] += dp[frame[0]];
        }
    }

    /**
     * Returns the largest possible sum, over all unordered vertex pairs, of the total edge weight
     * on the path between them, when the weights {@code a[0..n-2]} are assigned one per edge.
     *
     * <p>An edge that splits the tree into parts of sizes {@code s} and {@code n - s} lies on
     * {@code s * (n - s)} such paths. Pairing these counts and the weights both in ascending order
     * maximizes the total (rearrangement inequality). {@code a} is not modified.
     *
     * @param a     edge weights; only the first {@code n - 1} entries are used
     * @param edges the tree's edges as {@code {x, y}} pairs
     * @param tree  adjacency lists indexed by vertex {@code 1..n}
     * @param n     number of vertices
     * @return the maximal total
     * @throws IllegalArgumentException if {@code n} is negative, or fewer than {@code n - 1}
     *     weights or edges are supplied
     * @throws ArithmeticException if the total does not fit in an {@code int}
     */
    static int maximizeSum(int[] a, List<int[]> edges, List<Integer>[] tree, int n) {
        return Math.toIntExact(maximizeSumLong(a, edges, tree, n));
    }

    /**
     * Same as {@link #maximizeSum} but returns the exact total as a {@code long}, so totals
     * beyond {@code Integer.MAX_VALUE} can be obtained.
     *
     * @param a     edge weights; only the first {@code n - 1} entries are used
     * @param edges the tree's edges as {@code {x, y}} pairs
     * @param tree  adjacency lists indexed by vertex {@code 1..n}
     * @param n     number of vertices
     * @return the maximal total
     * @throws IllegalArgumentException if {@code n} is negative, or fewer than {@code n - 1}
     *     weights or edges are supplied
     * @throws ArithmeticException if the total does not fit in a {@code long}
     */
    static long maximizeSumLong(int[] a, List<int[]> edges, List<Integer>[] tree, int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        if (n <= 1) {
            return 0; // a tree with at most one vertex has no edges
        }
        int edgeCount = n - 1;
        if (a.length < edgeCount) {
            throw new IllegalArgumentException(
                    "need at least " + edgeCount + " weights, got " + a.length);
        }
        if (edges.size() < edgeCount) {
            throw new IllegalArgumentException(
                    "need at least " + edgeCount + " edges, got " + edges.size());
        }

        int[] subtreeSize = new int[n + 1];
        dfs(edges, tree, 1, 0, subtreeSize);

        // For each edge the smaller side is the child's subtree. The product is computed in long
        // because s * (n - s) exceeds Integer.MAX_VALUE once n is above roughly 92,700.
        long[] pathCounts = new long[edges.size()];
        for (int i = 0; i < pathCounts.length; i++) {
            int[] e = edges.get(i);
            int s = Math.min(subtreeSize[e[0]], subtreeSize[e[1]]);
            pathCounts[i] = (long) (n - s) * s;
        }
        Arrays.sort(pathCounts);

        // Sort a copy so the caller's array is left untouched.
        int[] weights = Arrays.copyOf(a, edgeCount);
        Arrays.sort(weights);

        long total = 0;
        for (int i = 0; i < edgeCount; i++) {
            total = Math.addExact(total, Math.multiplyExact(pathCounts[i], (long) weights[i]));
        }
        return total;
    }

    /** Builds a 5-vertex sample tree and prints {@link #maximizeSum} for it. */
    public static void main(String[] args) {
        int n = 5;
        List<int[]> edges = new ArrayList<>();
        // Generic arrays cannot be created directly; the raw array only ever holds List<Integer>.
        @SuppressWarnings("unchecked")
        List<Integer>[] tree = new ArrayList[n + 1];
        for (int i = 0; i <= n; i++) {
            tree[i] = new ArrayList<>();
        }
        addEdge(edges, tree, 1, 2);
        addEdge(edges, tree, 1, 3);
        addEdge(edges, tree, 3, 4);
        addEdge(edges, tree, 3, 5);
        int[] a = {6, 3, 1, 9, 3};
        System.out.println(maximizeSum(a, edges, tree, n));
    }
}