难
2026-03-23 22:57:46
发布于:河北
4阅读
0回复
0点赞
好像是正赛场切过的最牛的题。
考虑 dp。先考虑设 f i,j
记得滚动数组。
赛后回忆随手写的丑陋的代码,仅供参考:
#include <bits/stdc++.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define pii pair<int,int>
#define fir first
#define sec second
#define chmin(a,b) a=min(a,b)
#define chmax(a,b) a=max(a,b)
#define pb push_back
const int inf=0x3f3f3f3f3f3f3f3f;
const int mod=998244353;
int n,m,dp[2][510][510],fac[510],C[510][510],P[510][510],ans;
string s;
int c[510],rm[510];
signed main()
{
cin>>n>>m>>s;
s=" "+s;
for(int i=1;i<=n;i++)
{
int x;
cin>>x;
c[x]++;
for(int j=0;j<x;j++)rm[j]++;
}
fac[0]=1;
for(int i=1;i<=n;i++)fac[i]=fac[i-1]*i%mod;
for(int i=0;i<=n;i++)
{
C[i][0]=1;
for(int j=1;j<=i;j++)C[i][j]=(C[i-1][j]+C[i-1][j-1])%mod;
for(int j=0;j<=i;j++)P[i][j]=C[i][j]*fac[j]%mod;
}
dp[0][0][0]=1;
for(int i=1;i<=n;i++)
{
memset(dp[i&1],0,sizeof(dp[i&1]));
bool ps=s[i]-'0';
for(int j=0;j<i;j++)if(j<=n-m)
for(int k=0;k<i;k++)
{
int tx=c[j+1];
int tmp=dp[i&1^1][j][k];
if(!tmp)continue;
if(ps)(dp[i&1][j][k+1]+=tmp)%=mod;
int pr=(n-rm[j])-(i-1-k);
if(pr)for(int l=min(k,tx);~l;l--)(dp[i&1][j+1][k-l]+=tmp*pr%mod*C[k][l]%mod*P[tx][l]%mod)%=mod;
if(!ps)for(int l=min(k+1,tx);~l;l--)(dp[i&1][j+1][k+1-l]+=tmp%mod*C[k+1][l]%mod*P[tx][l]%mod)%=mod;
}
}
for(int i=0;i<=n-m;i++)(ans+=dp[n&1][i][rm[i]]*fac[rm[i]]%mod)%=mod;
cout<<ans<<endl;
return 0;
}
这里空空如也





有帮助,赞一个