Open In App

Merge Sort using Multi-threading

Last Updated : 16 Feb, 2024
Improve
Improve
Like Article
Like
Save
Share
Report

Merge Sort is a popular sorting technique that divides an array or list into two halves, recursively sorts each half, and then merges the sorted halves back together. The time complexity of merge sort is O(n log n).
Threads are lightweight processes: a thread shares its code section, data section, and OS resources such as open files and signals with the other threads of the same process. However, like a process, each thread has its own program counter (PC), register set, and stack space.
Multi-threading is a way to improve parallelism by running threads simultaneously on different cores of your processor. In this program we use 4 threads, but you may change this according to the number of cores your processor has (if you do, make sure the final merge step still matches the number of sorted parts).
Examples: 
 

Input :  83, 86, 77, 15, 93, 35, 86, 92, 49, 21, 
62, 27, 90, 59, 63, 26, 40, 26, 72, 36
Output : 15, 21, 26, 26, 27, 35, 36, 40, 49, 59,
62, 63, 72, 77, 83, 86, 86, 90, 92, 93
Input : 6, 5, 4, 3, 2, 1
Output : 1, 2, 3, 4, 5, 6

 

Note: It is better to run the program on a Linux-based system. 
To compile on a Linux system: 
 

g++ -pthread program_name.cpp

 

C++




// CPP Program to implement merge sort using
// multi-threading
#include <pthread.h>
#include <time.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstdlib>
#include <iostream>
#include <vector>
 
// number of elements in array
constexpr int MAX = 20;

// number of threads
constexpr int THREAD_MAX = 4;

// The per-thread ranges (MAX / 4 elements each) and the three final merges in
// main() split the array into exactly four equal quarters. Reject configurations
// that would silently leave trailing elements unsorted.
static_assert(THREAD_MAX == 4, "the final merge step assumes exactly 4 parts");
static_assert(MAX > 0 && MAX % 4 == 0, "MAX must be a positive multiple of 4");

using namespace std;

// array of size MAX, sorted in place
int a[MAX];

// Index of the next quarter to hand out to a worker thread. Every worker
// increments it concurrently, so it must be atomic: with a plain int the
// concurrent `part++` is a data race and two threads may get the same part.
std::atomic<int> part{0};
 
// merge function for merging two parts
void merge(int low, int mid, int high)
{
    int* left = new int[mid - low + 1];
    int* right = new int[high - mid];
 
    // n1 is size of left part and n2 is size
    // of right part
    int n1 = mid - low + 1, n2 = high - mid, i, j;
 
    // storing values in left part
    for (i = 0; i < n1; i++)
        left[i] = a[i + low];
 
    // storing values in right part
    for (i = 0; i < n2; i++)
        right[i] = a[i + mid + 1];
 
    int k = low;
    i = j = 0;
 
    // merge left and right in ascending order
    while (i < n1 && j < n2) {
        if (left[i] <= right[j])
            a[k++] = left[i++];
        else
            a[k++] = right[j++];
    }
 
    // insert remaining values from left
    while (i < n1) {
        a[k++] = left[i++];
    }
 
    // insert remaining values from right
    while (j < n2) {
        a[k++] = right[j++];
    }
}
 
// merge sort function
void merge_sort(int low, int high)
{
    // calculating mid point of array
    int mid = low + (high - low) / 2;
    if (low < high) {
 
        // calling first half
        merge_sort(low, mid);
 
        // calling second half
        merge_sort(mid + 1, high);
 
        // merging the two halves
        merge(low, mid, high);
    }
}
 
// thread function for multi-threading
void* merge_sort(void* arg)
{
    // which part out of 4 parts
    int thread_part = part++;
 
    // calculating low and high
    int low = thread_part * (MAX / 4);
    int high = (thread_part + 1) * (MAX / 4) - 1;
 
    // evaluating mid point
    int mid = low + (high - low) / 2;
    if (low < high) {
        merge_sort(low, mid);
        merge_sort(mid + 1, high);
        merge(low, mid, high);
    }
}
 
// Driver Code
int main()
{
    // generating random values in array
    for (int i = 0; i < MAX; i++)
        a[i] = rand() % 100;
 
    // t1 and t2 for calculating time for
    // merge sort
    clock_t t1, t2;
 
    t1 = clock();
    pthread_t threads[THREAD_MAX];
 
    // creating 4 threads
    for (int i = 0; i < THREAD_MAX; i++)
        pthread_create(&threads[i], NULL, merge_sort,
                                        (void*)NULL);
 
    // joining all 4 threads
    for (int i = 0; i < 4; i++)
        pthread_join(threads[i], NULL);
 
    // merging the final 4 parts
    merge(0, (MAX / 2 - 1) / 2, MAX / 2 - 1);
    merge(MAX / 2, MAX/2 + (MAX-1-MAX/2)/2, MAX - 1);
    merge(0, (MAX - 1)/2, MAX - 1);
 
    t2 = clock();
 
    // displaying sorted array
    cout << "Sorted array: ";
    for (int i = 0; i < MAX; i++)
        cout << a[i] << " ";
 
    // time taken by merge sort in seconds
    cout << "Time taken: " << (t2 - t1) /
              (double)CLOCKS_PER_SEC << endl;
 
    return 0;
}


Java




// Java Program to implement merge sort using
// multi-threading
import java.lang.System;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
 
class MergeSort{
     
      // Assuming system has 4 logical processors
    private static final int MAX_THREADS = 4;
       
    // Custom Thread class with constructors
    private static class SortThreads extends Thread{
        SortThreads(Integer[] array, int begin, int end){
            super(()->{
                MergeSort.mergeSort(array, begin, end);
            });
            this.start();
        }
    }
     
      // Perform Threaded merge sort
    public static void threadedSort(Integer[] array){
          // For performance - get current time in millis before starting
        long time = System.currentTimeMillis();
        final int length = array.length;
        // Workload per thread (chunk_of_data) = total_elements/core_count
        // if the no of elements exactly go into no of available threads,
        // then divide work equally,
        // else if some remainder is present, then assume we have (actual_threads-1) available workers
        // and assign the remaining elements to be worked upon by the remaining 1 actual thread.
        boolean exact = length%MAX_THREADS == 0;
        int maxlim = exact? length/MAX_THREADS: length/(MAX_THREADS-1);
        // if workload is less and no more than 1 thread is required for work, then assign all to 1 thread
        maxlim = maxlim < MAX_THREADS? MAX_THREADS : maxlim;
        // To keep track of threads
        final ArrayList<SortThreads> threads = new ArrayList<>();
        // Since each thread is independent to work on its assigned chunk,
        // spawn threads and assign their working index ranges
        // ex: for 16 element list, t1 = 0-3, t2 = 4-7, t3 = 8-11, t4 = 12-15
        for(int i=0; i < length; i+=maxlim){
            int beg = i;
            int remain = (length)-i;
            int end = remain < maxlim? i+(remain-1): i+(maxlim-1); 
            final SortThreads t = new SortThreads(array, beg, end);
            // Add the thread references to join them later
            threads.add(t);
        }
        for(Thread t: threads){
            try{
                  // This implementation of merge requires, all chunks worked by threads to be sorted first.
                // so we wait until all threads complete
                t.join();
            } catch(InterruptedException ignored){}
        }
        // System.out.println("Merging k-parts array, where m number of parts are distinctly sorted by each Threads of available MAX_THREADS="+MAX_THREADS);
        /*
          The merge takes 2 parts at a time and merges them into 1,
          then again merges the resultant into next part and so on...until end
          For MAXLIMIT = 2 (2 elements per thread where total threads = 4, in a total of 4*2 = 8 elements)
          list1 = (beg, mid); list2 = (mid+1, end);
          1st merge = 0,0,1 (beg, mid, end)
          2nd merge = 0,1,3 (beg, mid, end)
          3rd merge = 0,3,5 (beg, mid, end)
          4th merge = 0,5,7 (beg, mid, end)
        */
        for(int i=0; i < length; i+=maxlim){
            int mid = i == 0? 0 : i-1;
            int remain = (length)-i;
            int end = remain < maxlim? i+(remain-1): i+(maxlim-1);
            // System.out.println("Begin: "+0 + " Mid: "+ mid+ " End: "+ end + " MAXLIM = " + maxlim);
            merge(array, 0, mid, end);
        }
        time = System.currentTimeMillis() - time;
        System.out.println("Time spent for custom multi-threaded recursive merge_sort(): "+ time+ "ms");
    }
 
    // Typical recursive merge sort
    public static void mergeSort(Integer[] array, int begin, int end){
        if (begin<end){
            int mid = (begin+end)/2;
            mergeSort(array, begin, mid);
            mergeSort(array, mid+1, end);
            merge(array, begin, mid, end);
        }
    }
     
    //Typical 2-way merge
    public static void merge(Integer[] array, int begin, int mid, int end){
        Integer[] temp = new Integer[(end-begin)+1];
         
        int i = begin, j = mid+1;
        int k = 0;
 
        // Add elements from first half or second half based on whichever is lower,
        // do until one of the list is exhausted and no more direct one-to-one comparison could be made
        while(i<=mid && j<=end){
            if (array[i] <= array[j]){
                temp[k] = array[i];
                i+=1;
            }else{
                temp[k] = array[j];
                j+=1;
            }
            k+=1;
        }
 
        // Add remaining elements to temp array from first half that are left over
        while(i<=mid){
            temp[k] = array[i];
            i+=1; k+=1;
        }
         
        // Add remaining elements to temp array from second half that are left over
        while(j<=end){
            temp[k] = array[j];
            j+=1; k+=1;
        }
 
        for(i=begin, k=0; i<=end; i++,k++){
            array[i] = temp[k];
        }
    }
}
 
class Driver{
    // Array Size 
    private static Random random = new Random();
    private static final int size = random.nextInt(100);
    private static final Integer list[] = new Integer[size];
    // Fill the initial array with random elements within range
    static {
      for(int i=0; i<size; i++){
        // add a +ve offset to the generated random number and subtract same offset
        // from total so that the number shifts towards negative side by the offset.
        // ex: if random_num = 10, then (10+100)-100 => -10
        list[i] = random.nextInt(size+(size-1))-(size-1);
      }
    }
    // Test the sorting methods performance
    public static void main(String[] args){
      System.out.print("Input = [");
      for (Integer each: list)
        System.out.print(each+", ");
      System.out.print("] \n" +"Input.length = " + list.length + '\n');
 
      // Test standard Arrays.sort() method
      Integer[] arr1 = Arrays.copyOf(list, list.length);
      long t = System.currentTimeMillis();
      Arrays.sort(arr1, (a,b)->a>b? 1: a==b? 0: -1);
      t = System.currentTimeMillis() - t;
      System.out.println("Time spent for system based Arrays.sort(): " + t + "ms");
 
      // Test custom single-threaded merge sort (recursive merge) implementation
      Integer[] arr2 = Arrays.copyOf(list, list.length);
      t = System.currentTimeMillis();
      MergeSort.mergeSort(arr2, 0, arr2.length-1);
      t = System.currentTimeMillis() - t;
      System.out.println("Time spent for custom single threaded recursive merge_sort(): " + t + "ms");
 
      // Test custom (multi-threaded) merge sort (recursive merge) implementation
      Integer[] arr = Arrays.copyOf(list, list.length);
      MergeSort.threadedSort(arr);
      System.out.print("Output = [");
      for (Integer each: arr)
        System.out.print(each+", ");
      System.out.print("]\n");
    }
}


Python3




# Python Program to implement merge sort using
# multi-threading
import threading
import time
import random
 
# Problem size and degree of parallelism.
MAX = 20         # number of elements in the array
THREAD_MAX = 4   # number of worker threads

# Shared array that the sorting functions operate on in place,
# and a module-level part counter that starts at 0.
a = [0 for _ in range(MAX)]
part = 0
 
# merge function for merging two parts
def merge(low, mid, high):
    """Merge the sorted runs a[low..mid] and a[mid+1..high] back into ``a``.

    Both bounds are inclusive. On ties the element from the left run is taken
    first, so the merge is stable.
    """
    left_run = a[low:mid + 1]
    right_run = a[mid + 1:high + 1]
    left_len, right_len = len(left_run), len(right_run)
    li = ri = 0
    out = low

    while li < left_len and ri < right_len:
        if left_run[li] <= right_run[ri]:
            a[out] = left_run[li]
            li += 1
        else:
            a[out] = right_run[ri]
            ri += 1
        out += 1

    # At most one run still has elements; copy its tail over unchanged.
    for value in left_run[li:]:
        a[out] = value
        out += 1
    for value in right_run[ri:]:
        a[out] = value
        out += 1
 
# merge sort function
def merge_sort(low, high):
    """Recursively sort the inclusive slice a[low..high] of the shared array in place."""
    if not low < high:
        return  # a run of zero or one element is already sorted

    mid = low + (high - low) // 2
    for start, stop in ((low, mid), (mid + 1, high)):
        merge_sort(start, stop)
    merge(low, mid, high)
 
# thread function for multi-threading
def merge_sort_threaded():
    """Sort the whole shared array ``a`` using THREAD_MAX worker threads.

    The array is split into THREAD_MAX contiguous parts; the last part also takes
    any remainder when MAX is not a multiple of THREAD_MAX. Each part is
    merge-sorted on its own thread. Once every thread has been joined, the sorted
    parts are merged left to right on the calling thread.

    NOTE: in CPython the GIL generally keeps CPU-bound threads like these from
    running Python code in parallel, so this shows the structure of a threaded
    merge sort rather than a guaranteed speed-up.
    """
    chunk = MAX // THREAD_MAX

    def bounds(p):
        # Inclusive (low, high) index range of part p.
        low = p * chunk
        high = MAX - 1 if p == THREAD_MAX - 1 else (p + 1) * chunk - 1
        return low, high

    # Keep every thread object so each one can be joined. Joining a single
    # variable repeatedly would wait only for the last thread created.
    threads = []
    for p in range(THREAD_MAX):
        t = threading.Thread(target=merge_sort, args=bounds(p))
        t.start()
        threads.append(t)

    # Every part must be fully sorted before it is merged.
    for t in threads:
        t.join()

    # Fold part p into the sorted prefix a[0 .. p*chunk - 1].
    for p in range(1, THREAD_MAX):
        _, high = bounds(p)
        merge(0, p * chunk - 1, high)
 
# Driver Code
if __name__ == '__main__':
    # Populate the shared array with random integers between 0 and 100 inclusive.
    for i in range(MAX):
        a[i] = random.randint(0, 100)

    # Time the threaded sort with a high-resolution performance counter.
    t1 = time.perf_counter()
    merge_sort_threaded()
    t2 = time.perf_counter()

    print("Sorted array:", a)
    print(f"Time taken: {t2 - t1:.6f} seconds")


C#




using System;
using System.Threading;
 
public class MergeSortMultiThreaded
{
    // Number of elements in array
    const int MAX = 20;

    // Number of threads
    const int THREAD_MAX = 4;

    // Array of size MAX, sorted in place
    static int[] a = new int[MAX];

    // Index of the next part to hand out to a worker thread
    static int part = 0;

    // First index of part p; every part except the last has MAX / THREAD_MAX elements.
    static int PartLow(int p)
    {
        return p * (MAX / THREAD_MAX);
    }

    // Last index (inclusive) of part p; the last part also absorbs the
    // remainder when MAX is not a multiple of THREAD_MAX.
    static int PartHigh(int p)
    {
        return p == THREAD_MAX - 1 ? MAX - 1 : PartLow(p + 1) - 1;
    }

    // Stable merge of the sorted runs a[low..mid] and a[mid+1..high]
    // (inclusive bounds) back into a[low..high].
    static void Merge(int low, int mid, int high)
    {
        int[] left = new int[mid - low + 1];
        int[] right = new int[high - mid];

        // Size of left and right parts
        int n1 = mid - low + 1, n2 = high - mid, i, j;

        // Storing values in left part
        for (i = 0; i < n1; ++i)
            left[i] = a[i + low];

        // Storing values in right part
        for (i = 0; i < n2; ++i)
            right[i] = a[i + mid + 1];

        int k = low;
        i = j = 0;

        // Merge left and right in ascending order (left wins ties)
        while (i < n1 && j < n2)
        {
            if (left[i] <= right[j])
                a[k++] = left[i++];
            else
                a[k++] = right[j++];
        }

        // Insert remaining values from left
        while (i < n1)
            a[k++] = left[i++];

        // Insert remaining values from right
        while (j < n2)
            a[k++] = right[j++];
    }

    // Recursive merge sort of the inclusive range a[low..high]
    static void MergeSort(int low, int high)
    {
        if (low < high)
        {
            int mid = low + (high - low) / 2;
            MergeSort(low, mid);
            MergeSort(mid + 1, high);
            Merge(low, mid, high);
        }
    }

    // Thread function: atomically claims the next part index and sorts that part.
    static void MergeSortThreaded()
    {
        int threadPart = Interlocked.Increment(ref part) - 1;
        MergeSort(PartLow(threadPart), PartHigh(threadPart));
    }

    // Driver Code
    public static void Main(string[] args)
    {
        Random rand = new Random();
        // Generate random values in [0, 99]
        for (int i = 0; i < MAX; ++i)
            a[i] = rand.Next(100);

        // Stopwatch measures elapsed time with a monotonic, high-resolution timer.
        // DateTime.Now is coarser and can jump when the system clock is adjusted.
        System.Diagnostics.Stopwatch stopwatch = System.Diagnostics.Stopwatch.StartNew();

        // Create threads and start sorting
        Thread[] threads = new Thread[THREAD_MAX];
        for (int i = 0; i < THREAD_MAX; ++i)
        {
            threads[i] = new Thread(new ThreadStart(MergeSortThreaded));
            threads[i].Start();
        }

        // Wait for all threads to finish
        foreach (Thread t in threads)
            t.Join();

        // Fold the sorted parts together left to right: after step p,
        // a[0..PartHigh(p)] is sorted. Works for any THREAD_MAX and MAX.
        for (int p = 1; p < THREAD_MAX; ++p)
            Merge(0, PartLow(p) - 1, PartHigh(p));

        stopwatch.Stop();

        // Display sorted array
        Console.Write("Sorted array: ");
        for (int i = 0; i < MAX; ++i)
            Console.Write(a[i] + " ");

        // Display time taken
        Console.WriteLine("\nTime taken: " + stopwatch.Elapsed.TotalSeconds + " seconds");
    }
}
//This code is contributed by Aman


Javascript




// Number of elements to sort and number of simulated "threads".
const MAX = 20;
const THREAD_MAX = 4;

// Shared array the sorting functions operate on in place (starts as MAX empty slots).
const a = Array(MAX);
// Module-level part counter, starting at 0.
let part = 0;
 
// Merges the sorted runs a[low..mid] and a[mid+1..high] (inclusive bounds)
// back into the shared array `a`; ties favour the left run, so the merge is stable.
function merge(low, mid, high) {
    const leftRun = a.slice(low, mid + 1);
    const rightRun = a.slice(mid + 1, high + 1);
    let li = 0;
    let ri = 0;
    let out = low;

    for (; li < leftRun.length && ri < rightRun.length; out++) {
        if (leftRun[li] <= rightRun[ri]) {
            a[out] = leftRun[li++];
        } else {
            a[out] = rightRun[ri++];
        }
    }

    // Copy whichever run still has elements left.
    for (; li < leftRun.length; li++, out++) {
        a[out] = leftRun[li];
    }
    for (; ri < rightRun.length; ri++, out++) {
        a[out] = rightRun[ri];
    }
}
 
// Recursively sorts the inclusive range a[low..high] of the shared array.
function mergeSort(low, high) {
    if (!(low < high)) {
        return; // zero or one element: already sorted
    }

    const mid = low + Math.floor((high - low) / 2);
    [[low, mid], [mid + 1, high]].forEach(([from, to]) => mergeSort(from, to));
    merge(low, mid, high);
}
 
/**
 * Sorts the whole shared array `a`:
 * 1. Split it into THREAD_MAX parts.
 * 2. Sort each part in its own setTimeout callback.
 * 3. Merge the sorted parts and log the result.
 *
 * NOTE: setTimeout callbacks all run on the single JavaScript event-loop thread,
 * so the parts are sorted one after another, not in parallel. Real parallelism
 * would require worker threads.
 *
 * @returns {Promise<void>} resolves after the final merge and the "Sorted array"
 *   log line, so callers can await completion instead of guessing a delay.
 */
function mergeSortThreaded() {
    const chunk = Math.floor(MAX / THREAD_MAX);
    // Inclusive bounds of part p; the last part absorbs any remainder.
    // Integer bounds keep slice() indices whole even when MAX is odd.
    const partLow = (p) => p * chunk;
    const partHigh = (p) => (p === THREAD_MAX - 1 ? MAX - 1 : partLow(p + 1) - 1);

    const tasks = [];
    for (let p = 0; p < THREAD_MAX; p++) {
        tasks.push(new Promise((resolve, reject) => {
            setTimeout(() => {
                try {
                    mergeSort(partLow(p), partHigh(p));
                    resolve();
                } catch (err) {
                    reject(err);
                }
            });
        }));
    }

    // Waiting on the part promises, rather than a fixed delay, guarantees that
    // every part is sorted before the merge starts.
    return Promise.all(tasks).then(() => {
        // Fold part p into the sorted prefix a[0 .. partLow(p) - 1].
        for (let p = 1; p < THREAD_MAX; p++) {
            merge(0, partLow(p) - 1, partHigh(p));
        }
        console.log("Sorted array:", a);
    });
}
 
// Driver Code
async function main() {
    // Generating random values in [0, 100]
    for (let i = 0; i < MAX; i++) {
        a[i] = Math.floor(Math.random() * 101);
    }

    const t1 = performance.now();
    // Awaiting makes the measurement include the deferred (setTimeout-scheduled)
    // work whenever mergeSortThreaded returns a promise. Awaiting a non-promise
    // simply resolves immediately.
    await mergeSortThreaded();
    const t2 = performance.now();

    console.log(`Time taken: ${(t2 - t1) / 1000} seconds`);
}

// Invoke the main function; report asynchronous failures instead of leaving
// the returned promise rejected and unhandled.
main().catch((err) => {
    console.error(err);
});


Output: 

Sorted array: 15 21 26 26 27 35 36 40 49 59 62 63 72 77 83 86 86 90 92 93
Time taken: 0.001023

Time Complexity: O(n log n)
Auxiliary Space: O(n)



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads