Color N boxes using M colors such that K boxes have different color from the box on its left
Last Updated :
29 Oct, 2023
Given N boxes arranged in a row and M colors, find the number of ways to paint the N boxes using the M colors such that exactly K boxes have a color different from the color of the box on their left. Print the answer modulo 998244353.
Examples:
Input: N = 3, M = 3, K = 0
Output: 3
Since K is zero, no box can have a color different from the color of the box on its left. Thus all boxes must be painted with the same color, and since there are 3 colors, there are 3 ways in total.
Input: N = 3, M = 2, K = 1
Output: 4
Let’s number the colors 1 and 2. The four possible ways to paint the 3 boxes so that exactly 1 box has a color different from the box on its left are (1 2 2), (1 1 2), (2 1 1) and (2 2 1).
Prerequisites : Dynamic Programming
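Approach: This problem can be solved using dynamic programming. Let dp[i][j] denote the number of ways to paint the first i boxes using M colors such that exactly j of them have a color different from the color of the box on their left. Every box except the first either takes the same color as its left neighbour (one choice, leading to dp[i – 1][j]) or one of the remaining M – 1 colors (leading to dp[i – 1][j – 1]), so dp[i][j] = dp[i – 1][j] + (M – 1) · dp[i – 1][j – 1]. The first box can take any of the M colors and never counts as a change, which is why the programs below count the ways for boxes 2..N and multiply the result by M. (Equivalently, the answer has the closed form M · C(N – 1, K) · (M – 1)^K.)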
Steps to follow according to the above approach:
1. If idx is greater than N, check whether diff is equal to K:
   - if diff is equal to K, return 1;
   - otherwise, return 0.
2. If the stored result dp[idx][diff] is not equal to -1, it has already been computed, so return it.
3. Otherwise, recursively call solve() with idx + 1, diff, N, M and K (the box keeps its left neighbour's color), and store the result in the variable ans.
4. Recursively call solve() with idx + 1, diff + 1, N, M and K (the box takes one of the other M - 1 colors), and multiply the result by (M - 1). Use 64-bit arithmetic and reduce modulo MOD, because this product can exceed the range of a 32-bit integer.
5. Add the value obtained in step 4 to ans and take the result modulo MOD.
6. Store the result of step 5 in dp[idx][diff] and return it.
Below is the code to implement the above approach:
C++
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;

// Memo table for solve(): dp[idx][diff] caches the count for state (idx, diff),
// with -1 meaning "not computed yet". The caller must size it to at least
// (N + 1) x (N + 1) and fill it with -1 before the first call, and must reset it
// before solving a different (N, M, K), because cached values depend on them.
std::vector<std::vector<int>> dp;

// Counts the ways to colour boxes idx..N, given that `diff` of the earlier boxes
// already differ from their left neighbour, so that the total number of such
// changes ends up exactly K. The result is reduced modulo MOD.
// Box 1 is not handled here: callers start at idx = 2 and multiply by M.
// Assumes M >= 1.
int solve(int idx, int diff, int N, int M, int K)
{
    if (idx > N)
        return diff == K ? 1 : 0;
    // The change count never decreases, so once it exceeds K nothing can count.
    if (diff > K)
        return 0;
    if (dp[idx][diff] != -1)
        return dp[idx][diff];

    // Box idx keeps its left neighbour's colour: no new change.
    const long long same = solve(idx + 1, diff, N, M, K);
    // Box idx takes one of the other M - 1 colours: one more change.
    const long long changed = solve(idx + 1, diff + 1, N, M, K);
    // Both factors are below MOD (< 2^30): their product fits in 64 bits but
    // would overflow a 32-bit int.
    const long long others = (M - 1LL) % MOD;
    return dp[idx][diff] = static_cast<int>((same + others * changed) % MOD);
}
int main()
{
int N = 3, M = 3, K = 0;
dp = vector<vector< int >>(N+1,vector< int >(N+1,-1));
cout << (M * solve(2, 0, N, M, K)) << endl;
return 0;
}
|
Java
class GFG
{
    // Default memo size (main reallocates dp to fit N).
    static int M = 1001;
    static int MOD = 998244353;
    // Memo table: dp[idx][diff] caches solve(idx, diff, ...); -1 = "not computed yet".
    // Every entry that solve() may touch must be -1 before the first call, and the
    // table must be reset for a different (N, M, K) because cached values depend on them.
    static int[][] dp = new int[M][M];

    /**
     * Counts the ways to colour boxes idx..N, given that diff earlier boxes already
     * differ from their left neighbour, so that exactly K boxes differ in total.
     * Box 1 is handled by the caller (start at idx = 2 and multiply by M).
     * Returns the count modulo MOD. Assumes M >= 1.
     */
    static int solve(int idx, int diff, int N, int M, int K)
    {
        if (idx > N)
            return diff == K ? 1 : 0;
        // The change count never decreases, so overshooting K can never count.
        if (diff > K)
            return 0;
        if (dp[idx][diff] != -1)
            return dp[idx][diff];
        long ans = solve(idx + 1, diff, N, M, K);
        // long arithmetic: (M - 1) * count can reach ~10^18, far beyond int.
        ans += (M - 1L) % MOD * solve(idx + 1, diff + 1, N, M, K);
        return dp[idx][diff] = (int) (ans % MOD);
    }

    public static void main(String[] args)
    {
        int N = 3, M = 3, K = 0;
        // Size the table by N (not by M) and mark every entry as "not computed".
        dp = new int[N + 2][N + 2];
        for (int[] row : dp)
            java.util.Arrays.fill(row, -1);
        System.out.println((long) M * solve(2, 0, N, M, K) % MOD);
    }
}
|
Python3
M = 1001
MOD = 998244353
# Memo table: dp[idx][diff] caches solve(idx, diff, ...); -1 means "not computed".
# Rows are built independently: [[-1] * M] * M would alias one list M times, so
# every row would share the same cached values.
dp = [[-1] * M for _ in range(M)]


def solve(idx, diff, N, M, K):
    """Count colourings of boxes idx..N that end with exactly K changes.

    ``diff`` boxes before ``idx`` already differ from their left neighbour.
    Box 1 is handled by the caller (start at idx = 2 and multiply by M).
    The global ``dp`` must be filled with -1 for this (N, M, K); use
    ``count_ways`` to do that. Returns the count modulo MOD.
    """
    if idx > N:
        return 1 if diff == K else 0
    # The change count never decreases, so overshooting K can never count.
    if diff > K:
        return 0
    if dp[idx][diff] != -1:
        return dp[idx][diff]
    ans = solve(idx + 1, diff, N, M, K)
    ans += (M - 1) * solve(idx + 1, diff + 1, N, M, K)
    dp[idx][diff] = ans % MOD
    return dp[idx][diff]


def count_ways(N, M, K):
    """Return the number of ways to paint N boxes with M colours so that exactly
    K boxes differ from their left neighbour, modulo MOD."""
    import sys

    global dp
    # Fresh table sized for this query: cached values depend on N, M and K.
    dp = [[-1] * (N + 2) for _ in range(N + 2)]
    # solve() recurses once per box, so make room for N frames.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), N + 100))
    return M * solve(2, 0, N, M, K) % MOD


if __name__ == "__main__":
    N = 3
    M = 3
    K = 0
    print(count_ways(N, M, K))
|
C#
using System;
class GFG
{
    // Default memo size (CountWays reallocates dp to fit N).
    static int M = 1001;
    static int MOD = 998244353;
    // Memo table: dp[idx, diff] caches solve(idx, diff, ...); -1 = "not computed yet".
    // It must be reset for each (N, M, K) because cached values depend on them.
    static int[,] dp = new int[M, M];

    // Counts the ways to colour boxes idx..N, given that diff earlier boxes already
    // differ from their left neighbour, so that exactly K boxes differ in total.
    // Box 1 is handled by the caller (start at idx = 2 and multiply by M).
    // Returns the count modulo MOD. Assumes M >= 1.
    static int solve(int idx, int diff, int N, int M, int K)
    {
        if (idx > N)
            return diff == K ? 1 : 0;
        // The change count never decreases, so overshooting K can never count.
        if (diff > K)
            return 0;
        if (dp[idx, diff] != -1)
            return dp[idx, diff];
        long ans = solve(idx + 1, diff, N, M, K);
        // long arithmetic: (M - 1) * count can reach ~10^18, far beyond int.
        ans += (M - 1L) % MOD * solve(idx + 1, diff + 1, N, M, K);
        return dp[idx, diff] = (int)(ans % MOD);
    }

    // Number of ways to paint N boxes with M colours so that exactly K boxes
    // differ from their left neighbour, modulo MOD. Resets dp for this query.
    internal static int CountWays(int N, int M, int K)
    {
        dp = new int[N + 2, N + 2];
        for (int i = 0; i < N + 2; i++)
            for (int j = 0; j < N + 2; j++)
                dp[i, j] = -1;
        return (int)((long)M * solve(2, 0, N, M, K) % MOD);
    }

    public static void Main()
    {
        int N = 3, M = 3, K = 0;
        Console.WriteLine(CountWays(N, M, K));
    }
}
|
Javascript
<script>
let m = 1001;
let MOD = 998244353;
let dp = new Array(m);
for (let i = 0; i < m; i++)
{
dp[i] = new Array(m);
for (let j = 0; j < m; j++)
{
dp[i][j] = 0;
}
}
function solve(idx, diff, N, M, K)
{
if (idx > N)
{
if (diff == K)
return 1;
return 0;
}
if (dp[idx][ diff] != -1)
return dp[idx][ diff];
let ans = solve(idx + 1, diff, N, M, K);
ans += (M - 1) * solve(idx + 1,
diff + 1, N, M, K);
dp[idx][ diff] = ans % MOD;
return dp[idx][ diff];
}
let N = 3, M = 3, K = 0;
for (let i = 0; i <= M; i++)
for (let j = 0; j <= M; j++)
dp[i][j] = -1;
document.write((M * solve(2, 0, N, M, K)));
</script>
|
PHP
<?php
$M = 1001;
$MOD = 998244353;
// Memo table: $dp[idx][diff] caches solve(idx, diff, ...); -1 = "not computed yet".
// The driver below reallocates it so it always fits N.
$dp = array_fill(0, $M, array_fill(0, $M, -1));

// Counts the ways to colour boxes $idx..$N, given that $diff earlier boxes already
// differ from their left neighbour, so that exactly $K boxes differ in total.
// Box 1 is handled by the caller (start at idx = 2 and multiply by $M).
// Returns the count modulo $MOD. Assumes $M >= 1.
function solve($idx, $diff, $N, $M, $K)
{
    global $dp, $MOD;
    if ($idx > $N)
        return $diff == $K ? 1 : 0;
    // The change count never decreases, so overshooting K can never count.
    if ($diff > $K)
        return 0;
    if ($dp[$idx][$diff] != -1)
        return $dp[$idx][$diff];
    $ans = solve($idx + 1, $diff, $N, $M, $K);
    // Reduce (M - 1) first so the product stays below PHP_INT_MAX on 64-bit builds.
    $ans += (($M - 1) % $MOD) * solve($idx + 1, $diff + 1, $N, $M, $K);
    return $dp[$idx][$diff] = $ans % $MOD;
}

$N = 3;
$M = 3;
$K = 0;
// Fresh table sized for this N: cached values depend on N, M and K.
$dp = array_fill(0, $N + 2, array_fill(0, $N + 2, -1));
echo ($M % $MOD) * solve(2, 0, $N, $M, $K) % $MOD;
?>
|
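Time Complexity: O(N*N), since there are O(N) values of idx and O(N) values of diff, and each state is computed once in O(1).
Auxiliary Space: O(N*N) for the memo table, plus O(N) for the recursion stack.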
Efficient approach : Using DP Tabulation method ( Iterative approach )
The idea is the same, but the bottom-up DP tabulation method is better than top-down DP + memoization, because memoization needs extra stack space for the recursive calls (and deep recursion can overflow the stack for large N).
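Steps to solve this problem:
- Create a DP table to store the solutions of the subproblems. Only two rows are needed: one for the boxes placed so far and one for the next box.
- Initialize the base case: with only the first box placed there are no changes, so the count for 0 changes is 1 and every other count is 0.
- For each next box, compute the count for j changes from the previous row: keep the same color (the count for j changes) plus take one of the other M - 1 colors ((M - 1) times the count for j - 1 changes), reduced modulo MOD.
- After the last box, multiply the count for exactly K changes by M (the choices for the first box) to get the final answer.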
Implementation :
C++
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;

// Counts the ways to colour boxes 2..N (box 1's colour already fixed) with M
// colours so that exactly K - 1 of them differ from their left neighbour,
// modulo MOD.
//
// K is shifted by one on purpose: callers pass (wanted changes + 1), which is
// why main() prints M * solve(N, M, K + 1). Returns 0 when K - 1 is negative or
// exceeds N - 1, since only N - 1 boxes can change colour. Assumes M >= 1.
//
// O(N * K) time and O(K) extra space. The rows are local, so repeated calls
// with different arguments never see stale state.
int solve(int N, int M, int K)
{
    const int changes = K - 1;
    if (N < 1 || changes < 0 || changes > N - 1)
        return 0;

    // Number of colours that differ from the left neighbour.
    const long long others = (M - 1LL) % MOD;

    // prev[j] = colourings of the boxes placed so far with exactly j changes.
    std::vector<long long> prev(changes + 1, 0), next(changes + 1, 0);
    prev[0] = 1;  // box 1 alone: no changes yet
    for (int box = 2; box <= N; ++box) {
        next[0] = prev[0];  // only "same colour" keeps zero changes
        for (int j = 1; j <= changes; ++j) {
            // Same colour as the left box, or one of the others (one more change).
            // Both terms are below MOD, so the 64-bit product cannot overflow.
            next[j] = (prev[j] + others * prev[j - 1]) % MOD;
        }
        prev.swap(next);
    }
    return static_cast<int>(prev[changes]);
}
int main()
{
    int N = 3, M = 3, K = 0;
    // solve() takes the change count shifted by one. Box 1 can take any of the
    // M colours; widen to long long so M * ways cannot overflow int.
    std::cout << (1LL * M * solve(N, M, K + 1)) % MOD << '\n';
    return 0;
}
|
Java
import java.util.*;
public class Main {
    static final int MOD = 998244353;
    // Two rolling rows used by solve(): dp[0] holds the counts for the boxes placed
    // so far, dp[1] is scratch for the next box. Reallocated on every call.
    static int[][] dp;

    /**
     * Counts the ways to colour boxes 2..N (box 1's colour already fixed) with M
     * colours so that exactly K - 1 of them differ from their left neighbour, modulo MOD.
     * K is shifted by one on purpose: callers pass (wanted changes + 1), as main does.
     * Returns 0 when K - 1 is negative or exceeds N - 1. Assumes M >= 1.
     */
    static int solve(int N, int M, int K) {
        int changes = K - 1;
        if (N < 1 || changes < 0 || changes > N - 1) {
            return 0;
        }
        long others = (M - 1L) % MOD;
        dp = new int[2][changes + 1];
        dp[0][0] = 1; // box 1 alone: no changes yet
        for (int box = 2; box <= N; box++) {
            int[] prev = dp[0];
            int[] next = dp[1];
            next[0] = prev[0]; // only "same colour" keeps zero changes
            for (int j = 1; j <= changes; j++) {
                // Same colour as the left box, or one of the other M - 1 colours
                // (one more change). long arithmetic keeps the product exact.
                next[j] = (int) ((prev[j] + others * prev[j - 1]) % MOD);
            }
            dp[0] = next;
            dp[1] = prev;
        }
        return dp[0][changes];
    }

    public static void main(String[] args) {
        int N = 3, M = 3, K = 0;
        // Box 1 can take any of the M colours; widen to long before multiplying.
        System.out.println((long) M * solve(N, M, K + 1) % MOD);
    }
}
|
Python3
MOD = 998244353


def solve(N, M, K):
    """Count colourings of boxes 2..N (box 1's colour fixed) with M colours in
    which exactly K - 1 boxes differ from their left neighbour, modulo MOD.

    K is shifted by one on purpose: callers pass (wanted changes + 1), as the
    ``__main__`` block does. Returns 0 when K - 1 is negative or exceeds N - 1.
    Uses O(K) extra space and needs no global state.
    """
    changes = K - 1
    if N < 1 or changes < 0 or changes > N - 1:
        return 0
    others = (M - 1) % MOD
    # ways[j] = colourings of the boxes placed so far with exactly j changes.
    ways = [1] + [0] * changes
    for _ in range(2, N + 1):
        # Walk j downwards so ways[j - 1] still holds the previous box's value.
        for j in range(changes, 0, -1):
            ways[j] = (ways[j] + others * ways[j - 1]) % MOD
    return ways[changes]


if __name__ == '__main__':
    N, M, K = 3, 3, 0
    print((M * solve(N, M, K + 1)) % MOD)
|
C#
using System;
public class MainClass
{
    static readonly int MOD = 998244353;
    // Two rolling rows used by Solve(): row 0 holds the counts for the boxes placed
    // so far, row 1 is scratch for the next box. Reallocated on every call.
    static int[,] dp;

    // Counts the ways to colour boxes 2..N (box 1's colour already fixed) with M
    // colours so that exactly K - 1 of them differ from their left neighbour, modulo MOD.
    // K is shifted by one on purpose: callers pass (wanted changes + 1), as Main does.
    // Returns 0 when K - 1 is negative or exceeds N - 1. Assumes M >= 1.
    internal static int Solve(int N, int M, int K)
    {
        int changes = K - 1;
        if (N < 1 || changes < 0 || changes > N - 1)
            return 0;

        long others = (M - 1L) % MOD;
        dp = new int[2, changes + 1];
        dp[0, 0] = 1; // box 1 alone: no changes yet
        for (int box = 2; box <= N; box++)
        {
            dp[1, 0] = dp[0, 0]; // only "same colour" keeps zero changes
            for (int j = 1; j <= changes; j++)
            {
                // Same colour as the left box, or one of the other M - 1 colours
                // (one more change). long arithmetic keeps the product exact.
                dp[1, j] = (int)((dp[0, j] + others * dp[0, j - 1]) % MOD);
            }
            for (int j = 0; j <= changes; j++)
                dp[0, j] = dp[1, j];
        }
        return dp[0, changes];
    }

    public static void Main(string[] args)
    {
        int N = 3, M = 3, K = 0;
        // Box 1 can take any of the M colours; widen to long before multiplying.
        Console.WriteLine((long)M * Solve(N, M, K + 1) % MOD);
    }
}
|
Javascript
const MOD = 998244353;

// Returns (a * b) % MOD for non-negative a, b. A plain a * b can reach ~10^18,
// beyond 2^53 where JavaScript numbers stop being exact, so b is split into
// 15-bit halves and every partial product stays below 2^46.
function mulMod(a, b) {
    a %= MOD;
    b %= MOD;
    const hi = Math.floor(b / 32768), lo = b % 32768;
    return ((a * hi) % MOD * 32768 + a * lo) % MOD;
}

// Counts the ways to colour boxes 2..N (box 1's colour already fixed) with M
// colours so that exactly K - 1 of them differ from their left neighbour, modulo MOD.
// K is shifted by one on purpose: callers pass (wanted changes + 1), as below.
// Returns 0 when K - 1 is negative or exceeds N - 1. Assumes M >= 1.
function solve(N, M, K) {
    const changes = K - 1;
    if (N < 1 || changes < 0 || changes > N - 1)
        return 0;
    const others = (M - 1) % MOD;
    // prev[j] = colourings of the boxes placed so far with exactly j changes.
    let prev = new Array(changes + 1).fill(0);
    let next = new Array(changes + 1).fill(0);
    prev[0] = 1; // box 1 alone: no changes yet
    for (let box = 2; box <= N; box++) {
        next[0] = prev[0]; // only "same colour" keeps zero changes
        for (let j = 1; j <= changes; j++) {
            // Same colour as the left box, or one of the others (one more change).
            next[j] = (prev[j] + mulMod(others, prev[j - 1])) % MOD;
        }
        [prev, next] = [next, prev];
    }
    return prev[changes];
}

const N = 3, M = 3, K = 0;
console.log(mulMod(M, solve(N, M, K + 1)));
|
Output:
3
Time Complexity: O(N * N)
Auxiliary Space: O(N)
Like Article
Suggest improvement
Share your thoughts in the comments
Please Login to comment...