这是本文档旧的修订版!
树套树,就是在一个树型数据结构上,每个点不再是一个节点,而是另外一个树形数据结构。
经常应用在一些普通的数据结构外套上区间操作或者动态操作时候。没有固定的套路,根据题目来选不同的树型数据结构组合。
下面介绍一些常用的树套树
现在给出 $1∼n $ 的一个排列,按照某种顺序依次删除 $m$ 个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
我们可以建立树状数组,第 $i$ 位维护 $a[i-lowbit(i)+1] ∼ a[i]$ 的权值线段树。插入和普通的求逆序对方法相同,删除 $x$ 的时候查询位置在后面比 $x$ 大的数有多少即可。
为了防止内存爆找,线段树采用动态开点。
#include<iostream> #include<cstdio> #include<algorithm> #include<cctype> #include<cstring> #define ll long long using namespace std; inline int read() { int k=0,f=1;char c=getchar(); while(!isdigit(c)) {if(c=='-') f=-1;c=getchar();} while(isdigit(c)) k=k*10+c-'0',c=getchar();return f*k; } const int N=100055; int root[N],lch[N*91],rch[N*91],sum[N*91],cnt; int n,m,pos[N]; ll anss; void add(int &k,int l,int r,int x) { if(!k) k=++cnt;sum[k]++; if(l==r) return ; int mid=l+r>>1; if(x<=mid) add(lch[k],l,mid,x); else add(rch[k],mid+1,r,x); } void del(int k,int l,int r,int x) { if(!k) return ; sum[k]--; if(l==r) return ; int mid=l+r>>1; if(x<=mid) del(lch[k],l,mid,x); else del(rch[k],mid+1,r,x); } int query(int k,int l,int r,int a,int b) { if(!k) return 0; if(a<=l&&b>=r) return sum[k]; int mid=l+r>>1,ans=0; if(a<=mid) ans+=query(lch[k],l,mid,a,b); if(b>mid) ans+=query(rch[k],mid+1,r,a,b); return ans; } int main() { n=read();m=read(); for(int i=1;i<=n;i++) { int a=read();pos[a]=i; for(int j=i;j<=n;j+=j&-j) add(root[j],1,n,a); for(int j=i;j;j-=j&-j) anss+=query(root[j],1,n,a+1,n); } for(int i=1;i<=m;i++) { printf("%lld\n",anss); int a=read(); if(a!=n) { for(int j=pos[a];j;j-=j&-j) anss-=query(root[j],1,n,a+1,n); } if(a!=1) { for(int j=n;j;j-=j&-j) anss-=query(root[j],1,n,1,a-1); for(int j=pos[a]-1;j;j-=j&-j) anss+=query(root[j],1,n,1,a-1); } for(int j=pos[a];j<=n;j+=j&-j) del(root[j],1,n,a); } return 0; }
让你维护一个有序数列,有以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱
5.查询k在区间内的后继
就是平衡时问题加上了区间限制。我们外层建线段树,每一个节点维护该节点包含区间的平衡树。
操作一和操作三对线段树上包含区间的节点全部进行操作。
操作四,五,外层线段树区间查询,对每个包含[l,r]的节点得到的前驱/后继求一个最大值/最小值即可。
操作二 我们二分答案,和操作一类似,利用小于某数的个数进行二分。复杂度多一个 $log$ 。
#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<cctype> using namespace std; inline int read() { int k=0,f=1;char c=getchar(); while(!isdigit(c)) {if(c=='-') f=-1;c=getchar();} while(isdigit(c)) k=k*10+c-'0',c=getchar();return f*k; } const int N=100005,inf=2147483647; struct T { int ch[2],fa,size,cnt,v; }tr[N*200]; int tot,rt[N*20],n,m,a[N]; #define l(x) tr[x].ch[0] #define r(x) tr[x].ch[1] void pu(int x) { tr[x].size=tr[l(x)].size+tr[r(x)].size+tr[x].cnt; } void rotate(int x) { int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x; tr[z].ch[tr[z].ch[1]==y]=x;tr[x].fa=z; tr[y].ch[k]=tr[x].ch[k^1];tr[tr[x].ch[k^1]].fa=y; tr[y].fa=x;tr[x].ch[k^1]=y; pu(y);pu(x); } void splay(int x,int r,int pos) { while(tr[x].fa!=pos) { int y=tr[x].fa,z=tr[y].fa; if(z!=pos) (l(z)==y)^(l(y)==x)?rotate(x):rotate(y); rotate(x); } if(!pos) rt[r]=x; } void ins(int x,int r) { int u=rt[r],f=0; if(!u) { tr[++tot].v=x;tr[tot].size=tr[tot].cnt=1;rt[r]=tot;return ; } while(u&&tr[u].v!=x) f=u,u=tr[u].ch[tr[u].v<x]; if(u) {tr[u].cnt++;tr[u].size++;splay(u,r,0);return;} u=++tot; tr[u].fa=f;if(f) tr[f].ch[tr[f].v<x]=u; tr[u].v=x;tr[u].size=tr[u].cnt=1;splay(u,r,0); } void find(int x,int r) { int u=rt[r]; while(tr[u].v!=x&&tr[u].ch[tr[u].v<x]) u=tr[u].ch[tr[u].v<x]; splay(u,r,0); } int find_max(int x) { while(r(x)) x=r(x);return x; } int lxt(int x,int r) { int u=rt[r],ans=-inf; while(u) { if(tr[u].v<x) ans=max(ans,tr[u].v),u=r(u); else u=l(u); } return ans; } int rxt(int x,int r) { int u=rt[r],ans=inf; while(u) { if(x<tr[u].v) ans=min(ans,tr[u].v),u=l(u); else u=r(u); } return ans; } void replace(int x,int y,int r) { find(x,r); int u=rt[r]; if(tr[u].cnt>1) tr[u].size--,tr[u].cnt--; else if(!tr[u].ch[0]) rt[r]=r(u),tr[r(u)].fa=0; else if(!tr[u].ch[1]) rt[r]=l(u),tr[l(u)].fa=0; else { splay(find_max(l(u)),rt[r],u); tr[tr[u].ch[0]].ch[1]=tr[u].ch[1]; tr[r(u)].fa=l(u);tr[l(u)].fa=0; rt[r]=l(u);pu(l(u)); } ins(y,r); } int ran(int x,int r) { int u=rt[r],ans=0; while(u) { if(tr[u].v>x) u=l(u); else if(tr[u].v<x) ans+=tr[l(u)].size+tr[u].cnt,u=r(u); else {ans+=tr[l(u)].size;return ans;} } return ans; } #define lson k<<1,l,mid #define rson k<<1|1,mid+1,r void build(int k,int l,int r) { for(int i=l;i<=r;i++) ins(a[i],k); if(l==r) return ; int mid=l+r>>1; build(lson);build(rson); } int sol1(int k,int l,int r,int a,int b,int x) { if(a<=l&&b>=r) { return ran(x,k); } int mid=l+r>>1,ans=0; if(a<=mid) ans+=sol1(lson,a,b,x); if(b>mid) ans+=sol1(rson,a,b,x); return ans; } int sol2(int l,int r,int k) { int L=0,R=1e8; while(R>L) { int mid=L+R>>1; if(sol1(1,1,n,l,r,mid)<k) L=mid+1; else R=mid; } return L-1; } void sol3(int k,int l,int r,int c,int x) { replace(a[c],x,k); if(l==r) {a[c]=x;return;} int mid=l+r>>1; if(c<=mid) sol3(lson,c,x); else sol3(rson,c,x); } int sol4(int k,int l,int r,int a,int b,int x) { if(a<=l&&b>=r) return lxt(x,k); int mid=l+r>>1,ans=-inf; if(a<=mid) ans=max(ans,sol4(lson,a,b,x)); if(b>mid) ans=max(ans,sol4(rson,a,b,x)); return ans; } int sol5(int k,int l,int r,int a,int b,int x) { if(a<=l&&b>=r) { return rxt(x,k); } int mid=l+r>>1,ans=inf; if(a<=mid) ans=min(ans,sol5(lson,a,b,x)); if(b>mid) ans=min(ans,sol5(rson,a,b,x)); return ans; } int main() { int opt,b,c,d; n=read();m=read(); for(int i=1;i<=n;i++) a[i]=read(); build(1,1,n); for(int i=1;i<=m;i++) { opt=read();b=read();c=read();if(opt!=3) d=read(); if(opt==1) printf("%d\n",sol1(1,1,n,b,c,d)+1); else if(opt==2) printf("%d\n",sol2(b,c,d)); else if(opt==3) sol3(1,1,n,b,c); else if(opt==4) printf("%d\n",sol4(1,1,n,b,c,d)); else printf("%d\n",sol5(1,1,n,b,c,d)); } return 0; }