
# Program to generate an array having convolution of two given arrays

Given two arrays A[] and B[] consisting of N and M integers respectively, the task is to construct their convolution array C[] of size N + M – 1.

The convolution of two arrays is defined as C[k] = ∑ A[i] * B[j], where the sum runs over all index pairs (i, j) with i + j = k, for 0 ≤ k ≤ N + M – 2.
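Read per index, the definition says that C[k] collects every product A[i] * B[j] whose indices add up to k. This is exactly the coefficient of x^k in the product of the polynomials A[0] + A[1]x + ... and B[0] + B[1]x + .... The minimal Python sketch below computes one C[k] straight from the definition, assuming 0-indexed lists and no modulo; the helper name `conv_at` is hypothetical and the arrays are the ones from the first example.

```python
def conv_at(A, B, k):
    """Return C[k] = sum of A[i] * B[k - i] over all valid i (no modulo)."""
    n, m = len(A), len(B)
    # i must be a valid index of A, and j = k - i a valid index of B
    lo = max(0, k - (m - 1))
    hi = min(k, n - 1)
    return sum(A[i] * B[k - i] for i in range(lo, hi + 1))


A = [1, 2, 3, 4]
B = [5, 6, 7, 8, 9]
C = [conv_at(A, B, k) for k in range(len(A) + len(B) - 1)]
print(C)  # [5, 16, 34, 60, 70, 70, 59, 36]
```

The explicit bounds `lo` and `hi` skip the pairs where k - i would fall outside B[], so each C[k] touches only the products that actually exist.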

Note: Since the values can become very large, print every element of C[] modulo 998244353.

Examples:

Input: A[] = {1, 2, 3, 4}, B[] = {5, 6, 7, 8, 9}
Output: {5, 16, 34, 60, 70, 70, 59, 36}
Explanation:
Size of array C[] = N + M – 1 = 8.
C[0] = A[0] * B[0] = 1 * 5 = 5
C[1] = A[0] * B[1] + A[1] * B[0] = 1 * 6 + 2 * 5 = 16
C[2] = A[0] * B[2] + A[1] * B[1] + A[2] * B[0] = 1 * 7 + 2 * 6 + 3 * 5 = 34
Similarly, C[3] = 60, C[4] = 70, C[5] = 70, C[6] = 59, C[7] = 36.

Input: A[] = {10000000}, B[] = {10000000}
Output: {871938225}
Explanation:
Size of array C[] = N + M – 1 = 1.
C[0] = A[0] * B[0] = (10000000 * 10000000) % 998244353 = 871938225
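The second example can be reproduced directly, and it also shows why products need care in fixed-width languages: 10^14 does not fit in a 32-bit `int`, so the multiplication must be done in 64 bits. Because (x * y) mod p equals ((x mod p) * (y mod p)) mod p, each product can be reduced modulo 998244353 as soon as it is formed. A short Python check (Python integers are arbitrary precision, so the exact product is available):

```python
MOD = 998244353

x = y = 10_000_000
product = x * y  # exactly 10**14
print(product % MOD)  # 871938225

# 10**14 overflows a signed 32-bit int but fits in a signed 64-bit integer
assert 2**31 - 1 < product < 2**63

# Reducing the factors first does not change the residue
u, v = 10**12, 3 * 10**11
assert (u * v) % MOD == ((u % MOD) * (v % MOD)) % MOD
```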

Naive Approach: Every product A[i] * B[j] contributes to exactly one term, C[i + j]. Therefore, the simplest approach is to iterate over both arrays A[] and B[] using two nested loops, add (A[i] * B[j]) % 998244353 to C[i + j], and reduce every C[k] modulo 998244353.

Below is the implementation of the above approach:

## C++

```cpp
// C++ program for the above approach
#include <iostream>
#include <vector>
using namespace std;

constexpr int MOD = 998244353;

// Function to generate a convolution
// array of two given arrays
void findConvolution(const vector<int>& a,
                     const vector<int>& b)
{
    // Stores the size of arrays
    int n = a.size(), m = b.size();

    // Stores the final array
    vector<long long> c(n + m - 1);

    // Traverse the two given arrays
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {

            // Update the convolution array; widen to
            // 64 bits before multiplying to avoid overflow
            c[i + j] = (c[i + j] + 1LL * a[i] * b[j]) % MOD;
        }
    }

    // Print the convolution array c[]
    for (size_t k = 0; k < c.size(); ++k) {
        cout << c[k] << " ";
    }
}

// Driver Code
int main()
{
    vector<int> A = { 1, 2, 3, 4 };
    vector<int> B = { 5, 6, 7, 8, 9 };

    findConvolution(A, B);

    return 0;
}
```

## Java

```java
// Java program for the above approach
class GFG {

    static final int MOD = 998244353;

    // Function to generate a convolution
    // array of two given arrays
    static void findConvolution(int[] a, int[] b)
    {
        // Stores the size of arrays
        int n = a.length, m = b.length;

        // Stores the final array
        long[] c = new long[n + m - 1];

        // Traverse the two given arrays
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) {

                // Update the convolution array; cast to long
                // before multiplying to avoid int overflow
                c[i + j] = (c[i + j] + (long)a[i] * b[j]) % MOD;
            }
        }

        // Print the convolution array c[]
        for (int k = 0; k < c.length; ++k)
            System.out.print(c[k] + " ");
    }

    // Driver Code
    public static void main(String[] args)
    {
        int[] A = { 1, 2, 3, 4 };
        int[] B = { 5, 6, 7, 8, 9 };

        findConvolution(A, B);
    }
}

// This code is contributed by souravghosh0416
```

## Python3

```python
# Python3 program for the above approach
MOD = 998244353

# Function to generate a convolution
# array of two given arrays
def findConvolution(a, b):

    # Stores the size of arrays
    n, m = len(a), len(b)

    # Stores the final array
    c = [0] * (n + m - 1)

    # Traverse the two given arrays
    for i in range(n):
        for j in range(m):

            # Update the convolution array
            c[i + j] += (a[i] * b[j]) % MOD

    # Print the convolution array c[]
    for k in range(len(c)):
        c[k] %= MOD
        print(c[k], end=" ")

# Driver Code
if __name__ == '__main__':

    A = [1, 2, 3, 4]
    B = [5, 6, 7, 8, 9]

    findConvolution(A, B)

# This code is contributed by mohit kumar 29
```

## C#

```csharp
// C# program for the above approach
using System;

class GFG
{
    const int MOD = 998244353;

    // Function to generate a convolution
    // array of two given arrays
    static void findConvolution(int[] a, int[] b)
    {
        // Stores the size of arrays
        int n = a.Length, m = b.Length;

        // Stores the final array
        long[] c = new long[n + m - 1];

        // Traverse the two given arrays
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < m; ++j) {

                // Update the convolution array; cast to long
                // before multiplying to avoid int overflow
                c[i + j] = (c[i + j] + (long)a[i] * b[j]) % MOD;
            }
        }

        // Print the convolution array c[]
        for (int k = 0; k < c.Length; ++k)
            Console.Write(c[k] + " ");
    }

    // Driver Code
    public static void Main(String[] args)
    {
        int[] A = { 1, 2, 3, 4 };
        int[] B = { 5, 6, 7, 8, 9 };

        findConvolution(A, B);
    }
}

// This code is contributed by code_hunt.
```


Output:

`5 16 34 60 70 70 59 36`

Time Complexity: O(N*M)
Auxiliary Space: O(N+M)

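Efficient Approach: To optimize the above approach, use the Number-Theoretic Transform (NTT). NTT is the analogue of the Fast Fourier Transform (FFT) used for polynomial multiplication, but it works entirely in modular arithmetic, so there is no floating-point rounding error. Because 998244353 - 1 = 119 * 2^23, the integers modulo 998244353 contain primitive 2^k-th roots of unity for every k ≤ 23, and these have the same properties that the complex roots of unity have in the FFT. Hence the iterative FFT can be reused as-is: zero-pad both arrays to a power of two that is at least N + M - 1, transform them, multiply them pointwise, and apply the inverse transform.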
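The facts about roots of unity can be checked numerically. The implementations below start from the constant 15311432 and square it repeatedly; that only works if 15311432 has multiplicative order exactly 2^23 modulo 998244353. This Python sketch, which uses only the built-in three-argument `pow`, verifies that and mirrors the squaring loop; the helper name `root_for` is hypothetical.

```python
MOD = 998244353
W = 15311432  # constant used by the implementations in this article
LOG = 23

# p - 1 = 119 * 2^23, so power-of-two transform lengths up to 2^23 are possible
assert MOD - 1 == 119 * (1 << LOG)

# W has order exactly 2^23: W^(2^23) == 1 and W^(2^22) == -1 (mod p)
assert pow(W, 1 << LOG, MOD) == 1
assert pow(W, 1 << (LOG - 1), MOD) == MOD - 1


def root_for(length):
    """Square W down to a primitive `length`-th root of unity (length a power of two)."""
    w = W
    aux = 1 << LOG
    while aux > length:
        w = w * w % MOD
        aux >>= 1
    return w


w16 = root_for(16)
# w16 is a primitive 16th root: w16^16 == 1 but w16^8 == -1
assert pow(w16, 16, MOD) == 1 and pow(w16, 8, MOD) == MOD - 1
```

Each squaring halves the order of the root, so squaring 23 - log2(length) times turns an element of order 2^23 into one of order `length`.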

Below is the implementation of the above approach:

## C++

```cpp
// C++ program for the above approach
#include <iostream>
#include <utility>
using namespace std;

#define ll long long

const ll mod = 998244353, maxn = 3000000;
ll a[maxn], b[maxn];

// Iterative FFT (NTT) function to compute
// the DFT of given coefficient vector
void fft(ll w0, ll n, ll* a)
{
    // Do bit reversal of the given array
    for (ll i = 0, j = 0; i < n; i++) {

        // Swap a[i] and a[j], where j is the
        // bit reversal of i
        if (i < j)
            swap(a[i], a[j]);

        // Advance j to the bit reversal of i + 1
        ll bit = n >> 1;

        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
    }

    // Perform the iterative FFT
    for (ll len = 2; len <= n; len <<= 1) {

        // Root of unity of order len
        ll wlen = w0;
        for (ll aux = n; aux > len; aux >>= 1) {
            wlen = wlen * wlen % mod;
        }

        for (ll bat = 0; bat + len <= n; bat += len) {

            for (ll i = bat, w = 1; i < bat + len / 2;
                 i++, w = w * wlen % mod) {

                ll u = a[i], v = w * a[i + len / 2] % mod;

                // Update the value of a[i]
                a[i] = (u + v) % mod;

                // Update the value
                // of a[i + len/2]
                a[i + len / 2] = ((u - v) % mod + mod) % mod;
            }
        }
    }
}

// Function to find (a ^ x) % mod
ll binpow(ll a, ll x)
{
    // Stores the result of a ^ x
    ll ans = 1;

    // Iterate over the value of x
    for (; x; x /= 2, a = a * a % mod) {

        // If x is odd
        if (x & 1)
            ans = ans * a % mod;
    }

    // Return the resultant value
    return ans;
}

// Function to find the inverse of a % mod
// (Fermat's little theorem, mod is prime)
ll inv(ll a) { return binpow(a, mod - 2); }

// Function to find the
// convolution of two arrays
void findConvolution(ll a[], ll b[], ll n, ll m)
{
    // Stores the smallest power of 2
    // greater than or equal to n + m - 1
    // (must not exceed maxn or 2^23)
    ll _n = 1;
    while (_n < n + m - 1)
        _n <<= 1;

    // Element of order 2^23; square it down
    // to a primitive _n-th root of unity
    ll w = 15311432;
    for (ll aux = 1 << 23; aux > _n; aux >>= 1)
        w = w * w % mod;

    // Convert arrays a[] and
    // b[] to point value form
    fft(w, _n, a);
    fft(w, _n, b);

    // Perform multiplication
    for (ll i = 0; i < _n; i++)
        a[i] = a[i] * b[i] % mod;

    // Perform inverse fft to
    // recover final array
    fft(inv(w), _n, a);
    ll inv_n = inv(_n);
    for (ll i = 0; i < _n; i++)
        a[i] = a[i] * inv_n % mod;

    // Print the convolution
    for (ll i = 0; i < n + m - 1; i++)
        cout << a[i] << " ";
}

// Driver Code
int main()
{
    // Given size of the arrays
    ll N = 4, M = 5;

    // Fill the arrays
    for (ll i = 0; i < N; i++)
        a[i] = i + 1;

    for (ll i = 0; i < M; i++)
        b[i] = 5 + i;

    findConvolution(a, b, N, M);

    return 0;
}
```

## Java

```java
// Java program for the above approach
public class GFG {
    public static final long mod = 998244353;
    public static final int maxn = 3000000;

    public static long[] a = new long[maxn];
    public static long[] b = new long[maxn];

    // Function to find the smallest power of 2
    // that is greater than or equal to n (n >= 1)
    public static int nextPowerOf2(int n)
    {
        // If n is a power of two, simply return it
        if ((n & (n - 1)) == 0)
            return n;

        // Otherwise return the power of 2 one bit
        // above the most significant set bit of n
        return 0x40000000
            >> (Integer.numberOfLeadingZeros(n) - 2);
    }

    // Iterative fft (NTT) function to compute
    // the DFT of given coefficient vector
    public static void fft(long w0, int n, long[] a)
    {
        // Do bit reversal of the given array
        for (int i = 0, j = 0; i < n; i++) {
            // Swap a[i] and a[j]
            if (i < j) {
                long temp = a[i];
                a[i] = a[j];
                a[j] = temp;
            }

            // Advance j to the bit reversal of i + 1
            int bit = n >> 1;
            for (; (j & bit) != 0; bit >>= 1)
                j ^= bit;
            j ^= bit;
        }

        // Perform the iterative fft
        for (int len = 2; len <= n; len <<= 1) {
            // Root of unity of order len
            long wlen = w0;
            for (int aux = n; aux > len; aux >>= 1)
                wlen = wlen * wlen % mod;

            for (int bat = 0; bat + len <= n; bat += len) {
                long w = 1;
                for (int i = bat; i < bat + len / 2; i++) {
                    long u = a[i];
                    long v = w * a[i + len / 2] % mod;

                    // Update the value of a[i]
                    a[i] = (u + v) % mod;

                    // Update the value
                    // of a[i + len/2]
                    a[i + len / 2] = ((u - v) % mod + mod) % mod;

                    w = w * wlen % mod;
                }
            }
        }
    }

    // Function to find (a ^ x) % mod
    public static long binpow(long a, long x)
    {
        // Stores the result of a ^ x
        long ans = 1;

        // Iterate over the value of x
        for (; x != 0; x /= 2, a = a * a % mod) {
            // If x is odd
            if ((x & 1) != 0)
                ans = ans * a % mod;
        }

        // Return the resultant value
        return ans;
    }

    // Function to find the
    // inverse of a % mod
    public static long inv(long a)
    {
        return binpow(a, mod - 2);
    }

    // Function to find the
    // convolution of two arrays
    static void FindConvolution(long[] a, long[] b, int n, int m)
    {
        // Stores the smallest power of 2
        // greater than or equal to n + m - 1
        int _n = nextPowerOf2(n + m - 1);

        // Element of order 2^23; square it down
        // to a primitive _n-th root of unity
        long w = 15311432;
        for (int aux = 1 << 23; aux > _n; aux >>= 1)
            w = w * w % mod;

        // Convert arrays a[] and b[] to point value form
        fft(w, _n, a);
        fft(w, _n, b);

        // Perform multiplication
        for (int i = 0; i < _n; i++)
            a[i] = a[i] * b[i] % mod;

        // Perform inverse fft to recover final array
        fft(inv(w), _n, a);

        long invN = inv(_n);
        for (int i = 0; i < _n; i++)
            a[i] = a[i] * invN % mod;

        // Print the convolution
        for (int i = 0; i < n + m - 1; i++)
            System.out.print(a[i] + " ");
    }

    public static void main(String[] args)
    {
        // Given size of the arrays
        int N = 4, M = 5;

        // Fill the arrays
        for (int i = 0; i < N; i++)
            a[i] = i + 1;

        for (int i = 0; i < M; i++)
            b[i] = 5 + i;

        FindConvolution(a, b, N, M);
    }
}
```

## C#

```csharp
// C# program for the above approach
using System;

public class GFG {
  public const long mod = 998244353;
  public const int maxn = 3000000;

  public static long[] a = new long[maxn];
  public static long[] b = new long[maxn];

  // Function to find the smallest power of 2
  // that is greater than or equal to N (N >= 1)
  static int nextPowerOf2(int N)
  {
    int p = 1;
    while (p < N)
      p <<= 1;
    return p;
  }

  // Iterative fft (NTT) function to compute
  // the DFT of given coefficient vector
  public static void fft(long w0, int n, long[] a)
  {
    // Do bit reversal of the given array
    for (int i = 0, j = 0; i < n; i++) {
      // Swap a[i] and a[j]
      if (i < j) {
        long temp = a[i];
        a[i] = a[j];
        a[j] = temp;
      }

      // Advance j to the bit reversal of i + 1
      int bit = n >> 1;
      for (; (j & bit) != 0; bit >>= 1)
        j ^= bit;
      j ^= bit;
    }

    // Perform the iterative fft
    for (int len = 2; len <= n; len <<= 1) {
      // Root of unity of order len
      long wlen = w0;
      for (int aux = n; aux > len; aux >>= 1)
        wlen = wlen * wlen % mod;

      for (int bat = 0; bat + len <= n; bat += len) {
        long w = 1;
        for (int i = bat; i < bat + len / 2; i++) {
          long u = a[i];
          long v = w * a[i + len / 2] % mod;

          // Update the value of a[i]
          a[i] = (u + v) % mod;

          // Update the value
          // of a[i + len/2]
          a[i + len / 2] = ((u - v) % mod + mod) % mod;

          w = w * wlen % mod;
        }
      }
    }
  }

  // Function to find (a ^ x) % mod
  public static long binpow(long a, long x)
  {
    // Stores the result of a ^ x
    long ans = 1;

    // Iterate over the value of x
    for (; x != 0; x /= 2, a = a * a % mod) {
      // If x is odd
      if ((x & 1) != 0)
        ans = ans * a % mod;
    }

    // Return the resultant value
    return ans;
  }

  // Function to find the
  // inverse of a % mod
  public static long inv(long a)
  {
    return binpow(a, mod - 2);
  }

  // Function to find the convolution of two arrays
  static void FindConvolution(long[] a, long[] b, int n, int m)
  {
    // Stores the smallest power of 2
    // greater than or equal to n + m - 1
    int _n = nextPowerOf2(n + m - 1);

    // Element of order 2^23; square it down
    // to a primitive _n-th root of unity
    long w = 15311432;
    for (int aux = 1 << 23; aux > _n; aux >>= 1)
      w = w * w % mod;

    // Convert arrays a[] and b[] to point value form
    fft(w, _n, a);
    fft(w, _n, b);

    // Perform multiplication
    for (int i = 0; i < _n; i++)
      a[i] = a[i] * b[i] % mod;

    // Perform inverse fft to recover final array
    fft(inv(w), _n, a);

    long invN = inv(_n);
    for (int i = 0; i < _n; i++)
      a[i] = a[i] * invN % mod;

    // Print the convolution
    for (int i = 0; i < n + m - 1; i++)
      Console.Write(a[i] + " ");
  }

  public static void Main(string[] args)
  {
    // Given size of the arrays
    int N = 4, M = 5;

    // Fill the arrays
    for (int i = 0; i < N; i++)
      a[i] = i + 1;

    for (int i = 0; i < M; i++)
      b[i] = 5 + i;

    FindConvolution(a, b, N, M);
  }
}
```

Output:

`5 16 34 60 70 70 59 36`

Time Complexity: O((N + M) * log(N + M))
Auxiliary Space: O(N + M)
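For readers following the Python version of the naive approach, here is a minimal Python sketch of the same NTT-based convolution. It is an illustrative port, not the article's reference code, and the names `ntt`, `convolve`, `ROOT` and `ROOT_PW` are hypothetical. It uses the same constant 15311432 (assumed, as above, to have order 2^23), but it derives each stage's root as ROOT^(2^23 / length) with `pow` instead of repeated squaring, and it pads to the smallest power of two that is at least N + M - 1. Pure Python is far slower than the C++ version, so this is better suited to checking results than to large inputs.

```python
MOD = 998244353
ROOT = 15311432      # element of order 2^23 modulo MOD
ROOT_PW = 1 << 23


def ntt(a, invert=False):
    """In-place iterative NTT of list `a`; len(a) must be a power of two <= 2^23."""
    n = len(a)
    # Bit-reversal permutation
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    # Butterflies over blocks of size 2, 4, ..., n
    length = 2
    while length <= n:
        wlen = pow(ROOT, ROOT_PW // length, MOD)  # primitive length-th root
        if invert:
            wlen = pow(wlen, MOD - 2, MOD)
        half = length // 2
        for start in range(0, n, length):
            w = 1
            for k in range(start, start + half):
                u = a[k]
                v = a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                a[k + half] = (u - v) % MOD
                w = w * wlen % MOD
        length <<= 1
    if invert:
        # Divide by n using its modular inverse
        n_inv = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * n_inv % MOD


def convolve(A, B):
    """Return the N + M - 1 coefficients of the convolution of A and B modulo MOD."""
    size = len(A) + len(B) - 1
    n = 1
    while n < size:
        n <<= 1
    fa = [x % MOD for x in A] + [0] * (n - len(A))
    fb = [x % MOD for x in B] + [0] * (n - len(B))
    ntt(fa)
    ntt(fb)
    for i in range(n):
        fa[i] = fa[i] * fb[i] % MOD
    ntt(fa, invert=True)
    return fa[:size]


print(convolve([1, 2, 3, 4], [5, 6, 7, 8, 9]))  # [5, 16, 34, 60, 70, 70, 59, 36]
```

Padding to at least N + M - 1 points is what keeps the cyclic convolution computed by the transform equal to the ordinary one: no index i + j can wrap around past the padded length.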
