#include <bits/stdc++.h>
using
namespace
std;
// Global deletion set: printForests stores `true` under every key that must
// be removed, and deleteNode tests whether a key is present. Only key
// presence is ever consulted; the stored bool is never read back.
unordered_map<
int
,
bool
> mp;
// Binary-tree node holding an integer key and links to its two children.
// Every member has a default initializer, so a Node created without going
// through newNode (e.g. `Node n;` or `new Node`) starts out as a
// zero-keyed leaf instead of carrying indeterminate values.
struct Node {
    int key{0};            // value stored in this node
    Node* left{nullptr};   // left child, or null when absent
    Node* right{nullptr};  // right child, or null when absent
};
// Allocates a leaf node on the heap that carries `key` and has no children.
// The node is created with `new`; nothing in this function releases it, so
// the caller receives a raw pointer it is responsible for.
Node* newNode(int key)
{
    Node* leaf = new Node();
    leaf->key = key;
    leaf->left = nullptr;
    leaf->right = nullptr;
    return leaf;
}
// Predicate (despite the name, it deletes nothing): reports whether
// `nodeVal` is currently a key of the global map `mp`, i.e. whether a node
// with that key should be removed from the tree.
bool deleteNode(int nodeVal)
{
    return mp.count(nodeVal) != 0;
}
// Post-order pruning of the tree rooted at `root`.
//
// Both subtrees are pruned first and re-attached, so by the time a node is
// examined its children are already the surviving roots (or null). If the
// node's key is marked for deletion (deleteNode), its non-null children are
// appended to `result` as roots of new trees (left before right) and null is
// returned so the parent drops its link to this node. Otherwise the node
// itself is returned unchanged.
//
// The removed node is only unlinked from its parent; it is not freed here.
Node* treePruning(Node* root, std::vector<Node*>& result)
{
    if (!root)
        return nullptr;

    // Children must be handled before the node itself (post-order).
    Node* keptLeft = treePruning(root->left, result);
    root->left = keptLeft;
    Node* keptRight = treePruning(root->right, result);
    root->right = keptRight;

    // Guard clause: a node that is not marked stays in place.
    if (!deleteNode(root->key))
        return root;

    // The node goes away, so each surviving child becomes a tree root.
    Node* const orphans[2] = { keptLeft, keptRight };
    for (Node* orphan : orphans) {
        if (orphan != nullptr)
            result.push_back(orphan);
    }
    return nullptr;
}
// Writes the keys of the tree rooted at `root` to standard output in
// in-order (left subtree, node, right subtree), each key followed by a
// single space. A null root prints nothing.
void printInorderTree(Node* root)
{
    // Explicit stack of ancestors whose key has not been printed yet.
    std::stack<Node*> pending;
    Node* cursor = root;

    while (cursor != nullptr || !pending.empty()) {
        // Walk as far left as possible, remembering the path taken.
        while (cursor != nullptr) {
            pending.push(cursor);
            cursor = cursor->left;
        }

        // Leftmost unprinted node: emit it, then continue with its right side.
        cursor = pending.top();
        pending.pop();
        std::cout << cursor->key << " ";
        cursor = cursor->right;
    }
}
// Deletes every node whose key appears in arr[0..n) from the tree rooted at
// `root` and prints the resulting forest to standard output.
//
// Parameters:
//   root - root of the tree to prune (may be null; nothing is printed then)
//   arr  - keys of the nodes to delete
//   n    - number of entries in arr; a value <= 0 deletes nothing
//
// Each surviving tree is printed on its own line as an in-order listing of
// keys (see printInorderTree). Trees appear in the order treePruning
// discovers them, with the tree that still contains the original root, if it
// survived, printed last.
void printForests(Node* root, int arr[], int n)
{
    // mp is a global: start from an empty deletion set so keys passed to an
    // earlier call are not deleted again by this one.
    mp.clear();
    for (int i = 0; i < n; ++i) {
        mp[arr[i]] = true;
    }

    std::vector<Node*> forest;
    // A non-null result means the root itself was kept and is a tree too.
    if (treePruning(root, forest))
        forest.push_back(root);

    for (Node* tree : forest) {
        printInorderTree(tree);
        std::cout << '\n';  // plain newline: no need to flush after every tree
    }
}
int
main()
{
Node* root = newNode(1);
root->left = newNode(12);
root->right = newNode(13);
root->right->left = newNode(14);
root->right->right = newNode(15);
root->right->left->left = newNode(21);
root->right->left->right = newNode(22);
root->right->right->left = newNode(23);
root->right->right->right = newNode(24);
int
arr[] = { 14, 23, 1 };
int
n =
sizeof
(arr) /
sizeof
(arr[0]);
printForests(root, arr, n);
}