全部评论 2

  • 把你的思路给 DeepSeek 看后，它也写出了一个看起来能 AC、实际却 TLE + WA 的程序 :(

    #include <bits/stdc++.h>
    using namespace std;
    
    // Prime modulus used by every modular operation in this file.
    constexpr int MOD = 998244353;
    // Capacity of the factorial tables (2e7 + 10 entries).
    constexpr int MAX_N = 20000010;
    // Capacity of the Stirling-number table (2e5 + 10 entries).
    constexpr int MAX_L = 200010;

    int fact[MAX_N], inv_fact[MAX_N];   // n! and 1/n! modulo MOD
    int stirling[MAX_L];                // stirling[j] = S(L, j), written by precompute_stirling
    
    // Modular addition. Expects both operands already reduced into [0, MOD),
    // so a single conditional subtraction is enough to reduce the sum.
    int add(int a, int b) {
        const int sum = a + b;
        return sum >= MOD ? sum - MOD : sum;
    }
    
    // Modular multiplication. Widens to 64 bits before multiplying so the
    // intermediate product cannot overflow int.
    int mul(int a, int b) {
        const long long product = static_cast<long long>(a) * b;
        return static_cast<int>(product % MOD);
    }
    
    // Binary exponentiation: returns a^b mod MOD (pow_mod(x, 0) == 1).
    // Expects b >= 0.
    int pow_mod(int a, int b) {
        int result = 1;
        int base = a;
        for (int e = b; e != 0; e >>= 1) {
            if (e & 1) result = mul(result, base);
            base = mul(base, base);
        }
        return result;
    }
    
    // Fills fact[0..max_n] with n! mod MOD and inv_fact[0..max_n] with their
    // modular inverses. Only max_n! is inverted directly (Fermat); every other
    // inverse follows from 1/(k-1)! = k * (1/k!). Requires max_n < MAX_N.
    void precompute_factorials(int max_n) {
        fact[0] = 1;
        int i = 1;
        while (i <= max_n) {
            fact[i] = mul(fact[i - 1], i);
            ++i;
        }
        inv_fact[max_n] = pow_mod(fact[max_n], MOD - 2);
        for (int k = max_n; k > 0; --k) {
            inv_fact[k - 1] = mul(inv_fact[k], k);
        }
    }
    
    // Binomial coefficient C(n, k) mod MOD, or 0 when k lies outside [0, n].
    // Relies on fact/inv_fact having been filled up to at least n.
    int comb(int n, int k) {
        const bool in_range = (k >= 0) && (k <= n);
        if (!in_range) return 0;
        const int denom_inv = mul(inv_fact[k], inv_fact[n - k]);
        return mul(fact[n], denom_inv);
    }
    
    void precompute_stirling(int L) {
        // Using the inclusion-exclusion formula: S(L, j) = (1/j!) * sum_{i=0}^j (-1)^{j-i} * C(j, i) * i^L
        // We can compute for all j=0..L using FFT or directly for small L
        // Here we use the direct method for simplicity, but it's O(L^2)
        // For L up to 2e5, this is not feasible, so we need a better approach
        // However, for the purpose of this explanation, we proceed with the direct method
        stirling[0] = (L == 0);
        for (int j = 1; j <= L; ++j) {
            stirling[j] = 0;
            for (int i = 0, sign = 1; i <= j; ++i, sign = MOD - sign) {
                stirling[j] = add(stirling[j], mul(sign, mul(comb(j, i), pow_mod(i, L))));
            }
            stirling[j] = mul(stirling[j], inv_fact[j]);
        }
    }
    
    int solve_case(int n_i, int m_i, int k_i, int L) {
        int S = 0;
        int mj = 1; // mj = m_i * (m_i - 1) * ... * (m_i - j + 1)
        for (int j = 0; j <= L; ++j) {
            if (j > m_i || j > k_i) break;
            int term = mul(stirling[j], mj);
            term = mul(term, comb(n_i - j, k_i - j
    

    2025-07-14 来自 上海

    0
  • 不是哥们

    2024-07-24 来自 北京

    0
首页