AC题解
2026-04-11 19:58:29
发布于:重庆
2阅读
0回复
0点赞
问题本质
求所有 i∈[0,n−1]i∈[0,n−1],j∈[0,m−1]j∈[0,m−1] 中,max((i⊕j)−k, 0)
max((i⊕j)−k, 0) 的总和。
等价于:
先找出所有 i⊕j>ki⊕j>k 的格子
对它们的 (i⊕j) 求和,再减去 k*格子个数
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int T;
ll n, m, k, p;
int bits_n[61], bits_m[61], bits_k[61];
pair<ll, ll> memo[61][2][2][2];
bool vis[61][2][2][2];
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int T;
ll n, m, k, p;
int bits_n[61], bits_m[61], bits_k[61];
pair<ll, ll> memo[61][2][2][2];
bool vis[61][2][2][2];
// 数位DP,返回 (符合条件的(i,j)对数, 这些对中(i xor j)的总和) 均对 p 取模
pair<ll, ll> dfs(int pos, int lim_n, int lim_m, int gt) {
if (pos == -1) {
// 所有位处理完,如果已经大于 k 则有效,否则无效
if (gt) return {1, 0};
else return {0, 0};
}
if (vis[pos][lim_n][lim_m][gt]) return memo[pos][lim_n][lim_m][gt];
int n_bit = bits_n[pos];
int m_bit = bits_m[pos];
int k_bit = bits_k[pos];
int max_n = lim_n ? n_bit : 1;
int max_m = lim_m ? m_bit : 1;
ll cnt = 0, sum = 0;
for (int x = 0; x <= max_n; ++x) {
for (int y = 0; y <= max_m; ++y) {
int z = x ^ y; // 当前位的异或值
int new_lim_n = lim_n && (x == n_bit);
int new_lim_m = lim_m && (y == m_bit);
int new_gt = gt;
if (!gt) {
if (z > k_bit) new_gt = 1;
else if (z < k_bit) continue; // 已经小于,不可能大于 k
else new_gt = 0; // 相等,继续比较低位
}
auto [c, s] = dfs(pos - 1, new_lim_n, new_lim_m, new_gt);
cnt = (cnt + c) % p;
ll add = ( (z * ((1LL << pos) % p)) % p * c ) % p;
sum = (sum + s + add) % p;
}
}
vis[pos][lim_n][lim_m][gt] = true;
return memo[pos][lim_n][lim_m][gt] = {cnt, sum};
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cin >> T;
while (T--) {
cin >> n >> m >> k >> p;
// 注意:行号 i 的范围是 [0, n-1],列号 j 的范围是 [0, m-1]
ll n1 = n - 1, m1 = m - 1;
for (int i = 0; i < 61; ++i) {
bits_n[i] = (n1 >> i) & 1;
bits_m[i] = (m1 >> i) & 1;
bits_k[i] = (k >> i) & 1;
}
memset(vis, 0, sizeof(vis));
auto [cnt, sum] = dfs(60, 1, 1, 0);
ll ans = (sum - (k % p) * cnt % p + p) % p;
cout << ans << '\n';
}
return 0;
}
这里空空如也







有帮助,赞一个