Open In App

Find maximum product of K integers in a Binary Search Tree

Last Updated : 14 Oct, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

Given a Binary Search Tree (BST) and a positive integer K, find the maximum product of K integers in the tree. Return the product as an integer.

Examples:

Input: K = 3,
            10
         /       \
       3        12
      / \        /  \
    2   4   11  13
Output: 1716
Explanation: The maximum product of 3 integers can be obtained by selecting 12, 13, and 11. The product is 12*13*11 = 1716. Hence the output is 1716, which is the maximum product of 3 integers.

Input: K = 2,

          5
        /   \
      2   6
     /  \    \
   1   3   8
Output: 48
Explanation: The maximum product of 2 integers can be obtained by selecting 6 and 8. The product is 6*8 = 48. Hence the output is 48, which is the maximum product of 2 integers.

Approach: This can be solved with the following idea:

We can find the maximum product of K integers in the BST using the following approach:

  • Traverse the BST and store all the nodes in a vector.
  • Sort the vector in decreasing order of node values.
  • Select the first K nodes from the sorted vector and compute their product.
  • Return the product.

Below are the steps to implement the above idea :

  • Define a class for the BST node, which contains the node value and pointers to the left and right child nodes.
  • Define a function to traverse the BST in any order (preorder, inorder, postorder) and store the nodes in a vector.
  • Define a comparison function that is used to sort the vector in decreasing order of node values.
  • Define a function to compute the product of K nodes.
  • Define a function to find the maximum product of K nodes in the BST, which uses the above functions.

Below is the code for the above approach:

C++




// C++ Implementation
#include <bits/stdc++.h>
using namespace std;
 
// A single node of the binary search tree: an integer value plus
// pointers to the left and right subtrees (both start out null).
class Node {
public:
    int value;
    Node* left;
    Node* right;

    // Builds a leaf node holding `val`.
    Node(int val)
        : value(val)
        , left(nullptr)
        , right(nullptr)
    {
    }
};
 
// In-order walk of the tree rooted at `root`: appends every node
// pointer to `nodes` (left subtree, node, right subtree). A null
// root appends nothing.
void traverse(Node* root, vector<Node*>& nodes)
{
    // Ancestors whose own visit is still outstanding (used as a stack).
    vector<Node*> pending;
    Node* current = root;

    while (current != nullptr || !pending.empty()) {
        // Walk down the left spine, remembering the path.
        while (current != nullptr) {
            pending.push_back(current);
            current = current->left;
        }

        // Visit the deepest unvisited node, then move to its right subtree.
        current = pending.back();
        pending.pop_back();
        nodes.push_back(current);
        current = current->right;
    }
}
 
// Ordering predicate passed to sort(): true when `a` must be placed
// before `b`, i.e. when a's value is strictly greater (descending order).
bool cmp(Node* a, Node* b)
{
    return b->value < a->value;
}
 
// Multiplies together the values of the first `k` entries of `nodes`.
// Returns 1 when k <= 0. The caller must make sure k does not exceed
// nodes.size(); the result is a plain int and is not overflow-checked.
int product(vector<Node*>& nodes, int k)
{
    int result = 1;
    int taken = 0;
    while (taken < k) {
        result *= nodes[taken]->value;
        ++taken;
    }
    return result;
}
 
// Returns the product of the k largest node values in the tree rooted
// at `root`. Returns 0 when k is negative or the tree holds fewer than
// k nodes, and 1 when k == 0 (empty product).
// NOTE: the k largest values give the maximum product only when the
// tree's values are non-negative.
int max_product(Node* root, int k)
{
    // Gather every node of the tree.
    vector<Node*> collected;
    traverse(root, collected);

    // Reject requests that cannot be satisfied before sorting. The sign
    // test plus cast spells out what the mixed size_t/int comparison
    // would otherwise do implicitly.
    if (k < 0 || collected.size() < static_cast<size_t>(k)) {
        return 0;
    }

    // Largest values first, then multiply the leading k.
    sort(collected.begin(), collected.end(), cmp);
    return product(collected, k);
}
 
// Driver code
int main()
{
 
    // Create the BST
    /*
         10
        /  \
       3    12
      / \   / \
     2   4 11  13
 
    */
    Node* root1 = new Node(10);
    root1->left = new Node(3);
    root1->right = new Node(12);
    root1->left->left = new Node(2);
    root1->left->right = new Node(4);
    root1->right->left = new Node(11);
    root1->right->right = new Node(13);
 
    // Function call
    cout << max_product(root1, 3) << endl;
 
    return 0;
}


Java




import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
 
/** A binary-search-tree node: an int value plus left and right child links. */
class Node {
    public int value;
    public Node left;
    public Node right;

    /**
     * Creates a leaf node.
     *
     * @param val the value stored in this node
     */
    public Node(int val) {
        this.value = val;
        // left and right keep their default value, null.
    }
}
 
public class MaximumProductBST {
    // Function to traverse the BST in any
    // order and store the nodes in a list
    static void traverse(Node root, List<Node> nodes) {
        if (root == null) {
            return;
        }
        traverse(root.left, nodes);
        nodes.add(root);
        traverse(root.right, nodes);
    }
 
    // Function to sort the list in
    // decreasing order of node values
    static Comparator<Node> comparator = (a, b) -> b.value - a.value;
 
    // Function to compute the product
    // of k nodes
    static int product(List<Node> nodes, int k) {
        int prod = 1;
        for (int i = 0; i < k; i++) {
            prod *= nodes.get(i).value;
        }
        return prod;
    }
 
    // Function to find the maximum product
    // of k nodes in the BST
    static int maxProduct(Node root, int k) {
        // Create a list to store the nodes
        List<Node> nodes = new ArrayList<>();
 
        // Traverse the BST and store the
        // nodes in the list
        traverse(root, nodes);
 
        // Sort the list in decreasing
        // order of node values
        Collections.sort(nodes, comparator);
 
        // If the number of nodes in the
        // BST is less than k, return 0
        if (nodes.size() < k) {
            return 0;
        }
 
        // Compute the product of the first
        // k nodes in the sorted list
        int prod = product(nodes, k);
 
        // Return the product
        return prod;
    }
 
    // Driver code
    public static void main(String[] args) {
        // Create the BST
        /*
                 10
                /  \
               3    12
              / \   / \
             2   4 11  13
        */
        Node root = new Node(10);
        root.left = new Node(3);
        root.right = new Node(12);
        root.left.left = new Node(2);
        root.left.right = new Node(4);
        root.right.left = new Node(11);
        root.right.right = new Node(13);
 
        // Function call
        System.out.println(maxProduct(root, 3));
    }
}
// This code was contributed by codearcade


Python




# Python Implementation
# A node of the binary search tree: a value plus left/right child links.
class Node:
    def __init__(self, val):
        self.value = val
        # A freshly created node is a leaf.
        self.left = self.right = None
 
# Append every node of the tree rooted at `root` to `nodes` in in-order
# (left subtree, node, right subtree). A root of None appends nothing.
def traverse(root, nodes):
    pending = []  # ancestors whose own visit is still outstanding
    current = root
    while current is not None or pending:
        # Walk down the left spine, remembering the path.
        while current is not None:
            pending.append(current)
            current = current.left
        # Visit the deepest unvisited node, then move to its right subtree.
        current = pending.pop()
        nodes.append(current)
        current = current.right
 
# Predicate on two nodes: True when `a` holds a strictly larger value
# than `b`. It returns a bool, not a -1/0/1 ordering value.
def cmp(a, b):
    return b.value < a.value
 
# Multiply together the values of the first k entries of `nodes`.
# Returns 1 when k <= 0; raises IndexError when k exceeds len(nodes).
def product(nodes, k):
    result = 1
    for position in range(k):
        node = nodes[position]
        result = result * node.value
    return result
 
# Return the product of the k largest node values in the tree rooted at
# `root`, or 0 when the tree holds fewer than k nodes.
# NOTE: the k largest values give the maximum product only when the
# tree's values are non-negative.
def max_product(root, k):
    # Gather every node of the tree.
    collected = []
    traverse(root, collected)

    # Not enough nodes to choose k of them.
    if k > len(collected):
        return 0

    # Largest values first, then multiply the leading k.
    by_value_desc = sorted(collected, key=lambda node: node.value, reverse=True)
    return product(by_value_desc, k)
 
# Driver code: build the sample tree and print the product of its three
# largest values.
if __name__ == "__main__":

    # Create the BST. The diagram is kept in comments rather than in a
    # bare string literal, because "\ " inside a normal string is an
    # invalid escape sequence that newer Python versions warn about.
    #
    #          10
    #         /  \
    #        3    12
    #       / \   / \
    #      2   4 11  13
    root1 = Node(10)
    root1.left = Node(3)
    root1.right = Node(12)
    root1.left.left = Node(2)
    root1.left.right = Node(4)
    root1.right.left = Node(11)
    root1.right.right = Node(13)

    # Function call
    print(max_product(root1, 3))
 
# This code is contributed by Susobhan Akhuli


C#




// C# Implementation
using System;
using System.Collections.Generic;
using System.Linq;
 
// A node of the binary search tree: an int value plus left/right child links.
class Node {
    public int value;
    public Node left;
    public Node right;

    // Creates a leaf node holding `val`; the child references keep
    // their default value, null.
    public Node(int val)
    {
        this.value = val;
    }
}
 
class Program {
    // Appends every node of the tree rooted at `root` to `nodes` in
    // in-order (left subtree, node, right subtree). A null root adds
    // nothing.
    static void Traverse(Node root, List<Node> nodes)
    {
        // Ancestors whose own visit is still outstanding.
        Stack<Node> pending = new Stack<Node>();
        Node current = root;

        while (current != null || pending.Count > 0) {
            // Walk down the left spine, remembering the path.
            while (current != null) {
                pending.Push(current);
                current = current.left;
            }

            // Visit the deepest unvisited node, then go right.
            current = pending.Pop();
            nodes.Add(current);
            current = current.right;
        }
    }

    // Comparison used for sorting: orders nodes by decreasing value.
    static int CompareNodes(Node a, Node b)
    {
        int descending = b.value.CompareTo(a.value);
        return descending;
    }

    // Multiplies the values of the first k entries of `nodes`.
    // Returns 1 when k <= 0; the caller must make sure k does not
    // exceed nodes.Count. No overflow checking is requested here.
    static int Product(List<Node> nodes, int k)
    {
        int result = 1;
        int taken = 0;
        while (taken < k) {
            result *= nodes[taken].value;
            taken++;
        }
        return result;
    }

    // Returns the product of the k largest node values in the tree,
    // or 0 when the tree holds fewer than k nodes.
    // NOTE: the k largest values give the maximum product only when
    // the tree's values are non-negative.
    static int MaxProduct(Node root, int k)
    {
        // Gather every node of the tree.
        List<Node> collected = new List<Node>();
        Traverse(root, collected);

        // Not enough nodes to choose k of them.
        if (k > collected.Count) {
            return 0;
        }

        // Largest values first, then multiply the leading k.
        collected.Sort(CompareNodes);
        return Product(collected, k);
    }

    // Driver code
    static void Main()
    {
        /*
             10
            /  \
           3    12
          / \   / \
         2   4 11  13
        */
        Node root1 = new Node(10);

        // Left subtree: 3 with children 2 and 4.
        Node smaller = new Node(3);
        smaller.left = new Node(2);
        smaller.right = new Node(4);
        root1.left = smaller;

        // Right subtree: 12 with children 11 and 13.
        Node larger = new Node(12);
        larger.left = new Node(11);
        larger.right = new Node(13);
        root1.right = larger;

        // Function call
        int best = MaxProduct(root1, 3);
        Console.WriteLine(best);
    }
}
 
// This code is contributed by Susobhan Akhuli


Javascript




<script>
// JavaScript code for the above approach
// A node of the binary search tree: a value plus left/right child links.
class Node {
    constructor(val) {
        // A freshly created node is a leaf; properties are created in
        // the order value, left, right.
        Object.assign(this, { value: val, left: null, right: null });
    }
}
 
// Appends every node of the tree rooted at `root` to `nodes` in in-order
// (left subtree, node, right subtree). A null root appends nothing.
function traverse(root, nodes) {
    // Ancestors whose own visit is still outstanding.
    const pending = [];
    let current = root;

    while (current !== null || pending.length > 0) {
        // Walk down the left spine, remembering the path.
        while (current !== null) {
            pending.push(current);
            current = current.left;
        }

        // Visit the deepest unvisited node, then move to its right subtree.
        current = pending.pop();
        nodes.push(current);
        current = current.right;
    }
}
 
// Comparator for Array.prototype.sort that orders nodes by decreasing
// value: returns -1 when `a` must come first, 1 when `b` must come first,
// and 0 when both values are equal. Returning 0 for ties keeps the
// comparator consistent (cmp(a, b) and cmp(b, a) never both claim
// "b first").
function cmp(a, b) {
    if (a.value > b.value) {
        return -1;
    }
    if (a.value < b.value) {
        return 1;
    }
    return 0;
}
 
// Multiplies together the values of the first k entries of `nodes`.
// Returns 1 when k <= 0; the caller must make sure k does not exceed
// nodes.length.
function product(nodes, k) {
    let result = 1;
    let taken = 0;
    while (taken < k) {
        result *= nodes[taken].value;
        taken += 1;
    }
    return result;
}
 
// Returns the product of the k largest node values in the tree rooted at
// `root`, or 0 when the tree holds fewer than k nodes.
// NOTE: the k largest values give the maximum product only when the
// tree's values are non-negative.
function maxProduct(root, k) {
    // Gather every node of the tree.
    const collected = [];
    traverse(root, collected);

    // Not enough nodes to choose k of them.
    if (collected.length < k) {
        return 0;
    }

    // Largest values first, then multiply the leading k.
    collected.sort(cmp);
    return product(collected, k);
}
 
// Driver code: build the sample BST and write the product of its three
// largest values to the page.
/*
       10
      /  \
     3    12
    / \   / \
   2   4 11  13
  */
const root = new Node(10);

// First level below the root.
Object.assign(root, { left: new Node(3), right: new Node(12) });

// Leaves under 3 and under 12.
Object.assign(root.left, { left: new Node(2), right: new Node(4) });
Object.assign(root.right, { left: new Node(11), right: new Node(13) });
 
// Function call
document.write(maxProduct(root, 3));
 
// This code is contributed by Susobhan Akhuli
</script>


Output

1716







Time Complexity: O(n*logn)
Auxiliary Space: O(n)



Like Article
Suggest improvement
Share your thoughts in the comments

Similar Reads