Open In App

Count of Longest Increasing Subsequences (LIS) using Segment Tree

Given an array arr[] of size N, the task is to count the number of longest increasing subsequences present in the given array.

Example:



Input: arr[] = {2, 2, 2, 2, 2}
Output: 5
Explanation: The length of the longest increasing subsequence is 1, i.e. {2}. Therefore, the number of longest increasing subsequences (each of length 1) is 5. 
 

Input: arr[] = {1, 3, 5, 4, 7}
Output: 2
Explanation: The length of the longest increasing subsequence is 4, and there are 2 longest increasing subsequences of length 4, i.e. {1, 3, 4, 7} and {1, 3, 5, 7}.



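Approach: An approach to the given problem has already been discussed using dynamic programming in a previous article. 
This article suggests a different approach using segment trees. Follow the below steps to solve the given problem:
- Pair every element with its index and sort the pairs in increasing order of value. For equal values, sort in decreasing order of index, so that an element can never extend a subsequence ending at an equal value.
- Build a segment tree over the indices 0 to N − 1. Every node stores a pair {L, C}: L is the greatest LIS length ending at an already-processed index in the node's range, and C is the number of such subsequences. Merging two children keeps the larger L, and adds the two C values when both L values are equal.
- Process the sorted pairs in order. For a pair with index i, query the range [0, i − 1] to obtain {L, C}. The longest increasing subsequences ending at index i then have length L + 1 and number max(C, 1). Store this pair at leaf i.
- After all pairs are processed, the count C stored at the root is the answer.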

Below is the implementation of the above approach:




// C++ implementation of the above approach
#include <bits/stdc++.h>
using namespace std;
 
// Initial capacity: `tree` below is pre-sized for arrays of up to M
// elements; countLIS() grows it when a longer array is passed.
// A typed constant rather than a one-letter macro, which would rewrite
// every later token spelled `M`.
constexpr int M = 100000;

// Stores the Segment tree, indexed by array position with 1-based node
// indices. Each node holds {L, C}: L is the greatest LIS length ending at
// an already-processed position inside the node's range, and C is the
// number of increasing subsequences of length L ending at those positions.
std::vector<std::pair<int, int> > tree(4 * M + 1);

// Combines the {length, count} summaries of two disjoint ranges:
// the longer length wins and, on a tie, the counts are added.
static std::pair<int, int> combine(const std::pair<int, int>& left,
                                   const std::pair<int, int>& right)
{
    if (left.first > right.first) {
        return left;
    }
    if (right.first > left.first) {
        return right;
    }
    return std::make_pair(left.first, left.second + right.second);
}

// Function to update Segment tree, the root
// of which contains the length of the LIS.
// Sets the leaf for position update_idx to {max(old length, length_t),
// count_c} and recomputes every node on the path back to the root.
// [start, end] is the range covered by node idx.
void update_tree(int start, int end,
                 int update_idx, int length_t,
                 int count_c, int idx)
{
    // Leaf for the target position
    if (start == end && start == update_idx) {
        tree[idx].first = std::max(tree[idx].first, length_t);
        tree[idx].second = count_c;
        return;
    }

    // Target position lies outside this node's range
    if (update_idx < start || end < update_idx) {
        return;
    }

    // Partial overlap: update both halves, then refresh this node
    int mid = (start + end) / 2;
    update_tree(start, mid, update_idx, length_t, count_c, 2 * idx);
    update_tree(mid + 1, end, update_idx, length_t, count_c, 2 * idx + 1);
    tree[idx] = combine(tree[2 * idx], tree[2 * idx + 1]);
}

// Function to find the LIS length and count in the range
// [query_start, query_end]. [start, end] is the range covered by node idx.
// A node entirely outside the query range contributes {INT32_MIN, 0}.
std::pair<int, int> query(int start, int end,
                          int query_start,
                          int query_end, int idx)
{
    // Node range lies completely inside the query range
    if (query_start <= start && end <= query_end) {
        return tree[idx];
    }

    // Node range and query range do not intersect
    if (end < query_start || query_end < start) {
        return std::make_pair(INT32_MIN, 0);
    }

    // Partial overlap: combine the answers of both halves
    int mid = (start + end) / 2;
    return combine(query(start, mid, query_start, query_end, 2 * idx),
                   query(mid + 1, end, query_start, query_end,
                         2 * idx + 1));
}

// Comparator function to sort an array of pairs
// in increasing order of their 1st element and
// thereafter in decreasing order of the 2nd
bool comp(std::pair<int, int> a, std::pair<int, int> b)
{
    if (a.first == b.first) {
        return a.second > b.second;
    }
    return a.first < b.first;
}

// Function to find count
// of LIS in the given array.
//
// Positions are processed in increasing order of value, equal values in
// decreasing order of index. Hence, when position i is processed, every
// already-filled leaf in [0, i - 1] holds a strictly smaller value, and
// querying that range yields the longest subsequence arr[i] can extend
// together with the number of ways to form it.
//
// The shared tree is reset on every call: leaves left over from a
// previous call would otherwise leak into this one (the leaf update keeps
// the maximum of the old and new length). It is also grown when 4 * n
// nodes exceed its current size.
//
// NOTE: the count is an int and overflows once there are more than
// INT_MAX longest increasing subsequences.
int countLIS(int arr[], int n)
{
    if (n <= 0) {
        return 0;
    }

    // A segment tree over [0, n - 1] with 1-based node indices never
    // touches an index >= 4 * n.
    const std::size_t nodes_needed = 4 * static_cast<std::size_t>(n);
    if (tree.size() < nodes_needed) {
        tree.resize(nodes_needed);
    }
    std::fill_n(tree.begin(), nodes_needed, std::make_pair(0, 0));

    // Generating value-index pair array
    std::vector<std::pair<int, int> > pair_array(n);
    for (int i = 0; i < n; i++) {
        pair_array[i] = std::make_pair(arr[i], i);
    }

    // Sort with increasing order of value and
    // decreasing order of index
    std::sort(pair_array.begin(), pair_array.end(), comp);

    for (const auto& entry : pair_array) {
        const int update_idx = entry.second;

        // Nothing precedes index 0: it starts a subsequence of length 1
        if (update_idx == 0) {
            update_tree(0, n - 1, 0, 1, 1, 1);
            continue;
        }

        // Best subsequence ending strictly before update_idx; a length-0
        // result (nothing smaller yet) still gives one way to start
        const std::pair<int, int> best
            = query(0, n - 1, 0, update_idx - 1, 1);
        update_tree(0, n - 1, update_idx, best.first + 1,
                    std::max(1, best.second), 1);
    }

    // The root summarises the whole array
    return query(0, n - 1, 0, n - 1, 1).second;
}
 
// Driver Code
int main()
{
    int arr[] = { 1, 3, 5, 4, 7 };
    int n = sizeof(arr) / sizeof(int);
 
    cout << countLIS(arr, n);
 
    return 0;
}




import java.util.*;
import java.io.*;
 
// Java program for the above approach
public class GFG{

    // Capacity hint kept for compatibility; countLIS() sizes the tree
    // itself, so callers do not have to pre-populate it.
    public static int M = 100000;

    // Stores the Segment tree, indexed by array position with 1-based
    // node indices. Each node is a two-element list [L, C]: L is the
    // greatest LIS length ending at an already-processed position inside
    // the node's range and C is how many increasing subsequences of
    // length L end at those positions.
    public static ArrayList<ArrayList<Integer>> tree =
      new ArrayList<ArrayList<Integer>>();

    // Builds a fresh [length, count] node.
    private static ArrayList<Integer> node(int length, int count)
    {
        return new ArrayList<Integer>(List.of(length, count));
    }

    // Combines the [length, count] summaries of two disjoint ranges: the
    // longer length wins and, on a tie, the counts are added. Lengths are
    // unboxed to int before comparing -- comparing two Integer objects
    // with == compares references, which is false for equal values
    // outside the Integer cache (-128..127), i.e. for LIS lengths > 127.
    private static ArrayList<Integer> combine(ArrayList<Integer> left,
                                              ArrayList<Integer> right)
    {
        int leftLength = left.get(0);
        int rightLength = right.get(0);
        if (leftLength > rightLength) {
            return node(leftLength, left.get(1));
        }
        if (rightLength > leftLength) {
            return node(rightLength, right.get(1));
        }
        return node(leftLength, left.get(1) + right.get(1));
    }

    // Makes sure at least 4 * n nodes exist and resets them to [0, 0].
    // Length 0 means "no subsequence ends here yet", so an element with
    // no smaller value before it starts a new subsequence of length 1.
    // (Starting nodes at Integer.MIN_VALUE would give such elements the
    // length MIN_VALUE + 1, e.g. {2, 2, 2, 2, 2} would count 1, not 5.)
    private static void resetTree(int n)
    {
        int needed = 4 * n;
        for (int i = 0; i < needed; i++) {
            if (i < tree.size()) {
                tree.set(i, node(0, 0));
            } else {
                tree.add(node(0, 0));
            }
        }
    }

    // Function to update Segment tree, the root
    // of which contains the length of the LIS.
    // Sets the leaf for update_idx to [max(old length, length_t), count_c]
    // and recomputes the nodes on the path back to the root.
    // [start, end] is the range covered by node idx.
    public static void update_tree(int start, int end,
                                   int update_idx, int length_t,
                                   int count_c, int idx)
    {
        // Leaf for the target position
        if (start == end && start == update_idx) {
            int current = tree.get(idx).get(0);
            tree.set(idx, node(Math.max(current, length_t), count_c));
            return;
        }

        // Target position lies outside this node's range
        if (update_idx < start || end < update_idx) {
            return;
        }

        // Partial overlap: update both halves, then refresh this node
        int mid = (start + end) / 2;
        update_tree(start, mid, update_idx,
                    length_t, count_c, 2 * idx);
        update_tree(mid + 1, end, update_idx,
                    length_t, count_c, 2 * idx + 1);
        tree.set(idx, combine(tree.get(2 * idx), tree.get(2 * idx + 1)));
    }

    // Function to find the LIS length and count in the range
    // [query_start, query_end]. [start, end] is the range covered by node
    // idx; a node outside the query range contributes [MIN_VALUE, 0].
    // Always returns a new list, never a tree node itself.
    public static ArrayList<Integer> query(int start, int end,
                                           int query_start,
                                           int query_end, int idx)
    {
        // Node range lies completely inside the query range
        if (query_start <= start && end <= query_end) {
            return new ArrayList<Integer>(tree.get(idx));
        }

        // Node range and query range do not intersect
        if (end < query_start || query_end < start) {
            return node(Integer.MIN_VALUE, 0);
        }

        // Partial overlap: combine the answers of both halves
        int mid = (start + end) / 2;
        return combine(query(start, mid, query_start, query_end, 2 * idx),
                       query(mid + 1, end, query_start, query_end,
                             2 * idx + 1));
    }

    // Function to find count
    // of LIS in the given array.
    // Positions are processed in increasing order of value (equal values
    // by decreasing index), so when position i is processed every filled
    // leaf in [0, i - 1] holds a strictly smaller value. The tree is reset
    // on every call so that results do not depend on earlier calls.
    public static int countLIS(int arr[], int n)
    {
        if (n <= 0) {
            return 0;
        }
        resetTree(n);

        // Generating value-index pair array
        ArrayList<ArrayList<Integer>> pair_array = new ArrayList<ArrayList<Integer>>();
        for (int i = 0; i < n; i++) {
            pair_array.add(new ArrayList<Integer>(List.of(arr[i], i)));
        }

        // Sort array of pairs with increasing order
        // of value and decreasing order of index
        Collections.sort(pair_array, new comp());

        for (ArrayList<Integer> pair : pair_array) {
            int update_idx = pair.get(1);

            // Nothing precedes index 0: it starts a subsequence of length 1
            if (update_idx == 0) {
                update_tree(0, n - 1, 0, 1, 1, 1);
                continue;
            }

            // Best subsequence ending strictly before update_idx
            ArrayList<Integer> best = query(0, n - 1, 0, update_idx - 1, 1);
            update_tree(0, n - 1, update_idx, best.get(0) + 1,
                        Math.max(1, best.get(1)), 1);
        }

        // The root summarises the whole array
        return query(0, n - 1, 0, n - 1, 1).get(1);
    }


    // Driver code
    public static void main(String args[])
    {
        int arr[] = { 1, 3, 5, 4, 7 };
        int n = arr.length;

        System.out.println(countLIS(arr, n));
    }
}

// Comparator function to sort an array of pairs
// in increasing order of their 1st element and
// thereafter in decreasing order of the 2nd.
// Not public: a source file may declare only one public top-level class
// (GFG), so a public `comp` here would not compile.
class comp implements Comparator<ArrayList<Integer>>{
    @Override
    public int compare(ArrayList<Integer> a, ArrayList<Integer> b)
    {
        if (a.get(0).equals(b.get(0))) {
            return b.get(1).compareTo(a.get(1));
        }
        return a.get(0).compareTo(b.get(0));
    }
}

// This code is contributed by subhamgoyal2014.




// Finding the length of the Longest Increasing Subsequence using
// Segment Tree (this C# version computes the LIS length, not the count)
 
using System;
 
class SegmentTree {
    private int[] tree; // The segment tree array (maximum of each range)

    // Constructor that allocates room for `size` leaves. The leaf count
    // is rounded up to a power of two with integer arithmetic, and a
    // non-positive size still yields a single node instead of a negative
    // array length.
    public SegmentTree(int size)
    {
        int leaves = 1;
        while (leaves < size) {
            leaves *= 2;
        }
        tree = new int[2 * leaves - 1];
    }

    // Method that builds the segment tree from an array: node `pos`
    // covers arr[low..high] and stores its maximum.
    public void BuildTree(int[] arr, int pos, int low,
                          int high)
    {
        // If the segment has only one element, set the
        // corresponding value in the tree
        if (low == high) {
            tree[pos] = arr[low];
            return;
        }

        // Determine the middle index of the segment
        int mid = (low + high) / 2;
        // Recursively build the left subtree
        BuildTree(arr, 2 * pos + 1, low, mid);
        // Recursively build the right subtree
        BuildTree(arr, 2 * pos + 2, mid + 1, high);
        // Set the value of the current node to the maximum
        // value of its children
        tree[pos] = Math.Max(tree[2 * pos + 1],
                             tree[2 * pos + 2]);
    }

    // Method that raises position `index` to at least `value` and
    // refreshes the maxima on the path back to the root in O(log n).
    // Node `pos` covers [low, high].
    public void Update(int pos, int low, int high, int index, int value)
    {
        if (low == high) {
            tree[pos] = Math.Max(tree[pos], value);
            return;
        }

        int mid = (low + high) / 2;
        if (index <= mid) {
            Update(2 * pos + 1, low, mid, index, value);
        }
        else {
            Update(2 * pos + 2, mid + 1, high, index, value);
        }
        tree[pos] = Math.Max(tree[2 * pos + 1], tree[2 * pos + 2]);
    }

    // Method that returns the maximum value in the range [start, end];
    // node `pos` covers [low, high].
    public int Query(int pos, int low, int high, int start,
                     int end)
    {
        // If the given range is fully contained in the
        // current segment, return the corresponding value
        // in the tree
        if (start <= low && end >= high) {
            return tree[pos];
        }

        // If the given range does not overlap with the
        // current segment, return the minimum possible
        // value
        if (start > high || end < low) {
            return int.MinValue;
        }

        // Determine the middle index of the segment
        int mid = (low + high) / 2;
        // Recursively query the left subtree
        int left = Query(2 * pos + 1, low, mid, start, end);
        // Recursively query the right subtree
        int right
            = Query(2 * pos + 2, mid + 1, high, start, end);
        // Return the maximum value of the two subtrees
        return Math.Max(left, right);
    }
}

class LIS {
    // Method that returns the length of the longest strictly increasing
    // subsequence in an array, in O(N log N).
    //
    // Each value is replaced by its rank among the distinct values. A max
    // segment tree over ranks stores, for every rank r, the longest
    // increasing subsequence seen so far that ends with the value of rank
    // r. Scanning arr from left to right, the longest subsequence ending
    // at the current element extends the longest one ending at a strictly
    // smaller value: 1 + max over ranks [0, r - 1].
    public static int GetLISLength(int[] arr)
    {
        int n = arr == null ? 0 : arr.Length;
        if (n == 0) {
            return 0;
        }

        // Sorted distinct values; the rank of a value is its index here
        int[] distinct = new int[n];
        Array.Copy(arr, distinct, n);
        Array.Sort(distinct);
        int m = 0;
        for (int i = 0; i < n; i++) {
            if (m == 0 || distinct[i] != distinct[m - 1]) {
                distinct[m++] = distinct[i];
            }
        }

        // Max segment tree over ranks, all lengths initially zero
        SegmentTree tree = new SegmentTree(m);
        tree.BuildTree(new int[m], 0, 0, m - 1);

        int maxLIS = 0;
        foreach (int value in arr) {
            int rank = Array.BinarySearch(distinct, 0, m, value);
            // Rank 0 has no smaller value to extend
            int prevMax = rank == 0
                ? 0
                : tree.Query(0, 0, m - 1, 0, rank - 1);
            int length = prevMax + 1;
            tree.Update(0, 0, m - 1, rank, length);
            maxLIS = Math.Max(maxLIS, length);
        }

        // Return the maximum value of longest increasing
        // subsequence
        return maxLIS;
    }
}
 
// Driver code: prints the value LIS.GetLISLength returns for a sample
// array.
class Program {
    static void Main(string[] args)
    {
        int[] sample = { 1, 3, 5, 4, 7 };
        int result = LIS.GetLISLength(sample);
        Console.WriteLine(result);
    }
}




<script>
// Javascript implementation of the above approach
 
 
// Number of array positions the tree is pre-sized for.
let M = 100000;

// Stores the Segment tree (1-based node indices). Every node starts as
// its own [length, count] = [0, 0] array; with empty arrays the first
// leaf update would compute Math.max(undefined, length), which is NaN.
let tree = Array.from({ length: 4 * M + 1 }, () => [0, 0]);
 
// Function to update Segment tree, the root of which contains the
// length of the LIS. Sets the leaf for position update_idx to
// [max(old length, length_t), count_c] and recomputes each node on the
// path back to the root; [start, end] is the range covered by node idx.
// Every node is written as a fresh array: copying a child's array
// reference into its parent would make a later in-place write to the
// parent corrupt that child. Never-written nodes are read as [0, 0].
function update_tree(start, end, update_idx, length_t, count_c, idx) {
  // Leaf for the target position
  if (start == end && start == update_idx) {
    const [current = 0] = tree[idx] || [];
    tree[idx] = [Math.max(current, length_t), count_c];
    return;
  }

  // Target position lies outside this node's range
  if (update_idx < start || end < update_idx) {
    return;
  }

  // Partial overlap: update both halves, then refresh this node
  const mid = Math.floor((start + end) / 2);
  update_tree(start, mid, update_idx, length_t, count_c, 2 * idx);
  update_tree(mid + 1, end, update_idx, length_t, count_c, 2 * idx + 1);

  const [leftLength = 0, leftCount = 0] = tree[2 * idx] || [];
  const [rightLength = 0, rightCount = 0] = tree[2 * idx + 1] || [];

  if (leftLength > rightLength) {
    tree[idx] = [leftLength, leftCount];
  } else if (rightLength > leftLength) {
    tree[idx] = [rightLength, rightCount];
  } else {
    // Equal lengths: both halves contribute subsequences
    tree[idx] = [leftLength, leftCount + rightCount];
  }
}
 
// Function to find the LIS length and count in the range
// [query_start, query_end]; [start, end] is the range covered by node
// idx. A fully covered node returns its tree entry itself, and a node
// outside the query range contributes [Number.MIN_SAFE_INTEGER, 0].
function query(start, end, query_start, query_end, idx) {
  // Node range lies entirely inside the query range
  if (start >= query_start && query_end >= end) {
    return tree[idx];
  }

  // Node range and query range do not intersect
  if (query_start > end || start > query_end) {
    return [Number.MIN_SAFE_INTEGER, 0];
  }

  // Partial overlap: ask both halves, keep the longer length and add
  // the counts when the lengths tie
  const mid = Math.floor((start + end) / 2);
  const leftResult = query(start, mid, query_start, query_end, 2 * idx);
  const rightResult = query(mid + 1, end, query_start, query_end, 2 * idx + 1);

  return leftResult[0] > rightResult[0]
    ? leftResult
    : rightResult[0] > leftResult[0]
      ? rightResult
      : [leftResult[0], leftResult[1] + rightResult[1]];
}
 
// Comparator function for Array.prototype.sort: orders [value, index]
// pairs by increasing value and, for equal values, by decreasing index.
// sort() expects a number (negative: a first, positive: b first, zero:
// equal). A boolean result is coerced to 1/0, which yields an
// inconsistent ordering, so numbers are returned explicitly.
function comp(a, b) {
  if (a[0] !== b[0]) {
    return a[0] < b[0] ? -1 : 1;
  }
  return b[1] - a[1];
}
 
// Function to find count of LIS in the given array.
// Positions are processed in increasing order of value (equal values by
// decreasing index), so when position i is processed every filled leaf
// in [0, i - 1] holds a strictly smaller value; querying that range gives
// the longest subsequence arr[i] can extend and the number of ways to
// form it. The shared tree is reset for 4 * n nodes on every call, so
// results do not depend on earlier calls and longer arrays get enough
// nodes.
function countLIS(arr, n) {
  if (n <= 0) {
    return 0;
  }

  // A segment tree over [0, n - 1] with 1-based indices stays below 4 * n
  for (let i = 0; i < 4 * n; i++) {
    tree[i] = [0, 0];
  }

  // Generating value-index pair array
  const pair_array = [];
  for (let i = 0; i < n; i++) {
    pair_array.push([arr[i], i]);
  }

  // Sort array of pairs with increasing order
  // of value and decreasing order of index
  pair_array.sort(comp);

  for (const [, update_idx] of pair_array) {
    // Nothing precedes index 0: it starts a subsequence of length 1
    if (update_idx == 0) {
      update_tree(0, n - 1, 0, 1, 1, 1);
      continue;
    }

    // Best subsequence ending strictly before update_idx
    const [bestLength, bestCount] = query(0, n - 1, 0, update_idx - 1, 1);
    update_tree(0, n - 1, update_idx, bestLength + 1, Math.max(1, bestCount), 1);
  }

  // The root summarises the whole array
  return query(0, n - 1, 0, n - 1, 1)[1];
}
 
// Driver Code: print the number of longest increasing subsequences
// of a sample array.

let arr = [1, 3, 5, 4, 7];
let n = arr.length;
const answer = countLIS(arr, n);
document.write(answer);

// This code is contributed by saurabh_jaiswal.
</script>




# Python program for the above approach
 
import math
 
# Finding the length of the longest increasing subsequence using a
# Segment Tree (this version computes the LIS length, not the number of
# longest increasing subsequences).
class SegmentTree:

    # Constructor that allocates room for `size` leaves. The leaf count is
    # rounded up to a power of two with integer arithmetic; a size of 0
    # still yields a single node instead of raising from math.log2(0).
    def __init__(self, size):
        leaves = 1
        while leaves < size:
            leaves *= 2

        # Create the tree array
        self.tree = [0] * (2 * leaves - 1)

    # Method that builds the segment tree from an array: node `pos`
    # covers arr[low..high] and stores its maximum.
    def BuildTree(self, arr, pos, low, high):

        # If the segment has only one element, set the
        # corresponding value in the tree
        if low == high:
            self.tree[pos] = arr[low]
            return

        # Determine the middle index of the segment
        mid = (low + high) // 2

        # Recursively build the left and right subtrees
        self.BuildTree(arr, 2 * pos + 1, low, mid)
        self.BuildTree(arr, 2 * pos + 2, mid + 1, high)

        # Set the value of the current node to the maximum
        # value of its children
        self.tree[pos] = max(self.tree[2 * pos + 1], self.tree[2 * pos + 2])

    # Method that raises position `index` to at least `value` and
    # refreshes the maxima on the path back to the root in O(log n).
    # Node `pos` covers [low, high].
    def Update(self, pos, low, high, index, value):
        if low == high:
            self.tree[pos] = max(self.tree[pos], value)
            return

        mid = (low + high) // 2
        if index <= mid:
            self.Update(2 * pos + 1, low, mid, index, value)
        else:
            self.Update(2 * pos + 2, mid + 1, high, index, value)
        self.tree[pos] = max(self.tree[2 * pos + 1], self.tree[2 * pos + 2])

    # Method that returns the maximum value in the range [start, end];
    # node `pos` covers [low, high].
    def Query(self, pos, low, high, start, end):

        # If the given range is fully contained in the
        # current segment, return the corresponding value
        # in the tree
        if start <= low and end >= high:
            return self.tree[pos]

        # If the given range does not overlap with the
        # current segment, return the minimum possible
        # value
        if start > high or end < low:
            return float('-inf')

        # Determine the middle index of the segment
        mid = (low + high) // 2

        # Recursively query the left and right subtrees
        left = self.Query(2 * pos + 1, low, mid, start, end)
        right = self.Query(2 * pos + 2, mid + 1, high, start, end)

        # Return the maximum value of the two subtrees
        return max(left, right)

class LIS:
    # Method that returns the length of the longest strictly increasing
    # subsequence in an array, in O(N log N). Values must be hashable
    # and mutually comparable.
    #
    # Each value is replaced by its rank among the distinct values. A max
    # segment tree over ranks stores, for every rank r, the longest
    # increasing subsequence seen so far that ends with the value of rank
    # r. Scanning arr from left to right, the longest subsequence ending at
    # the current element extends the longest one ending at a strictly
    # smaller value: 1 + max over ranks [0, r - 1].
    @staticmethod
    def GetLISLength(arr):
        n = len(arr)
        if n == 0:
            return 0

        # Sorted distinct values; the rank of a value is its index here
        distinct = sorted(set(arr))
        rank_of = {value: rank for rank, value in enumerate(distinct)}
        m = len(distinct)

        # Max segment tree over ranks, all lengths initially zero
        tree = SegmentTree(m)
        tree.BuildTree([0] * m, 0, 0, m - 1)

        maxLIS = 0
        for value in arr:
            rank = rank_of[value]
            # Rank 0 has no smaller value to extend
            prevMax = tree.Query(0, 0, m - 1, 0, rank - 1) if rank > 0 else 0
            length = prevMax + 1
            tree.Update(0, 0, m - 1, rank, length)
            maxLIS = max(maxLIS, length)

        # Return the maximum value of longest increasing
        # subsequence
        return maxLIS
 
# Driver code: print the value LIS.GetLISLength returns for a sample array.
arr = [1, 3, 5, 4, 7]
result = LIS.GetLISLength(arr)
print(result)


# This code is contributed by princekumaras

Output (of the programs that count the longest increasing subsequences of {1, 3, 5, 4, 7})
2

Time Complexity: O(N*log N)
Auxiliary Space: O(N)

Related Topic: Segment Tree


Article Tags :