Open In App

Barrett Reduction Algorithm (Optimized Variant)

Last Updated : 28 Oct, 2023
Improve
Improve
Like Article
Like
Save
Share
Report

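The Barrett Reduction Algorithm computes the remainder of a large integer x divided by another integer mod without performing a division at reduction time. It precomputes mu, a fixed-point approximation of the reciprocal 1/mod scaled by a power of 2 (here mu = ⌊2^64 / mod⌋, or the rounded-up ⌈2^64 / mod⌉). Every reduction then uses only multiplications, shifts and subtractions.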

Approach:

The Barrett Reduction Algorithm estimates the quotient x / mod using the precomputed value mu instead of dividing: q = ⌊x · mu / 2^64⌋. The candidate remainder is r = x – q · mod. For 0 ≤ x < 2^64, q differs from the true quotient ⌊x / mod⌋ by at most one, which leads to two cases:
- If mu was rounded down, q may be one too small. Then r lies in [mod, 2 · mod), and one subtraction of mod fixes it.
- If mu was rounded up, q may be one too large. Then r is negative, and one addition of mod fixes it. For a signed 64-bit r, the branch-free form of this is res = r + ((r >> 63) & mod), which adds mod exactly when r is negative.

Below are the steps involved in the implementation of the code:

  • Define the helpers mul_mod and pow_mod. mul_mod takes three 64-bit unsigned integers a, b and mod and returns (a · b) mod mod using shift-and-add, so the full product is never formed. pow_mod takes the same three arguments and returns a^b mod mod by square-and-multiply.
  • The barrett_reduce function computes the remainder of x divided by mod using the Barrett Reduction Algorithm for three 64-bit unsigned integers (x, mod, and mu).
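  • In barrett_reduce, estimate the quotient as q = ⌊x · mu / 2^64⌋, i.e. the high 64 bits of the 128-bit product x · mu. Without a 128-bit type, this is computed by splitting the operands into 32-bit halves with bit shifting and masking.
  • Compute the candidate remainder r = x – q · mod.
  • For 0 ≤ x < 2^64, q is off from ⌊x / mod⌋ by at most one, so r is off from the true remainder by at most one multiple of mod.
  • Correct r in a single step: subtract mod if r ≥ mod (mu rounded down), or add mod if r is negative (mu rounded up). For a signed 64-bit r, the branch-free form of the latter is r + ((r >> 63) & mod).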
  • Call barrett_reduce twice in the main function to calculate x^2 and x^1234.
  • Print the results to the console.

Below is the implementation of the above approach:

C++




// C++ code for the above approach:
#include <cstdint>
#include <iostream>
 
// Calculates (a * b) % mod by binary shift-and-add, so the full 128-bit
// product is never formed.
//
// Requires mod > 0. Every addition is performed on operands already in
// [0, mod) and never produces a sum above 2^64 - 1, so the result is correct
// for any 64-bit a and any 64-bit mod (including mod > 2^63).
std::uint64_t mul_mod(std::uint64_t a, std::uint64_t b,
                      std::uint64_t mod)
{
    // (lhs + rhs) % mod for lhs, rhs in [0, mod) without unsigned wrap-around:
    // lhs + rhs >= mod exactly when lhs >= mod - rhs.
    const auto add_mod = [mod](std::uint64_t lhs, std::uint64_t rhs) {
        return lhs >= mod - rhs ? lhs - (mod - rhs) : lhs + rhs;
    };

    std::uint64_t res = 0;
    a %= mod;  // the doubling below assumes a < mod
    while (b > 0) {
        if (b & 1) {
            res = add_mod(res, a);
        }
        a = add_mod(a, a);
        b >>= 1;
    }
    return res;
}
 
// Calculates (a^b) % mod
std::uint64_t pow_mod(std::uint64_t a, std::uint64_t b,
                      std::uint64_t mod)
{
    std::uint64_t res = 1;
    a %= mod;
    while (b > 0) {
        if (b & 1) {
            res = mul_mod(res, a, mod);
        }
        a = mul_mod(a, a, mod);
        b >>= 1;
    }
    return res;
}
 
// Calculates Barrett reduction
// of x modulo mod
std::uint64_t barrett_reduce(std::uint64_t x,
                             std::uint64_t mod,
                             std::uint64_t mu)
{
    std::uint64_t q1 = x >> 32;
    std::uint64_t q2 = ((mu * q1) >> 32) * mod;
    std::uint64_t r1 = x & 0xffffffff;
    std::uint64_t r2 = mu * r1;
    std::uint64_t q3 = r2 >> 32;
    std::uint64_t r3 = r1 - q3 * mod;
    std::uint64_t res = r3 + ((r3 >> 63) & mod);
    if (res >= mod) {
        res -= mod;
    }
    return res;
}
 
// Driver code
int main()
{
    std::uint64_t x = 123456789;
    std::uint64_t mod = 1000000007;
    std::uint64_t mu = ((1ULL << 64) + mod - 1) / mod;
 
    // Function call
    // Calculate x^2 % mod
    std::uint64_t x_squared = mul_mod(x, x, mod);
    std::uint64_t x_squared_barrett
        = barrett_reduce(x_squared, mod, mu);
    std::cout << "x^2 % mod = " << x_squared_barrett
              << std::endl;
 
    // Calculate x^1234 % mod
    std::uint64_t x_pow_1234 = pow_mod(x, 1234, mod);
    std::uint64_t x_pow_1234_barrett
        = barrett_reduce(x_pow_1234, mod, mu);
    std::cout << "x^1234 % mod = " << x_pow_1234_barrett
              << std::endl;
 
    return 0;
}


Java




// java code for above approach
 
public class Main {

    // Calculates (a * b) % mod for mod > 0, returning a value in [0, mod).
    // Uses shift-and-add, and addMod keeps every sum below mod, so nothing
    // overflows a long even when mod exceeds 2^62. Returns 0 when b <= 0.
    public static long mulMod(long a, long b, long mod)
    {
        long res = 0;
        a = Math.floorMod(a, mod); // doubling below assumes 0 <= a < mod
        while (b > 0) {
            if ((b & 1) == 1) {
                res = addMod(res, a, mod);
            }
            a = addMod(a, a, mod);
            b >>= 1;
        }
        return res;
    }

    // (lhs + rhs) % mod for lhs, rhs in [0, mod) without overflowing a long:
    // lhs + rhs >= mod exactly when lhs >= mod - rhs.
    private static long addMod(long lhs, long rhs, long mod)
    {
        long gap = mod - rhs; // in (0, mod], cannot overflow
        return lhs >= gap ? lhs - gap : lhs + rhs;
    }

    // Calculates (a^b) % mod for mod > 0 by square-and-multiply.
    // powMod(a, 0, 1) == 0.
    public static long powMod(long a, long b, long mod)
    {
        long res = 1 % mod; // plain 1 is not a valid residue when mod == 1
        a = Math.floorMod(a, mod);
        while (b > 0) {
            if ((b & 1) == 1) {
                res = mulMod(res, a, mod);
            }
            a = mulMod(a, a, mod);
            b >>= 1;
        }
        return res;
    }

    // High 64 bits of the unsigned 128-bit product a * b, assembled from
    // four 32x32 partial products (works on Java 8, no Math.multiplyHigh).
    private static long mulHighUnsigned(long a, long b)
    {
        long aLo = a & 0xFFFFFFFFL;
        long aHi = a >>> 32;
        long bLo = b & 0xFFFFFFFFL;
        long bHi = b >>> 32;

        long loLo = aLo * bLo;
        long hiLo = aHi * bLo;
        long loHi = aLo * bHi;
        long hiHi = aHi * bHi;

        // Middle column; as an unsigned value it is at most 2^64 - 1.
        long cross = (loLo >>> 32) + (hiLo & 0xFFFFFFFFL) + loHi;
        return hiHi + (hiLo >>> 32) + (cross >>> 32);
    }

    // Calculates Barrett reduction of x modulo mod, i.e. the remainder of x
    // (read as an unsigned 64-bit value) divided by mod, for mod > 0.
    // mu should be Long.divideUnsigned(-1L, mod) == floor((2^64 - 1) / mod).
    // With that mu the quotient estimate is off by at most one, so a single
    // subtraction corrects it. A smaller mu still works via the fallback.
    public static long barrettReduce(long x, long mod,
                                     long mu)
    {
        long q = mulHighUnsigned(x, mu); // estimate of x / mod
        long r = x - q * mod;            // q * mod <= x (unsigned)
        if (Long.compareUnsigned(r, mod) >= 0) {
            r -= mod;
        }
        if (Long.compareUnsigned(r, mod) >= 0) {
            // Only reachable when mu underestimates the reciprocal.
            r = Long.remainderUnsigned(r, mod);
        }
        return r;
    }

    // Driver code
    public static void main(String[] args)
    {
        long x = 123456789;
        long mod = 1000000007;
        // floor((2^64 - 1) / mod). Note that 1L << 64 evaluates to 1L in Java,
        // because only the low 6 bits of the shift count are used.
        long mu = Long.divideUnsigned(-1L, mod);

        // Calculate x^2 % mod. x < 2^31, so x * x fits in a long
        // and barrettReduce performs the whole reduction.
        long xSquared = x * x;
        long xSquaredBarrett
            = barrettReduce(xSquared, mod, mu);
        System.out.println("x^2 % mod = "
                           + xSquaredBarrett);

        // Calculate x^1234 % mod. powMod already returns a value in
        // [0, mod), which barrettReduce leaves unchanged.
        long xPow1234 = powMod(x, 1234, mod);
        long xPow1234Barrett
            = barrettReduce(xPow1234, mod, mu);
        System.out.println("x^1234 % mod = "
                           + xPow1234Barrett);
    }
}


Python3




# Calculates (a * b) % mod
def mul_mod(a, b, mod):
    """Return (a * b) % mod using binary shift-and-add multiplication.

    Each set bit of ``b`` adds the current doubled copy of ``a`` into the
    accumulator; every intermediate value is reduced modulo ``mod``.
    Returns 0 when ``b <= 0``.
    """
    total, addend, bits = 0, a, b
    while bits > 0:
        if bits & 1:
            total = (total + addend) % mod
        addend = (addend * 2) % mod
        bits >>= 1
    return total
 
# Calculates (a^b) % mod
def pow_mod(a, b, mod):
    """Return (a ** b) % mod by square-and-multiply, using mul_mod for products.

    The result is always reduced, so pow_mod(a, 0, 1) == 0, the same as the
    built-in pow(a, 0, 1). Returns ``1 % mod`` when ``b <= 0``.
    """
    res = 1 % mod  # plain 1 is not a valid residue when mod == 1
    a %= mod
    while b > 0:
        if b & 1 == 1:
            res = mul_mod(res, a, mod)
        a = mul_mod(a, a, mod)
        b >>= 1
    return res
 
# Calculates Barrett reduction of x modulo mod
def barrett_reduce(x, mod, mu):
    """Return x % mod (for mod > 0) using Barrett reduction.

    ``mu`` is a 64-bit fixed-point reciprocal of ``mod``: either
    ceil(2**64 / mod) or floor(2**64 / mod). For 0 <= x < 2**64 the quotient
    estimate (x * mu) >> 64 is within one of x // mod, so a single addition
    (rounded-up mu) or subtraction (rounded-down mu) of ``mod`` corrects the
    remainder. Any other x or mu falls back to the ``%`` operator.
    """
    q = (x * mu) >> 64  # estimate of x // mod
    r = x - q * mod
    if r < 0:
        # A rounded-up mu can overshoot the quotient by one.
        r += mod
    elif r >= mod:
        # A rounded-down mu can undershoot the quotient by one.
        r -= mod
    if not 0 <= r < mod:
        # Estimate was off by more than one (x >= 2**64 or unexpected mu).
        r %= mod
    return r
 
# Driver code
def main():
    """Compute x^2 and x^1234 modulo 1e9+7, pass each through barrett_reduce, and print them."""
    x, mod = 123456789, 1000000007
    # ceil(2**64 / mod); Python integers are unbounded, so the shift is exact.
    mu = ((1 << 64) + mod - 1) // mod

    # Calculate x^2 % mod
    squared = barrett_reduce(mul_mod(x, x, mod), mod, mu)
    print("x^2 % mod =", squared)

    # Calculate x^1234 % mod
    powered = barrett_reduce(pow_mod(x, 1234, mod), mod, mu)
    print("x^1234 % mod =", powered)


# Run the main function
if __name__ == "__main__":
    main()


C#




using System;
 
public class MainClass
{
    // Calculates (a * b) % mod for mod > 0, returning a value in [0, mod).
    // Uses shift-and-add, and AddMod keeps every sum below mod, so nothing
    // overflows a long even when mod exceeds 2^62. Returns 0 when b <= 0.
    public static long MulMod(long a, long b, long mod)
    {
        long res = 0;
        a = FloorMod(a, mod); // doubling below assumes 0 <= a < mod
        while (b > 0)
        {
            if ((b & 1) == 1)
            {
                res = AddMod(res, a, mod);
            }
            a = AddMod(a, a, mod);
            b >>= 1;
        }
        return res;
    }

    // (lhs + rhs) % mod for lhs, rhs in [0, mod) without overflowing a long:
    // lhs + rhs >= mod exactly when lhs >= mod - rhs.
    private static long AddMod(long lhs, long rhs, long mod)
    {
        long gap = mod - rhs; // in (0, mod], cannot overflow
        return lhs >= gap ? lhs - gap : lhs + rhs;
    }

    // Non-negative remainder for mod > 0 (C#'s % keeps the dividend's sign).
    private static long FloorMod(long value, long mod)
    {
        long r = value % mod;
        return r < 0 ? r + mod : r;
    }

    // Calculates (a^b) % mod for mod > 0 by square-and-multiply.
    // PowMod(a, 0, 1) == 0.
    public static long PowMod(long a, long b, long mod)
    {
        long res = 1 % mod; // plain 1 is not a valid residue when mod == 1
        a = FloorMod(a, mod);
        while (b > 0)
        {
            if ((b & 1) == 1)
            {
                res = MulMod(res, a, mod);
            }
            a = MulMod(a, a, mod);
            b >>= 1;
        }
        return res;
    }

    // High 64 bits of the 128-bit product a * b, assembled from four 32x32
    // partial products (no dependency on Math.BigMul).
    private static ulong MulHigh(ulong a, ulong b)
    {
        unchecked
        {
            ulong aLo = a & 0xFFFFFFFFUL;
            ulong aHi = a >> 32;
            ulong bLo = b & 0xFFFFFFFFUL;
            ulong bHi = b >> 32;

            ulong loLo = aLo * bLo;
            ulong hiLo = aHi * bLo;
            ulong loHi = aLo * bHi;
            ulong hiHi = aHi * bHi;

            // Middle column; its maximum is exactly 2^64 - 1, so no overflow.
            ulong cross = (loLo >> 32) + (hiLo & 0xFFFFFFFFUL) + loHi;
            return hiHi + (hiLo >> 32) + (cross >> 32);
        }
    }

    // Calculates Barrett reduction of x modulo mod, i.e. the remainder of x
    // (read as an unsigned 64-bit value) divided by mod, for mod > 0.
    // mu should be (long)(ulong.MaxValue / (ulong)mod). With that mu the
    // quotient estimate is off by at most one, so a single subtraction
    // corrects it. A smaller mu still works via the fallback.
    public static long BarrettReduce(long x, long mod, long mu)
    {
        unchecked
        {
            ulong value = (ulong)x;
            ulong modulus = (ulong)mod;
            ulong q = MulHigh(value, (ulong)mu); // estimate of x / mod
            ulong r = value - q * modulus;       // q * mod <= x
            if (r >= modulus)
            {
                r -= modulus;
            }
            if (r >= modulus)
            {
                // Only reachable when mu underestimates the reciprocal.
                r %= modulus;
            }
            return (long)r;
        }
    }

    // Driver code
    public static void Main(string[] args)
    {
        long x = 123456789;
        long mod = 1000000007;
        // floor((2^64 - 1) / mod). Note that 1L << 64 evaluates to 1 in C#,
        // because the shift count is masked to its low 6 bits.
        long mu = unchecked((long)(ulong.MaxValue / (ulong)mod));

        // Calculate x^2 % mod. x < 2^31, so x * x fits in a long and
        // BarrettReduce performs the whole reduction.
        long xSquared = x * x;
        long xSquaredBarrett = BarrettReduce(xSquared, mod, mu);
        Console.WriteLine("x^2 % mod = " + xSquaredBarrett);

        // Calculate x^1234 % mod. PowMod already returns a value in
        // [0, mod), which BarrettReduce leaves unchanged.
        long xPow1234 = PowMod(x, 1234, mod);
        long xPow1234Barrett = BarrettReduce(xPow1234, mod, mu);
        Console.WriteLine("x^1234 % mod = " + xPow1234Barrett);
    }
}


Javascript




// Calculates (a * b) % mod for BigInt operands with binary shift-and-add:
// each set bit of b adds the current doubled copy of a into the accumulator.
function mulMod(a, b, mod) {
    let acc = BigInt(0);
    for (; b > 0; b >>= 1n) {
        if (b & 1n) {
            acc = (acc + a) % mod;
        }
        a = (a * 2n) % mod;
    }
    return acc;
}
 
// Calculates (a^b) % mod for BigInt operands by square-and-multiply,
// using mulMod for every product. Requires mod > 0n;
// powMod(a, 0n, 1n) === 0n.
function powMod(a, b, mod) {
    let res = BigInt(1) % mod; // plain 1n is not a valid residue when mod === 1n
    a %= mod;
    while (b > 0) {
        if (b & 1n) {
            res = mulMod(res, a, mod);
        }
        a = mulMod(a, a, mod);
        b >>= 1n;
    }
    return res;
}
 
// Calculates Barrett reduction of x modulo mod (BigInt, mod > 0n): returns
// the non-negative remainder of x / mod.
// mu is a 64-bit fixed-point reciprocal of mod: ceil(2^64 / mod) or
// floor(2^64 / mod). For 0n <= x < 2^64 the quotient estimate is within one
// of the true quotient, so a single add/subtract of mod corrects the result.
// Any other x or mu falls back to the % operator.
function barrettReduce(x, mod, mu) {
    const q = (x * mu) >> 64n; // estimate of x / mod
    let r = x - q * mod;
    if (r < 0n) {
        // A rounded-up mu can overshoot the quotient by one.
        r += mod;
    } else if (r >= mod) {
        // A rounded-down mu can undershoot the quotient by one.
        r -= mod;
    }
    if (r < 0n || r >= mod) {
        // Estimate was off by more than one; BigInt % keeps the dividend's sign.
        r = ((r % mod) + mod) % mod;
    }
    return r;
}
 
// Driver code: computes x^2 and x^1234 modulo 1e9+7, passes each through
// barrettReduce, and prints both results.
function main() {
    const x = 123456789n;
    const mod = 1000000007n;
    // ceil(2^64 / mod), computed exactly with BigInt.
    const mu = ((1n << 64n) + mod - 1n) / mod;

    const report = (label, value) =>
        console.log(`${label} = ${barrettReduce(value, mod, mu)}`);

    // Calculate x^2 % mod
    report("x^2 % mod", mulMod(x, x, mod));
    // Calculate x^1234 % mod
    report("x^1234 % mod", powMod(x, 1234n, mod));
}

// Run the main function
main();


Output

x^2 % mod = 643499475
x^1234 % mod = 182778789








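Time Complexity: O(1) for barrett_reduce. mul_mod runs in O(log b). pow_mod runs in O(log b · log mod), because each of its O(log b) iterations calls mul_mod with a second operand below mod.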
Auxiliary Space: O(1)



Like Article
Suggest improvement
Previous
Next
Share your thoughts in the comments

Similar Reads