// Solution post by: tj
// Posted: 2025-09-13 20:11:53
// 发布于:广东 (published from: Guangdong)
#include <bits/stdc++.h>
using namespace std;
// Shared type alias, limits and fixed-capacity storage for the solution.
using ll = long long;

// Modulus applied to every count.
constexpr int MOD = 998244353;

// Static capacities: row indices (1-based) must stay below N, column
// indices (1-based) below M.
constexpr int N = 105;
constexpr int M = 2005;

// Matrix dimensions read from stdin: n rows, m columns.
int n = 0;
int m = 0;

// a[i][j]: matrix entry at row i, column j.
ll a[N][M];
// s[i]: sum of row i of a, reduced modulo MOD.
ll s[N];

// Two alternating DP layers indexed by a shifted difference value.
ll dp[2][2 * N + 5];
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
s[i] = 0;
for (int j = 1; j <= m; ++j) {
cin >> a[i][j];
s[i] = (s[i] + a[i][j]) % MOD;
}
}
ll total = 1;
for (int i = 1; i <= n; ++i) {
total = total * (s[i] + 1) % MOD;
}
total = (total - 1 + MOD) % MOD;
ll invalid = 0;
const int offset = N;
for (int j = 1; j <= m; ++j) {
memset(dp, 0, sizeof(dp));
int cur = 0;
dp[cur][offset] = 1;
for (int i = 1; i <= n; ++i) {
int prev = cur;
cur ^= 1;
memset(dp[cur], 0, sizeof(dp[cur]));
ll c1 = a[i][j] % MOD;
ll c2 = (s[i] - c1 + MOD) % MOD;
for (int d_prev = 0; d_prev <= 2 * N; ++d_prev) {
if (dp[prev][d_prev] == 0) continue;
dp[cur][d_prev] = (dp[cur][d_prev] + dp[prev][d_prev]) % MOD;
int d_new = d_prev + 1;
if (d_new <= 2 * N) {
dp[cur][d_new] = (dp[cur][d_new] + dp[prev][d_prev] * c1) % MOD;
}
d_new = d_prev - 1;
if (d_new >= 0 && d_new <= 2 * N) {
dp[cur][d_new] = (dp[cur][d_new] + dp[prev][d_prev] * c2) % MOD;
}
}
}
for (int d = offset + 1; d <= 2 * N; ++d) {
invalid = (invalid + dp[cur][d]) % MOD;
}
}
ll ans = (total - invalid + MOD) % MOD;
cout << ans << '\n';
return 0;
}
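// 这里空空如也 ("nothing here yet" — empty comment section on the source page)
// 有帮助,赞一个 ("helpful? give it a like" — page UI text)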