题解(求赞)
2025-11-12 17:56:56
发布于:福建
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
// Short type aliases used by the solution (PI is not referenced in the visible code).
using PI = pair<int, int>;
using LL = long long;
// N: array capacity; MAX: large constant (not referenced in the visible code);
// SIZ: block size used by the Mo's-algorithm comparator.
const int N = 200005;
const int MAX = (1 << 28);
const int SIZ = 440;
// Adjacency lists stored as per-vertex singly linked lists of edges.
// e[i] is edge i (x -> y); last[v] is the most recently added edge leaving v
// (main() resets last[] to -1 = "no edge"); e[i].last links to the previous one.
struct qq
{
    int x,y,last;
}e[N];
int num,last[N];
// Append the directed edge x -> y to x's list.
// num is the next free edge slot and must advance; otherwise every edge
// overwrites e[0] and the list chain becomes a self-loop.
void addedge (int x,int y)
{
    // printf("link:%d %d\n",x,y);
    e[num].x=x;e[num].y=y;
    e[num].last=last[x];
    last[x]=num++;
}
int T;      // number of test cases
int n;      // number of vertices
int p;      // vertices 1..p form the generated chain
int m;      // number of queries in the current test
int a[N];   // color of each vertex, 1-indexed
// State of the problem's pseudo-random generator: three 32-bit words.
unsigned int SA, SB, SC;
// Advance the generator by one step and return the new SC.
unsigned int rng61()
{
    unsigned int s = SA;
    s ^= s << 16;
    s ^= s >> 5;
    s ^= s << 1;
    // Rotate the state words, then fold the scrambled old SA into SC.
    SA = SB;
    SB = SC;
    SC ^= s ^ SA;
    return SC;
}
// Read n, p and the generator seeds, then build the test tree:
// vertices 1..p form a chain (edge i-1 -> i); every later vertex i is attached
// to a pseudo-randomly chosen earlier vertex in [1, i-1]; colors a[i] are in [1, n].
void gen()
{
    scanf("%d%d%u%u%u", &n, &p, &SA, &SB, &SC);
    for (int i = 2; i <= p; i++) addedge(i - 1, i);   // increment was missing: infinite loop
    for (int i = p + 1; i <= n; i++) addedge(rng61() % (i - 1) + 1, i);
    for (int i = 1; i <= n; i++) a[i] = rng61() % n + 1;
}
// Euler-tour data filled by dfs(): L[x] / R[x] are the entry / exit timestamps
// of vertex x, and id[t] is the vertex stamped at time t. Every vertex receives
// two timestamps (1..2n), so id[] must hold 2*N entries, not N.
int L[N],R[N],id[2*N];
int fa[N][21],dep[N]; // binary-lifting ancestors; depth (root has depth 1)
int lst[N]; // per vertex: depth of the nearest ancestor with the same color (0 if none)
int g[N],mx; // g[col]: depth of the deepest vertex of color col on the current DFS path; mx: max depth
// First DFS from x. Computes:
// - Euler-tour timestamps L/R/id, using the global counter num
//   (main() resets it to 0 before the call);
// - the binary-lifting table fa[x][*];
// - lst[x], the depth of the nearest same-colored ancestor;
// - the maximum depth mx.
// NOTE(review): recursion depth equals the tree height (at least p on the
// chain), so this may need a large stack.
void dfs (int x)
{
    mx=max(mx,dep[x]);
    // g[] holds depths; save the previous value so it can be restored on exit.
    lst[x]=g[a[x]];
    int Lst=g[a[x]];
    g[a[x]]=dep[x];
    L[x]=++num;id[num]=x;              // timestamps start at 1 (case1 starts its window at 1)
    for (int u=1;u<=20;u++) fa[x][u]=fa[fa[x][u-1]][u-1];
    for (int u=last[x];u!=-1;u=e[u].last)
    {
        int y=e[u].y;
        fa[y][0]=x;dep[y]=dep[x]+1;dfs(y);
    }
    R[x]=++num;id[num]=x;
    g[a[x]]=Lst;
}
// Lowest common ancestor of x and y by binary lifting.
// Jumps past the root land on vertex 0; this works because fa[root][*] is 0
// (fa is a zero-initialized global) and main() sets dep[0] = 0.
int get_LCA (int x,int y)
{
    if (dep[x]<dep[y]) swap(x,y);
    // Lift x up to y's depth.
    for (int u=20;u>=0;u--)
        if (dep[fa[x][u]]>=dep[y])
            x=fa[x][u];
    if (x==y) return x;
    // Lift both while their ancestors differ; the LCA is then the common parent.
    for (int u=20;u>=0;u--)
        if (fa[x][u]!=fa[y][u])
        {x=fa[x][u];y=fa[y][u];}
    return fa[x][0];
}
LL ans[N]; // answer per query, indexed by query number
// One offline Mo query over Euler-tour positions [l, r]. id is the query
// number; LCA is an extra vertex toggled while answering (0 when none).
struct qt
{
    int l, r;
    int id, LCA;
};
qt h[N*2];
int tot=0; // number of queries stored in h[1..tot]
// Mo's ordering: by block of l (block size SIZ); within a block, by r.
bool cmp (qt x,qt y) {return x.l/SIZ==y.l/SIZ?x.r<y.r:x.l<y.l;}
LL Ans;     // number of colors whose TOT[] count is positive
int TOT[N]; // occurrence count of each color in the current window
bool in[N]; // whether each vertex is currently toggled into the window
// Toggle vertex x into or out of the current Mo window, keeping the per-color
// counts TOT[] and the distinct-color count Ans consistent.
void modify (int x)
{
    if (in[x])
    {
        // Removing x: its color disappears when the count drops to zero.
        TOT[a[x]]--;
        if (TOT[a[x]]==0) Ans--;
    }
    else
    {
        // Adding x: a new color appears when the count was zero.
        if (TOT[a[x]]==0) Ans++;
        TOT[a[x]]++;
    }
    in[x]^=1;
}
// Answer every stored type-1 query offline with Mo's algorithm on the Euler tour.
// modify() toggles, so a vertex whose two timestamps both fall in [l, r]
// cancels out. The window therefore holds the path vertices, except the LCA
// in the non-ancestor case; that vertex is toggled in only while recording the answer.
void case1 ()
{
    memset(in,false,sizeof(in));
    Ans=0;
    sort(h+1,h+1+tot,cmp);
    int L=1,R=0; // current window [L, R]; shadows the global L[]/R[] arrays
    for (int u=1;u<=tot;u++)
    {
        while (R<h[u].r) modify(id[++R]);
        while (L>h[u].l) modify(id[--L]);
        while (R>h[u].r) modify(id[R--]);
        while (L<h[u].l) modify(id[L++]);
        if (h[u].LCA) modify(h[u].LCA);
        ans[h[u].id]=Ans;
        if (h[u].LCA) modify(h[u].LCA);
    }
}
// Aggregate stored in each persistent-segment-tree node, built per vertex
// `now` in dfs2 and summed over a subtree:
//   c  : count of vertices
//   c1 : sum of xx, where xx = dep[now] - key and key is the depth of the
//        nearest same-colored ancestor z (0 if none)
//   c2 : sum of 2*xx*dep[now]
//   c3 : sum of the key (depth of z)
//   c4 : sum of dep[now]
//   c5 : sum of dep[now] * key
//   c6 : dep[now]; change() treats it as a min at leaves, while inner nodes hold a sum
struct qy
{
    LL c,c1,c2,c3,c4,c5,c6;
    qy () {}
    qy(LL _c,LL _c1,LL _c2,LL _c3,LL _c4,LL _c5,LL _c6)
        : c(_c), c1(_c1), c2(_c2), c3(_c3), c4(_c4), c5(_c5), c6(_c6) {}
    // Debug dump of all fields.
    void print()
    {
        printf("c:%lld c1:%lld c2:%lld c3:%lld c4:%lld c5:%lld c6:%lld\n",c,c1,c2,c3,c4,c5,c6);
    }
};
// The all-zero aggregate (assigned in main()); returned for empty subtrees.
qy zero;
// Field-wise sum of two aggregates.
qy operator + (qy x,qy y)
{
    qy r = x;
    r.c += y.c;
    r.c1 += y.c1;
    r.c2 += y.c2;
    r.c3 += y.c3;
    r.c4 += y.c4;
    r.c5 += y.c5;
    r.c6 += y.c6;
    return r;
}
// Field-wise difference, except that c6 is ADDED.
// NOTE(review): c6 is a min-style field at leaves, and in the visible code it
// is only read after ask(0, ...), where the subtrahend is the empty node
// c[0] == zero. Adding or subtracting gives the same value there; confirm
// before changing.
qy operator - (qy x,qy y)
{
    qy r = x;
    r.c -= y.c;
    r.c1 -= y.c1;
    r.c2 -= y.c2;
    r.c3 -= y.c3;
    r.c4 -= y.c4;
    r.c5 -= y.c5;
    r.c6 += y.c6;
    return r;
}
// Persistent segment tree over depth keys [0, mx].
// rt[v] is the version containing the vertices on the root->v path; num1 is
// the last allocated node. Each insertion path-copies one node per level
// (roughly log2(mx)+2 nodes), so N*20 nodes cover n <= N insertions.
int rt[N],num1;
qy c[N*20];
int s1[N*20],s2[N*20]; // left / right child of each node (0 = empty node)
void change (int &now,int l,int r,int x,qy cc)
{
num1++;c[num1]=c[now];s1[num1]=s1[now];s2[num1]=s2[now];now=num1;
c[now]=c[now]+cc;
if (lr)
{
if (c[now].c60) c[now].c6=cc.c6;
else c[now].c6=min(c[now].c6,cc.c6);
return ;
}
int mid=(l+r)>>1;
if (x<=mid) change(s1[now],l,mid,x,cc);
else change(s2[now],mid+1,r,x,cc);
}
// Second DFS: rt[v] = rt[parent] plus one entry for v, keyed by the depth of
// v's nearest same-colored ancestor (0 if none). See struct qy for the fields.
void dfs2 (int x)
{
    // dfs() copies lst[x] from g[], which stores depths, so lst[x] is already
    // a depth. The previous dep[lst[x]] read the depth of the vertex NUMBERED
    // lst[x]; that only equals lst[x] on the chain 1..p, where dep[i] == i.
    int pd=lst[x];
    int xx=dep[x]-pd;
    qy cc=qy(1LL,(LL)xx,(LL)2*xx*dep[x],(LL)pd,(LL)dep[x],(LL)dep[x]*pd,(LL)dep[x]);
    change(rt[x],0,mx,pd,cc);
    for (int u=last[x];u!=-1;u=e[u].last)
    {
        int y=e[u].y;
        rt[y]=rt[x];dfs2(y);
    }
}
// Sum of the aggregates with keys in [L, R] that are in version rt2 but not in
// version rt1. The range covered by the current node is [l, r].
// Callers pass rt1 = 0 or an earlier version on the same root path.
qy ask (int rt1,int rt2,int l,int r,int L,int R)
{
    if (rt2==0) return zero;
    if (l==L&&r==R) return c[rt2]-c[rt1];
    int mid=(l+r)>>1;
    if (R<=mid) return ask(s1[rt1],s1[rt2],l,mid,L,R);
    else if (L>mid) return ask(s2[rt1],s2[rt2],mid+1,r,L,R);
    else return ask(s1[rt1],s1[rt2],l,mid,L,mid)+ask(s2[rt1],s2[rt2],mid+1,r,mid+1,R);
}
// calc(x, y): x and y lie on one root path (swapped so that x is the shallower).
// Returns the sum, over ordered pairs (u, v) with u on root..x and v on
// root..y, of the number of distinct colors on the chain between u and v.
// calc(0, y) == 0.
//
// Let D1 = dep[x], D2 = dep[y], dz = dep[z], and lz = depth of z's nearest
// same-colored ancestor. A vertex z counts for chain [a, b] iff lz < a <= dz <= b.
// - z on root..x contributes xx*(D1+D2+2-2*dz) - 1, with xx = dz - lz.
// - z strictly below x with lz < D1 contributes (D1-lz)*(D2+1-dz).
// - z strictly below x with lz >= D1 contributes nothing.
LL calc (int x,int y)
{
    if (x==0) return 0;
    if (dep[x]>dep[y]) swap(x,y);
    const LL D1=dep[x],D2=dep[y];
    LL res=0;
    // Vertices on root..x: the whole version rt[x].
    qy d=c[rt[x]];
    res+=d.c1*(D1+D2+2);
    res-=d.c2;
    res-=d.c;
    // Vertices strictly below x on the way to y whose key is < dep[x].
    d=ask(rt[x],rt[y],0,mx,0,dep[x]-1);
    res+=d.c*D1*(D2+1);
    res-=d.c3*(D2+1);
    res-=d.c4*D1;
    res+=d.c5;
    return res;
}
// Add (c == 1) or remove (any other c) one occurrence of color x from the
// brute-force counter used by case2, keeping the distinct-color count Ans.
// The parameter c shadows the global segment-tree array c[].
void add (int x,int c)
{
    if (c==1)
    {
        if (TOT[x]==0) Ans++;
        TOT[x]++;
    }
    else
    {
        TOT[x]--;
        if (TOT[x]==0) Ans--;
    }
}
// Scratch state used by case2.
int vec[N];         // vertices collected while walking up from the first endpoint
int vec1[N];        // same, for the second endpoint
int siz;            // number of entries in vec
int siz1;           // number of entries in vec1
vector<int> t[N];   // t[col]: chain vertices u in 1..p (ascending) with lst[u] == 0, filled in main()
int f[N];           // f[col]: depth recorded by case2 on the second endpoint's branch; -1 when unset
// Answer a type-2 query for endpoints x, y (main() ensures L[x] <= L[y]) with
// LCA = get_LCA(x, y): the sum, over u on root..x and v on root..y, of the
// number of distinct colors on the tree path u-v. Three cases:
//   1. x is an ancestor of y: the closed form calc(x, y).
//   2. LCA is off the chain (LCA > p): closed form for pairs touching root..LCA,
//      plus brute force over the two branches below LCA.
//   3. LCA is on the chain: closed form, plus an incremental sweep down the branch
//      of the endpoint whose chain attachment point is shallower.
LL case2 (int x,int y,int LCA)
{
    if (LCA==x) return calc(x,y);
    if (LCA>p)//LCA is not on the chain
    {
        // Pairs with at least one endpoint on root..LCA.
        LL lalal=calc(LCA,x)+calc(LCA,y)-calc(LCA,LCA);
        // Collect the vertices strictly below LCA on both sides. vec[siz] and
        // vec1[siz1] are the ones adjacent to LCA.
        siz=siz1=0;
        while (x!=LCA) {vec[++siz]=x;x=fa[x][0];}
        while (y!=LCA) {vec1[++siz1]=y;y=fa[y][0];}
        // Brute force the remaining pairs (vec[u], vec1[i]) by growing the path
        // outward from LCA and recounting distinct colors.
        TOT[a[LCA]]++;
        for (int u=siz;u>=1;u--)
        {
            Ans=1; // LCA's own color
            for (int i=siz;i>=u;i--) add(a[vec[i]],1);
            for (int i=siz1;i>=1;i--)
            {
                add(a[vec1[i]],1);
                lalal=lalal+Ans;
            }
            for (int i=siz;i>=u;i--) add(a[vec[i]],-1);
            for (int i=siz1;i>=1;i--) add(a[vec1[i]],-1);
        }
        TOT[a[LCA]]--;
        return lalal;
    }
    else//LCA is on the chain
    {
        siz=0;
        int xx=x,yy=y;
        // Find where each endpoint's branch meets the chain.
        while (x>p) x=fa[x][0];
        while (y>p) y=fa[y][0];
        // Orient so that xx attaches at the shallower chain vertex.
        if (x>y) swap(xx,yy);
        x=xx;y=yy;
        // vec: xx's off-chain vertices; x ends at xx's attachment point.
        while (x>p) {vec[++siz]=x;x=fa[x][0];}
        // f[col]: depth of the shallowest vertex of color col on yy's off-chain
        // branch (written bottom-up, so the last write wins); y ends at the attachment point.
        while (y>p) {f[a[y]]=dep[y];y=fa[y][0];}
        // Pairs whose xx-side endpoint lies on root..x.
        LL lalal=calc(x,yy)+calc(x,xx)-calc(x,x);
        // now = sum over v in (x, yy] of distinct(u, v), starting with u = x.
        LL now=(calc(x,yy)-calc(fa[x][0],yy))-(calc(x,x)-calc(fa[x][0],x));
        // Move u down from x towards xx. Each step adds vertex w to the path and
        // raises the count by one for every v whose path does not yet contain w's color.
        for (int u=siz;u>=1;u--)
        {
            int w=vec[u];
            int col=a[w];
            if (lst[w]==0)
            {
                // col has no ancestor occurrence: locate its first occurrence on the v side.
                int o=-1;
                for (int cand : t[col])
                    if (cand>x)
                    {o=cand;break;}
                if (o>y) o=-1;
                if (o!=-1) now=now+(dep[o]-dep[x]-1);
                else if (f[col]!=-1) now=now+(f[col]-dep[x]-1);
                else now=now+dep[yy]-dep[x];
            }
            else if (lst[w]<dep[x])
            {
                // The previous occurrence is above x, so it is not on the u-v path.
                // The vertex on root..yy keyed by that depth is the next occurrence.
                qy cc=ask(0,rt[yy],0,mx,lst[w],lst[w]);
                if (cc.c==0)//no further occurrence
                    now=now+(dep[yy]-dep[x]);
                else now=now+(cc.c6-(dep[x]+1));
            }
            // Otherwise col already occurs between x and w: no v gains a color.
            lalal=lalal+now;
        }
        // Restore f[] to -1 for the next query.
        y=yy;while (y>p) {f[a[y]]=-1;y=fa[y][0];}
        return lalal;
    }
}
// Driver: for each test case, generate the tree, build the per-vertex
// structures, read the queries and print one answer per line.
// Type-1 queries are deferred to case1() (Mo's algorithm); the others are
// answered immediately by case2().
int main()
{
    // freopen("old-task3.in","r",stdin);
    //freopen("a.out","w",stdout);
    c[0]=zero=qy(0,0,0,0,0,0,0);
    scanf("%d",&T);
    while (T--)
    {
        // Reset per-test state.
        memset(g,0,sizeof(g));
        memset(TOT,0,sizeof(TOT));
        num1=mx=0;
        memset(f,-1,sizeof(f));
        tot=num=0;memset(last,-1,sizeof(last));
        gen();
        for (int u=1;u<=n;u++) t[u].clear();
        dep[0]=0;dep[1]=1;num=0;dfs(1);
        // t[col]: chain vertices with no same-colored ancestor, ascending.
        for (int u=1;u<=p;u++) if (lst[u]==0)
        {
            t[a[u]].push_back(u);
        }
        rt[1]=1;s1[1]=0;s2[1]=0;c[1]=zero;
        dfs2(1);
        scanf("%d",&m);
        for (int u=1;u<=m;u++)
        {
            ans[u]=0;
            int op,x,y;
            scanf("%d%d%d",&op,&x,&y);
            if (L[x]>L[y]) swap(x,y);
            int LCA=get_LCA(x,y);
            if (op==1)
            {
                tot++;
                if (LCA==x) {h[tot].LCA=0;h[tot].l=L[x];h[tot].r=L[y];h[tot].id=u;}
                else {h[tot].LCA=LCA;h[tot].l=R[x];h[tot].r=L[y];h[tot].id=u;}
            }
            else ans[u]=case2(x,y,LCA);
        }
        case1();
        for (int u=1;u<=m;u++) printf("%lld\n",ans[u]);
    }
    return 0;
}
这里空空如也







有帮助,赞一个