Find the distance between two keys in a binary tree when no parent pointers are given. The distance between two nodes is the minimum number of edges that must be traversed to reach one node from the other.

The distance between two nodes can be expressed in terms of their lowest common ancestor (LCA) using the following formula:
Dist(n1, n2) = Dist(root, n1) + Dist(root, n2) - 2*Dist(root, lca)
'n1' and 'n2' are the two given keys
'root' is root of given Binary Tree.
'lca' is lowest common ancestor of n1 and n2
Dist(n1, n2) is the distance between n1 and n2.
Following is the implementation of the above approach. The implementation is adapted from the last code provided in the Lowest Common Ancestor post.
C++
#include <iostream>
using namespace std;
// Binary-tree node holding an integer key and two child links.
struct Node
{
    Node *left;   // left child, or null when absent
    Node *right;  // right child, or null when absent
    int key;      // value the distance routines search for
};
// Allocates a childless node carrying `key`; the caller owns the result.
Node* newNode(int key)
{
    Node *node = new Node;
    node->left = nullptr;
    node->right = nullptr;
    node->key = key;
    return node;
}
// Returns the level of key k inside the subtree rooted at `root`, where
// `root` itself is at `level`. Returns -1 when k is not in that subtree.
int findLevel(Node *root, int k, int level)
{
    // Fell off the tree without finding k.
    if (!root)
        return -1;
    if (root->key == k)
        return level;
    const int inLeft = findLevel(root->left, k, level + 1);
    if (inLeft != -1)
        return inLeft;
    return findLevel(root->right, k, level + 1);
}
// Single-pass helper for findDistance. Searches the subtree at `root`
// (which sits at level `lvl`) for n1 and n2:
//  - a node matching n1 (checked first) or n2 records its level in d1 or d2
//    and is returned immediately, without searching below it;
//  - when the keys turn up in different subtrees of a node, that node is
//    their LCA: dist is set to d1 + d2 - 2*lvl and the node is returned;
//  - otherwise whatever (possibly null) the children reported is returned.
Node *findDistUtil(Node* root, int n1, int n2, int &d1,
                   int &d2, int &dist, int lvl)
{
    if (!root)
        return nullptr;
    if (root->key == n1) { d1 = lvl; return root; }
    if (root->key == n2) { d2 = lvl; return root; }

    const int next = lvl + 1;
    Node *fromLeft = findDistUtil(root->left, n1, n2, d1, d2, dist, next);
    Node *fromRight = findDistUtil(root->right, n1, n2, d1, d2, dist, next);

    if (fromLeft == nullptr)
        return fromRight;
    if (fromRight == nullptr)
        return fromLeft;
    // Keys found on both sides: this node is the lowest common ancestor.
    dist = d1 + d2 - 2 * lvl;
    return root;
}
// Returns the number of edges between the nodes keyed n1 and n2, or -1 if
// either key is absent from the tree. Returns 0 when n1 == n2 and the key
// is present.
int findDistance(Node *root, int n1, int n2)
{
    // Initialize every out-parameter: findDistUtil writes dist only when the
    // keys are found in different subtrees.
    int d1 = -1;
    int d2 = -1;
    int dist = -1;
    Node *lca = findDistUtil(root, n1, n2, d1, d2, dist, 1);

    // Both keys seen in separate subtrees: dist was computed at their LCA.
    if (d1 != -1 && d2 != -1)
        return dist;
    // Only one key was reached, so the search stopped there. Here lca is that
    // node, and the other key (if present at all) lies below it.
    if (d1 != -1)
        return findLevel(lca, n2, 0);
    if (d2 != -1)
        return findLevel(lca, n1, 0);
    return -1;
}
int main()
{
Node * root = newNode(1);
root->left = newNode(2);
root->right = newNode(3);
root->left->left = newNode(4);
root->left->right = newNode(5);
root->right->left = newNode(6);
root->right->right = newNode(7);
root->right->left->right = newNode(8);
cout << "Dist(4, 5) = " << findDistance(root, 4, 5);
cout << "\nDist(4, 6) = " << findDistance(root, 4, 6);
cout << "\nDist(3, 4) = " << findDistance(root, 3, 4);
cout << "\nDist(2, 4) = " << findDistance(root, 2, 4);
cout << "\nDist(8, 5) = " << findDistance(root, 8, 5);
return 0;
}
|
Java
class GFG
{
    // Search state shared by findDistUtil (writer) and findDistance (reader):
    // d1/d2 are the levels at which n1/n2 were found (-1 = not found) and
    // dist is the distance computed at their lowest common ancestor.
    static int d1 = -1;
    static int d2 = -1;
    static int dist = 0;

    /** Binary-tree node with an integer key. */
    static class Node
    {
        Node left, right;
        int key;

        Node(int key)
        {
            this.key = key;
            this.left = null;
            this.right = null;
        }
    }

    /**
     * Returns the level of key k in the subtree at root (root being at
     * {@code level}), or -1 if k is not in that subtree.
     */
    static int findLevel(Node root, int k, int level)
    {
        if (root == null)
            return -1;
        if (root.key == k)
            return level;
        int found = findLevel(root.left, k, level + 1);
        if (found != -1)
            return found;
        return findLevel(root.right, k, level + 1);
    }

    /**
     * One pass over the subtree at root (at level lvl). A node matching n1
     * (checked first) or n2 records its level in d1/d2 and is returned without
     * searching below it. When the keys appear in different subtrees of a node,
     * that node is their LCA: dist is set and the node is returned.
     */
    public static Node findDistUtil(Node root, int n1, int n2, int lvl)
    {
        if (root == null)
            return null;
        if (root.key == n1) {
            d1 = lvl;
            return root;
        }
        if (root.key == n2) {
            d2 = lvl;
            return root;
        }
        Node fromLeft = findDistUtil(root.left, n1, n2, lvl + 1);
        Node fromRight = findDistUtil(root.right, n1, n2, lvl + 1);
        if (fromLeft == null)
            return fromRight;
        if (fromRight == null)
            return fromLeft;
        dist = (d1 + d2) - 2 * lvl;
        return root;
    }

    /** Distance in edges between keys n1 and n2; -1 if either is absent. */
    public static int findDistance(Node root, int n1, int n2)
    {
        d1 = -1;
        d2 = -1;
        dist = 0;
        Node lca = findDistUtil(root, n1, n2, 1);
        boolean haveFirst = d1 != -1;
        boolean haveSecond = d2 != -1;
        if (haveFirst && haveSecond)
            return dist;
        // Only one key reached: the other, if present, lies below lca.
        if (haveFirst) {
            dist = findLevel(lca, n2, 0);
            return dist;
        }
        if (haveSecond) {
            dist = findLevel(lca, n1, 0);
            return dist;
        }
        return -1;
    }

    public static void main(String[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        System.out.println("Dist(4, 5) = " + findDistance(root, 4, 5));
        System.out.println("Dist(4, 6) = " + findDistance(root, 4, 6));
        System.out.println("Dist(3, 4) = " + findDistance(root, 3, 4));
        System.out.println("Dist(2, 4) = " + findDistance(root, 2, 4));
        System.out.println("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}
|
Python3
class Node:
    """Binary-tree node storing a value in ``data`` and two child links."""

    def __init__(self, data):
        self.data = data
        # Children start empty; callers attach them by assignment.
        self.right = None
        self.left = None
def pathToNode(root, path, k):
    """Append to ``path`` the ``data`` values on the path from ``root`` to ``k``.

    Returns True if a node whose ``data`` equals ``k`` exists in the subtree
    rooted at ``root`` (``path`` then ends with ``k``). Otherwise returns False
    and leaves ``path`` as it was on entry.
    """
    if root is None:
        return False
    path.append(root.data)
    if root.data == k:
        return True
    # The recursive call handles a missing child itself, so no None check here.
    if pathToNode(root.left, path, k) or pathToNode(root.right, path, k):
        return True
    # k is not below this node: undo the speculative append.
    path.pop()
    return False


def distance(root, data1, data2):
    """Return the number of edges between the nodes holding data1 and data2.

    Returns 0 for an empty tree (unchanged behaviour). Returns -1 if either
    value is missing from a non-empty tree; previously a missing value
    produced a meaningless count.
    """
    if not root:
        return 0
    path1 = []
    path2 = []
    if not pathToNode(root, path1, data1) or not pathToNode(root, path2, data2):
        return -1
    # The common prefix of the two root paths ends at the lowest common
    # ancestor; its length is the number of shared nodes.
    common = 0
    for a, b in zip(path1, path2):
        if a != b:
            break
        common += 1
    return len(path1) + len(path2) - 2 * common
# Sample tree: 1 -> (2, 3); 2 -> (4, 5); 3 -> (6, 7); 6 -> right child 8.
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.right.right = Node(7)
root.right.left = Node(6)
root.left.right = Node(5)
root.right.left.right = Node(8)

for pair in ((4, 5), (4, 6), (3, 4), (2, 4), (8, 5)):
    dist = distance(root, pair[0], pair[1])
    print ("Distance between node {} & {}: {}".format(pair[0], pair[1], dist))
|
C#
using System;
class GFG
{
    // Shared search state: d1/d2 hold the levels at which n1/n2 were found
    // (-1 = not found) and dist the distance computed at their LCA.
    public static int d1 = -1;
    public static int d2 = -1;
    public static int dist = 0;

    /// <summary>Binary-tree node with an integer key.</summary>
    public class Node
    {
        public Node left, right;
        public int key;

        public Node(int key)
        {
            this.key = key;
            left = null;
            right = null;
        }
    }

    /// <summary>
    /// Level of key <paramref name="k"/> in the subtree at <paramref name="root"/>,
    /// with the root at <paramref name="level"/>; -1 if the key is absent.
    /// </summary>
    public static int findLevel(Node root, int k, int level)
    {
        if (root == null)
            return -1;
        if (root.key == k)
            return level;
        int inLeft = findLevel(root.left, k, level + 1);
        return inLeft == -1 ? findLevel(root.right, k, level + 1) : inLeft;
    }

    /// <summary>
    /// One pass over the subtree at root (at level lvl). A node matching n1
    /// (checked first) or n2 stores its level in d1/d2 and is returned without
    /// searching below it. When the keys sit in different subtrees of a node,
    /// that node is their LCA: dist is set and the node is returned.
    /// </summary>
    public static Node findDistUtil(Node root, int n1, int n2, int lvl)
    {
        if (root == null)
            return null;
        if (root.key == n1)
        {
            d1 = lvl;
            return root;
        }
        if (root.key == n2)
        {
            d2 = lvl;
            return root;
        }
        Node fromLeft = findDistUtil(root.left, n1, n2, lvl + 1);
        Node fromRight = findDistUtil(root.right, n1, n2, lvl + 1);
        if (fromLeft == null)
            return fromRight;
        if (fromRight == null)
            return fromLeft;
        dist = (d1 + d2) - 2 * lvl;
        return root;
    }

    /// <summary>Distance in edges between keys n1 and n2; -1 if either is absent.</summary>
    public static int findDistance(Node root, int n1, int n2)
    {
        d1 = -1;
        d2 = -1;
        dist = 0;
        Node lca = findDistUtil(root, n1, n2, 1);
        bool foundFirst = d1 != -1;
        bool foundSecond = d2 != -1;
        if (foundFirst && foundSecond)
            return dist;
        // Only one key reached: the other, if present, lies below lca.
        if (foundFirst)
            return dist = findLevel(lca, n2, 0);
        if (foundSecond)
            return dist = findLevel(lca, n1, 0);
        return -1;
    }

    public static void Main(string[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        Console.WriteLine("Dist(4, 5) = " + findDistance(root, 4, 5));
        Console.WriteLine("Dist(4, 6) = " + findDistance(root, 4, 6));
        Console.WriteLine("Dist(3, 4) = " + findDistance(root, 3, 4));
        Console.WriteLine("Dist(2, 4) = " + findDistance(root, 2, 4));
        Console.WriteLine("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}
|
Javascript
// Shared state written by findDistUtil and read by findDistance:
// d1/d2 are the levels at which n1/n2 were found (-1 = not found), and
// dist is the distance computed at their lowest common ancestor.
let d1 = -1, d2 = -1;
let dist = 0;
/** Binary-tree node: an integer `key` plus left/right child links. */
class Node
{
    constructor(key) {
        // Children start empty and are attached by the caller.
        this.left = null;
        this.right = null;
        this.key = key;
    }
}
/**
 * Returns the level of key `k` in the subtree rooted at `root`, where
 * `root` itself is at `level`; returns -1 when `k` is absent.
 */
function findLevel(root, k, level)
{
    if (root == null) return -1;
    if (root.key == k) return level;
    const below = level + 1;
    const inLeft = findLevel(root.left, k, below);
    return inLeft !== -1 ? inLeft : findLevel(root.right, k, below);
}
/**
 * One-pass search for n1 and n2 below `root` (which is at level `lvl`).
 * A node matching n1 (checked first) or n2 stores its level in the global
 * d1 or d2 and is returned without searching further down. When the keys
 * turn up in different subtrees of a node, that node is their LCA: the
 * global dist is set to d1 + d2 - 2 * lvl and the node is returned.
 * Otherwise whatever (possibly null) the children reported is returned.
 */
function findDistUtil(root, n1, n2, lvl)
{
    if (root == null) return null;
    if (root.key == n1) { d1 = lvl; return root; }
    if (root.key == n2) { d2 = lvl; return root; }

    const fromLeft = findDistUtil(root.left, n1, n2, lvl + 1);
    const fromRight = findDistUtil(root.right, n1, n2, lvl + 1);
    if (fromLeft == null) return fromRight;
    if (fromRight == null) return fromLeft;

    dist = (d1 + d2) - 2 * lvl;
    return root;
}
/**
 * Distance in edges between keys n1 and n2, or -1 if either is absent.
 * Resets and then reads the shared globals d1, d2 and dist.
 */
function findDistance(root, n1, n2)
{
    d1 = -1;
    d2 = -1;
    dist = 0;
    const lca = findDistUtil(root, n1, n2, 1);
    const foundFirst = d1 != -1;
    const foundSecond = d2 != -1;
    if (foundFirst && foundSecond) return dist;
    // Only one key reached: the other, if present, lies below lca.
    if (foundFirst) return (dist = findLevel(lca, n2, 0));
    if (foundSecond) return (dist = findLevel(lca, n1, 0));
    return -1;
}
// Sample tree: 1 -> (2, 3); 2 -> (4, 5); 3 -> (6, 7); 6 -> right child 8.
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(3);
root.left.left = new Node(4);
root.left.right = new Node(5);
root.right.left = new Node(6);
root.right.right = new Node(7);
root.right.left.right = new Node(8);

// console.log already ends each line; the old "</br>" suffix (a leftover
// from a document.write version) was printed literally.
console.log("Dist(4, 5) = " + findDistance(root, 4, 5));
console.log("Dist(4, 6) = " + findDistance(root, 4, 6));
console.log("Dist(3, 4) = " + findDistance(root, 3, 4));
console.log("Dist(2, 4) = " + findDistance(root, 2, 4));
console.log("Dist(8, 5) = " + findDistance(root, 8, 5));
|
Output
Dist(4, 5) = 2
Dist(4, 6) = 4
Dist(3, 4) = 3
Dist(2, 4) = 1
Dist(8, 5) = 5
Time Complexity: O(n), as each node is visited a constant number of times. Here n is the number of nodes in the tree.
Auxiliary Space: O(h), Here h is the height of the tree and the extra space is used in recursion call stack.
Thanks to Atul Singh for providing the initial solution for this post.
Better Solution :
We first find the LCA of the two nodes, and then add the distances from the LCA to each of the two nodes.
C++
#include <iostream>
using namespace std;
// Node of a binary tree: left/right subtree pointers plus an integer key.
struct Node {
    Node* left;
    Node* right;
    int key;
};
// Heap-allocates a node with the given key and no children; the caller
// owns the returned pointer.
Node* newNode(int key)
{
    Node* fresh = new Node;
    fresh->key = key;
    fresh->left = nullptr;
    fresh->right = nullptr;
    return fresh;
}
// Lowest common ancestor of the nodes keyed n1 and n2 in the subtree at root.
// A node matching either key is returned as soon as it is reached. So if only
// one key is present, its node is returned; if neither is, nullptr is returned.
Node* LCA(Node* root, int n1, int n2)
{
    if (root == nullptr)
        return nullptr;
    if (root->key == n1 || root->key == n2)
        return root;
    Node* left = LCA(root->left, n1, n2);
    Node* right = LCA(root->right, n1, n2);
    // Keys split across the two subtrees: this node is their LCA.
    if (left != nullptr && right != nullptr)
        return root;
    // Reuse the result already computed. The previous version called LCA on
    // the same child a second time, which made the search O(2^h) on skewed
    // trees instead of O(n).
    return left != nullptr ? left : right;
}
// Depth of key k below root, counting root itself as `level`;
// -1 if k does not occur in this subtree.
int findLevel(Node* root, int k, int level)
{
    if (!root)
        return -1;
    if (root->key == k)
        return level;
    const int viaLeft = findLevel(root->left, k, level + 1);
    return viaLeft != -1 ? viaLeft : findLevel(root->right, k, level + 1);
}
// Number of edges between the nodes keyed a and b, or -1 if either key is
// missing. Previously the result was -2 when both keys were missing.
int findDistance(Node* root, int a, int b)
{
    Node* lca = LCA(root, a, b);
    if (lca == nullptr)
        return -1;  // neither key is in the tree
    const int d1 = findLevel(lca, a, 0);
    const int d2 = findLevel(lca, b, 0);
    // One key missing: LCA returned the other key's node, so report -1
    // instead of a partial sum.
    if (d1 == -1 || d2 == -1)
        return -1;
    return d1 + d2;
}
int main()
{
    // Sample tree: 1 has children 2 and 3; 2 has children 4 and 5;
    // 3 has children 6 and 7; 6 has a right child 8.
    Node* root = newNode(1);
    root->left = newNode(2);
    root->right = newNode(3);
    root->left->left = newNode(4);
    root->left->right = newNode(5);
    root->right->left = newNode(6);
    root->right->right = newNode(7);
    root->right->left->right = newNode(8);

    const int d45 = findDistance(root, 4, 5);
    const int d46 = findDistance(root, 4, 6);
    const int d34 = findDistance(root, 3, 4);
    const int d24 = findDistance(root, 2, 4);
    const int d85 = findDistance(root, 8, 5);
    cout << "Dist(4, 5) = " << d45;
    cout << "\nDist(4, 6) = " << d46;
    cout << "\nDist(3, 4) = " << d34;
    cout << "\nDist(2, 4) = " << d24;
    cout << "\nDist(8, 5) = " << d85;
    return 0;
}
|
Java
public class GFG {
    /** Binary-tree node holding an integer value. */
    public static class Node {
        int value;
        Node left;
        Node right;
        public Node(int value) { this.value = value; }
    }

    /**
     * Lowest common ancestor of the nodes valued n1 and n2 under root. A node
     * matching either value is returned as soon as it is reached. So if only
     * one value is present its node is returned; null if neither is present.
     */
    public static Node LCA(Node root, int n1, int n2)
    {
        if (root == null)
            return root;
        if (root.value == n1 || root.value == n2)
            return root;
        Node left = LCA(root.left, n1, n2);
        Node right = LCA(root.right, n1, n2);
        if (left != null && right != null)
            return root;
        return left != null ? left : right;
    }

    /** Level of value a below root (root at {@code level}); -1 if absent. */
    public static int findLevel(Node root, int a, int level)
    {
        if (root == null)
            return -1;
        if (root.value == a)
            return level;
        int left = findLevel(root.left, a, level + 1);
        if (left == -1)
            return findLevel(root.right, a, level + 1);
        return left;
    }

    /**
     * Number of edges between the nodes valued a and b, or -1 if either
     * value is missing. Previously -2 was returned when both were missing.
     */
    public static int findDistance(Node root, int a, int b)
    {
        Node lca = LCA(root, a, b);
        if (lca == null)
            return -1;  // neither value is in the tree
        int d1 = findLevel(lca, a, 0);
        int d2 = findLevel(lca, b, 0);
        // One value missing: report -1 rather than a partial sum.
        if (d1 == -1 || d2 == -1)
            return -1;
        return d1 + d2;
    }

    public static void main(String[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);
        System.out.println("Dist(4, 5) = "
                           + findDistance(root, 4, 5));
        System.out.println("Dist(4, 6) = "
                           + findDistance(root, 4, 6));
        System.out.println("Dist(3, 4) = "
                           + findDistance(root, 3, 4));
        System.out.println("Dist(2, 4) = "
                           + findDistance(root, 2, 4));
        System.out.println("Dist(8, 5) = "
                           + findDistance(root, 8, 5));
    }
}
|
Python3
class Node:
    """A binary-tree node with a ``data`` payload and optional children."""

    def __init__(self, data):
        self.data = data
        self.left = None
        self.right = None
def find_least_common_ancestor(root: Node, n1: int, n2: int) -> Node:
    """Return the lowest common ancestor of the nodes holding n1 and n2.

    A node matching either key is returned as soon as it is reached. So if
    only one key is present its node is returned, and None is returned when
    neither key is present (or when ``root`` is None).
    """
    if root is None or root.data == n1 or root.data == n2:
        return root
    left = find_least_common_ancestor(root.left, n1, n2)
    right = find_least_common_ancestor(root.right, n1, n2)
    # Keys found on both sides: this node is the split point.
    if left and right:
        return root
    return left or right
def find_distance_from_ancestor_node(root: Node, data: int) -> int:
    """Return the number of edges from ``root`` down to the node holding
    ``data``, or -1 if no such node exists in this subtree."""
    if root is None:
        return -1
    if root.data == data:
        return 0
    # Each child reports -1 (absent) or its own distance; keep the larger.
    best = max(find_distance_from_ancestor_node(child, data)
               for child in (root.left, root.right))
    return best + 1 if best >= 0 else -1
def find_distance_between_two_nodes(root: Node, n1: int, n2: int):
    """Return the number of edges between the nodes holding n1 and n2.

    Returns -1 when neither key is present. If only one key is present, the
    LCA is that key's node and the missing key contributes 0 + (-1) = -1.
    """
    lca = find_least_common_ancestor(root, n1, n2)
    if not lca:
        return -1
    return (find_distance_from_ancestor_node(lca, n1)
            + find_distance_from_ancestor_node(lca, n2))
# Sample tree: 1 -> (2, 3); 2 -> (4, 5); 3 -> (6, 7); 6 -> right child 8.
root = Node(1)
root.left = Node(2)
root.right = Node(3)
root.left.left = Node(4)
root.left.right = Node(5)
root.right.left = Node(6)
root.right.right = Node(7)
root.right.left.right = Node(8)

queries = (
    ("Dist(4,5) = ", 4, 5),
    ("Dist(4,6) = ", 4, 6),
    ("Dist(3,4) = ", 3, 4),
    ("Dist(2,4) = ", 2, 4),
    ("Dist(8,5) = ", 8, 5),
)
for label, a, b in queries:
    print(label, find_distance_between_two_nodes(root, a, b))
|
C#
using System;
public class GFG {
    /// <summary>Binary-tree node holding an integer value.</summary>
    public class Node {
        public int value;
        public Node left;
        public Node right;
        public Node(int value) { this.value = value; }
    }

    /// <summary>
    /// Lowest common ancestor of the nodes valued n1 and n2 under root. A node
    /// matching either value is returned as soon as it is reached. So if only
    /// one value is present its node is returned; null if neither is present.
    /// </summary>
    public static Node LCA(Node root, int n1, int n2)
    {
        if (root == null) {
            return null;
        }
        if (root.value == n1 || root.value == n2) {
            return root;
        }
        Node left = LCA(root.left, n1, n2);
        Node right = LCA(root.right, n1, n2);
        if (left != null && right != null) {
            return root;  // values split across the two subtrees
        }
        // Reuse the results already computed. The previous version re-ran
        // LCA on the same child, which is O(2^h) on skewed trees.
        return left ?? right;
    }

    /// <summary>Level of value a below root (root at level); -1 if absent.</summary>
    public static int findLevel(Node root, int a, int level)
    {
        if (root == null) {
            return -1;
        }
        if (root.value == a) {
            return level;
        }
        int left = findLevel(root.left, a, level + 1);
        if (left == -1) {
            return findLevel(root.right, a, level + 1);
        }
        return left;
    }

    /// <summary>
    /// Number of edges between the nodes valued a and b, or -1 if either value
    /// is missing. Previously -2 was returned when both were missing.
    /// </summary>
    public static int findDistance(Node root, int a, int b)
    {
        Node lca = LCA(root, a, b);
        if (lca == null) {
            return -1;  // neither value is in the tree
        }
        int d1 = findLevel(lca, a, 0);
        int d2 = findLevel(lca, b, 0);
        if (d1 == -1 || d2 == -1) {
            return -1;  // one value missing: no partial sums
        }
        return d1 + d2;
    }

    public static void Main(string[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);
        Console.WriteLine("Dist(4, 5) = "
                          + findDistance(root, 4, 5));
        Console.WriteLine("Dist(4, 6) = "
                          + findDistance(root, 4, 6));
        Console.WriteLine("Dist(3, 4) = "
                          + findDistance(root, 3, 4));
        Console.WriteLine("Dist(2, 4) = "
                          + findDistance(root, 2, 4));
        Console.WriteLine("Dist(8, 5) = "
                          + findDistance(root, 8, 5));
    }
}
|
Javascript
<script>
/** Binary-tree node holding a numeric `value` and two child links. */
class Node
{
    constructor(value)
    {
        this.value = value;
        this.left = null;   // attached later by the caller
        this.right = null;
    }
}
/**
 * Lowest common ancestor of the nodes valued n1 and n2 under `root`.
 * A node matching either value is returned as soon as it is reached. So if
 * only one value is present its node is returned; null if neither is present.
 */
function LCA(root, n1, n2)
{
    if (root == null)
    {
        return root;
    }
    if (root.value == n1 || root.value == n2)
    {
        return root;
    }
    var left = LCA(root.left, n1, n2);
    var right = LCA(root.right, n1, n2);
    if (left != null && right != null)
    {
        return root;  // values split across the two subtrees
    }
    if (left == null && right == null)
    {
        return null;
    }
    // Reuse the result already computed. The previous version re-ran LCA on
    // the same child, which is O(2^h) on skewed trees.
    return left != null ? left : right;
}
/**
 * Level of value `a` in the subtree at `root`, with `root` at `level`;
 * -1 when `a` is not found.
 */
function findLevel(root, a, level)
{
    if (root == null) return -1;
    if (root.value == a) return level;
    var fromLeft = findLevel(root.left, a, level + 1);
    return fromLeft == -1 ? findLevel(root.right, a, level + 1) : fromLeft;
}
/**
 * Number of edges between the nodes valued a and b, or -1 if either value
 * is missing. Previously -2 was returned when both were missing.
 */
function findDistance(root, a, b)
{
    var lca = LCA(root, a, b);
    if (lca == null)
    {
        return -1;  // neither value is in the tree
    }
    var d1 = findLevel(lca, a, 0);
    var d2 = findLevel(lca, b, 0);
    if (d1 == -1 || d2 == -1)
    {
        return -1;  // one value missing: no partial sums
    }
    return d1 + d2;
}
// Sample tree: 1 -> (2, 3); 2 -> (4, 5); 3 -> (6, 7); 6 -> right child 8.
var root = new Node(1);
root.left = new Node(2);
root.right = new Node(3);
root.left.left = new Node(4);
root.left.right = new Node(5);
root.right.left = new Node(6);
root.right.right = new Node(7);
root.right.left.right = new Node(8);
[
    ["Dist(4, 5) = ", 4, 5],
    ["Dist(4, 6) = ", 4, 6],
    ["Dist(3, 4) = ", 3, 4],
    ["Dist(2, 4) = ", 2, 4],
    ["Dist(8, 5) = ", 8, 5]
].forEach(function (query) {
    document.write(query[0] + findDistance(root, query[1], query[2]) + "<br>");
});
</script>
|
Output
Dist(4, 5) = 2
Dist(4, 6) = 4
Dist(3, 4) = 3
Dist(2, 4) = 1
Dist(8, 5) = 5
Time Complexity: O(n), as the method makes a constant number of passes over the tree: one to find the LCA and two to find the levels of the keys below it. Here n is the number of nodes in the tree.
Auxiliary Space: O(h), Here h is the height of the tree and the extra space is used in recursion call stack.
Thanks to NILMADHAB MONDAL for suggesting this solution.
Another Better Solution (one pass):
We know that the distance between two nodes (say n1 and n2) equals the distance between their LCA and n1 plus the distance between their LCA and n2.
A straightforward solution based on the above formula that may come to mind is:
// NOTE: intentionally flawed first attempt, kept for illustration only.
// A node matching n1 or n2 returns 1 without searching its own subtree, so
// when one key is a descendant of the other the deeper key is never seen.
// The count then keeps growing (max + 1) on the way back to the root instead
// of stopping at the ancestor key. This snippet reads the key from
// root->data, unlike the `key` field of the Node structs in this article.
int findDistance(Node* root, int n1, int n2) {
if (!root) return 0;
// Match: report distance 1 to the parent (subtree deliberately not searched).
if (root->data == n1 || root->data == n2)
return 1;
int left = findDistance(root->left, n1, n2);
int right = findDistance(root->right, n1, n2);
// Keys found in different subtrees: this node is their LCA.
if (left && right)
return left + right;
// Exactly one key below: pass its distance up one more edge.
else if (left || right)
return max(left, right) + 1;
return 0;
}
But this solution has a flaw (a missing edge case) when n2 is a descendant of n1 or n1 is a descendant of n2. The recursion stops at the ancestor key and never sees the other key below it.
Below is dry run of above code with edge case example :

In the above binary tree the expected output is 2, but the function returns 3. The solution code given below handles this case:
Note: both n1 and n2 must be present in the binary tree.
C++
#include <iostream>
using namespace std;
// One node of the binary tree; a null child pointer marks a missing subtree.
struct Node {
    Node *left;
    Node *right;
    int key;
};
// Returns a heap-allocated leaf holding `key` (both children null).
// The caller owns the node.
Node* newNode(int key)
{
    return new Node{nullptr, nullptr, key};
}
// Distance found by the most recent findDistance() query. It is written by
// _findDistance() and reset to 0 by findDistance() before each search.
int ans;

// Post-order helper for findDistance().
// Returns, for a subtree containing exactly one of the keys (with no answer
// recorded yet), the number of edges from this node's parent down to that
// key. Returns 0 when the subtree holds neither key, or once the answer has
// been stored in `ans`.
int _findDistance(Node* root, int n1, int n2)
{
    if (!root)
        return 0;
    const int left = _findDistance(root->left, n1, n2);
    const int right = _findDistance(root->right, n1, n2);
    // Larger child report, written out so the block does not depend on
    // <iostream> happening to declare std::max (that lives in <algorithm>).
    const int deeper = left > right ? left : right;

    if (root->key == n1 || root->key == n2)
    {
        if (left || right)
        {
            // The other key lies below this one, `deeper` edges away.
            ans = deeper;
            return 0;
        }
        return 1;
    }
    if (left && right)
    {
        // Keys in different subtrees: this node is their LCA.
        ans = left + right;
        return 0;
    }
    if (left || right)
        return deeper + 1;
    return 0;
}

// Number of edges between the nodes keyed n1 and n2. Both keys must be
// present; if either is missing the result is 0, not an error code.
int findDistance(Node* root, int n1, int n2)
{
    ans = 0;
    _findDistance(root, n1, n2);
    return ans;
}
int main()
{
Node* root = newNode(1);
root->left = newNode(2);
root->right = newNode(3);
root->left->left = newNode(4);
root->left->right = newNode(5);
root->right->left = newNode(6);
root->right->right = newNode(7);
root->right->left->right = newNode(8);
cout << "Dist(4, 5) = " << findDistance(root, 4, 5);
cout << "\nDist(4, 6) = " << findDistance(root, 4, 6);
cout << "\nDist(3, 4) = " << findDistance(root, 3, 4);
cout << "\nDist(2, 4) = " << findDistance(root, 2, 4);
cout << "\nDist(8, 5) = " << findDistance(root, 8, 5);
return 0;
}
|
Java
public class Main {
    /** Binary-tree node holding an integer value. */
    public static class Node {
        int value;
        Node left;
        Node right;

        public Node(int value) {
            this.value = value;
        }
    }

    /** Distance recorded by the most recent query; reset by findDistance. */
    public static int ans;

    /**
     * Post-order helper. For a subtree containing exactly one key, returns the
     * edge count from this node's parent to that key. Returns 0 if the subtree
     * has neither key or the answer has already been stored in {@link #ans}.
     */
    public static int _findDistance(Node root, int n1, int n2)
    {
        if (root == null)
            return 0;
        int left = _findDistance(root.left, n1, n2);
        int right = _findDistance(root.right, n1, n2);
        boolean leftHit = left != 0;
        boolean rightHit = right != 0;
        boolean isTarget = root.value == n1 || root.value == n2;

        if (isTarget) {
            if (!leftHit && !rightHit)
                return 1;
            // The other key is somewhere below this one.
            ans = Math.max(left, right);
            return 0;
        }
        if (leftHit && rightHit) {
            // Keys split across subtrees: this node is the LCA.
            ans = left + right;
            return 0;
        }
        return (leftHit || rightHit) ? Math.max(left, right) + 1 : 0;
    }

    /** Edge count between values n1 and n2; both must be in the tree. */
    public static int findDistance(Node root, int n1, int n2)
    {
        ans = 0;
        _findDistance(root, n1, n2);
        return ans;
    }

    public static void main(String[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        System.out.println("Dist(4, 5) = " + findDistance(root, 4, 5));
        System.out.println("Dist(4, 6) = " + findDistance(root, 4, 6));
        System.out.println("Dist(3, 4) = " + findDistance(root, 3, 4));
        System.out.println("Dist(2, 4) = " + findDistance(root, 2, 4));
        System.out.println("Dist(8, 5) = " + findDistance(root, 8, 5));
    }
}
|
Python3
class Node:
    """Binary-tree node keyed by ``key``, with optional left/right children."""

    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
# Distance recorded by _findDistance; findDistance resets it before each query.
ans = 0


def _findDistance(root, n1, n2):
    """Post-order helper for findDistance.

    For a subtree containing exactly one of the keys, returns the number of
    edges from this node's parent down to that key. Returns 0 when the subtree
    holds neither key, or once the answer has been stored in the global ``ans``.
    """
    global ans
    if not root:
        return 0
    left = _findDistance(root.left, n1, n2)
    right = _findDistance(root.right, n1, n2)
    if root.key == n1 or root.key == n2:
        if left or right:
            # The other key lies below this one.
            ans = max(left, right)
            return 0
        return 1
    if left and right:
        # Keys in different subtrees: this node is their LCA.
        ans = left + right
        return 0
    if left or right:
        return max(left, right) + 1
    return 0


def findDistance(root, n1, n2):
    """Return the number of edges between the nodes keyed n1 and n2.

    Both keys must be present. ``ans`` is reset first. Without the reset, a
    query that records nothing (e.g. n1 == n2) returned the previous query's
    answer.
    """
    global ans
    ans = 0
    _findDistance(root, n1, n2)
    return ans
if __name__ == '__main__':
    # Sample tree: 1 -> (2, 3); 2 -> (4, 5); 3 -> (6, 7); 6 -> right child 8.
    root = Node(1)
    root.left = Node(2)
    root.right = Node(3)
    root.left.left = Node(4)
    root.left.right = Node(5)
    root.right.left = Node(6)
    root.right.right = Node(7)
    root.right.left.right = Node(8)

    queries = (
        ("Dist(4, 5) =", 4, 5),
        ("Dist(4, 6) =", 4, 6),
        ("Dist(3, 4) =", 3, 4),
        ("Dist(2, 4) =", 2, 4),
        ("Dist(8, 5) =", 8, 5),
    )
    for label, a, b in queries:
        print(label, findDistance(root, a, b))
|
C#
using System;
class Program {
    /// <summary>Binary-tree node holding an integer value.</summary>
    public class Node {
        public int value;
        public Node left;
        public Node right;

        public Node(int value)
        {
            this.value = value;
        }
    }

    /// <summary>Distance recorded by the latest query; reset by findDistance.</summary>
    public static int ans = 0;

    /// <summary>
    /// Post-order helper. For a subtree containing exactly one key, returns the
    /// edge count from this node's parent to that key. Returns 0 if the subtree
    /// has neither key or the answer is already stored in <see cref="ans"/>.
    /// </summary>
    public static int _findDistance(Node root, int n1, int n2)
    {
        if (root == null)
            return 0;
        int left = _findDistance(root.left, n1, n2);
        int right = _findDistance(root.right, n1, n2);
        bool leftHit = left != 0;
        bool rightHit = right != 0;

        if (root.value == n1 || root.value == n2) {
            if (!leftHit && !rightHit)
                return 1;
            // The other key lies below this one.
            ans = Math.Max(left, right);
            return 0;
        }
        if (leftHit && rightHit) {
            // Keys split across subtrees: this node is the LCA.
            ans = left + right;
            return 0;
        }
        return (leftHit || rightHit) ? Math.Max(left, right) + 1 : 0;
    }

    /// <summary>Edge count between values n1 and n2; both must be present.</summary>
    public static int findDistance(Node root, int n1, int n2)
    {
        ans = 0;
        _findDistance(root, n1, n2);
        return ans;
    }

    public static void Main(string[] args)
    {
        Node root = new Node(1);
        root.left = new Node(2);
        root.right = new Node(3);
        root.left.left = new Node(4);
        root.left.right = new Node(5);
        root.right.left = new Node(6);
        root.right.right = new Node(7);
        root.right.left.right = new Node(8);

        Console.WriteLine("Dist({0}, {1}) = {2}", 4, 5, findDistance(root, 4, 5));
        Console.WriteLine("Dist({0}, {1}) = {2}", 4, 6, findDistance(root, 4, 6));
        Console.WriteLine("Dist({0}, {1}) = {2}", 3, 4, findDistance(root, 3, 4));
        Console.WriteLine("Dist({0}, {1}) = {2}", 2, 4, findDistance(root, 2, 4));
        Console.WriteLine("Dist({0}, {1}) = {2}", 8, 5, findDistance(root, 8, 5));
    }
}
|
Javascript
<script>
/** Binary-tree node keyed by `key`, with left/right child links. */
class Node {
    constructor(key) {
        this.key = key;
        this.left = null;   // attached later by the caller
        this.right = null;
    }
}
// Distance recorded by _findDistance; findDistance resets it before each query.
let ans = 0;

/**
 * Post-order helper for findDistance. For a subtree containing exactly one
 * key, returns the number of edges from this node's parent to that key.
 * Returns 0 when the subtree holds neither key, or once the answer has been
 * stored in the global `ans`.
 */
function _findDistance(root, n1, n2) {
    if (!root) {
        return 0;
    }
    const left = _findDistance(root.left, n1, n2);
    const right = _findDistance(root.right, n1, n2);
    if (root.key === n1 || root.key === n2) {
        if (left || right) {
            // The other key lies below this one.
            ans = Math.max(left, right);
            return 0;
        }
        return 1;
    }
    if (left && right) {
        // Keys in different subtrees: this node is their LCA.
        ans = left + right;
        return 0;
    }
    if (left || right) {
        return Math.max(left, right) + 1;
    }
    return 0;
}

/**
 * Number of edges between the nodes keyed n1 and n2 (both must be present).
 * `ans` is reset first. Without the reset, a query that records nothing
 * (e.g. n1 === n2) returned the previous query's answer.
 */
function findDistance(root, n1, n2) {
    ans = 0;
    _findDistance(root, n1, n2);
    return ans;
}
// Sample tree: 1 -> (2, 3); 2 -> (4, 5); 3 -> (6, 7); 6 -> right child 8.
let root = new Node(1);
root.left = new Node(2);
root.right = new Node(3);
root.left.left = new Node(4);
root.left.right = new Node(5);
root.right.left = new Node(6);
root.right.right = new Node(7);
root.right.left.right = new Node(8);

// End each result with "<br>" so the results appear on separate lines, as in
// the documented output; previously they were written run together.
document.write("Dist(4, 5) = ", findDistance(root, 4, 5), "<br>");
document.write("Dist(4, 6) = ", findDistance(root, 4, 6), "<br>");
document.write("Dist(3, 4) = ", findDistance(root, 3, 4), "<br>");
document.write("Dist(2, 4) = ", findDistance(root, 2, 4), "<br>");
document.write("Dist(8, 5) = ", findDistance(root, 8, 5), "<br>");
</script>
|
Output
Dist(4, 5) = 2
Dist(4, 6) = 4
Dist(3, 4) = 3
Dist(2, 4) = 1
Dist(8, 5) = 5
Time Complexity: O(n), where n is the number of nodes in the binary tree.
Auxiliary Space: O(h), where h is the height of the binary tree.
Thanks to Gurudev Singh for suggesting this solution.
Feeling lost in the world of random DSA topics, wasting time without progress? It's time for a change! Join our DSA course, where we'll guide you on an exciting journey to master DSA efficiently and on schedule.
Ready to dive in? Explore our Free Demo Content and join our DSA course, trusted by over 100,000 geeks!
Commit to GfG's Three-90 Challenge! Purchase a course, complete 90% in 90 days, and save 90% cost click here to explore.
Last Updated :
30 Jan, 2023
Like Article
Save Article
Share your thoughts in the comments
Please Login to comment...