Open In App

Find Square Root under Modulo p | (When p is product of two primes in the form 4*i + 3)

Improve
Improve
Like Article
Like
Save
Share
Report

Given an integer N and an integer P that is the product of two primes, the task is to find all square roots of N modulo P, i.e. every x in [0, P) with x² ≡ N (mod P), if any exist. It is given that P = p1 · p2, where p1 and p2 are distinct primes of the form 4i + 3 for some non-negative integer i. Examples of such primes are 7, 11, 19, 23 and 31. 
Examples: 
 

Input: N = 67, P = 77 
Output: 23 65 12 54 
Explanation: 
All possible answers are (23, 65, 12, 54). For example, 23² = 529 = 6 · 77 + 67, so 23² ≡ 67 (mod 77).
Input: N = 188, P = 437 
Output: 25 44 393 412 
Explanation:

 
All possible answers are (25, 44, 393, 412). For example, 25² = 625 = 437 + 188, so 25² ≡ 188 (mod 437).
 

Naive Approach: 
The simplest approach is to try every number x from 0 to P – 1 and check whether x² ≡ N (mod P). Print every x for which this holds; do not stop at the first one, since all square roots are required. 
Time Complexity: O(P) 
Auxiliary Space: O(1)
Efficient Approach: 
To optimize the above approach, factorize P into its two prime factors p1 and p2 and find the square roots of N modulo each prime separately. For a prime p ≡ 3 (mod 4), if N has a square root modulo p then r = N^((p+1)/4) mod p is one. This holds because r² = N^((p+1)/2) = N · N^((p−1)/2) ≡ N (mod p) by Euler's criterion. The other root is p − r. This gives two roots modulo p1 and two modulo p2. Choosing one root for each prime yields four systems of two congruences. Solving each system with the Chinese Remainder Theorem gives one of the four square roots modulo p1 · p2. 
Follow the steps below to solve this problem: 
 

  • Factorize the number P by using Prime Factorization. Factors will be two primes p1 and p2.
  • Find the square roots of N modulo p1 and modulo p2.
  • This gives two sets of answers, one per prime, each containing two numbers (r and p − r).
  • Choosing one number from each set gives 4 systems of congruences.
  • Apply the Chinese Remainder Theorem to each system to find the square roots modulo p1 * p2, and print them.

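Below is the implementation of the above approach. The Java version uses the BigInteger class to handle very large numbers. The C++, C# and JavaScript versions use built-in fixed-size numeric types, and the Python version uses Python's arbitrary-precision integers: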
 

C++




#include <cmath>
#include <iostream>
 
// Returns floor(sqrt(n)) for n >= 0, and 0 for negative n.
// std::sqrt gives a close estimate; the two integer loops correct any
// floating-point rounding so the result is exact.
int isqrt(int n)
{
    if (n <= 0) {
        // sqrt of a negative value is NaN, and casting NaN to int is UB.
        return 0;
    }

    long long r = static_cast<long long>(std::sqrt(static_cast<double>(n)));
    while (r * r > n) {
        --r;
    }
    while ((r + 1) * (r + 1) <= n) {
        ++r;
    }
    return static_cast<int>(r);
}
 
// Returns the inverse of a modulo m, i.e. the x in [0, m) with
// a * x ≡ 1 (mod m), using the extended Euclidean algorithm.
// Returns 0 when m == 1, and also when no inverse exists (gcd(a, m) != 1).
// 0 is never a valid inverse for m > 1, so callers can test for it.
// Assumes m >= 1.
int mod_inverse(int a, int m)
{
    if (m == 1) {
        return 0;
    }

    const int m0 = m;

    // Bring a into [0, m): C++ % keeps the sign of the dividend.
    a %= m;
    if (a < 0) {
        a += m;
    }

    int y = 0;
    int x = 1;

    while (a > 1) {
        if (m == 0) {
            // gcd(a, m0) > 1: no inverse, and a / m would divide by zero.
            return 0;
        }
        const int q = a / m;

        int t = m;
        m = a % m;
        a = t;

        t = y;
        y = x - q * y;
        x = t;
    }

    if (a == 0) {
        return 0; // a was a multiple of m0
    }

    while (x < 0) {
        x += m0;
    }
    return x;
}
 
// Chinese Remainder Theorem (CRT) function
int CRT(int num[], int rem[], int k)
{
    int prod = 1;
 
    for (int i = 0; i < k; i++) {
        prod *= num[i];
    }
 
    int result = 0;
 
    for (int i = 0; i < k; i++) {
        int pp = prod / num[i];
        result += rem[i] * mod_inverse(pp, num[i]) * pp;
    }
 
    return result % prod;
}
 
// Returns base^exponent mod modulus, a value in [0, modulus), by binary
// (square-and-multiply) exponentiation.
// Assumes exponent >= 0 and modulus >= 1. Products are formed in long long,
// so any int modulus is safe from overflow.
int pow_mod(int base, int exponent, int modulus)
{
    long long result = 1 % modulus; // modulus == 1 must give 0, not 1

    long long b = base % modulus;
    if (b < 0) {
        b += modulus; // C++ % keeps the sign of the dividend
    }

    while (exponent > 0) {
        if (exponent % 2 == 1) {
            result = (result * b) % modulus;
        }
        exponent /= 2;
        b = (b * b) % modulus;
    }

    return static_cast<int>(result);
}
 
// Returns a square root of n modulo the prime p, i.e. an x in [0, p) with
// x * x ≡ n (mod p), or -1 if n is not a quadratic residue modulo p.
// Only p % 4 == 3 is supported: for such p, n^((p + 1) / 4) is a root
// whenever one exists. For any other p it prints "Invalid Input" and
// returns -1. Primality of p is not checked.
int square_root(int n, int p)
{
    if (p % 4 != 3) {
        std::cout << "Invalid Input" << std::endl;
        return -1;
    }

    // Normalise n into [0, p); C++ % can yield a negative value, which
    // would never compare equal to a (non-negative) square below.
    n %= p;
    if (n < 0) {
        n += p;
    }

    const int x = pow_mod(n, (p + 1) / 4, p);

    // The other root, p - x, has the same square, so checking x suffices.
    if (pow_mod(x, 2, p) == n) {
        return x;
    }
    return -1;
}
 
// Returns the ceiling of sqrt(x) for x >= 0, or -1 for negative x.
// Integer Newton iteration starting from x / 2 converges to floor(sqrt(x));
// one is added when x is not a perfect square.
int sqrt_c(int x)
{
    if (x < 0) {
        return -1;
    }
    if (x == 0 || x == 1) {
        return x;
    }

    int y = x / 2;
    while (y > x / y) {
        y = (y + x / y) / 2;
    }

    // Here y == floor(sqrt(x)), so y * y <= x and cannot overflow.
    // (The old test `x % y == 0` misclassified e.g. 2, 3 and 6.)
    return (y * y == x) ? y : y + 1;
}
 
// Splits p into two primes {p1, p2} with p1 * p2 == p and p1 <= p2, where
// p1 is the smallest prime factor of p. Returns {0, 0} when p is not a
// product of exactly two primes (p prime, p < 4, or three or more prime
// factors counted with multiplicity).
// The result lives in a static buffer that is reset and overwritten on
// every call, so earlier results never leak into later calls.
int* factorize(int p)
{
    static int ans[2];
    ans[0] = 0;
    ans[1] = 0;

    // b <= p / b is b * b <= p without sqrt rounding or int overflow.
    for (int b = 2; b <= p / b; b++) {
        if (p % b != 0) {
            continue;
        }

        // b is the smallest divisor, hence prime. p is a product of two
        // primes exactly when the cofactor is prime as well; any factor of
        // the cofactor is >= b, so trial division can start at b.
        const int cofactor = p / b;
        for (int d = b; d <= cofactor / d; d++) {
            if (cofactor % d == 0) {
                return ans; // three or more prime factors
            }
        }

        ans[0] = b;
        ans[1] = cofactor;
        return ans;
    }

    return ans;
}
 
// Function to solve the given problem
void solve(int a, int m)
{
    int* s = factorize(m);
    const int ZERO = 0;
 
    if (s[0] * s[1] != m) {
        std::cout << "Not a product of two primes"
                  << std::endl;
    }
    else {
        int s1[4] = { 0 };
        int a1[4] = { 0 };
        int a2[2] = { 0 };
 
        a1[0] = square_root(a % s[0], s[0]);
        a1[1] = square_root(a % s[1], s[1]);
 
        a1[2] = -a1[0];
        a1[3] = -a1[1];
 
        while (a1[2] < ZERO) {
            a1[2] += s[0];
        }
 
        while (a1[3] < ZERO) {
            a1[3] += s[1];
        }
 
        s1[0] = s[0];
        s1[1] = s[1];
        s1[2] = s[0];
        s1[3] = s[1];
 
        if (a1[0] == -1 || a1[1] == -1) {
            std::cout << "No Solution" << std::endl;
        }
        else {
            std::cout << CRT(s, a1, 2) << " ";
 
            a2[0] = a1[0];
            a2[1] = a1[3];
            std::cout << CRT(s, a2, 2) << " ";
 
            a2[0] = a1[2];
            a2[1] = a1[1];
            std::cout << CRT(s, a2, 2) << " ";
 
            a2[0] = a1[2];
            a2[1] = a1[3];
            std::cout << CRT(s, a2, 2) << " ";
        }
    }
}
 
int main()
{
    // Given Number N
    int N = 188;
 
    // Given product of Prime Numbers
    int P = 437;
 
    // Function Call
    solve(N, P);
 
    return 0;
}


Java




// Java program for the above approach
import java.math.*;
import java.util.*;
 
class SqrtMod {
 
    // Function to find the modulo inverse
    // of a with respect to m
    static BigInteger modInverse(BigInteger a,
                                 BigInteger m)
    {
 
        BigInteger m0 = m;
        BigInteger y = new BigInteger("0"),
                   x = new BigInteger("1");
 
        if (m.compareTo(new BigInteger("1")) == 0)
            return new BigInteger("0");
 
        // Perform Extended Euclid Algorithm
        while (a.compareTo(new BigInteger("1")) > 0) {
 
            if (m.toString().equals("0"))
                break;
            BigInteger q = a.divide(m);
 
            BigInteger t = m;
 
            // M is remainder now, process
            // Extended Euclid Algorithm
            m = a.mod(m);
            a = t;
 
            t = y;
            y = x.subtract(q.multiply(y));
            x = t;
        }
 
        // Make x positive
        while (x.compareTo(new BigInteger("0")) < 0)
            x = x.add(m0);
        return x;
    }
 
    // Function to apply the Chinese
    // Remainder Theorem
    static String CRT(BigInteger num[],
                      BigInteger rem[],
                      int k)
    {
 
        BigInteger prod = new BigInteger("1");
 
        // Find product of all numbers
        for (int i = 0; i < k; i++)
            prod = prod.multiply(num[i]);
 
        BigInteger result = new BigInteger("0");
 
        // Apply the above formula
        for (int i = 0; i < k; i++) {
            BigInteger pp = prod.divide(num[i]);
            result = result.add(rem[i]
                                    .multiply(modInverse(pp,
                                                         num[i]))
                                    .multiply(pp));
        }
 
        // Return result
        return result.mod(prod).toString();
    }
 
    // Function to perform modular
    // exponentiation
    static BigInteger powMod(BigInteger base1,
                             BigInteger exponent,
                             BigInteger modulus)
    {
 
        BigInteger result = new BigInteger("1");
 
        base1 = base1.mod(modulus);
 
        while (exponent
                   .compareTo(new BigInteger("0"))
               > 0) {
 
            if (exponent
                    .mod(new BigInteger("2"))
                    .equals(new BigInteger("1")))
 
                result = (result.multiply(base1))
                             .mod(modulus);
 
            exponent = exponent
                           .divide(new BigInteger("2"));
 
            base1 = base1
                        .multiply(base1)
                        .mod(modulus);
        }
 
        // Return the result
        return result;
    }
 
    // Function that returns square root
    // of n under modulo p if exists
    static BigInteger squareRoot(
        BigInteger n,
        BigInteger p)
    {
 
        // For Invalid Input
        if (!p.mod(new BigInteger("4"))
                 .equals(new BigInteger("3"))) {
 
            System.out.print("Invalid Input");
            return new BigInteger("-1");
        }
 
        n = n.mod(p);
 
        // Try "+(n^((p + 1)/4))"
        BigInteger x = powMod(n,
                              (p.add(new BigInteger("1")))
                                  .divide(new BigInteger("4")),
                              p);
 
        if ((x.multiply(x)).mod(p).equals(n)) {
            return x;
        }
 
        // Try "-(n ^ ((p + 1)/4))"
        x = p.subtract(x);
        if ((x.multiply(x)).mod(p).equals(n)) {
            return x;
        }
 
        // If none of above two work,
        // then sqrt doesn't exist
        return new BigInteger("-1");
    }
 
    // Function to find the ceiling
    // of square root of a number
    public static BigInteger
    sqrtC(BigInteger x)
    {
 
        // If number < zero
        if (x.compareTo(BigInteger.ZERO) < 0) {
            return new BigInteger("-1");
        }
 
        // If number is zero or 1
        if (x == BigInteger.ZERO
            || x == BigInteger.ONE) {
            return x;
        }
 
        BigInteger two
            = BigInteger.valueOf(2L);
 
        BigInteger y;
 
        // Iterate to find the square root
        for (y = x.divide(two);
             y.compareTo(x.divide(y)) > 0;
             y = ((x.divide(y)).add(y)).divide(two))
            ;
        if (x.compareTo(y.multiply(y)) == 0) {
            return y;
        }
        else {
            return y.add(BigInteger.ONE);
        }
    }
 
    // Function to factorise number P
    static BigInteger[] factorise(BigInteger p)
    {
 
        BigInteger ans[] = new BigInteger[2];
        BigInteger b = new BigInteger("2");
 
        BigInteger ONE = new BigInteger("1");
        BigInteger ZERO = new BigInteger("0");
 
        // Iterate over [2, sqrt(N)]
        for (; b.compareTo(
                   sqrtC(p).add(ONE))
               < 0;
             b = b.add(ONE)) {
 
            // Check if mod is zero
            if (p.mod(b).equals(ZERO)) {
                ans[0] = b;
                ans[1] = p.divide(b);
            }
        }
 
        // Return the ans
        return ans;
    }
 
    // Function to find the all possible
    // squareRoot of a under modulo m
    public static void
    solve(BigInteger a, BigInteger m)
    {
 
        // Find factorization of m
        BigInteger s[] = factorise(m);
        BigInteger ZERO = new BigInteger("0");
 
        // If number P is not product
        // of two primes
        if (!s[0].multiply(s[1]).equals(m)) {
 
            System.out.println("Not a product "
                               + "of two primes");
        }
 
        // Find the numbers
        else {
 
            BigInteger s1[] = new BigInteger[4];
            BigInteger a1[] = new BigInteger[4];
            BigInteger a2[] = new BigInteger[2];
 
            // Find squareRoot
            a1[0] = squareRoot(a.mod(s[0]), s[0]);
            a1[1] = squareRoot(a.mod(s[1]), s[1]);
 
            a1[2] = a1[0].multiply(new BigInteger("-1"));
            a1[3] = a1[1].multiply(new BigInteger("-1"));
 
            // Compare to Zero
            while (a1[2].compareTo(ZERO) < 0)
                a1[2] = a1[2].add(s[0]);
            while (a1[3].compareTo(ZERO) < 0)
                a1[3] = a1[3].add(s[1]);
            s1[0] = s[0];
            s1[1] = s[1];
            s1[2] = s[0];
            s1[3] = s[1];
 
            // Condition for no solution
            if (a1[0].equals(new BigInteger("-1"))
                || a1[1].equals(new BigInteger("-1"))) {
 
                System.out.println("No Solution");
            }
 
            // Else print all possible answers
            else {
 
                // Perform Chinese Remainder
                // Theorem and print numbers
                System.out.print(
                    CRT(s, a1, 2) + " ");
 
                a2[0] = a1[0];
                a2[1] = a1[3];
 
                System.out.print(
                    CRT(s, a2, 2) + " ");
 
                a2[0] = a1[2];
                a2[1] = a1[1];
 
                System.out.print(
                    CRT(s, a2, 2) + " ");
 
                a2[0] = a1[2];
                a2[1] = a1[3];
 
                System.out.print(
                    CRT(s, a2, 2) + " ");
            }
        }
    }
 
    // Driver Code
    public static void
        main(String[] args) throws Exception
    {
 
        // Given Number N
        BigInteger N = new BigInteger("188");
 
        // Given product of Prime Number
        BigInteger P = new BigInteger("437");
 
        // Function Call
        solve(N, P);
    }
}


Python3




import math
 
def isqrt(n):
    """Return floor(sqrt(n)) exactly for an integer n >= 0.

    Uses integer Newton iteration instead of ``math.sqrt``. Float rounding
    gives wrong answers for large integers (e.g. ``(2**53 + 1) ** 2``).

    Raises:
        ValueError: if n is negative.
    """
    if n < 0:
        raise ValueError("isqrt() argument must be nonnegative")
    if n < 2:
        return n

    # Start from a power of two that is >= sqrt(n); Newton then decreases
    # monotonically until it reaches floor(sqrt(n)).
    x = 1 << ((n.bit_length() + 1) // 2)
    while True:
        y = (x + n // x) // 2
        if y >= x:
            return x
        x = y
 
def mod_inverse(a, m):
    """Return the inverse of a modulo m, i.e. the x in [0, m) with a*x % m == 1.

    Uses the extended Euclidean algorithm. Returns 0 when m == 1, since
    every value is congruent to 0 modulo 1. Assumes m >= 1.

    Raises:
        ValueError: if gcd(a, m) != 1, because then no inverse exists. The
            previous version hit ZeroDivisionError here.
    """
    if m == 1:
        return 0

    m0 = m
    a %= m  # handles negative a and a >= m
    y, x = 0, 1

    while a > 1:
        if m == 0:
            raise ValueError("a is not invertible modulo m")
        q = a // m
        a, m = m, a % m
        x, y = y, x - q * y

    if a == 0:
        # a was a multiple of m
        raise ValueError("a is not invertible modulo m")

    return x % m0
 
def CRT(num, rem, k):
    """Solve x ≡ rem[i] (mod num[i]) for i in range(k) via the Chinese Remainder Theorem.

    The moduli num[0..k-1] are expected to be pairwise coprime. Returns the
    solution reduced into [0, num[0] * ... * num[k-1]).
    """
    modulus_product = 1
    for idx in range(k):
        modulus_product *= num[idx]

    def contribution(idx):
        # Partial product is 0 modulo every other modulus; multiplying by its
        # inverse modulo num[idx] makes it 1 modulo num[idx].
        partial = modulus_product // num[idx]
        return rem[idx] * mod_inverse(partial, num[idx]) * partial

    total = sum(contribution(idx) for idx in range(k))
    return total % modulus_product
 
def pow_mod(base, exponent, modulus):
    """Return base**exponent % modulus, a value in [0, modulus).

    Delegates to the built-in three-argument ``pow``, which does
    square-and-multiply internally. It also returns 0 (not 1) when
    modulus == 1.

    Raises:
        ValueError: if exponent is negative. The old loop silently returned 1
            in that case.
    """
    if exponent < 0:
        raise ValueError("exponent must be non-negative")
    return pow(base, exponent, modulus)
 
def square_root(n, p):
    """Return a square root of n modulo the prime p, or -1 if there is none.

    Only primes with p % 4 == 3 are handled: for them n**((p + 1) // 4) mod p
    is a root whenever one exists. For any other p, "Invalid Input" is
    printed and -1 is returned.
    """
    if p % 4 == 3:
        residue = n % p
        root = pow_mod(residue, (p + 1) // 4, p)
        # Try the root and its negation, in that order.
        for candidate in (root, p - root):
            if pow_mod(candidate, 2, p) == residue:
                return candidate
        return -1

    print("Invalid Input")
    return -1
 
def sqrt_c(x):
    """Return ceil(sqrt(x)) for an integer x >= 0, or -1 when x is negative."""
    if x < 0:
        return -1
    if x == 0 or x == 1:
        return x

    # Integer Newton iteration starting at x // 2 converges to floor(sqrt(x)).
    y = x // 2
    while y > x // y:
        y = (y + x // y) // 2

    # Round up unless x is a perfect square. The old test `x % y == 0`
    # misclassified values such as 2, 3 and 6.
    return y if y * y == x else y + 1
 
def factorize(p):
    """Split p into [p1, p2] with p1 * p2 == p, p1 <= p2 and both prime.

    p1 is the smallest prime factor of p. Returns [0, 0] when p is not a
    product of exactly two primes: p is prime, p < 4, or p has three or more
    prime factors counted with multiplicity.
    """
    b = 2
    while b * b <= p:
        if p % b == 0:
            cofactor = p // b
            # b is the smallest divisor (hence prime); any factor of the
            # cofactor is >= b, so trial division can start at b.
            d = b
            while d * d <= cofactor:
                if cofactor % d == 0:
                    return [0, 0]
                d += 1
            return [b, cofactor]
        b += 1
    return [0, 0]
 
def solve(a, m):
    """Print all square roots of a modulo m on one line.

    m must be p1 * p2 for distinct primes p1, p2 with p1 % 4 == p2 % 4 == 3.
    The roots modulo each prime are r and p - r. Combining one root per prime
    with the Chinese Remainder Theorem gives the four roots modulo m. A
    diagnostic is printed instead when m has the wrong shape or a has no root.
    """
    s = factorize(m)

    # A zero factor means no divisor was found.
    if s[0] == 0 or s[0] * s[1] != m:
        print("Not a product of two primes")
        return

    # CRT needs coprime moduli, so p * p cannot be handled this way.
    if s[0] == s[1]:
        print("Primes must be distinct")
        return

    p1, p2 = s
    r1 = square_root(a % p1, p1)
    r2 = square_root(a % p2, p2)

    if r1 == -1 or r2 == -1:
        print("No Solution")
        return

    # The second root modulo each prime is the negation (0 stays 0).
    neg_r1 = (p1 - r1) % p1
    neg_r2 = (p2 - r2) % p2

    combos = ((r1, r2), (r1, neg_r2), (neg_r1, r2), (neg_r1, neg_r2))
    roots = [CRT(s, [x, y], 2) for x, y in combos]
    print(*roots)
 
# Driver Code
if __name__ == "__main__":
    # Example from the problem statement: the number N and the modulus
    # P = 437 = 19 * 23, both primes being of the form 4i + 3.
    N, P = 188, 437

    solve(N, P)


C#




using System;
 
class Program
{
    // Returns floor(sqrt(n)) for n >= 0 and 0 for negative n.
    // Math.Sqrt gives an estimate; the loops correct any rounding error.
    static int isqrt(int n)
    {
        if (n <= 0)
            return 0; // Math.Sqrt of a negative value is NaN

        long r = (long)Math.Sqrt(n);
        while (r * r > n)
            r--;
        while ((r + 1) * (r + 1) <= n)
            r++;
        return (int)r;
    }

    // Returns the inverse of a modulo m (the x in [0, m) with
    // a * x ≡ 1 mod m) via the extended Euclidean algorithm.
    // Returns 0 when m == 1 or when no inverse exists (gcd(a, m) != 1).
    // Assumes m >= 1.
    static int mod_inverse(int a, int m)
    {
        if (m == 1)
            return 0;

        int m0 = m;

        // C# % keeps the dividend's sign; bring a into [0, m).
        a %= m;
        if (a < 0)
            a += m;

        int y = 0, x = 1;

        while (a > 1)
        {
            if (m == 0)
                return 0; // gcd(a, m0) > 1: would divide by zero below

            int q = a / m;

            int t = m;
            m = a % m;
            a = t;

            t = y;
            y = x - q * y;
            x = t;
        }

        if (a == 0)
            return 0; // a was a multiple of m0

        while (x < 0)
            x += m0;

        return x;
    }

    // Chinese Remainder Theorem: returns the x in [0, prod) with
    // x ≡ rem[i] (mod num[i]) for i < k, where prod is the product of the
    // first k moduli (pairwise coprime, product fitting in an int).
    // long intermediates, reduced every step, avoid silent int wrap-around.
    static int CRT(int[] num, int[] rem, int k)
    {
        long prod = 1;
        for (int i = 0; i < k; i++)
            prod *= num[i];

        long result = 0;
        for (int i = 0; i < k; i++)
        {
            long n = num[i];
            long pp = prod / n;

            long r = rem[i] % n;
            if (r < 0)
                r += n;

            long inv = mod_inverse((int)(pp % n), num[i]);

            // (r * inv) % n < n, so the scaled term is below n * pp == prod.
            result = (result + (r * inv) % n * pp) % prod;
        }

        return (int)result;
    }

    // Returns baseNum^exponent mod modulus in [0, modulus) by binary
    // exponentiation. Assumes exponent >= 0 and modulus >= 1.
    static int pow_mod(int baseNum, int exponent, int modulus)
    {
        long result = 1 % modulus; // modulus == 1 must give 0
        long b = baseNum % modulus;
        if (b < 0)
            b += modulus;

        while (exponent > 0)
        {
            if (exponent % 2 == 1)
                result = result * b % modulus;
            exponent /= 2;
            b = b * b % modulus;
        }

        return (int)result;
    }

    // Returns a square root of n modulo the prime p, or -1 if none exists.
    // Only p % 4 == 3 is supported (then n^((p + 1) / 4) is a root whenever
    // one exists); other p print "Invalid Input" and return -1.
    static int square_root(int n, int p)
    {
        if (p % 4 != 3)
        {
            Console.WriteLine("Invalid Input");
            return -1;
        }

        // Normalise n into [0, p); C# % can be negative.
        n %= p;
        if (n < 0)
            n += p;

        int x = pow_mod(n, (p + 1) / 4, p);

        // p - x has the same square as x, so checking x suffices.
        return pow_mod(x, 2, p) == n ? x : -1;
    }

    // Returns the ceiling of sqrt(x) for x >= 0, or -1 for negative x.
    static int sqrt_c(int x)
    {
        if (x < 0)
            return -1;
        if (x == 0 || x == 1)
            return x;

        // Integer Newton iteration converges to floor(sqrt(x)).
        int y = x / 2;
        while (y > x / y)
            y = (y + x / y) / 2;

        // Round up unless x is a perfect square.
        return y * y == x ? y : y + 1;
    }

    // Returns {p1, p2} with p1 * p2 == p, p1 <= p2, both prime and p1 the
    // smallest prime factor of p; {0, 0} when p is not such a product.
    static int[] factorize(int p)
    {
        // b <= p / b avoids b * b overflow and floating-point sqrt.
        for (int b = 2; b <= p / b; b++)
        {
            if (p % b != 0)
                continue;

            // b is the smallest divisor (prime); check the cofactor too.
            int cofactor = p / b;
            for (int d = b; d <= cofactor / d; d++)
            {
                if (cofactor % d == 0)
                    return new int[] { 0, 0 };
            }
            return new int[] { b, cofactor };
        }

        return new int[] { 0, 0 };
    }

    // Prints all square roots of a modulo m (m = p1 * p2, distinct primes
    // with p % 4 == 3) on one line, or a diagnostic message.
    static void solve(int a, int m)
    {
        int[] s = factorize(m);

        if (s[0] == 0 || s[0] * s[1] != m)
        {
            Console.WriteLine("Not a product of two primes");
            return;
        }

        // CRT needs coprime moduli, so p * p cannot be handled this way.
        if (s[0] == s[1])
        {
            Console.WriteLine("Primes must be distinct");
            return;
        }

        int r1 = square_root(a % s[0], s[0]);
        int r2 = square_root(a % s[1], s[1]);

        if (r1 == -1 || r2 == -1)
        {
            Console.WriteLine("No Solution");
            return;
        }

        // The second root modulo each prime is the negation (0 stays 0).
        int negR1 = (s[0] - r1) % s[0];
        int negR2 = (s[1] - r2) % s[1];

        int[][] pairs =
        {
            new[] { r1, r2 }, new[] { r1, negR2 },
            new[] { negR1, r2 }, new[] { negR1, negR2 }
        };

        foreach (int[] pair in pairs)
            Console.Write(CRT(s, pair, 2) + " ");
        Console.WriteLine();
    }

    static void Main()
    {
        // Example from the problem statement: 437 = 19 * 23.
        int N = 188;
        int P = 437;

        solve(N, P);
    }
}


Javascript




// Integer part of the square root of n: Math.sqrt followed by Math.floor.
function isqrt(n) {
    const root = Math.sqrt(n);
    return Math.floor(root);
}
 
// Returns the inverse of a modulo m, i.e. the x in [0, m) with
// a * x ≡ 1 (mod m), using the extended Euclidean algorithm.
// Returns 0 when m === 1, and also when no inverse exists (gcd(a, m) !== 1);
// 0 is never a valid inverse for m > 1. Assumes m is a positive integer.
function mod_inverse(a, m) {
    if (m === 1) {
        return 0;
    }

    const m0 = m;

    // JS % keeps the dividend's sign; bring a into [0, m).
    a = ((a % m) + m) % m;

    let y = 0;
    let x = 1;

    while (a > 1) {
        if (m === 0) {
            // gcd(a, m0) > 1: dividing by m would produce Infinity/NaN.
            return 0;
        }
        const q = Math.floor(a / m);

        let t = m;
        m = a % m;
        a = t;

        t = y;
        y = x - q * y;
        x = t;
    }

    if (a === 0) {
        return 0; // a was a multiple of m0
    }

    return ((x % m0) + m0) % m0;
}
 
// Chinese Remainder Theorem: returns the x in [0, prod) with
// x ≡ rem[i] (mod num[i]) for every i < k, where prod is the product of the
// first k moduli. The moduli must be pairwise coprime.
// Each term is reduced modulo num[i] before being scaled, so intermediates
// stay exact (below 2^53) for moduli under about 9.4e7 and products
// under 2^52. The old rem * inv * pp product lost precision much earlier.
function CRT(num, rem, k) {
    let prod = 1;
    for (let i = 0; i < k; i++) {
        prod *= num[i];
    }

    let result = 0;
    for (let i = 0; i < k; i++) {
        const n = num[i];
        const pp = prod / n; // exact: n divides prod

        const r = ((rem[i] % n) + n) % n;
        const inv = mod_inverse(pp % n, n);

        // (r * inv) % n < n, so the term is below n * pp === prod.
        const term = ((r * inv) % n) * pp;
        result = (result + term) % prod;
    }

    return result;
}
 
// Returns base^exponent mod modulus, a value in [0, modulus), by binary
// exponentiation. Assumes an integer exponent >= 0 and an integer modulus
// in [1, 2^52]. Products that would exceed Number.MAX_SAFE_INTEGER are
// formed by double-and-add, so no precision is lost for large moduli.
function pow_mod(base, exponent, modulus) {
    // (x * y) % modulus for x, y in [0, modulus); exact for modulus <= 2^52.
    function mulMod(x, y) {
        if (x * y <= Number.MAX_SAFE_INTEGER) {
            return (x * y) % modulus;
        }
        let acc = 0;
        while (y > 0) {
            if (y % 2 === 1) {
                acc = (acc + x) % modulus;
            }
            x = (x * 2) % modulus;
            y = Math.floor(y / 2);
        }
        return acc;
    }

    let result = 1 % modulus; // modulus === 1 must give 0
    base = ((base % modulus) + modulus) % modulus; // JS % keeps the sign

    while (exponent > 0) {
        if (exponent % 2 === 1) {
            result = mulMod(result, base);
        }
        exponent = Math.floor(exponent / 2);
        base = mulMod(base, base);
    }

    return result;
}
 
// Returns a square root of n modulo the prime p (an x in [0, p) with
// x * x ≡ n mod p), or -1 if n is not a quadratic residue modulo p.
// Only p % 4 === 3 is supported: for such p, n^((p + 1) / 4) is a root
// whenever one exists. Any other p logs "Invalid Input" and returns -1.
function square_root(n, p) {
    if (p % 4 !== 3) {
        console.log("Invalid Input");
        return -1;
    }

    // Normalise n into [0, p); JS % keeps the dividend's sign, and a
    // negative n would never equal a (non-negative) square below.
    n = ((n % p) + p) % p;

    const x = pow_mod(n, Math.floor((p + 1) / 4), p);

    // p - x has the same square as x, so checking x alone suffices.
    return pow_mod(x, 2, p) === n ? x : -1;
}
 
// Returns the ceiling of sqrt(x) for integer x >= 0, or -1 for negative x.
// Integer Newton iteration from x / 2 converges to floor(sqrt(x)); one is
// added unless x is a perfect square.
function sqrt_c(x) {
    if (x < 0) {
        return -1;
    }
    if (x === 0 || x === 1) {
        return x;
    }

    let y = Math.floor(x / 2);
    while (y > Math.floor(x / y)) {
        y = Math.floor((y + Math.floor(x / y)) / 2);
    }

    // The old test `x % y === 0` misclassified values such as 2, 3 and 6.
    return y * y === x ? y : y + 1;
}
 
// Splits p into [p1, p2] with p1 * p2 === p, p1 <= p2 and both prime, where
// p1 is the smallest prime factor of p. Returns [0, 0] when p is not a
// product of exactly two primes (p prime, p < 4, or three or more prime
// factors counted with multiplicity).
function factorize(p) {
    for (let b = 2; b * b <= p; b++) {
        if (p % b !== 0) {
            continue;
        }

        // b is the smallest divisor, hence prime. Any factor of the cofactor
        // is >= b, so trial division can start at b.
        const cofactor = p / b;
        for (let d = b; d * d <= cofactor; d++) {
            if (cofactor % d === 0) {
                return [0, 0];
            }
        }
        return [b, cofactor];
    }

    return [0, 0];
}
 
// Logs every square root of a modulo m on a single line, where m must be the
// product of two distinct primes p1, p2 with p1 % 4 === p2 % 4 === 3.
// The roots modulo each prime are r and p - r. The four CRT combinations
// give the roots modulo m. A diagnostic is logged instead when m has the
// wrong shape or a has no root.
function solve(a, m) {
    const s = factorize(m);

    // A zero factor means no divisor was found.
    if (s[0] === 0 || s[0] * s[1] !== m) {
        console.log("Not a product of two primes");
        return;
    }

    // CRT needs coprime moduli, so p * p cannot be handled this way.
    if (s[0] === s[1]) {
        console.log("Primes must be distinct");
        return;
    }

    const p1 = s[0];
    const p2 = s[1];
    const r1 = square_root(a % p1, p1);
    const r2 = square_root(a % p2, p2);

    if (r1 === -1 || r2 === -1) {
        console.log("No Solution");
        return;
    }

    // The second root modulo each prime is the negation (0 stays 0).
    const negR1 = (p1 - r1) % p1;
    const negR2 = (p2 - r2) % p2;

    const pairs = [[r1, r2], [r1, negR2], [negR1, r2], [negR1, negR2]];
    const roots = pairs.map(function (pair) {
        return CRT(s, pair, 2);
    });

    // One line, matching the documented output "25 44 393 412".
    console.log(roots.join(" "));
}
 
// Example from the problem statement: square roots of N modulo
// P = 437 = 19 * 23, both primes being of the form 4i + 3.
const N = 188; // Given Number N
const P = 437; // Given product of Prime Numbers

solve(N, P);


Output

25 44 393 412





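Time Complexity: O(√P + log P), dominated by the trial-division factorization of P; each modular exponentiation takes O(log P) steps.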
Auxiliary Space: O(1)
 



Last Updated : 26 Dec, 2023
Like Article
Save Article
Previous
Next
Share your thoughts in the comments
Similar Reads