vhbbvjhb
2026-03-31 17:53:44
发布于:浙江
#include <bits/stdc++.h>
using namespace std;
// Array capacities: grid rows and columns (including a one-cell sentinel
// border on each side) and the label sequence L (including the sentinels
// L[0] and L[n*m+1] written in main).
const int maxn = 5;
const int maxm = 55;
const int maxl = 155;
// `mod` reduces every count; `mo` is the number of hash buckets used by hah().
const int mod = 11192869;
const int mo = 500000;

typedef long long ll;
typedef unsigned int uint;

int a[maxn][maxm];     // grid values, 1-based; main fills the border with 233
int L[maxl];           // a cell assigned index k must satisfy a[i][j] == L[k]
int n, m;              // grid rows, grid columns
int pos[maxn];         // pos[i]: index assigned to the latest processed cell of row i
int plug[maxn];        // frontier plugs: 0 none, 1 neighbour gets index-1, 2 gets index+1
int head[2][mo + 10];  // bucket heads of the two rolling hash layers
int tot[2];            // number of entries stored in each layer
int cur, pre;          // layer being written / layer being read
int chc[maxl];         // candidate indices for the cell being processed
int ans;               // accumulated answer, reduced modulo `mod`
struct State{
bitset<maxl> used;
uint stt;
int val, nxt;
State() { val = nxt = stt = 0; used.reset(); }
}ptr[2][1001000];
void hah(uint stt, int val, bitset<maxl> &used) {
int x = stt % mo;
for(int i = head[cur][x]; i; i = ptr[cur][i].nxt) if(ptr[cur][i].stt == stt) {
ptr[cur][i].val = (ptr[cur][i].val + val) % mod; return;
}
ptr[cur][tot[cur]].stt = stt;
ptr[cur][tot[cur]].val = val;
ptr[cur][tot[cur]].used = used;
ptr[cur][tot[cur]].nxt = head[cur][x];
head[cur][x] = tot[cur];
}
// Packs the frontier held in the globals pos[1..n] and plug[0..n] into one
// integer. Layout, most significant first:
//   8 bits for each pos[i], i = 1..n, then 2 bits for each plug[i], i = 0..n.
// decode() is the exact inverse. The packing needs every pos[i] < 256 and
// 8*n + 2*(n+1) bits, which is at most 32 (assuming a 32-bit unsigned int)
// only for n <= 3.
uint encode() {
    uint stt = 0;
    // The original loop had no increment (`i` instead of `i++`) and never terminated.
    for (int i = 1; i <= n; i++) stt = (stt << 8) + pos[i];
    for (int i = 0; i <= n; i++) stt = (stt << 2) + plug[i];
    return stt;
}
// Inverse of encode(): unpacks `stt` into the globals plug[0..n] and pos[1..n].
void decode(uint stt) {
    // The plugs occupy the low 2*(n+1) bits, with plug[n] in the lowest pair.
    for (int k = n; k >= 0; --k) {
        plug[k] = stt & 3;
        stt >>= 2;
    }
    // Above them sit the 8-bit row indices, with pos[n] lowest.
    for (int k = n; k >= 1; --k) {
        pos[k] = stt & 255;
        stt >>= 8;
    }
}
// Broken-profile DP over the grid, processed column by column (column j in
// the outer loop, row i in the inner loop). A state (see encode/decode)
// stores, per row, the index pos[i] given to that row's latest processed
// cell and the plugs crossing the frontier. `used` marks which indices of
// L are already taken. A plug value of 1 means the cell across the plug
// must receive index (pos - 1); 2 means (pos + 1); 0 means no connection.
// Completed counts are accumulated into the global `ans` modulo `mod`.
void solve() {
bitset<maxl> used;
used.reset();
// Seed the first layer with the empty frontier (state 0) and one way.
cur = 0; pre = 1; hah(0, 1, used);
for(int j = 1; j <= m; j++) {
// Starting a new column: shift the whole plug array right by one, so that
// plug[i] now holds the plug entering row i from its left neighbour. Clear
// plug[0], since nothing lies above row 1. Entries are re-encoded in place
// and this layer's hash chains are not rebuilt. That is fine because the
// next step walks the entries by index instead of looking them up.
for(int t = 1; t <= tot[cur]; t++) {
decode(ptr[cur][t].stt);
for(int i = n - 1; i >= 0; i--) plug[i + 1] = plug[i];
plug[0] = 0;
ptr[cur][t].stt = encode();
}
for(int i = 1; i <= n; i++) {
// Roll the layers: the states built so far become `pre`, and `cur` is
// emptied to receive the states after placing cell (i, j).
swap(cur, pre); tot[cur] = 0;
memset(head[cur], 0, sizeof(head[cur]));
for(int t = 1; t <= tot[pre]; t++) {
uint stt = ptr[pre][t].stt;
int val = ptr[pre][t].val;
used = ptr[pre][t].used;
decode(stt);
// r: plug entering from the cell above, (i-1, j), whose index is pos[i-1].
// d: plug entering from the left neighbour, (i, j-1), whose index is still
// in pos[i].
int r = plug[i - 1], d = plug[i];
int cnt = 0;
// Collect candidate indices x for cell (i, j). If no plug enters, every
// index is a candidate; otherwise use the index demanded by each plug.
// (The inner loop variable shadows the row index i only inside that loop.)
if(!r && !d) for(int i = 1; i <= n * m; i++) chc[++cnt] = i;
else {
if(r == 1) chc[++cnt] = pos[i-1] - 1;
else if(r == 2) chc[++cnt] = pos[i-1] + 1;
if(d == 1) chc[++cnt] = pos[i] - 1;
else if(d == 2) chc[++cnt] = pos[i] + 1;
}
sort(chc + 1, chc + 1 + cnt);
cnt = unique(chc + 1, chc + 1 + cnt) - chc - 1;
for(int hh = 1; hh <= cnt; hh++) {
int x = chc[hh];
// The cell's value must equal L[x], and x must not be assigned yet.
if(a[i][j] != L[x]) continue; if(used[x]) continue;
// Every incoming plug must agree with x.
if(r == 1 && x != pos[i - 1] - 1) continue;
if(r == 2 && x != pos[i - 1] + 1) continue;
if(d == 1 && x != pos[i] - 1) continue;
if(d == 2 && x != pos[i] + 1) continue;
// NOTE(review): index 1 is rejected on cells strictly inside the grid
// (not on any border row or column), while no such rule exists for
// index n*m. This looks like a problem-specific constraint (perhaps the
// path must start on the border); confirm against the problem statement.
if(x == 1 && i > 1 && i < n && j > 1 && j < m) continue;
// At the last cell, every surviving candidate is counted. This happens
// before the connection-count checks below; for a 1x1 grid (pnum == 0)
// it is the only place the single-cell path is counted.
if(i == n && j == m) ans = (ans + val) % mod;
// Tentatively give index x to cell (i, j); undone after the loops below.
used[x] = 1; int od = pos[i]; pos[i] = x;
// Try every pair of outgoing plugs: npr toward the right neighbour
// (i, j+1) and npd toward the cell below (i+1, j).
for(int npr = 0; npr <= 2; npr++)
for(int npd = 0; npd <= 2; npd++) {
// pnum: total number of connections of this cell.
int pnum = (r > 0) + (d > 0) + (npr > 0) + (npd > 0);
// Cells with indices other than 1 and n*m need exactly two
// connections; indices 1 and n*m need exactly one.
if(x != 1 && x != n * m && pnum != 2) continue;
if((x == 1 || x == n * m) && pnum != 1) continue;
// The two outgoing plugs cannot demand the same index.
if(npr == npd && npr) continue;
// No plug may leave the grid on the right or bottom edge.
if(j == m && npr) continue; if(i == n && npd) continue;
// The index an outgoing plug demands must still be free...
if((npr == 1 || npd == 1) && used[x - 1]) continue;
if((npr == 2 || npd == 2) && used[x + 1]) continue;
// ...and the target cell's value must match L at that index. L[0] and
// L[n*m+1] hold the sentinel 521 (set in main), which rejects indices
// 0 and n*m+1 as long as no grid value equals 521.
if(npr == 1 && a[i][j+1] != L[x - 1]) continue;
if(npr == 2 && a[i][j+1] != L[x + 1]) continue;
if(npd == 1 && a[i+1][j] != L[x - 1]) continue;
if(npd == 2 && a[i+1][j] != L[x + 1]) continue;
// The transition is valid: write the new plugs, store the successor
// state (carrying this state's count) in the current layer, then restore.
plug[i - 1] = npr; plug[i] = npd;
hah(encode(), val, used);
plug[i - 1] = r; plug[i] = d;
}
used[x] = 0; pos[i] = od;
}
}
}
}
}
// Reads n, m, the n x m grid (1-based) and the sequence L[1..n*m]. It then
// surrounds both with sentinels, runs the DP and prints the count.
// Returns 1 without running the DP if input is malformed or if the
// dimensions do not fit the fixed-size global arrays.
int main() {
    if (scanf("%d%d", &n, &m) != 2) return 1;
    // Capacities taken from the array declarations themselves. `a` needs
    // rows 0..n+1 and columns 0..m+1; `L` needs entries 0..n*m+1 (the
    // sentinels written below). n <= 3 is also what keeps encode()'s packed
    // state within 32 bits. The short-circuit order bounds n and m before
    // n*m is computed.
    const int rowCap = static_cast<int>(sizeof(a) / sizeof(a[0]));
    const int colCap = static_cast<int>(sizeof(a[0]) / sizeof(a[0][0]));
    const int seqCap = static_cast<int>(sizeof(L) / sizeof(L[0]));
    if (n < 0 || m < 0 || n + 1 >= rowCap || m + 1 >= colCap || n * m + 1 >= seqCap) {
        fprintf(stderr, "unsupported grid size %d x %d\n", n, m);
        return 1;
    }
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++)
            if (scanf("%d", &a[i][j]) != 1) return 1;
    for (int i = 1; i <= n * m; i++)
        if (scanf("%d", &L[i]) != 1) return 1;
    // Sentinels: out-of-range sequence indices and border cells.
    L[0] = L[n * m + 1] = 521;
    for (int i = 0; i <= m + 1; i++) a[0][i] = a[n + 1][i] = 233;
    for (int i = 1; i <= n; i++) a[i][0] = a[i][m + 1] = 233;
    solve();
    printf("%d\n", ans);
    return 0;
}
这里空空如也





有帮助,赞一个