import
java.util.ArrayList;
import
java.util.Arrays;
import
java.util.List;
/**
 * Counts the integers in a range whose decimal digits, XOR-ed together, equal a target value.
 *
 * <p>The count for {@code [0, bound]} is computed with a digit DP over the decimal digits of
 * {@code bound}. A range {@code [num1, num2]} is then answered as {@code count(num2) -
 * count(num1 - 1)}.
 *
 * <p>Because the memo {@link #dp} is a shared static field, this class is not safe for
 * concurrent use.
 */
public class CountNumbersWithXOR {

  /**
   * Number of distinct running-XOR values. Each decimal digit is in 0..9 and fits in 4 bits, so
   * the XOR of any set of digits stays in 0..15.
   */
  static final int XOR_STATES = 16;

  /**
   * Memo for {@link #countNumbers}, indexed as {@code dp[index][smaller ? 1 : 0][currXOR]}. A value
   * of {@code -1} means "not computed yet". It is rebuilt by {@link #countUpTo} for every bound.
   */
  static int[][][] dp;

  /**
   * Counts the ways to fill digit positions {@code index..digits.size()-1} so that the number
   * stays {@code <=} the bound in {@code digits} and the total digit XOR equals {@code k}.
   *
   * <p>Shorter numbers are represented with leading zeros. Those zeros do not change the XOR, so
   * every value in {@code [0, bound]} is counted exactly once.
   *
   * <p>Before the first call, {@link #dp} must be sized {@code [digits.size()][2][XOR_STATES]} and
   * filled with {@code -1} for this {@code digits}/{@code k} pair.
   *
   * @param digits decimal digits of the upper bound, most significant digit first
   * @param index position currently being chosen
   * @param smaller whether an earlier position already chose a digit below the bound's digit (if
   *     so, every remaining position may take any digit 0..9)
   * @param currXOR XOR of the digits chosen so far; must be in {@code 0..XOR_STATES-1}
   * @param k target XOR value
   * @return number of completions whose total digit XOR equals {@code k}
   */
  static int countNumbers(List<Integer> digits, int index, boolean smaller, int currXOR, int k) {
    if (index == digits.size()) {
      return currXOR == k ? 1 : 0;
    }
    int smallerIdx = smaller ? 1 : 0;
    if (dp[index][smallerIdx][currXOR] != -1) {
      return dp[index][smallerIdx][currXOR];
    }
    int boundDigit = digits.get(index);
    // Once we are strictly below the bound, any digit is allowed.
    // While still equal to the bound's prefix, we must not exceed the bound's digit here.
    int limit = smaller ? 9 : boundDigit;
    int count = 0;
    for (int digit = 0; digit <= limit; digit++) {
      count += countNumbers(digits, index + 1, smaller || digit < boundDigit, currXOR ^ digit, k);
    }
    dp[index][smallerIdx][currXOR] = count;
    return count;
  }

  /**
   * Counts the integers {@code x} in {@code [0, bound]} whose decimal digit XOR equals {@code k}.
   *
   * @param bound inclusive upper bound; a negative bound yields 0
   * @param k target XOR value
   * @return the count, or 0 if {@code bound < 0}
   */
  private static int countUpTo(int bound, int k) {
    if (bound < 0) {
      return 0;
    }
    // Most significant digit first, as countNumbers expects.
    String text = Integer.toString(bound);
    List<Integer> digits = new ArrayList<>(text.length());
    for (int i = 0; i < text.length(); i++) {
      digits.add(text.charAt(i) - '0');
    }
    dp = new int[digits.size()][2][XOR_STATES];
    for (int[][] bySmaller : dp) {
      for (int[] byXor : bySmaller) {
        Arrays.fill(byXor, -1);
      }
    }
    return countNumbers(digits, 0, false, 0, k);
  }

  /**
   * Counts the integers in the inclusive range {@code [num1, num2]} whose decimal digits XOR to
   * {@code k}.
   *
   * <p>Only non-negative integers are considered; the negative part of a range contributes
   * nothing.
   *
   * @param num1 inclusive lower bound
   * @param num2 inclusive upper bound
   * @param k target XOR value; values outside {@code 0..15} can never be reached
   * @return the count, or 0 if the range is empty ({@code num1 > num2})
   */
  static int countNumbersWithXOR(int num1, int num2, int k) {
    if (num1 > num2 || k < 0 || k >= XOR_STATES) {
      return 0;
    }
    // Guard avoids num1 - 1 overflowing at Integer.MIN_VALUE; nothing below 0 is counted anyway.
    int below = num1 > 0 ? countUpTo(num1 - 1, k) : 0;
    return countUpTo(num2, k) - below;
  }

  /** Prints how many numbers in [100, 500] have a digit XOR of 3. */
  public static void main(String[] args) {
    int num1 = 100;
    int num2 = 500;
    int k = 3;
    int result = countNumbersWithXOR(num1, num2, k);
    System.out.println("Count of numbers = " + result);
  }
}