题解
2025-08-09 19:22:35
发布于：广东
2阅读
0回复
0点赞
#include <bits/stdc++.h>
using namespace std;
// Upper bound on the string length / modulus; sizes every per-position and per-residue array.
const int Maxn = 500005;
// now    : current modulus of the residue table f (set to n in main, replaced by change_mod).
// ct     : number of entries stored in border[1..ct].
// fail   : KMP failure function of str; fail[k] is for the prefix of length k (get_fail).
// Q      : scratch holding one residue cycle; change_mod lays a cycle out twice, hence 2 * Maxn.
// seq    : the cycle from Q rotated to start at its smallest f value (work).
// border : despite the name, holds periods: len - b for each nonzero border b, then len (get_fail).
// pos    : cycle index of each monotone-queue entry (work).
int now, ct, fail[Maxn], Q[2 * Maxn], seq[Maxn], border[Maxn], pos[Maxn];
// f   : f[r] = smallest value found so far that is congruent to r modulo now.
// res : snapshot of f taken by change_mod before re-basing to a new modulus.
// sta : monotone-queue keys f[seq[l]] - l * diff (work).
// ans : per-test answer; w : per-test limit read in main (reduced by n there).
long long f[Maxn], res[Maxn], sta[Maxn], ans, w;
// Input string of the current test case.
string str;
// Builds the KMP failure function of `str` and collects the string's periods.
//
// fail[k] receives the length of the longest proper border of the prefix
// str[0..k). Then, walking the border chain of the whole string from the
// longest border down, appends the period len - b for every nonzero border b
// to border[] (so periods come out in increasing order), and finally appends
// len itself. `ct` is advanced once per appended entry.
void get_fail(void){
    const int len = static_cast<int>(str.size());
    fail[0] = 0;
    fail[1] = 0;
    for (int i = 1; i < len; ++i){
        // Candidate border length to extend by str[i].
        int k = fail[i];
        while (k > 0 && str[k] != str[i])
            k = fail[k];
        if (str[k] == str[i])
            fail[i + 1] = k + 1;
        else
            fail[i + 1] = 0;
    }
    // A local name distinct from the global `now`, to avoid shadowing it.
    for (int b = fail[len]; b != 0; b = fail[b])
        border[++ct] = len - b;
    border[++ct] = len;
}
// Re-bases the residue table f from modulus `now` to modulus `mod`, adding the
// old modulus `now` as a generator on the way.
// On entry f[0..now) is indexed by residue mod `now`; on exit f[0..mod) is
// indexed by residue mod `mod`, and the global `now` equals `mod`.
// Scratch: res[0..now) holds a snapshot of the old table; Q holds one cycle.
void change_mod(int mod){
// The residues 0..mod-1 split into cnt = gcd(mod, now) disjoint "+now (mod mod)" cycles.
int cnt = __gcd(mod, now);
// Snapshot the old table before f is overwritten below.
for (int i = 0; i < now; i++)
res[i] = f[i];
// 0x3f3f3f3f3f3f3f3f marks "no value yet"; adding a small amount to it cannot overflow long long.
for (int i = 0; i < mod; i++)
f[i] = 0x3f3f3f3f3f3f3f3fLL;
// Each old value keeps its magnitude and moves to its residue mod `mod`; keep the minimum per class.
for (int i = 0, tmp; i < now; i++)
tmp = res[i] % mod, f[tmp] = min(f[tmp], res[i]);
for (int i = 0; i < cnt; i++){
// Collect the cycle through residue i into Q[1..top]; each cycle contains exactly one of 0..cnt-1.
int top = 0;
Q[++top] = i;
int tmp = (i + now) % mod;
while (tmp != Q[1])
Q[++top] = tmp, tmp = (tmp + now) % mod;
// Lay the cycle out twice so one linear sweep goes around it two times:
// the first lap reaches the cycle's minimum, the second propagates from it to every element.
for (int j = top + 1; j <= 2 * top; j++)
Q[j] = Q[j - top];
top <<= 1;
// Relax along the cycle: adding `now` moves residue Q[j - 1] to Q[j].
for (int j = 2; j <= top; j++)
f[Q[j]] = min(f[Q[j]], f[Q[j - 1]] + now);
}
now = mod;
}
// Switches the modulus to `first` (via change_mod), then adds the arithmetic
// progression of generators first + k * diff, k = 1..siz, to the table f.
// (k = 0 is `first` itself, already covered by making it the modulus: adding
// the modulus never changes a residue.)
// A negative `diff` means "only switch the modulus": the function returns
// right after change_mod.
// Scratch: Q (one cycle), seq (that cycle rotated), sta/pos (monotone queue).
void work(int first, int diff, int siz){
// Number of "+diff (mod first)" cycles; unused when diff < 0.
int cnt = __gcd(diff, first);
change_mod(first);
if (diff < 0) return ;
for (int i = 0; i < cnt; i++){
// Collect the cycle through residue i into Q[1..top]; each cycle contains exactly one of 0..cnt-1.
int top = 0;
Q[++top] = i;
int tmp = (i + diff) % first;
while (tmp != Q[1])
Q[++top] = tmp, tmp = (tmp + diff) % first;
// Rotate the cycle so it starts at its smallest f value. That entry cannot be
// improved, and any relaxation chain that would wrap around past it costs at
// least as much as one starting from it, so a single non-wrapping pass suffices.
int mini_pos = 1;
for (int j = 1; j <= top; j++)
if (f[Q[j]] < f[Q[mini_pos]]) mini_pos = j;
int tmp_cnt = 0;
for (int j = mini_pos; j <= top; j++)
seq[++tmp_cnt] = Q[j];
for (int j = 1; j < mini_pos; j++)
seq[++tmp_cnt] = Q[j];
// Sliding-window minimum over l in [j - siz, j - 1] of the key f[seq[l]] - l * diff, giving
//   f[seq[j]] = min(f[seq[j]], f[seq[l]] + (j - l) * diff + first),
// i.e. applying generator first + k * diff with k = j - l in 1..siz.
int head = 1, tail = 1;
pos[1] = 1, sta[1] = f[seq[1]] - diff;
for (int j = 2; j <= top; j++){
// Drop entries more than siz steps behind j.
while (head <= tail && pos[head] + siz < j) head++;
if (head <= tail) f[seq[j]] = min(f[seq[j]], sta[head] + j * (long long) diff + first);
// Push j's (possibly just improved) key, keeping keys strictly increasing from head to tail.
while (head <= tail && sta[tail] >= f[seq[j]] - j * (long long) diff) tail--;
sta[++tail] = f[seq[j]] - j * (long long) diff, pos[tail] = j;
}
}
}
int T, n;
// Driver. For each of T test cases: reads n, w and the string, collects the
// string's periods (get_fail), then feeds them into the residue table f.
// Maximal runs of periods that form an arithmetic progression are added
// together by work(); the last period (the full string length) only switches
// the modulus. Finally prints the number of x in [0, w - n] with
// f[x mod now] <= x, i.e. the sum over residues r with f[r] <= w - n of
// (w - n - f[r]) / now + 1.
int main(){
    scanf("%d", &T);
    while (T--){
        ans = ct = 0;
        scanf("%d%lld", &n, &w);
        w -= n;  // x ranges over [0, w - n] from here on
        // Same 0x3f-per-byte sentinel as before, without taking sizeof of a
        // variable-length array type (a GCC extension, not ISO C++).
        std::fill(f, f + n, 0x3f3f3f3f3f3f3f3fLL);
        f[0] = 0;
        now = n;
        cin >> str;
        get_fail();
        // Group border[1..ct] into maximal arithmetic runs. Every index stays
        // within 1..ct: border[] is not cleared between test cases, so entries
        // past ct may be stale periods of an earlier, longer string.
        int i = 1;
        while (i < ct){
            const int diff = border[i + 1] - border[i];
            int j = i + 1;
            while (j < ct && border[j + 1] - border[j] == diff)
                j++;
            // border[i] is the new modulus; border[i + 1..j - 1] are the extra
            // terms (k = 1..j - i - 1). border[j] starts the next run.
            work(border[i], diff, j - i - 1);
            i = j;
        }
        // Last period (the full string length): a negative step makes work()
        // return right after switching the modulus to it.
        work(border[ct], -1, 0);
        for (int r = 0; r < now; r++)
            if (f[r] <= w) ans += (w - f[r]) / now + 1;
        printf("%lld\n", ans);
    }
    return 0;
}
这里空空如也
有帮助，赞一个