题解(求赞)
2025-11-12 21:13:09
发布于:福建
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <ctime>
#include <cctype>
#include <algorithm>
#include <random>
#include <bitset>
#include <queue>
#include <functional>
#include <set>
#include <map>
#include <vector>
#include <chrono>
#include <iostream>
#include <limits>
#include <numeric>
// Debug logging to stderr (standard variadic-macro form; currently unused).
#define LOG(...) fprintf(stderr, __VA_ARGS__)
using namespace std;
using ll = long long;
using ull = unsigned long long;
// N: capacity for the per-item arrays s/p (1-indexed).
// M: capacity for the coefficient arrays a/b (indices 0..m).
// P: the prime modulus all arithmetic is reduced by.
constexpr int N = 110, M = 50010, P = 998244353;
// n: number of items read; m: running sum of p[i], i.e. current degree of a/b.
int n, m;
// s/p: per-item inputs; a/b: polynomial coefficient arrays built in main.
int s[N], p[N], a[M], b[M];
// Folds a value from [0, 2P) back into [0, P) with one conditional
// subtraction; values below P are returned unchanged.
int norm(int x) {
if (x >= P)
x -= P;
return x;
}
// Extended Euclid: writes into x, y a pair of integers with
//   a * x + b * y == gcd(a, b).
// Base case (b == 0) yields x = 1, y = 0.
void exGcd(int a, int b, int& x, int& y) {
if (b == 0) {
x = 1, y = 0;
return;
}
// Solve for (b, a mod b) first, then lift the coefficients back:
// b*px + (a - (a/b)*b)*py == g  =>  a*py + b*(px - (a/b)*py) == g.
int px, py;
exGcd(b, a % b, px, py);
x = py;
y = px - a / b * py;
}
// Modular inverse of a modulo P via extended Euclid: finds u with
// a*u + P*v == gcd(a, P) and returns u shifted into [0, P).
// Meaningful when gcd(a, P) == 1 (any 1 <= a < P, since P is prime);
// for a == 0 the computation yields 0.
int inv(int a) {
int u, v;
exGcd(a, P, u, v);
// u may be negative; one +P and a conditional -P bring it into range.
return norm(P + u);
}
int calc(int* arr) {
int ret = 0;
for (int i = 0; i < m; ++i) {
int q = norm(2LL * i * inv(m) % P + P - 1);
ret = (ret + arr[i] * (ll)inv(norm(P + q - 1))) % P;
}
return ret;
}
// Reads n, then s[1..n], then p[1..n]. Builds by in-place multiplication
//   a(x) = prod_i (x^{p_i} - 1  if s_i == 1,  else x^{p_i} + 1)
//   b(x) = prod_i (x^{p_i} + 1)
// and prints (calc(a) - calc(b)) mod P.
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i) scanf("%d", &s[i]);
for (int i = 1; i <= n; ++i) scanf("%d", &p[i]);

a[0] = b[0] = 1;  // both products start as the constant polynomial 1
for (int i = 1; i <= n; ++i) {
const int d = p[i];
const bool minus = (s[i] == 1);
m += d;  // degree of both products once this factor is applied
// Walk from high degree down so that a[k - d] / b[k - d] still hold the
// coefficients from before this factor. Slots above the previous degree
// are still 0 (zero-initialized globals, never written yet).
for (int k = m; k >= d; --k) {
a[k] = minus ? norm(P + a[k - d] - a[k]) : norm(a[k - d] + a[k]);
b[k] = norm(b[k - d] + b[k]);
}
// For (x^d - 1), coefficients below d receive no shifted term and are
// simply negated.
if (minus)
for (int k = 0; k < d; ++k) a[k] = norm(P - a[k]);
}

const int f = calc(a), g = calc(b);
printf("%d\n", norm(f + P - g));
return 0;
}
这里空空如也







有帮助,赞一个