dpsk做的
2025-07-14 06:12:30
发布于：上海
0阅读
0回复
0点赞
#include <bits/stdc++.h>
using namespace std;
// Prime modulus for all arithmetic. 998244353 = 119 * 2^23 + 1, so power-of-two
// NTT lengths up to 2^23 divide MOD - 1.
constexpr int MOD = 998244353;
// Root used by the forward NTT and its modular inverse (3 * 332748118 == 1 mod MOD).
constexpr int G = 3, Gi = 332748118;
// Capacity of the factorial tables: valid indices are 0 .. MAX_N - 1.
constexpr int MAX_N = 20000010;
// Capacity of the Stirling-number table: valid indices are 0 .. MAX_L - 1.
constexpr int MAX_L = 200010;
// fact[i] = i! mod MOD and inv_fact[i] = (i!)^{-1} mod MOD, filled by precompute_factorials.
int fact[MAX_N], inv_fact[MAX_N];
// stirling[j] = S(L, j), filled by precompute_stirling.
int stirling[MAX_L];
// Modular sum a + b mod MOD. A single conditional subtraction is enough because
// callers pass operands whose sum is below 2 * MOD.
int add(int a, int b) {
    const int sum = a + b;
    return sum < MOD ? sum : sum - MOD;
}
// Modular product a * b mod MOD. The product is formed in 64 bits so it cannot
// overflow before the reduction.
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % MOD);
}
// Computes a^b mod MOD by square-and-multiply, consuming the exponent's bits from
// least to most significant. pow_mod(x, 0) == 1; pow_mod(0, b) == 0 for b > 0.
int pow_mod(int a, int b) {
    int result = 1;
    for (int base = a; b != 0; b >>= 1) {
        if (b & 1) {
            result = mul(result, base);
        }
        base = mul(base, base);
    }
    return result;
}
// Fills fact[0..max_n] with i! mod MOD and inv_fact[0..max_n] with the modular
// inverses. Only one exponentiation is needed; every other inverse follows from
// 1/(i-1)! = i * (1/i!). Requires 0 <= max_n < MAX_N (the table capacity).
void precompute_factorials(int max_n) {
    fact[0] = 1;
    for (int i = 1; i <= max_n; ++i) fact[i] = mul(fact[i - 1], i);

    inv_fact[max_n] = pow_mod(fact[max_n], MOD - 2);  // Fermat inverse of max_n!
    for (int i = max_n; i > 0; --i) inv_fact[i - 1] = mul(inv_fact[i], i);
}
// Binomial coefficient C(n, k) mod MOD, read from the precomputed factorial tables.
// Returns 0 when k lies outside [0, n]. Otherwise n must be within the range
// already filled by precompute_factorials.
int comb(int n, int k) {
    if (0 <= k && k <= n) {
        return mul(mul(fact[n], inv_fact[n - k]), inv_fact[k]);
    }
    return 0;
}
// In-place iterative radix-2 number-theoretic transform over Z/MOD of a[0 .. len-1].
//   type ==  1 : forward transform using the root G.
//   otherwise  : transform using the inverse root Gi. When type == -1 the output is
//                also multiplied by len^{-1}, which yields the inverse transform.
// The caller in this file passes a power-of-two len; a must hold at least len elements.
void ntt(vector<int>& a, int len, int type) {
    // Bit-reversal permutation. `rev` tracks the bit-reversed counterpart of i and is
    // advanced by a "reversed increment": flip bits from the top down, continuing
    // while the flipped bit became 0 (a carry) and stopping once one becomes 1.
    for (int i = 1, rev = 0; i < len - 1; ++i) {
        int bit = len;
        do {
            bit >>= 1;
            rev ^= bit;
        } while ((~rev & bit) != 0);
        if (i < rev) swap(a[i], a[rev]);
    }

    // Butterfly passes, doubling the block size each time.
    const int root = (type == 1) ? G : Gi;
    for (int half = 1; 2 * half <= len; half <<= 1) {
        const int block = 2 * half;
        // Root of unity of order `block` (block divides MOD - 1 for power-of-two block <= 2^23).
        const int step = pow_mod(root, (MOD - 1) / block);
        for (int start = 0; start < len; start += block) {
            int w = 1;
            for (int off = 0; off < half; ++off) {
                const int lo = start + off;
                const int hi = lo + half;
                const int u = a[lo];
                const int t = mul(w, a[hi]);
                a[lo] = add(u, t);
                a[hi] = add(u, MOD - t);
                w = mul(w, step);
            }
        }
    }

    if (type == -1) {
        const int inv_len = pow_mod(len, MOD - 2);
        for (int i = 0; i < len; ++i) a[i] = mul(a[i], inv_len);
    }
}
// Fills stirling[j] = S(L, j) (Stirling numbers of the second kind) for 0 <= j <= L,
// using the explicit formula
//     S(L, j) = sum_{i=0}^{j} [(-1)^{j-i} / (j-i)!] * [i^L / i!]  (mod MOD),
// which is the j-th coefficient of the product of two polynomials. That product is
// computed with one forward/inverse NTT pair. Reads inv_fact[0..L], which must
// already be filled by precompute_factorials.
void precompute_stirling(int L) {
    if (L == 0) {  // S(0, 0) = 1 is the only entry needed.
        stirling[0] = 1;
        return;
    }

    // Transform length must hold the full degree-2L product without cyclic wrap-around.
    int len = 1;
    while (len < 2 * L + 1) len <<= 1;

    vector<int> alt(len, 0), powers(len, 0);
    for (int i = 0; i <= L; ++i) {
        alt[i] = (i & 1) ? (MOD - inv_fact[i]) : inv_fact[i];  // (-1)^i / i!
    }
    for (int i = 1; i <= L; ++i) {
        powers[i] = mul(pow_mod(i, L), inv_fact[i]);  // i^L / i!  (0^L = 0 for L > 0)
    }

    ntt(alt, len, 1);
    ntt(powers, len, 1);
    transform(alt.begin(), alt.end(), powers.begin(), alt.begin(), mul);
    ntt(alt, len, -1);

    for (int j = 0; j <= L; ++j) stirling[j] = alt[j];
}
// Returns  [ sum_{j=0}^{min(L, m_i, k_i)} S(L, j) * m_i^(j falling) * C(n_i - j, k_i - j) ] / C(n_i, k_i)
// mod MOD, where S(L, j) comes from the stirling[] table.
// Since C(n-j, k-j) / C(n, k) = k^(j falling) / n^(j falling), this appears to be the
// L-th moment E[X^L] of a hypergeometric count (k draws from n items, m of them marked).
// NOTE(review): confirm that interpretation against the problem statement.
// When k_i > n_i, C(n_i, k_i) is 0, its "inverse" evaluates to 0, and so does the result.
int solve_case(int n_i, int m_i, int k_i, int L) {
    const int limit = min(L, min(m_i, k_i));
    int total = 0;
    int falling = 1;  // m_i * (m_i - 1) * ... * (m_i - j + 1) mod MOD
    for (int j = 0; j <= limit; ++j) {
        total = add(total, mul(mul(stirling[j], falling), comb(n_i - j, k_i - j)));
        falling = (j < m_i) ? mul(falling, m_i - j) : 0;
    }
    return mul(total, pow_mod(comb(n_i, k_i), MOD - 2));
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int N, M, S_val, L;
cin >> N >> M >> S_val >> L;
precompute_factorials(N);
precompute_stirling(L);
for (int i = 0; i < S_val; ++i) {
int n_i, m_i, k_i;
cin >> n_i >> m_i >> k_i;
cout << solve_case(n_i, m_i, k_i, L) << '\n';
}
return 0;
}
这里空空如也
有帮助，赞一个