Open In App

Maximum Path Sum in a Binary Tree

Last Updated : 04 Oct, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a binary tree, find the maximum path sum. The path may start and end at any node in the tree.

Example: 

Input: Root of below tree

       1
      / \
    2   3

Output: 6

Input: Root of below tree

          10
         /  \
        2    10
       / \     \
     20   1    -25
               /  \
              3    4

Output: 42
Explanation: The maximum-sum path is 20 -> 2 -> 10 -> 10, whose sum is 20 + 2 + 10 + 10 = 42.

Approach: To solve the problem follow the below idea:

For each node there can be four ways that the max path goes through the node: 

  • Node only 
  • Max path through Left Child + Node 
  • Max path through Right Child + Node 
  • Max path through Left Child + Node + Max path through Right Child

The idea is to evaluate these four candidate paths at every node and keep the largest value seen anywhere in the tree. An important point is that the call for the root of every subtree must return the maximum path sum in which at most one child of that root is involved: a path cannot branch, so the parent can only extend such a one-sided path upward. In the code below, this sum is stored in ‘max_single’ and returned by the recursive function.

Follow the given steps to solve the problem:

  • If the root is NULL, return 0 (base case).
  • Recursively compute leftSum and rightSum, the values returned for the left and right subtrees.
  • Store max(root->data, max(leftSum, rightSum) + root->data) in a variable (max_single in the code).
  • Store max(max_single, root->data + leftSum + rightSum) in another variable (max_top), and update the overall result if this value is larger.
  • Return max_single (the value from the third step) to the caller, since the parent can only extend a path that uses at most one child.

Below is the implementation of the above approach:

C++




// C/C++ program to find maximum path sum in Binary Tree
#include <bits/stdc++.h>
using namespace std;
 
// A binary tree node
struct Node {
    int data;    // value stored at this node (may be negative)
    Node* left;  // left child, or nullptr when absent
    Node* right; // right child, or nullptr when absent
};

// Allocates a node holding `data` with no children. The node comes from plain
// `new`; the sample driver never frees it.
struct Node* newNode(int data)
{
    Node* node = new Node;
    node->data = data;
    node->left = nullptr;
    node->right = nullptr;
    return node;
}

// Post-order helper.
// Returns the best "downward" path sum that starts at `root` and continues
// into at most one child (0 for an empty subtree), so the caller can extend it.
// As a side effect, raises `res` to the best path sum whose topmost node lies
// in this subtree.
int findMaxUtil(Node* root, int& res)
{
    if (!root)
        return 0;

    const int fromLeft = findMaxUtil(root->left, res);
    const int fromRight = findMaxUtil(root->right, res);
    const int here = root->data;

    // Either stop at this node or extend the better child path upward.
    const int bestBranch = std::max(here, here + std::max(fromLeft, fromRight));

    // Path that bends at this node and may use both children.
    const int bendHere = std::max(bestBranch, fromLeft + fromRight + here);
    if (bendHere > res)
        res = bendHere;

    return bestBranch;
}

// Returns the maximum path sum over all paths in the tree rooted at `root`
// (a path may start and end at any node). An empty tree yields INT_MIN.
int findMaxSum(Node* root)
{
    int best = INT_MIN;
    findMaxUtil(root, best);
    return best;
}
 
// Driver code
int main(void)
{
    struct Node* root = newNode(10);
    root->left = newNode(2);
    root->right = newNode(10);
    root->left->left = newNode(20);
    root->left->right = newNode(1);
    root->right->right = newNode(-25);
    root->right->right->left = newNode(3);
    root->right->right->right = newNode(4);
 
    // Function call
    cout << "Max path sum is " << findMaxSum(root);
    return 0;
}


Java




// Java program to find maximum path sum in Binary Tree
 
/* Class containing left and right child of current
 node and key value*/
class Node {

    // Value held by this node.
    int data;

    // Child links; null when the child is absent.
    Node left;
    Node right;

    public Node(int item)
    {
        this.data = item;
        this.left = null;
        this.right = null;
    }
}
 
// An object of Res is passed around so that the
// same value can be used by multiple recursive calls.
class Res {
    // Running maximum shared across recursive calls; the caller sets the
    // starting value before the traversal begins.
    public int val = 0;
}
 
class BinaryTree {
 
    // Root of the Binary Tree
    Node root;
 
    // This function returns overall maximum path sum in
    // 'res' And returns max path sum going through root.
    int findMaxUtil(Node node, Res res)
    {
 
        // Base Case
        if (node == null)
            return 0;
 
        // l and r store maximum path sum going through left
        // and right child of root respectively
        int l = findMaxUtil(node.left, res);
        int r = findMaxUtil(node.right, res);
 
        // Max path for parent call of root. This path must
        // include at-most one child of root
        int max_single = Math.max(
            Math.max(l, r) + node.data, node.data);
 
        // Max Top represents the sum when the Node under
        // consideration is the root of the maxsum path and
        // no ancestors of root are there in max sum path
        int max_top
            = Math.max(max_single, l + r + node.data);
 
        // Store the Maximum Result.
        res.val = Math.max(res.val, max_top);
 
        return max_single;
    }
 
    int findMaxSum() { return findMaxSum(root); }
 
    // Returns maximum path sum in tree with given root
    int findMaxSum(Node node)
    {
 
        // Initialize result
        // int res2 = Integer.MIN_VALUE;
        Res res = new Res();
        res.val = Integer.MIN_VALUE;
 
        // Compute and return result
        findMaxUtil(node, res);
        return res.val;
    }
 
    /* Driver code */
    public static void main(String args[])
    {
        BinaryTree tree = new BinaryTree();
        tree.root = new Node(10);
        tree.root.left = new Node(2);
        tree.root.right = new Node(10);
        tree.root.left.left = new Node(20);
        tree.root.left.right = new Node(1);
        tree.root.right.right = new Node(-25);
        tree.root.right.right.left = new Node(3);
        tree.root.right.right.right = new Node(4);
 
        // Function call
        System.out.println("maximum path sum is : "
                           + tree.findMaxSum());
    }
}


Python




# Python3 program to find maximum path sum in Binary Tree
 
# A Binary Tree Node
 
 
class Node:
    """A binary tree node holding ``data`` and optional ``left``/``right`` children."""

    def __init__(self, data):
        self.data = data
        # Children start empty; callers attach subtrees by assignment.
        self.left, self.right = None, None
 
# This function returns overall maximum path sum in 'res'
# And returns max path sum going through root
 
 
def findMaxUtil(root):
    """Return the best downward path sum starting at ``root``.

    The returned path uses ``root`` plus at most one child path (0 for an
    empty subtree), so the caller can extend it upward. As a side effect,
    ``findMaxUtil.res`` is raised to the best path sum whose topmost node
    lies in this subtree; the caller must initialise ``findMaxUtil.res``
    before the first call.
    """
    if root is None:
        return 0

    from_left = findMaxUtil(root.left)
    from_right = findMaxUtil(root.right)
    value = root.data

    # Stop at this node, or extend the better child path upward.
    best_branch = max(value, value + max(from_left, from_right))

    # Path that bends at this node and may use both children.
    bend_here = max(best_branch, from_left + from_right + value)

    # Function attribute acting as a static accumulator across calls.
    if bend_here > findMaxUtil.res:
        findMaxUtil.res = bend_here

    return best_branch
 
# Return maximum path sum in tree with given root
 
 
def findMaxSum(root):
    """Return the maximum path sum in the tree rooted at ``root``.

    A path may start and end at any node. Returns ``float("-inf")`` for an
    empty tree.
    """
    # Reset the accumulator that findMaxUtil updates in place.
    findMaxUtil.res = float("-inf")
    findMaxUtil(root)
    best = findMaxUtil.res
    return best
 
 
# Driver code
if __name__ == '__main__':
    # Sample tree; the best path is 20 -> 2 -> 10 -> 10 (sum 42).
    root = Node(10)
    root.left = Node(2)
    root.right = Node(10)
    root.left.left = Node(20)
    root.left.right = Node(1)
    root.right.right = Node(-25)
    root.right.right.left = Node(3)
    root.right.right.right = Node(4)

    # Function call. print is a function in Python 3; the Python 2 statement
    # form `print "...", x` is a SyntaxError there.
    print("Max path sum is " + str(findMaxSum(root)))
 
# This code is contributed by Nikhil Kumar Singh(nickzuck_007)


C#




// C# program to find maximum
// path sum in Binary Tree
using System;
 
/* Class containing left and
right child of current
node and key value*/
public class Node {

    // Value held by this node.
    public int data;

    // Child links; null when the child is absent.
    public Node left;
    public Node right;

    public Node(int item)
    {
        this.data = item;
        this.left = null;
        this.right = null;
    }
}
 
// An object of Res is passed
// around so that the same value
// can be used by multiple recursive calls.
class Res {
    // Running maximum shared across recursive calls; the caller sets the
    // starting value before the traversal begins.
    public int val = 0;
}
 
public class BinaryTree {

    // Root of the Binary Tree
    Node root;

    // Post-order helper: raises res.val to the best path sum whose
    // topmost node lies in this subtree, and returns the best sum of a
    // path that starts at `node` and descends into at most one child
    // (0 for an empty subtree).
    int findMaxUtil(Node node, Res res)
    {
        if (node == null)
        {
            return 0;
        }

        int fromLeft = findMaxUtil(node.left, res);
        int fromRight = findMaxUtil(node.right, res);
        int here = node.data;

        // Stop at this node, or extend the better child path upward.
        int bestBranch = Math.Max(here, here + Math.Max(fromLeft, fromRight));

        // Path that bends at this node and may use both children.
        int bendHere = Math.Max(bestBranch, fromLeft + fromRight + here);

        if (bendHere > res.val)
        {
            res.val = bendHere;
        }
        return bestBranch;
    }

    // Convenience overload that starts from this tree's root.
    int findMaxSum() { return findMaxSum(root); }

    // Returns the maximum path sum of the tree rooted at `node`
    // (int.MinValue for an empty tree).
    int findMaxSum(Node node)
    {
        Res best = new Res();
        best.val = int.MinValue;
        findMaxUtil(node, best);
        return best.val;
    }

    /* Driver code */
    public static void Main(String[] args)
    {
        // Sample tree; the best path is 20 -> 2 -> 10 -> 10 (sum 42).
        Node top = new Node(10);
        top.left = new Node(2);
        top.right = new Node(10);
        top.left.left = new Node(20);
        top.left.right = new Node(1);
        top.right.right = new Node(-25);
        top.right.right.left = new Node(3);
        top.right.right.right = new Node(4);

        BinaryTree tree = new BinaryTree();
        tree.root = top;

        // Function call
        Console.WriteLine("maximum path sum is : "
                          + tree.findMaxSum());
    }
}
 
// This code is contributed Rajput-Ji.


Javascript




<script>
 
    // JavaScript program to find maximum
    // path sum in Binary Tree
     
    class Node
    {
        constructor(item) {
           this.left = null;
           this.right = null;
           this.data = item;
        }
    }
     
    let val;
     
    // Root of the Binary Tree
    let root;
  
    // This function returns overall maximum path sum in 'res'
    // And returns max path sum going through root.
    function findMaxUtil(node)
    {
  
        // Base Case
        if (node == null)
            return 0;
  
        // l and r store maximum path sum going through left and
        // right child of root respectively
        let l = findMaxUtil(node.left);
        let r = findMaxUtil(node.right);
  
        // Max path for parent call of root. This path must
        // include at-most one child of root
        let max_single = Math.max(Math.max(l, r) + node.data,
                                  node.data);
  
  
        // Max Top represents the sum when the Node under
        // consideration is the root of the maxsum path and no
        // ancestors of root are there in max sum path
        let max_top = Math.max(max_single, l + r + node.data);
  
        // Store the Maximum Result.
        val = Math.max(val, max_top);
  
        return max_single;
    }
  
    function findMaxsum() {
        return findMaxSum(root);
    }
  
    // Returns maximum path sum in tree with given root
    function findMaxSum(node) {
  
        // Initialize result
        // int res2 = Integer.MIN_VALUE;
        val = Number.MIN_VALUE;
  
        // Compute and return result
        findMaxUtil(node);
        return val;
    }
     
    root = new Node(10);
    root.left = new Node(2);
    root.right = new Node(10);
    root.left.left = new Node(20);
    root.left.right = new Node(1);
    root.right.right = new Node(-25);
    root.right.right.left = new Node(3);
    root.right.right.right = new Node(4);
    document.write("Max path sum is : " + findMaxsum());
     
</script>


Output

Max path sum is 42







Time Complexity: O(N) where N is the number of nodes in the Binary Tree
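Auxiliary Space: O(H) for the recursion stack, where H is the height of the tree; this is O(N) in the worst case (a skewed tree).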

Using Iterative DFS: 

The basic idea behind the iterative DFS approach to finding the maximum path sum in a binary tree is to traverse the tree using a stack, maintaining the state of each node as we visit it.

Follow the steps to implement the approach:

  1. Initialize the max_sum variable to INT_MIN and create a stack to perform iterative DFS. Push the root node with a state of 0 onto the stack.
  2. Pop the top element from the stack and check if it is null. If it is null, skip the rest of the loop. Otherwise, check the state of the node.
  3. If the state is 0, push the node onto the stack with a state of 1 and push its left child onto the stack with a state of 0.
  4. If the state is 1, push the node onto the stack with a state of 2 and push its right child onto the stack with a state of 0.
  5. If the state is 2, both subtrees have been processed, so compute the best path whose topmost node is the current node. Take the gains of the left and right children (a child's gain is the best downward path sum that starts at that child; a missing child counts as 0). If a gain is negative, ignore that child by using 0 instead. The node's value plus the two clamped gains is the best path through the current node; update max_sum if it is larger.
  6. Record the current node's gain: its value plus the larger of the two clamped child gains. A path that continues upward to the parent can pass through at most one of the node's children, so this is the value the parent uses in step 5.
  7. Repeat steps 2 to 6 until the stack is empty.
  8. Return the max_sum variable.

Below is the implementation of the above approach:

C++




// C++ code to implement iterative DFS approach
#include <bits/stdc++.h>
using namespace std;
 
// Definition of a binary tree node
struct Node {
    int val;              // value stored at this node
    Node *left, *right;   // children, nullptr when absent
    Node(int val)
        : val(val)
        , left(nullptr)
        , right(nullptr)
    {
    }
};

// Returns the maximum path sum of the tree rooted at `root`; a path may start
// and end at any node. An empty tree yields INT_MIN.
//
// Iterative post-order DFS: each stack entry is a node plus a visit state
// (0 = descend left, 1 = descend right, 2 = both subtrees finished).
// The caller's tree is NOT modified. Each finished node's "gain" (best
// downward path starting at it through at most one child) is kept in a side
// table and erased as soon as its parent consumes it. Only children of nodes
// still on the stack can have entries, so the table stays O(height).
// Assumes a proper tree (no node shared between two parents).
int findMaxSum(Node* root)
{
    int max_sum = INT_MIN;
    std::unordered_map<const Node*, int> gain;
    std::stack<std::pair<Node*, int> > s;
    s.push(std::make_pair(root, 0));

    // Gain of a finished child clamped at 0 (0 for a missing child);
    // removes the child's entry from the table.
    auto take_gain = [&gain](const Node* child) {
        if (child == nullptr)
            return 0;
        auto it = gain.find(child);
        int g = it->second;
        gain.erase(it);
        return std::max(0, g);
    };

    while (!s.empty()) {
        Node* node = s.top().first;
        int state = s.top().second;
        s.pop();

        if (node == nullptr) {
            continue;
        }

        if (state == 0) {
            // first visit: come back after the left subtree
            s.push(std::make_pair(node, 1));
            s.push(std::make_pair(node->left, 0));
        }
        else if (state == 1) {
            // second visit: come back after the right subtree
            s.push(std::make_pair(node, 2));
            s.push(std::make_pair(node->right, 0));
        }
        else {
            // third visit: both child gains are available
            int left_gain = take_gain(node->left);
            int right_gain = take_gain(node->right);
            // best path whose topmost node is `node`
            max_sum = std::max(max_sum, node->val + left_gain + right_gain);
            // a parent may extend through at most one child
            gain[node] = node->val + std::max(left_gain, right_gain);
        }
    }

    return max_sum;
}
// Driver Code
int main()
{
    // Constructing the tree given in the above example
    Node* root = new Node(10);
    root->left = new Node(2);
    root->right = new Node(-25);
    root->left->left = new Node(20);
    root->left->right = new Node(1);
    root->right->left = new Node(3);
    root->right->right = new Node(4);
 
    int max_sum = findMaxSum(root);
    cout << "Maximum Path Sum: " << max_sum << endl;
 
    return 0;
}
// This code is contributed by Veerendra_Singh_Rajpoot


Java




import java.util.AbstractMap.SimpleEntry;
import java.util.Stack;
 
// Definition of a binary tree node
class Node {
    // Value stored at this node.
    int val;

    // Child links; null when the child is absent.
    Node left;
    Node right;

    Node(int val)
    {
        this.val = val;
        left = null;
        right = null;
    }
}
 
public class Main {
 
    static int findMaxSum(Node root)
    {
        int max_sum = Integer.MIN_VALUE;
        Stack<SimpleEntry<Node, Integer> > stack
            = new Stack<>();
        stack.push(new SimpleEntry<>(root, 0));
 
        while (!stack.empty()) {
            SimpleEntry<Node, Integer> pair = stack.pop();
            Node node = pair.getKey();
            int state = pair.getValue();
 
            if (node == null) {
                continue;
            }
 
            if (state == 0) {
                // first visit to the node
                stack.push(new SimpleEntry<>(node, 1));
                stack.push(new SimpleEntry<>(node.left, 0));
            }
            else if (state == 1) {
                // second visit to the node
                stack.push(new SimpleEntry<>(node, 2));
                stack.push(
                    new SimpleEntry<>(node.right, 0));
            }
            else {
                // third visit to the node
                int left_sum = (node.left != null)
                                   ? node.left.val
                                   : 0;
                int right_sum = (node.right != null)
                                    ? node.right.val
                                    : 0;
                max_sum = Math.max(
                    max_sum, node.val
                                 + Math.max(0, left_sum)
                                 + Math.max(0, right_sum));
                int max_child_sum
                    = Math.max(left_sum, right_sum);
                node.val += Math.max(0, max_child_sum);
            }
        }
 
        return max_sum;
    }
    // Driver Code
    public static void main(String[] args)
    {
        // Constructing the tree given in the above example
        Node root = new Node(10);
        root.left = new Node(2);
        root.right = new Node(-25);
        root.left.left = new Node(20);
        root.left.right = new Node(1);
        root.right.left = new Node(3);
        root.right.right = new Node(4);
 
        int max_sum = findMaxSum(root);
        System.out.println("Maximum Path Sum: " + max_sum);
    }
}
// This code is contributed by Veerendra_Singh_Rajpoot


Python




class TreeNode:
    """Binary tree node storing ``val`` with optional ``left``/``right`` children."""

    def __init__(self, val):
        self.val = val
        # Children are attached by the caller after construction.
        self.left, self.right = None, None
 
def find_max_sum(root):
    """Return the maximum path sum of the tree rooted at ``root``.

    A path may start and end at any node. Returns ``float('-inf')`` for an
    empty tree.

    Iterative post-order DFS over an explicit stack of ``(node, state)``
    pairs (0 = descend left, 1 = descend right, 2 = both subtrees finished).
    The caller's tree is NOT modified. Each finished node's gain (best
    downward path starting at it through at most one child) is kept in a
    dict keyed by ``id(node)`` and popped once its parent has used it.
    """
    max_sum = float('-inf')
    gains = {}
    stack = [(root, 0)]

    def take_gain(child):
        # Finished child's gain clamped at 0 (0 for a missing child).
        if child is None:
            return 0
        return max(0, gains.pop(id(child)))

    while stack:
        node, state = stack.pop()

        if node is None:
            continue

        if state == 0:
            # First visit: come back after the left subtree
            stack.append((node, 1))
            stack.append((node.left, 0))
        elif state == 1:
            # Second visit: come back after the right subtree
            stack.append((node, 2))
            stack.append((node.right, 0))
        else:
            # Third visit: both child gains are available
            left_gain = take_gain(node.left)
            right_gain = take_gain(node.right)
            # Best path whose topmost node is `node`
            max_sum = max(max_sum, node.val + left_gain + right_gain)
            # A parent may extend through at most one child
            gains[id(node)] = node.val + max(left_gain, right_gain)

    return max_sum
 
# Sample tree (not the one from the recursive section): 10 has children 2 and
# -25; 2 has children 20 and 1; -25 has children 3 and 4.
# The best path is 20 -> 2 -> 10, with sum 32.
root = TreeNode(10)
root.left = TreeNode(2)
root.right = TreeNode(-25)
root.left.left = TreeNode(20)
root.left.right = TreeNode(1)
root.right.left = TreeNode(3)
root.right.right = TreeNode(4)

max_sum = find_max_sum(root)
print("Maximum Path Sum:", max_sum)


C#




using System;
using System.Collections.Generic;
 
// Definition of a binary tree node
class Node
{
    // Value stored at this node.
    public int val;

    // Child links; null when the child is absent.
    public Node left;
    public Node right;

    public Node(int val)
    {
        this.val = val;
        left = null;
        right = null;
    }
}
 
public class GFG
{
    // Returns the maximum path sum of the tree rooted at `root`; a path may
    // start and end at any node. Returns int.MinValue for an empty tree.
    //
    // Iterative post-order DFS: each stack entry is a node plus a visit state
    // (0 = descend left, 1 = descend right, 2 = both subtrees finished).
    // The caller's tree is NOT modified. Each finished node's gain (best
    // downward path starting at it through at most one child) lives in a side
    // table and is removed once its parent has used it.
    static int FindMaxSum(Node root)
    {
        int max_sum = int.MinValue;
        // Node does not override Equals/GetHashCode, so keys compare by reference.
        Dictionary<Node, int> gain = new Dictionary<Node, int>();
        Stack<Tuple<Node, int>> stack = new Stack<Tuple<Node, int>>();
        stack.Push(new Tuple<Node, int>(root, 0));

        while (stack.Count > 0)
        {
            Tuple<Node, int> pair = stack.Pop();
            Node node = pair.Item1;
            int state = pair.Item2;

            if (node == null)
            {
                continue;
            }

            if (state == 0)
            {
                // first visit: come back after the left subtree
                stack.Push(new Tuple<Node, int>(node, 1));
                stack.Push(new Tuple<Node, int>(node.left, 0));
            }
            else if (state == 1)
            {
                // second visit: come back after the right subtree
                stack.Push(new Tuple<Node, int>(node, 2));
                stack.Push(new Tuple<Node, int>(node.right, 0));
            }
            else
            {
                // third visit: both child gains are available
                int left_gain = TakeGain(gain, node.left);
                int right_gain = TakeGain(gain, node.right);
                // best path whose topmost node is `node`
                max_sum = Math.Max(max_sum, node.val + left_gain + right_gain);
                // a parent may extend through at most one child
                gain[node] = node.val + Math.Max(left_gain, right_gain);
            }
        }

        return max_sum;
    }

    // Removes and returns a finished child's gain clamped at 0
    // (0 when the child is missing).
    static int TakeGain(Dictionary<Node, int> gain, Node child)
    {
        if (child == null)
        {
            return 0;
        }
        int g = gain[child];
        gain.Remove(child);
        return Math.Max(0, g);
    }

    // Driver Code
    public static void Main(string[] args)
    {
        // Sample tree (not the one from the recursive section): 10 has
        // children 2 and -25; 2 has children 20 and 1; -25 has children
        // 3 and 4. The best path is 20 -> 2 -> 10, with sum 32.
        Node root = new Node(10);
        root.left = new Node(2);
        root.right = new Node(-25);
        root.left.left = new Node(20);
        root.left.right = new Node(1);
        root.right.left = new Node(3);
        root.right.right = new Node(4);

        int max_sum = FindMaxSum(root);
        Console.WriteLine("Maximum Path Sum: " + max_sum);
    }
}


Javascript




// JavaScript code for above approach
 
class Node {
    constructor(val) {
        this.val = val;
        this.left = this.right = null;
    }
}
 
function findMaxSum(root) {
    let max_sum = Number.MIN_SAFE_INTEGER;
    const stack = [];
    stack.push({ node: root, state: 0 });
 
    while (stack.length > 0) {
        const { node, state } = stack.pop();
 
        if (node === null) {
            continue;
        }
 
        if (state === 0) {
            // first visit to the node
            stack.push({ node: node, state: 1 });
            stack.push({ node: node.left, state: 0 });
        } else if (state === 1) {
            // second visit to the node
            stack.push({ node: node, state: 2 });
            stack.push({ node: node.right, state: 0 });
        } else {
            // third visit to the node
            const left_sum = (node.left !== null) ? node.left.val : 0;
            const right_sum = (node.right !== null) ? node.right.val : 0;
            max_sum = Math.max(
                max_sum,
                node.val + Math.max(0, left_sum) + Math.max(0, right_sum)
            );
            const max_child_sum = Math.max(left_sum, right_sum);
            node.val += Math.max(0, max_child_sum);
        }
    }
 
    return max_sum;
}
 
 
// Sample tree (not the one from the recursive section): 10 has children 2
// and -25; 2 has children 20 and 1; -25 has children 3 and 4.
// The best path is 20 -> 2 -> 10, with sum 32.
const root = new Node(10);
root.left = new Node(2);
root.right = new Node(-25);
root.left.left = new Node(20);
root.left.right = new Node(1);
root.right.left = new Node(3);
root.right.right = new Node(4);

const max_sum = findMaxSum(root);
console.log("Maximum Path Sum: " + max_sum);


Output

Maximum Path Sum: 32







Time Complexity: O(n) , The time complexity of this approach is O(n), where n is the number of nodes in the tree, since we need to visit each node once.

Space Complexity: O(h) , The space complexity is O(h), where h is the height of the tree, since we need to store at most h nodes in the stack

This article is contributed by Anmol Varshney (FB Profile: https://www.facebook.com/anmolvarshney695).



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads