#include <bits/stdc++.h>
using
namespace
std;
// A binary-tree node holding an int payload and two child links.
struct TreeNode {
    int val;          // value stored in this node
    TreeNode* left;   // left child, NULL when absent
    TreeNode* right;  // right child, NULL when absent

    // Creates a leaf node: stores `val` and leaves both child links NULL.
    TreeNode(int val)
        : val(val), left(NULL), right(NULL)
    {
    }
};
// Shared buffer for the nodes lying between a target node and the root.
// findPath() appends to it as its recursion unwinds, so the target is stored
// first and the root last; distanceK() then walks it by index.
vector<TreeNode*> path;
// Searches the subtree rooted at `node` for `target` (compared by pointer).
//
// When the target is found, every node on the way from the target up to
// `node` is appended to the global `path`, target first. This function only
// appends; it never clears `path`.
//
// Returns true if `target` is `node` itself or lies somewhere below it,
// false otherwise (including when `node` is NULL).
bool findPath(TreeNode* node, TreeNode* target)
{
    if (node == NULL) {
        return false;
    }

    // Same evaluation order as a short-circuit OR: self, then left, then
    // right, stopping at the first hit.
    bool found = (node == target);
    if (!found) {
        found = findPath(node->left, target);
    }
    if (!found) {
        found = findPath(node->right, target);
    }

    if (found) {
        // Recorded on the way back up, so deeper nodes are stored first.
        path.push_back(node);
    }
    return found;
}
// Appends to `result` the values of all nodes that sit exactly `dist` edges
// below `node` (or `node` itself when `dist` is 0), walking downwards only.
//
// `blocker`, when not NULL, marks a node whose whole subtree is skipped: the
// walk stops as soon as it reaches that node. Pass NULL to disable this.
//
// Nothing is appended when `dist` is negative or `node` is NULL.
void findKDistanceFromNode(TreeNode* node, int dist, vector<int>& result,
                           TreeNode* blocker)
{
    if (dist < 0) {
        return;
    }
    if (node == NULL) {
        return;
    }
    // `node` is non-NULL here, so this can only match a non-NULL blocker.
    if (node == blocker) {
        return;
    }

    if (dist == 0) {
        result.push_back(node->val);
    }

    // Descend left first, then right, with one less edge to go.
    TreeNode* children[2] = { node->left, node->right };
    const int remaining = dist - 1;
    for (int c = 0; c < 2; c++) {
        findKDistanceFromNode(children[c], remaining, result, blocker);
    }
}
// Returns the values of all nodes that are exactly `K` edges away from
// `target` in the tree rooted at `root`.
//
// Approach: findPath() records the chain target -> ... -> root in the global
// `path`. For the i-th node on that chain (i edges above the target) the
// remaining K - i edges are spent walking downwards from it, while the child
// we just came from (path[i - 1]) is blocked so its subtree is not revisited.
//
// Returns an empty vector when `target` is not found under `root` or when
// `K` is negative. The global `path` is left holding the chain for this call.
vector<int> distanceK(TreeNode* root, TreeNode* target, int K)
{
    // findPath() only appends, so drop anything left over from an earlier
    // call; otherwise stale nodes would be treated as ancestors of `target`.
    path.clear();
    findPath(root, target);

    vector<int> result;
    const int pathLength = static_cast<int>(path.size());

    // Ancestors more than K edges above the target cannot contribute
    // (the remaining distance would be negative), so stop at i == K.
    for (int i = 0; i < pathLength && i <= K; i++) {
        TreeNode* cameFrom = (i == 0) ? NULL : path[i - 1];
        findKDistanceFromNode(path[i], K - i, result, cameFrom);
    }
    return result;
}
// Demo driver: builds a small fixed tree, asks for the nodes two edges away
// from the node holding 12, and prints them as a bracketed, comma-separated
// list on one line.
int main()
{
    // Sample tree: 20 is the root with children 8 and 22; 8 has children
    // 4 and 12; 12 has children 10 and 4.
    // The nodes are never deleted; the program exits right after printing.
    TreeNode* root = new TreeNode(20);
    root->left = new TreeNode(8);
    root->right = new TreeNode(22);
    root->left->left = new TreeNode(4);
    root->left->right = new TreeNode(12);
    root->left->right->left = new TreeNode(10);
    root->left->right->right = new TreeNode(4);

    TreeNode* target = root->left->right;
    vector<int> result = distanceK(root, target, 2);

    // Write the separator before every element except the first. Unlike
    // looping to size() - 1, this is safe for an empty result (size() is
    // unsigned, so size() - 1 would wrap around) and prints "[]" then.
    cout << "[";
    for (size_t i = 0; i < result.size(); i++) {
        if (i > 0) {
            cout << ", ";
        }
        cout << result[i];
    }
    cout << "]" << endl;
    return 0;
}