前言
主席树全称可持久化权值线段树,用于解决区间第 小问题。
主席树通过对取值范围进行多次建树并进行可持久化操作,利用其可相减的性质来解决区间第 小问题。
前置芝士:可持久化数组
正文
Ⅰ可持久化数组
首先简单提一下可持续化数组。
可持续化数组需要我们在用线段树维护最新数组的同时用线段树维护其历史版本,也就是维护每一次修改前后的版本。
非常暴力的想法就是每一次改点前把线段树复制一遍然后再改,然而这显然是不行的。
我们考虑每一个相邻版本的线段树其实是有很多相似的节点的,所以我们考虑能不能把这些节点合并起来,也就是每次创建新的副本的时候,只创建必须更改的,而和历史版本完全相同的部分直接引用历史版本。这就是可持久化数组的核心思想。
我这里直接给出P3919 【模板】可持久化线段树 1(可持久化数组)的AC代码。
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define rep(i,a,b) for(register int i = (a);i <= (b);++i)
#define per(i,a,b) for(register int i = (a);i >= (b);--i)
typedef long long ll;
typedef unsigned long long ull;
using std::string;using std::cin;using std::cout;
const int N = 1e6+10;
int n,m,a[N],tot,version,opt,p,x;
struct node{
int l,r,num;
node * ls , * rs;
}Tree[20*N],*root[2*N];
inline node * create(){return &Tree[++tot];}
inline void build(node * cur,int L,int R){
cur->l = L , cur->r = R;
if(L == R){
cur->num = a[L];
return;
}
int mid = (L+R)>>1;
cur->ls = create() , cur->rs = create();
build(cur->ls,L,mid) , build(cur->rs,mid+1,R);
return;
}
inline node * upd(node * cur,int L,int R){
node * now = create();
now->ls = cur->ls , now->rs = cur->rs , now->l = cur->l , now->r = cur->r;
int mid = (L+R)>>1;
if(L == R) now->num = x;
else if(p <= mid) now->ls = upd(now->ls,L,mid);
else now->rs = upd(now->rs,mid+1,R);
return now;
}
inline int query(node * cur){
int mid = (cur->l + cur->r) >> 1;
if(cur->l == cur->r) return cur->num;
else if(p <= mid) return query(cur->ls);
else return query(cur->rs);
}
int main(){
std::ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
//freopen("in.in", "r", stdin);
cin >> n >> m;
rep(i,1,n) cin >> a[i];
root[0] = create();
build(root[0],1,n);
rep(i,1,m){
cin >> version >> opt >> p;
if(opt&1){ // 改点
cin >> x;
root[i] = upd(root[version],1,n);
} else { // 查询
cout << query(root[version]) << "\n";
root[i] = root[version];
}
}
return 0;
}
其中特别注意的就是upd()
这个函数
inline node * upd(node * cur,int L,int R){
node * now = create();
now->ls = cur->ls , now->rs = cur->rs , now->l = cur->l , now->r = cur->r;
int mid = (L+R)>>1;
if(L == R) now->num = x;
else if(p <= mid) now->ls = upd(now->ls,L,mid);
else now->rs = upd(now->rs,mid+1,R);
return now;
}
因为除了这个函数,其他部分基本上就是裸的线段树。
我们更新节点的时候,先人为将旧版本的节点完全复制,然后根据改点的需求来更新其中一个自节点,将这个子节点更新。
说明①:右儿子先复制为历史版本的右儿子,然后更新为新的节点,递归当前过程
说明②:左儿子先复制为历史版本的左儿子,然后不再处理左儿子以及左儿子的子树
总之,可持久化数组就是不断动态开点,并最终引回历史版本。而每个版本有各自的root[version]
索引。
Ⅱ主席树求静态区间第k小
➀建树
我们先给出主席树的建树流程再进一步做解释。
- 1.将原数组进行去重排序离散化,求得原数组中的元素种类数,去重有序数组,离散化映射函数使,记忆化的离散化映射数组,有。
rep(i,1,n) cin >> a[i]; // 读入
rep(i,1,n) b[i] = a[i]; // 复制
b[0] = std::unique(b+1,b+n+1) - b - 1; // 去重
std::sort(b+1,b+b[0]+1,cmp); // 排序
rep(i,1,b[0]) mp[ b[i] ] = i; // 建立映射关系
rep(i,1,n) p[i] = mp[ a[i] ]; // 记忆化
- 2.建立以
root[0]
为根,下标范围为的最初版本线段树,并将所有节点的权值初始化为。- 这里给出线段树维护内容:
root[i]
索引的线段树中,区间为的节点维护的是符合且的元素个数- 懒得画图,如果觉得过于抽象请参考这篇
build(root[0],1,b[0]);
inline void build(node * cur,int l,int r){
cur->l = l , cur->r = r , cur->sum = 0;
if(l >= r) return;
int mid = (l+r)>>1;
cur->ls = create() , cur->rs = create();
build(cur->ls,l,mid) , build(cur->rs,mid+1,r);
return;
}
- 3.枚举,创建新的版本使下标为的节点值加一,其父节点依然。
rep(i,1,n) root[i] = add(root[i-1],p[i]);
inline node * add(node * cur,int x){
node * now = create(); // 先复制历史版本
now->ls = cur->ls , now->l = cur->l ,
now->rs = cur->rs , now->r = cur->r ,
now->sum = cur->sum + 1; // 直接从上往下更新,等效于pushup
if(now->l == now->r) return now; // 看情况更新
if(now->l <= x && x <= now->ls->r) now->ls = add(cur->ls,x);
if(now->rs->l <= x && x <= now->r) now->rs = add(cur->rs,x);
return now;
}
- 4.建树完毕
我们借助了可持久化数组的历史版本这个特性,用它来维护的所有前缀中各个元素的个数。
从主席树的角度来讲,我们需要维护的所有前缀串中各元素的个数,为了缩小空间,需要将类似的部分合并,于是利用了可持久化数组的思想。
➁查询
我们先考虑任意区间内的第小如何借助我们刚刚建的树完成。很显然,我们可以通过二分实现。
我们现在给出一个节点cur
,它维护的信息是这样的:
- 所处版本: (即维护的是内的各个元素个数)
- 维护区间:[
cur->l
,cur->r
] - 节点权值:
cur->sum
(表示版本下属于维护区间的元素个数) - 左儿子:
cur->ls
- 右儿子:
cur->rs
它左儿子和右儿子的维护信息高度相似。
我们现在要找这第小,并且我们又知道这个节点左儿子的权值,很容易得到以下结论:
- 时,第小存在于左儿子维护的区间中
- 时,第小存在于右儿子维护的区间中
考虑到,由于区间不再是前缀,所以我们维护的信息不能直接使用,也就是不再能直接得到。想到我们之前维护的是前缀串,联想前缀和思想,我们可以猜测这一步这里能否做减法实现,也就是用版本和版本中维护同一个区间的节点的权值做差,得到的就是中属于当前维护范围的元素个数。
类比的过程,通过二分递归实现查询。
//query(root[L-1],root[R],k)返回原数组[L,R]中第k小的离散结果,记得反映射回去
inline int query(node * u,node * v,int k){
if(v->l >= v->r) return v->l;
int diff = v->ls->sum - u->ls->sum;
if(k > diff) return query(u->rs,v->rs,k-diff);
else return query(u->ls,v->ls,k);
}
相关题目
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define rep(i,a,b) for(register int i = (a);i <= (b);++i)
#define per(i,a,b) for(register int i = (a);i >= (b);--i)
typedef long long ll;
typedef unsigned long long ull;
using std::string;using std::cin;using std::cout;
const int N = 2e5+10;
int n,m,a[2*N],b[2*N],p[2*N],tot,L,R,k;
std::map<int,int> mp;
struct node{
int l,r,sum;
node * ls, * rs;
}Tree[32*N],*root[2*N];
inline node * create(){return &Tree[++tot];}
inline bool cmp(int x,int y){return x < y;}
inline void build(node * cur,int l,int r){
cur->l = l , cur->r = r , cur->sum = 0;
if(l >= r) return;
int mid = (l+r)>>1;
cur->ls = create() , cur->rs = create();
build(cur->ls,l,mid) , build(cur->rs,mid+1,r);
return;
}
inline node * add(node * cur,int x){
node * now = create();
now->ls = cur->ls , now->l = cur->l , now->rs = cur->rs , now->r = cur->r , now->sum = cur->sum + 1;
if(now->l == now->r) return now;
if(now->l <= x && x <= now->ls->r) now->ls = add(cur->ls,x);
if(now->rs->l <= x && x <= now->r) now->rs = add(cur->rs,x);
return now;
}
inline int query(node * u,node * v,int k){
if(v->l >= v->r) return v->l;
int diff = v->ls->sum - u->ls->sum;
if(k > diff) return query(u->rs,v->rs,k-diff);
else return query(u->ls,v->ls,k);
}
int main(){
std::ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
//freopen("in.in", "r", stdin);
cin >> n >> m;
rep(i,1,n) cin >> a[i];
rep(i,1,n) b[i] = a[i];
b[0] = std::unique(b+1,b+n+1) - b - 1; // 去重
std::sort(b+1,b+b[0]+1,cmp);
rep(i,1,b[0]) mp[ b[i] ] = i; // 建立映射关系
rep(i,1,n) p[i] = mp[ a[i] ]; // 记忆化
root[0] = create();
build(root[0],1,b[0]);
rep(i,1,n) root[i] = add(root[i-1],p[i]);
while(m--){
cin >> L >> R >> k;
cout << b[query(root[L-1],root[R],k)] << "\n";
}
return 0;
}
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#define rep(i,a,b) for(register int i = (a);i <= (b);++i)
#define per(i,a,b) for(register int i = (a);i >= (b);--i)
typedef long long ll;
typedef unsigned long long ull;
using std::string;using std::cin;using std::cout;
const int N = 1e5+10;
int n,m,a[2*N],opt,tot,FLAG;
struct node{
int l,r,fa,deep,num;
node * ls, * rs;
}Tree[32*N],*root[4*N];
inline node * create(){return &Tree[++tot];}
inline void copy(node * u , node * v){
u->l = v->l , u->r = v->r;
u->ls = v->ls , u->rs = v->rs;
u->fa = v->fa , u->num = v->num , u->deep = v->deep;
return;
}
inline void build(node * cur,int l,int r){
cur->l = l , cur->r = r;
if(l == r){
cur->fa = cur->num = l , cur->deep = 1;
return;
}
int mid = (l+r)>>1;
cur->ls = create() , cur->rs = create();
build(cur->ls,l,mid) , build(cur->rs,mid+1,r);
return;
}
inline node * upd(node * cur,int x,int F){
node * now = create();
copy(now,cur);
if(cur->l == cur->r && cur->l == x){
now->fa = F;
return now;
} else if(cur->l == cur->r && cur->l == F){
now->deep = now->deep + FLAG;
return now;
}
if(cur->l <= x && x <= cur->ls->r){
now->ls = upd(cur->ls,x,F);
if(cur->rs->l <= F && F <= cur->r && FLAG) now->rs = upd(cur->rs,x,F);
}
if(cur->rs->l <= x && x <= cur->r){
now->rs = upd(cur->rs,x,F);
if(cur->l <= F && F <= cur->ls->r && FLAG) now->ls = upd(cur->ls,x,F);
}
return now;
}
inline node * find(node * cur,int x){
if(cur->l == cur->r) return cur;
if(x <= cur->ls->r) return find(cur->ls,x);
if(x >= cur->rs->l) return find(cur->rs,x);
return 0;
}
inline node * get_fa(node * cur,int x){
node * now = find(cur,x);
if(x == now->fa) return now;
else return get_fa(cur,now->fa);
}
int main(){
std::ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
// freopen("in.in","r",stdin);
// freopen("out.out","w",stdout);
int x,y,F,D,f;
cin >> n >> m;
root[0] = create();
build(root[0],1,n);
rep(i,1,m){
cin >> opt;
if(opt == 1){
cin >> x >> y;
node * px = get_fa(root[i-1],x) ,* py = get_fa(root[i-1],y);
if(px->fa == py->fa){
root[i] = root[i-1];
continue;
}
if(px->deep > py->deep) F = px->fa , f = py->fa;
else F = py->fa , f = px->fa;
FLAG = px->deep == py->deep;
root[i] = upd(root[i-1],f,F);
} else if(opt == 2){
cin >> x;
root[i] = root[x];
} else {
root[i] = root[i-1];
cin >> x >> y;
node * px = get_fa(root[i],x) ,* py = get_fa(root[i],y);
if(px->fa == py->fa) cout << "1\n";
else cout << "0\n";
}
}
return 0;
}