Implementation of the Chinese Remainder Theorem (Inverse Modulo Based Implementation)
Last Updated :
28 Nov, 2022
We are given two arrays num[0..k-1] and rem[0..k-1]. In num[0..k-1], every pair is coprime (the gcd of every pair is 1). We need to find the minimum positive number x such that:
x % num[0] = rem[0],
x % num[1] = rem[1],
.......................
x % num[k-1] = rem[k-1]
Example:
Input: num[] = {3, 4, 5}, rem[] = {2, 3, 1}
Output: 11
Explanation:
11 is the smallest number such that:
(1) When we divide it by 3, we get remainder 2.
(2) When we divide it by 4, we get remainder 3.
(3) When we divide it by 5, we get remainder 1.
We strongly recommend referring to the post below as a prerequisite for this one.
Chinese Remainder Theorem | Set 1 (Introduction)
We have discussed a naive solution to find the minimum x. In this article, an efficient solution to find x is discussed.
The solution is based on the formula below:
$$x = \left( \sum_{i=0}^{k-1} rem[i] \cdot pp[i] \cdot inv[i] \right) \bmod prod$$
where
- rem[i] is the given array of remainders,
- prod is the product of all the given numbers:
  prod = num[0] * num[1] * ... * num[k-1],
- pp[i] is the product of all the numbers divided by num[i]:
  pp[i] = prod / num[i],
- inv[i] is the modular multiplicative inverse of
  pp[i] with respect to num[i].
Example:
Let us take below example to understand the solution
num[] = {3, 4, 5}, rem[] = {2, 3, 1}
prod = 60
pp[] = {20, 15, 12}
inv[] = {2, 3, 3} // (20*2)%3 = 1, (15*3)%4 = 1
// (12*3)%5 = 1
x = (rem[0]*pp[0]*inv[0] + rem[1]*pp[1]*inv[1] +
rem[2]*pp[2]*inv[2]) % prod
= (2*20*2 + 3*15*3 + 1*12*3) % 60
= (80 + 135 + 36) % 60
= 11
Refer to this for a nice visual explanation of the above formula.
Below is the implementation of the above formula. We can use the Extended Euclid based method discussed here to find the modular inverse.
C++
#include <bits/stdc++.h>
using namespace std;
// Returns the modular multiplicative inverse of `a` modulo `m`: the value y in
// [0, m) with (a * y) % m == 1, computed with the iterative extended Euclidean
// algorithm. `a` may be any int (it is reduced into [0, m) first); `m` must be
// positive. Returns 0 when m == 1 (every integer is congruent to 0 mod 1).
// Throws std::invalid_argument when gcd(a, m) != 1, since no inverse exists.
int inv( int a, int m)
{
    if (m == 1)
        return 0;
    const long long mod = m;
    // Invariant: oldR == oldS * a and r == s * a (mod m).
    long long oldR = ((a % mod) + mod) % mod;
    long long r = mod;
    long long oldS = 1;
    long long s = 0;
    while (r != 0) {
        const long long q = oldR / r;
        const long long nextR = oldR - q * r;
        oldR = r;
        r = nextR;
        const long long nextS = oldS - q * s;
        oldS = s;
        s = nextS;
    }
    // oldR is now gcd(a, m); an inverse exists only when it equals 1.
    if (oldR != 1)
        throw std::invalid_argument("inv: a and m are not coprime");
    return static_cast<int>(((oldS % mod) + mod) % mod);
}

// Returns the smallest non-negative x with x % num[i] == rem[i] for every
// i in [0, k), using the Chinese Remainder Theorem formula
//   x = (sum of rem[i] * pp[i] * inv(pp[i], num[i])) % prod,
// where prod = num[0] * ... * num[k-1] and pp[i] = prod / num[i].
// Preconditions: every num[i] > 0 and the num[i] are pairwise coprime; the
// answer (< prod) must fit in int. rem[i] may be negative or >= num[i]; it is
// normalized into [0, num[i]). Returns 0 when k == 0.
int findMinX( int num[], int rem[], int k)
{
    long long prod = 1;
    for (int i = 0; i < k; i++)
        prod *= num[i];

    long long result = 0;
    for (int i = 0; i < k; i++) {
        const long long n = num[i];
        const long long pp = prod / n;
        // Only pp mod n matters for the inverse; keeps the argument int-sized.
        const long long inverse = inv(static_cast<int>(pp % n), num[i]);
        const long long r = ((rem[i] % n) + n) % n;
        // (r * inverse) % n < n, so term < n * pp == prod: no overflow.
        const long long term = (r * inverse % n) * pp;
        result = (result + term) % prod;
    }
    return static_cast<int>(result);
}
int main( void )
{
    // Sample system: x % 3 == 2, x % 4 == 3, x % 5 == 1 (smallest solution: 11).
    int num[] = { 3, 4, 5 };
    int rem[] = { 2, 3, 1 };
    int count = sizeof num / sizeof *num;
    int x = findMinX(num, rem, count);
    cout << "x is " << x;
    return 0;
}
|
Java
import java.io.*;
class GFG {
    /**
     * Returns the modular multiplicative inverse of {@code a} modulo {@code m}:
     * the value y in [0, m) with (a * y) % m == 1, via the iterative extended
     * Euclidean algorithm. {@code a} may be any int (it is reduced into [0, m)
     * first); {@code m} must be positive. Returns 0 when m == 1.
     *
     * @throws IllegalArgumentException if gcd(a, m) != 1 (no inverse exists)
     */
    static int inv( int a, int m)
    {
        if (m == 1)
            return 0;
        long mod = m;
        // Invariant: oldR == oldS * a and r == s * a (mod m).
        long oldR = ((a % mod) + mod) % mod;
        long r = mod;
        long oldS = 1;
        long s = 0;
        while (r != 0) {
            long q = oldR / r;
            long nextR = oldR - q * r;
            oldR = r;
            r = nextR;
            long nextS = oldS - q * s;
            oldS = s;
            s = nextS;
        }
        // oldR is now gcd(a, m).
        if (oldR != 1)
            throw new IllegalArgumentException("inv: a and m are not coprime");
        return (int) (((oldS % mod) + mod) % mod);
    }

    /**
     * Returns the smallest non-negative x with x % num[i] == rem[i] for all
     * i in [0, k), using x = (sum rem[i] * pp[i] * inv(pp[i], num[i])) % prod,
     * where prod is the product of num[0..k-1] and pp[i] = prod / num[i].
     * The num[i] must be positive and pairwise coprime, and the answer must
     * fit in an int. Negative or out-of-range rem[i] values are normalized.
     * Intermediate values are kept in long and reduced so they cannot overflow.
     */
    static int findMinX( int num[], int rem[], int k)
    {
        long prod = 1;
        for (int i = 0; i < k; i++)
            prod *= num[i];
        long result = 0;
        for (int i = 0; i < k; i++) {
            long n = num[i];
            long pp = prod / n;
            long r = ((rem[i] % n) + n) % n;
            // (r * inverse) % n < n, so term < n * pp == prod.
            long term = (r * inv((int) (pp % n), num[i]) % n) * pp;
            result = (result + term) % prod;
        }
        return (int) result;
    }

    public static void main(String args[])
    {
        int num[] = { 3, 4, 5 };
        int rem[] = { 2, 3, 1 };
        int k = num.length;
        System.out.println("x is " + findMinX(num, rem, k));
    }
}
|
Python3
def inv(a, m):
    """Return the modular multiplicative inverse of ``a`` modulo ``m``.

    Uses the iterative extended Euclidean algorithm and returns the value
    ``y`` in ``[0, m)`` with ``(a * y) % m == 1``. ``a`` may be any integer;
    ``m`` must be positive. Returns 0 when ``m == 1``.

    Raises:
        ValueError: if ``a`` and ``m`` are not coprime (no inverse exists).
    """
    if m == 1:
        return 0
    # Invariant: old_r == old_s * a and r == s * a (mod m).
    old_r, r = a % m, m
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    # old_r is now gcd(a, m).
    if old_r != 1:
        raise ValueError("inv: a and m are not coprime")
    return old_s % m


def findMinX(num, rem, k):
    """Return the smallest non-negative x with x % num[i] == rem[i] for i < k.

    Applies the Chinese Remainder Theorem formula
    ``x = (sum(rem[i] * pp[i] * inv(pp[i], num[i]))) % prod``, where ``prod``
    is the product of ``num[0..k-1]`` and ``pp[i] = prod // num[i]``. The
    moduli must be positive and pairwise coprime. Returns 0 when ``k == 0``.
    """
    prod = 1
    for i in range(k):
        prod *= num[i]
    result = 0
    for i in range(k):
        pp = prod // num[i]
        # Reduce as we go so the running sum stays below prod.
        result = (result + rem[i] * inv(pp, num[i]) * pp) % prod
    return result
# Driver: solve x % 3 == 2, x % 4 == 3, x % 5 == 1 (expected answer: 11).
num, rem = [3, 4, 5], [2, 3, 1]
k = len(num)
print("x is ", findMinX(num, rem, k))
|
C#
using System;
class GFG
{
    // Returns the modular multiplicative inverse of a modulo m: the value y in
    // [0, m) with (a * y) % m == 1, via the iterative extended Euclidean
    // algorithm. a may be any int (it is reduced into [0, m) first); m must be
    // positive. Returns 0 when m == 1.
    // Throws ArgumentException if gcd(a, m) != 1 (no inverse exists).
    internal static int inv(int a, int m)
    {
        if (m == 1)
            return 0;
        long mod = m;
        // Invariant: oldR == oldS * a and r == s * a (mod m).
        long oldR = ((a % mod) + mod) % mod;
        long r = mod;
        long oldS = 1;
        long s = 0;
        while (r != 0)
        {
            long q = oldR / r;
            long nextR = oldR - q * r;
            oldR = r;
            r = nextR;
            long nextS = oldS - q * s;
            oldS = s;
            s = nextS;
        }
        // oldR is now gcd(a, m).
        if (oldR != 1)
            throw new ArgumentException("inv: a and m are not coprime");
        return (int)(((oldS % mod) + mod) % mod);
    }

    // Returns the smallest non-negative x with x % num[i] == rem[i] for all
    // i in [0, k), using x = (sum rem[i] * pp[i] * inv(pp[i], num[i])) % prod,
    // where prod is the product of num[0..k-1] and pp[i] = prod / num[i].
    // The num[i] must be positive and pairwise coprime, and the answer must fit
    // in an int. Negative or out-of-range rem[i] values are normalized.
    // Intermediate values are kept in long and reduced so they cannot overflow.
    internal static int findMinX(int []num,
                                 int []rem,
                                 int k)
    {
        long prod = 1;
        for (int i = 0; i < k; i++)
            prod *= num[i];
        long result = 0;
        for (int i = 0; i < k; i++)
        {
            long n = num[i];
            long pp = prod / n;
            long r = ((rem[i] % n) + n) % n;
            // (r * inverse) % n < n, so term < n * pp == prod.
            long term = (r * inv((int)(pp % n), num[i]) % n) * pp;
            result = (result + term) % prod;
        }
        return (int)result;
    }

    static public void Main ()
    {
        int []num = {3, 4, 5};
        int []rem = {2, 3, 1};
        int k = num.Length;
        Console.WriteLine("x is " +
                          findMinX(num, rem, k));
    }
}
|
PHP
<?php
/**
 * Returns the modular multiplicative inverse of $a modulo $m: the value y in
 * [0, $m) with ($a * y) % $m == 1, via the iterative extended Euclidean
 * algorithm. $a may be any int (it is reduced into [0, $m) first); $m must be
 * positive. Returns 0 when $m == 1.
 *
 * @throws InvalidArgumentException if gcd($a, $m) != 1 (no inverse exists)
 */
function inv( $a , $m )
{
    if ($m == 1)
        return 0;
    // Invariant: $oldR == $oldS * $a and $r == $s * $a (mod $m).
    $oldR = (($a % $m) + $m) % $m;
    $r = $m;
    $oldS = 1;
    $s = 0;
    while ($r != 0)
    {
        // Exact integer quotient: avoids float division and precision loss.
        $q = ($oldR - $oldR % $r) / $r;
        $nextR = $oldR - $q * $r;
        $oldR = $r;
        $r = $nextR;
        $nextS = $oldS - $q * $s;
        $oldS = $s;
        $s = $nextS;
    }
    // $oldR is now gcd($a, $m).
    if ($oldR != 1)
        throw new InvalidArgumentException("inv: a and m are not coprime");
    return (($oldS % $m) + $m) % $m;
}

/**
 * Returns the smallest non-negative x with x % $num[i] == $rem[i] for all
 * i in [0, $k), using x = (sum rem[i] * pp[i] * inv(pp[i], num[i])) % prod,
 * where prod is the product of num[0..k-1] and pp[i] = prod / num[i].
 * The moduli must be positive and pairwise coprime. Negative or out-of-range
 * remainders are normalized. Each term is reduced so intermediates stay below
 * prod and remain integers.
 */
function findMinX( $num , $rem , $k )
{
    $prod = 1;
    for ($i = 0; $i < $k; $i++)
        $prod *= $num[$i];
    $result = 0;
    for ($i = 0; $i < $k; $i++)
    {
        $n = $num[$i];
        // $n divides $prod exactly, so this integer division yields an int.
        $pp = $prod / $n;
        $r = (($rem[$i] % $n) + $n) % $n;
        // ($r * inverse) % $n < $n, so the term is < $n * $pp == $prod.
        $term = (($r * inv($pp % $n, $n)) % $n) * $pp;
        $result = ($result + $term) % $prod;
    }
    return $result;
}
// Driver: find x with x % 3 == 2, x % 4 == 3 and x % 5 == 1 (expected 11).
$num = array(3, 4, 5);
$rem = array(2, 3, 1);
$k = count($num);
echo "x is ", findMinX($num, $rem, $k);
?>
|
Javascript
<script>
function inv(a, m)
{
let m0 = m;
let x0 = 0;
let x1 = 1;
if (m == 1)
return 0;
while (a > 1)
{
let q = parseInt(a / m);
let t = m;
m = a % m;
a = t;
t = x0;
x0 = x1 - q * x0;
x1 = t;
}
if (x1 < 0)
x1 += m0;
return x1;
}
function findMinX(num, rem, k)
{
let prod = 1;
for (let i = 0; i < k; i++)
prod *= num[i];
let result = 0;
for (let i = 0; i < k; i++)
{
pp = parseInt(prod / num[i]);
result += rem[i] * inv(pp,
num[i]) * pp;
}
return result % prod;
}
// Driver: x % 3 == 2, x % 4 == 3, x % 5 == 1 (expected answer: 11).
let num = [3, 4, 5];
let rem = [2, 3, 1];
let k = num.length;
document.write(`x is ${findMinX(num, rem, k)}`);
</script>
|
Output:
x is 11
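Time Complexity: O(k * log M), where k is the number of congruences and M is the largest num[i] (one extended Euclid call per congruence)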
Auxiliary Space : O(1)
Share your thoughts in the comments
Please Login to comment...