#include <bits/stdc++.h>
using
namespace
std;
/// A binary-tree node holding an integer payload and two child links.
struct Node {
    int val;                // payload stored in this node
    Node* left = nullptr;   // left child, null when absent
    Node* right = nullptr;  // right child, null when absent

    /// Creates a childless node whose payload is `x`.
    Node(int x) : val(x) {}
};
void
traverse(Node* node,
int
level, vector<
int
>& rightmost)
{
if
(!node) {
return
;
}
if
(rightmost.size() == level) {
rightmost.push_back(node->val);
}
traverse(node->right, level + 1, rightmost);
traverse(node->left, level + 1, rightmost);
}
int
count_to_remove(Node* node,
int
level,
vector<
int
>& rightmost)
{
if
(!node) {
return
0;
}
int
count = 0;
if
(node->left) {
count += count_to_remove(node->left, level + 1,
rightmost);
if
(rightmost[level] != node->val) {
count++;
}
}
if
(node->right) {
count += count_to_remove(node->right, level + 1,
rightmost);
if
(rightmost[level] != node->val) {
count++;
}
}
return
count;
}
/// Collects the per-depth values of the tree via `traverse` (starting at
/// depth 0 with an empty vector) and returns the result of `count_to_remove`
/// run over the same tree with those values.
///
/// @param root  root of the tree; a null root produces 0.
/// @return the count reported by `count_to_remove` for the whole tree.
int min_nodes_to_remove(Node* root)
{
    std::vector<int> right_view;
    traverse(root, 0, right_view);

    const int removals = count_to_remove(root, 0, right_view);
    return removals;
}
int
main()
{
Node* root =
new
Node(1);
root->left =
new
Node(2);
root->left->right =
new
Node(8);
root->left->right->right =
new
Node(6);
root->right =
new
Node(3);
root->right->right =
new
Node(5);
cout << min_nodes_to_remove(root) << endl;
return
0;
}