import
java.util.*;
/**
 * Offline solver for queries of the form "how many nodes in the subtree of node X carry a label
 * strictly smaller than V".
 *
 * <p>The tree is flattened with an Euler tour in which every node appears exactly twice (on entry
 * and on exit), so the subtree of a node is the contiguous tour range between its two occurrences.
 * Tour positions are then activated in a Fenwick tree (binary indexed tree) in ascending order of
 * node label while the queries are swept in ascending order of threshold.
 */
public class SubtreeQueries {

    /**
     * Appends the Euler tour of everything reachable from {@code root} to {@code tour}.
     *
     * <p>Each node is appended once when it is first reached and once more after all of its
     * unvisited neighbours have been fully explored. The traversal uses an explicit stack instead
     * of recursion so that very deep trees cannot overflow the call stack; the order of appended
     * nodes is the same as that of the straightforward recursive depth-first walk.
     *
     * @param tree adjacency lists indexed by node label
     * @param vst  visit markers indexed by node label; {@code 0} means unvisited, set to {@code 1}
     *             when a node is reached
     * @param root node at which the traversal starts
     * @param tour output list that receives the tour
     */
    static void eulerTour(List<List<Integer>> tree, List<Integer> vst, int root, List<Integer> tour) {
        Deque<Integer> path = new ArrayDeque<>();
        Deque<Iterator<Integer>> pending = new ArrayDeque<>();

        tour.add(root);
        vst.set(root, 1);
        path.push(root);
        pending.push(tree.get(root).iterator());

        while (!path.isEmpty()) {
            Iterator<Integer> neighbours = pending.peek();
            boolean descended = false;
            while (neighbours.hasNext()) {
                int next = neighbours.next();
                if (vst.get(next) == 0) {
                    // Entry occurrence of the child; resume the parent's iterator later.
                    tour.add(next);
                    vst.set(next, 1);
                    path.push(next);
                    pending.push(tree.get(next).iterator());
                    descended = true;
                    break;
                }
            }
            if (!descended) {
                // All neighbours handled: record the exit occurrence and backtrack.
                tour.add(path.pop());
                pending.pop();
            }
        }
    }

    /**
     * Records, for every node in the tour, the tour index of its first and second occurrence.
     *
     * <p>Index 0 of {@code tour} is skipped; it is expected to hold a placeholder so that tour
     * positions are 1-based.
     *
     * @param tour  Euler tour with a placeholder at index 0
     * @param start per-node first occurrence; entries must be {@code -1} initially
     * @param end   per-node later occurrence
     */
    static void createStartEnd(List<Integer> tour, List<Integer> start, List<Integer> end) {
        for (int i = 1; i < tour.size(); ++i) {
            int node = tour.get(i);
            if (start.get(node) == -1) {
                start.set(node, i);
            } else {
                end.set(node, i);
            }
        }
    }

    /**
     * Builds the list of (node label, tour index) pairs for tour indices 1 and above, sorted by
     * node label in ascending order.
     *
     * @param tour Euler tour with a placeholder at index 0
     * @return the sorted pairs
     */
    static List<Pair<Integer, Integer>> createSortedTour(List<Integer> tour) {
        List<Pair<Integer, Integer>> arr = new ArrayList<>();
        for (int i = 1; i < tour.size(); ++i) {
            arr.add(new Pair<>(tour.get(i), i));
        }
        Collections.sort(arr, Comparator.comparingInt(Pair::getKey));
        return arr;
    }

    /**
     * Adds one to position {@code pos} of the Fenwick tree stored in {@code bit}.
     *
     * @param bit Fenwick tree storage; index 0 is unused
     * @param pos 1-based position to increment
     * @throws IndexOutOfBoundsException if {@code pos} is not positive (a zero position would
     *                                   otherwise never advance and loop forever)
     */
    static void increment(List<Integer> bit, int pos) {
        if (pos <= 0) {
            throw new IndexOutOfBoundsException("Fenwick tree positions are 1-based, got " + pos);
        }
        for (; pos < bit.size(); pos += pos & -pos) {
            bit.set(pos, bit.get(pos) + 1);
        }
    }

    /**
     * Returns the sum of the Fenwick tree over the inclusive 1-based range
     * {@code [start, end]}, computed as prefix(end) - prefix(start - 1).
     *
     * @param bit   Fenwick tree storage; index 0 is unused
     * @param start first position of the range
     * @param end   last position of the range
     * @return the range sum
     */
    static int query(List<Integer> bit, int start, int end) {
        --start;
        int s1 = 0;
        int s2 = 0;
        for (; start > 0; start -= start & -start) {
            s1 += bit.get(start);
        }
        for (; end > 0; end -= end & -end) {
            s2 += bit.get(end);
        }
        return s2 - s1;
    }

    /**
     * Answers every query (V, X): the number of nodes in the subtree of X, with the tree rooted
     * at node 1, whose label is strictly smaller than V.
     *
     * <p>Queries are processed in ascending order of V. Before a query is answered, every tour
     * position belonging to a node with label smaller than V is activated in the Fenwick tree.
     * A threshold larger than every label simply activates the whole tour, so such a query
     * counts the entire subtree.
     *
     * @param N       number of nodes; labels are 1..N
     * @param Q       number of queries; not used by the computation, kept for interface
     *                compatibility
     * @param tree    adjacency lists indexed by node label (index 0 unused)
     * @param queries pairs whose key is the threshold V and whose value is the subtree root X
     * @return a map from each query pair (the same objects contained in {@code queries}) to its
     *         answer, iterating in ascending threshold order
     */
    static Map<Pair<Integer, Integer>, Integer> cal(
            int N, int Q, List<List<Integer>> tree, List<Pair<Integer, Integer>> queries) {
        List<Integer> tour = new ArrayList<>();
        List<Integer> vst = new ArrayList<>(Collections.nCopies(N + 1, 0));
        List<Integer> start = new ArrayList<>(Collections.nCopies(N + 1, -1));
        List<Integer> end = new ArrayList<>(Collections.nCopies(N + 1, -1));
        // The tour holds two entries per node plus the placeholder, hence the 2 * N sizing.
        List<Integer> bit = new ArrayList<>(Collections.nCopies(2 * N + 4, 0));

        tour.add(-1); // placeholder so that tour positions are 1-based for the Fenwick tree
        eulerTour(tree, vst, 1, tour);
        createStartEnd(tour, start, end);

        List<Pair<Integer, Integer>> sortedTour = createSortedTour(tour);
        List<Pair<Integer, Integer>> sortedQuery = new ArrayList<>(queries);
        sortedQuery.sort(Comparator.comparingInt(Pair::getKey));

        Map<Pair<Integer, Integer>, Integer> queryAns = new LinkedHashMap<>();
        int tourPtr = 0;
        for (Pair<Integer, Integer> current : sortedQuery) {
            int threshold = current.getKey();
            // Activate every tour position whose node label is below the threshold. The bounds
            // check comes first so an exhausted tour is never indexed.
            while (tourPtr < sortedTour.size() && sortedTour.get(tourPtr).getKey() < threshold) {
                increment(bit, sortedTour.get(tourPtr).getValue());
                ++tourPtr;
            }
            int node = current.getValue();
            // Every node occurs twice inside its own subtree range, so halve the count.
            queryAns.put(current, query(bit, start.get(node), end.get(node)) / 2);
        }
        return queryAns;
    }

    /**
     * Runs the solver on a fixed seven-node example and prints one answer per line, in the
     * order the queries were given.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        int N = 7;
        int Q = 3;
        List<List<Integer>> tree = Arrays.asList(
                Collections.emptyList(),
                Arrays.asList(4, 6),
                Collections.singletonList(4),
                Collections.singletonList(4),
                Arrays.asList(1, 2, 3, 5),
                Collections.singletonList(4),
                Arrays.asList(1, 7),
                Collections.singletonList(6)
        );
        List<Pair<Integer, Integer>> queries = Arrays.asList(
                new Pair<>(4, 1),
                new Pair<>(7, 6),
                new Pair<>(5, 1)
        );
        Map<Pair<Integer, Integer>, Integer> queryAns = cal(N, Q, tree, queries);
        for (Pair<Integer, Integer> x : queries) {
            System.out.println(queryAns.get(x));
        }
    }
}
class
Pair<K, V> {
private
K key;
private
V value;
public
Pair(K key, V value) {
this
.key = key;
this
.value = value;
}
public
K getKey() {
return
key;
}
public
V getValue() {
return
value;
}
}