[0.x]前言
树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。
具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。
──OI Wiki
[1.x] 轻重链剖分
[1.1.x]前置概念
- 重儿子:对于一个非叶节点,它的所有儿子中子树大小最大的那个儿子称为重儿子(若有多个并列,任取其一)
- 轻儿子:除重儿子以外的其他儿子
- 重边:连接一个节点与其重儿子的边
- 轻边:除重边以外的边
- 重链:由若干条重边首尾相连组成的极大路径。每条重链的链头是根节点或某个轻儿子,从链头出发一直沿重儿子向下走;单独一个(作为轻儿子的)叶子也视为一条重链
[1.2.x]预处理
我们进行两次 dfs,分别维护下面这些信息。
[1.2.1.x]第一次dfs
第一次 dfs 主要是确定重儿子,为第二次 dfs 做准备,顺便求出可以一次性处理的信息:

| 数组 | 含义 |
| --- | --- |
| `deep[i]` | 节点 i 的深度 |
| `fa[i]` | 节点 i 的父亲 |
| `son[i]` | 节点 i 的重儿子 |
| `size[i]` | 以节点 i 为根的子树大小 |
// First dfs: for `now` record its parent (fa), depth (deep) and subtree size
// (size), and choose its heavy child son[now] = the child with the largest
// subtree (on a tie the child visited later wins). `father` is the parent of
// `now`; deep[father] is read, so it must already hold the parent's depth.
inline void dfs1(int now,int father){
    fa[now] = father;
    deep[now] = deep[father] + 1;
    size[now] = 1;
    for(int e = head[now]; e != 0; e = next[e]){
        const int child = ver[e];
        if(child == fa[now]) continue;      // do not walk back up to the parent
        dfs1(child, now);
        size[now] += size[child];
        if(size[child] >= size[son[now]]) son[now] = child;   // heavy-child update
    }
}
[1.2.2.x]第二次dfs
第二次 dfs 根据已经求出的重儿子确定一种特殊的 dfs 序,以维护一些性质。
这个 dfs 序的特殊要求是:先走重儿子,这样能保证同一条重链上节点的 dfs 序是连续的。

| 数组 | 含义 |
| --- | --- |
| `top[i]` | 节点 i 所在重链的链头 |
| `id[i]` | 节点 i 的 dfs 序号 |
| `b[id[i]]` | 等于 `a[i]`,即按 dfs 序存放的节点权值(用于建线段树) |
// Second dfs: give `now` its dfs-order number id[now], record its chain head
// top[now] = toper, and store its weight at b[id[now]]. The heavy child is
// visited first, so every heavy chain occupies a contiguous range of ids.
// id[0] serves as the running counter.
inline void dfs2(int now,int toper){
    const int order = ++id[0];
    id[now] = order;
    top[now] = toper;
    b[order] = a[now];
    if(son[now] == 0) return;            // no heavy child: leaf
    dfs2(son[now], toper);               // heavy child continues the same chain
    for(int e = head[now]; e; e = next[e]){
        const int to = ver[e];
        if(id[to] != 0) continue;        // parent or heavy child: already numbered
        dfs2(to, to);                    // light child starts a new chain
    }
}
[1.3.x]线段树维护
经过预处理后,我们可以发现:同一条重链上的节点 dfs 序是连续的。此外还有一个性质:以 i 为根的子树中所有点的 dfs 序都落在区间 [id[i], id[i]+size[i]-1] 中。因此我们可以以 dfs 序为下标建立线段树。
// Segment tree indexed by dfs order: range add and range sum, with every stored
// value kept modulo p. Nodes come from the static pool Tree; TOT counts nodes used.
int TOT;
struct node{
    int l,r;                   // covered interval [l, r]
    ll sum,tag;                // interval sum mod p; pending add for the children (kept mod p)
    int dis(){return r-l+1;}   // interval length
    node * ls, * rs;           // children (not set for leaves)
}Tree[4*N];
// Take the next unused node from the pool.
inline node * create(){return &Tree[++TOT];}
// Hand cur's pending add down to both children, then clear it.
inline void pushdown(node * cur){
    // tag is reduced mod p, so tag * dis() stays far inside long long.
    cur->ls->sum = (cur->ls->sum + cur->tag * cur->ls->dis()) % p;
    cur->ls->tag = (cur->ls->tag + cur->tag) % p;
    cur->rs->sum = (cur->rs->sum + cur->tag * cur->rs->dis()) % p;
    cur->rs->tag = (cur->rs->tag + cur->tag) % p;
    cur->tag = 0;
}
// Recompute cur's sum from its children.
inline void pushup(node * cur){cur->sum = (cur->ls->sum + cur->rs->sum)%p;}
// Build the subtree rooted at cur over [l, r] from the dfs-ordered weights b.
inline void build(node * cur,int l,int r){
    cur->l = l , cur->r = r , cur->sum = cur->tag = 0;
    if(l == r){
        // Reduce leaves too: a query that hits one leaf returns its sum directly.
        // Widened to long long so "+ p" cannot overflow int.
        cur->sum = (static_cast<ll>(b[l]) % p + p) % p;
        return;
    }
    int mid = (l+r)>>1;
    cur->ls = create() , cur->rs = create();
    build(cur->ls,l,mid) , build(cur->rs,mid+1,r);
    pushup(cur);
}
// Add x (taken modulo p) to every position in [l, r].
inline void edit(node * cur,int l,int r,int x){
    if(l > cur->r || r < cur->l) return;
    if(l <= cur->l && cur->r <= r){
        // Widen before multiplying: x * dis() in int overflows for a large p.
        cur->sum = (cur->sum + static_cast<ll>(x) * cur->dis()) % p;
        // Reduce the tag as well. An unreduced tag grows with every update, and
        // tag * dis() in pushdown could then overflow long long.
        cur->tag = (cur->tag + x) % p;
        return;
    }
    if(cur->l == cur->r) return;
    pushdown(cur);
    if(l <= cur->ls->r) edit(cur->ls,l,r,x);
    if(r >= cur->rs->l) edit(cur->rs,l,r,x);
    pushup(cur);
}
// Sum (mod p) of the positions in [l, r] below cur.
inline ll query(node * cur,int l,int r){
    if(l > cur->r || r < cur->l) return 0;
    if(l <= cur->l && cur->r <= r) return cur->sum;
    pushdown(cur);
    return (query(cur->ls,l,r) + query(cur->rs,l,r))%p;
}
[1.4.x]实现
[1.4.1.x]链上修改
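这里显然要用到线段树的区间修改,关键是如何把一条路径拆成尽量少的区间。由于每条重链上的 dfs 序是连续的,我们可以借助 top[i] 不断向上跳,思路类似于倍增求 LCA。

每次选取链头深度较大的端点 x,修改区间 [id[top[x]], id[x]],再令 x = fa[top[x]]。重复这一过程,直到两个端点位于同一条重链上,最后修改两点之间的那一段。

每经过一条轻边向上跳,子树大小至少翻倍,因此任意一条路径只会经过 O(log n) 条重链。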
// Add the global z to every node on the tree path between the globals x and y.
// While x and y are on different heavy chains, lift the endpoint whose chain
// head is deeper: that chain segment is the contiguous id range
// [id[top[x]], id[x]]. Finally both lie on one chain and the remaining segment
// is a single id range. Leaves x and y modified.
inline void edit_chain(node * cur){
    for(; top[x] != top[y]; x = fa[ top[x] ]){
        if(deep[ top[y] ] > deep[ top[x] ]) std::swap(x,y);
        edit(cur, id[ top[x] ], id[x], z);
    }
    if(id[y] < id[x]) std::swap(x,y);
    edit(cur, id[x], id[y], z);
}
链上查询和链上修改同理。
// Print the sum (mod p) of the weights on the path between the globals x and y,
// followed by a newline. The path is split into heavy-chain segments exactly
// as in the path update. Leaves x and y modified.
inline void query_chain(node * cur){
    ll total = 0;
    for(; top[x] != top[y]; x = fa[ top[x] ]){
        if(deep[ top[y] ] > deep[ top[x] ]) std::swap(x,y);
        total = (total + query(cur, id[ top[x] ], id[x])) % p;
    }
    if(id[y] < id[x]) std::swap(x,y);
    total = (total + query(cur, id[x], id[y])) % p;
    cout << total << "\n";
}
子树修改和查询就更简单了,直接对区间 [id[i], id[i]+size[i]-1] 操作即可。
// Add the global z to every node in the subtree of the global x. That subtree
// is the contiguous dfs-order range [id[x], id[x] + size[x] - 1].
inline void edit_tree(node * cur){
    const int first = id[x];
    const int last = first + size[x] - 1;
    edit(cur, first, last, z);
}
// Print the sum of the subtree of the global x (dfs-order range
// [id[x], id[x] + size[x] - 1]) as returned by query, followed by a newline.
inline void query_tree(node * cur){
    const int first = id[x];
    const int last = id[x] + size[x] - 1;
    const ll subtotal = query(cur, first, last);
    cout << subtotal << "\n";
}
区间最值的维护和区间和的维护基本上类似,可以参考相关题目第二题的代码。
[1.5.x]相关题目
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
// Inclusive loop helpers. The `register` storage class was removed in C++17
// (ill-formed there; clang rejects it), so the loop variable is a plain int.
#define rep(i,a,b) for(int i = (a);i <= (b);++i)
#define per(i,a,b) for(int i = (a);i >= (b);--i)
typedef long long ll;
typedef unsigned long long ull;
using std::string;using std::cin;using std::cout;
const int N = 1e5+5;   // array capacity bound
// Problem input and per-query operands shared by the helper functions.
int n,m,root,p,a[2*N],b[2*N],x,y,z,opt;
// Forward-star adjacency list: head[u] = newest edge of u, next[e] = older edge, ver[e] = target.
int head[2*N],next[2*N],ver[2*N],tot;
// Decomposition data: parent, chain head, depth, heavy child, subtree size, dfs order.
int fa[2*N],top[2*N],deep[2*N],son[2*N],size[2*N],id[2*N];
//------------------------------------------------------
// Append the directed edge x -> y at the front of x's adjacency list.
inline void link(int x,int y){
    ++tot;
    ver[tot] = y;
    next[tot] = head[x];
    head[x] = tot;
}
// First dfs: subtree size, parent and depth of `now`, plus its heavy child
// (largest subtree; the later-visited child wins a tie). `father` is the
// parent of `now`, and its depth must already be set.
inline void dfs1(int now,int father){
    size[now] = 1;
    fa[now] = father;
    deep[now] = 1 + deep[father];
    int e = head[now];
    while(e){
        const int to = ver[e];
        if(to != fa[now]){                 // skip the edge leading back to the parent
            dfs1(to, now);
            size[now] += size[to];
            son[now] = (size[to] >= size[son[now]]) ? to : son[now];
        }
        e = next[e];
    }
}
// Second dfs: number nodes in dfs order with the heavy child first, so each
// heavy chain is a contiguous id range. Also records the chain head and copies
// the weight a[now] into b at the node's dfs position (id[0] is the counter).
inline void dfs2(int now,int toper){
    top[now] = toper;
    id[now] = ++id[0];
    b[id[now]] = a[now];
    if(son[now]){
        dfs2(son[now], toper);                 // heavy child stays on this chain
        for(int e = head[now]; e; e = next[e])
            if(!id[ ver[e] ]) dfs2(ver[e], ver[e]);   // each light child heads a new chain
    }
}
//------------------------------------------------------
// Segment tree indexed by dfs order: range add and range sum, with every stored
// value kept modulo p. Nodes come from the static pool Tree; TOT counts nodes used.
int TOT;
struct node{
    int l,r;                   // covered interval [l, r]
    ll sum,tag;                // interval sum mod p; pending add for the children (kept mod p)
    int dis(){return r-l+1;}   // interval length
    node * ls, * rs;           // children (not set for leaves)
}Tree[4*N];
// Take the next unused node from the pool.
inline node * create(){return &Tree[++TOT];}
// Hand cur's pending add down to both children, then clear it.
inline void pushdown(node * cur){
    // tag is reduced mod p (p is an int), so tag * dis() stays far inside long long.
    cur->ls->sum = (cur->ls->sum + cur->tag * cur->ls->dis()) % p;
    cur->ls->tag = (cur->ls->tag + cur->tag) % p;
    cur->rs->sum = (cur->rs->sum + cur->tag * cur->rs->dis()) % p;
    cur->rs->tag = (cur->rs->tag + cur->tag) % p;
    cur->tag = 0;
}
// Recompute cur's sum from its children.
inline void pushup(node * cur){cur->sum = (cur->ls->sum + cur->rs->sum)%p;}
// Build the subtree rooted at cur over [l, r] from the dfs-ordered weights b.
inline void build(node * cur,int l,int r){
    cur->l = l , cur->r = r , cur->sum = cur->tag = 0;
    if(l == r){
        // Reduce leaves too: a query that hits one leaf (e.g. the subtree of a
        // tree leaf) returns this value directly and prints it.
        cur->sum = (static_cast<ll>(b[l]) % p + p) % p;
        return;
    }
    int mid = (l+r)>>1;
    cur->ls = create() , cur->rs = create();
    build(cur->ls,l,mid) , build(cur->rs,mid+1,r);
    pushup(cur);
}
// Add x (taken modulo p) to every position in [l, r].
inline void edit(node * cur,int l,int r,int x){
    if(l > cur->r || r < cur->l) return;
    if(l <= cur->l && cur->r <= r){
        // Widen before multiplying: x * dis() in int overflows for a large p.
        cur->sum = (cur->sum + static_cast<ll>(x) * cur->dis()) % p;
        // Reduce the tag as well. An unreduced tag grows with every update, and
        // tag * dis() in pushdown could then overflow long long.
        cur->tag = (cur->tag + x) % p;
        return;
    }
    if(cur->l == cur->r) return;
    pushdown(cur);
    if(l <= cur->ls->r) edit(cur->ls,l,r,x);
    if(r >= cur->rs->l) edit(cur->rs,l,r,x);
    pushup(cur);
}
// Sum (mod p) of the positions in [l, r] below cur.
inline ll query(node * cur,int l,int r){
    if(l > cur->r || r < cur->l) return 0;
    if(l <= cur->l && cur->r <= r) return cur->sum;
    pushdown(cur);
    return (query(cur->ls,l,r) + query(cur->rs,l,r))%p;
}
//------------------------------------------------------
// Add the global z to every node on the path between the globals x and y.
// Repeatedly lift the endpoint with the deeper chain head, updating its chain
// segment [id[top[x]], id[x]]; once both share a chain, update the id range
// between them. Leaves x and y modified.
inline void edit_chain(node * cur){
    while(top[x] != top[y]){
        if(!(deep[ top[x] ] >= deep[ top[y] ])) std::swap(x,y);
        const int head_of_chain = top[x];
        edit(cur, id[head_of_chain], id[x], z);
        x = fa[head_of_chain];
    }
    if(id[x] > id[y]) std::swap(x,y);
    edit(cur, id[x], id[y], z);
}
// Print the sum (mod p) of the weights on the path between the globals x and y,
// followed by a newline. Leaves x and y modified.
inline void query_chain(node * cur){
    ll acc = 0;
    while(top[x] != top[y]){
        if(!(deep[ top[x] ] >= deep[ top[y] ])) std::swap(x,y);
        const int head_of_chain = top[x];
        acc = (acc + query(cur, id[head_of_chain], id[x])) % p;
        x = fa[head_of_chain];
    }
    if(id[x] > id[y]) std::swap(x,y);
    acc = (acc + query(cur, id[x], id[y])) % p;
    cout << acc << "\n";
}
// Add the global z to the whole subtree of the global x, which occupies the
// dfs-order range starting at id[x] with length size[x].
inline void edit_tree(node * cur){
    const int begin = id[x];
    edit(cur, begin, begin + size[x] - 1, z);
}
// Print the subtree sum of the global x, followed by a newline.
inline void query_tree(node * cur){
    const int begin = id[x];
    const int end = begin + size[x] - 1;
    cout << query(cur, begin, end) << "\n";
}
//------------------------------------------------------
int main(){
std::ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
//freopen("in.in", "r", stdin);
cin >> n >> m >> root >> p;
rep(i,1,n) cin >> a[i];
rep(i,1,n-1){
cin >> x >> y;
link(x,y) , link(y,x);
}
fa[root] = 0 , deep[fa[root]] = 0;
dfs1(root,0);
dfs2(root,root);
node * Root = create();
build(Root,1,n);
while(m--){
cin >> opt;
if(opt == 1){//链改
cin >> x >> y >> z;
z %= p;
edit_chain(Root);
} else if(opt == 2){//链查
cin >> x >> y;
query_chain(Root);
} else if(opt == 3){//树改
cin >> x >> z;
z %= p;
edit_tree(Root);
} else if(opt == 4){//树查
cin >> x;
query_tree(Root);
}
}
return 0;
}
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
// Inclusive loop helpers. The `register` storage class was removed in C++17
// (ill-formed there; clang rejects it), so the loop variable is a plain int.
#define rep(i,a,b) for(int i = (a);i <= (b);++i)
#define per(i,a,b) for(int i = (a);i >= (b);--i)
#define mkp std::make_pair
typedef long long ll;
typedef unsigned long long ull;
using std::string;using std::cin;using std::cout;
inline bool cmp(int x,int y){return x < y;}
const int inf = 1e9+9;     // sentinel returned by empty max queries
const int N = 1e6+9;       // array capacity bound
const double eps = 1e-7;
// n nodes, m queries, weights a; u and v are also used as the operands of the
// segment-tree routines (position / new value, or query range).
int _,n,m,a[2*N],u,v;
string str;                // current command word
//----------------------------
int head[2*N];  // head[u]: most recently added edge leaving u (0 = none)
int next[2*N];  // next[e]: edge added before e from the same vertex
int ver[2*N];   // ver[e]: target vertex of edge e
int tot;        // number of edges stored so far
// Store the directed edge x -> y at the front of x's edge list.
inline void link(int x,int y){
    ++tot;
    ver[tot] = y;
    next[tot] = head[x];
    head[x] = tot;
}
//----------------------------
// Decomposition data: parent, chain head, heavy child, subtree size, depth,
// dfs order (id[0] is the running counter) and the weights laid out in dfs order.
int fa[2*N],top[2*N],son[2*N],size[2*N],deep[2*N],id[2*N],b[2*N];
// First dfs: set parent, depth and subtree size of `now` and choose its heavy
// child, i.e. the child with the largest subtree (the later child wins ties).
// `father` is the parent of `now`; its depth must already be set.
inline void dfs1(int now,int father){
    fa[now] = father , deep[now] = deep[father] + 1 , size[now] = 1;
    for(int i = head[now];i;i = next[i]){
        if(ver[i] == father) continue;
        dfs1(ver[i],now);
        size[now] += size[ ver[i] ];
        // Compare the child's OWN subtree size. Comparing size[son[ver[i]]]
        // (the child's heavy child) can pick a lighter child. Answers remain
        // correct, but a path may then cross O(n) chains instead of O(log n).
        son[now] = size[ son[now] ] > size[ ver[i] ] ? son[now] : ver[i];
    }
}
// Second dfs: heavy-child-first numbering so every heavy chain is a contiguous
// id range. Records the chain head and stores a[now] at b[id[now]].
inline void dfs2(int now,int toper){
    const int pos = ++id[0];
    id[now] = pos;
    b[pos] = a[now];
    top[now] = toper;
    if(!son[now]) return;                  // leaf
    dfs2(son[now], toper);
    for(int e = head[now]; e != 0; e = next[e]){
        const int to = ver[e];
        if(!id[to]) dfs2(to, to);          // light child opens a new chain
    }
}
//----------------------------
// Segment tree over dfs order supporting point assignment and range max / sum.
// Operands are passed through the globals u and v.
int TOT;                 // nodes handed out from the Tree pool so far
struct node{
    int l,r,sum,max;     // covered interval [l, r]; sum and maximum over it
    node * ls , * rs;    // children (not set for leaves)
}Tree[4*N];
inline node * create(){return &Tree[++TOT];}
// Recompute cur's sum and max from its children.
inline void pushup(node * cur){
    const node * L = cur->ls;
    const node * R = cur->rs;
    cur->sum = L->sum + R->sum;
    cur->max = std::max(L->max, R->max);
}
// Build cur over [l, r] from the dfs-ordered weights b.
inline void build(node * cur,int l,int r){
    cur->l = l;
    cur->r = r;
    if(l == r){
        cur->sum = cur->max = b[l];
        return;
    }
    const int mid = (l + r) >> 1;
    cur->ls = create();
    cur->rs = create();
    build(cur->ls, l, mid);
    build(cur->rs, mid + 1, r);
    pushup(cur);
}
// Point assignment: set position u to value v.
inline void edit(node * cur){
    if(cur->l == cur->r){
        cur->sum = cur->max = v;
        return;
    }
    // The children split [l, r] at ls->r, so exactly one side contains u.
    if(u <= cur->ls->r) edit(cur->ls);
    else edit(cur->rs);
    pushup(cur);
}
// Maximum over positions [u, v]; -inf when the node is disjoint from it.
inline int query_max(node * cur){
    if(cur->r < u || v < cur->l) return -inf;
    if(u <= cur->l && cur->r <= v) return cur->max;
    const int left = query_max(cur->ls);
    const int right = query_max(cur->rs);
    return std::max(left, right);
}
// Sum over positions [u, v].
inline int query_sum(node * cur){
    if(cur->r < u || v < cur->l) return 0;
    if(u <= cur->l && cur->r <= v) return cur->sum;
    const int left = query_sum(cur->ls);
    return left + query_sum(cur->rs);
}
//----------------------------
int main(){
std::ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
//freopen("in.in", "r", stdin);
cin >> n;
rep(i,1,n-1){
cin >> u >> v;
link(u,v) , link(v,u);
}
rep(i,1,n) cin >> a[i];
//----------------------------
dfs1(1,1) , dfs2(1,1);
node * root = create();
build(root,1,n);
//----------------------------
cin >> m;
while(m--){
cin >> str >> u >> v;
if(str == "CHANGE"){
u = id[u];
edit(root);
} else if(str == "QMAX"){
int x = u , y = v , ans = -inf;
while(top[x] != top[y]){
if(deep[ top[x] ] < deep[ top[y] ]) std::swap(x,y);
u = id[ top[x] ] , v = id[x] , x = fa[ top[x] ] , ans = std::max(ans,query_max(root));
}
if(id[x] > id[y]) std::swap(x,y);
u = id[x] , v = id[y] , ans = std::max(ans,query_max(root));
cout << ans << "\n";
} else if(str == "QSUM"){
int x = u , y = v , ans = 0;
while(top[x] != top[y]){
if(deep[ top[x] ] < deep[ top[y] ]) std::swap(x,y);
u = id[ top[x] ] , v = id[x] , x = fa[ top[x] ] , ans += query_sum(root);
}
if(id[x] > id[y]) std::swap(x,y);
u = id[x] , v = id[y] , ans += query_sum(root);
cout << ans << "\n";
}
}
return 0;
}
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<bitset>
#include<vector>
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<algorithm>
// Inclusive loop helpers. The `register` storage class was removed in C++17
// (ill-formed there; clang rejects it), so the loop variable is a plain int.
#define rep(i,a,b) for(int i = (a);i <= (b);++i)
#define per(i,a,b) for(int i = (a);i >= (b);--i)
#define mkp std::make_pair
typedef long long ll;
typedef unsigned long long ull;
using std::string;using std::cin;using std::cout;
inline bool cmp(int x,int y){return x < y;}
const int N = 1e5+9;       // array capacity bound
const int inf = 1e9+9;
const double eps = 1e-7;
// n nodes, m operations, weights a, current opcode; u, v, w are the operands of
// the segment-tree routines (range [u, v], added value w).
int _,n,m,a[2*N],opt,u,v,w;
//----------------------------
int head[2*N];  // head[u]: most recently added edge leaving u (0 = none)
int ver[2*N];   // ver[e]: target vertex of edge e
int next[2*N];  // next[e]: edge added before e from the same vertex
int tot;        // number of edges stored so far
// Store the directed edge x -> y at the front of x's edge list.
inline void link(int x,int y){
    ver[++tot] = y;
    next[tot] = head[x];
    head[x] = tot;
}
//----------------------------
// Decomposition data: parent, chain head, depth, subtree size, heavy child,
// dfs order (id[0] is the running counter) and the weights laid out in dfs order.
int fa[2*N],top[2*N],deep[2*N],size[2*N],son[2*N],id[2*N],b[2*N];
// First dfs: parent, depth, subtree size and heavy child (largest subtree,
// later child wins ties) of `now`, whose parent is `father`.
inline void dfs1(int now,int father){
    fa[now] = father;
    deep[now] = deep[father] + 1;
    size[now] = 1;
    for(int e = head[now]; e; e = next[e]){
        const int to = ver[e];
        if(to == father) continue;
        dfs1(to, now);
        size[now] += size[to];
        if(!(size[son[now]] > size[to])) son[now] = to;
    }
}
// Second dfs: heavy-child-first numbering (contiguous ids per heavy chain),
// chain head top[now], and the weight a[now] copied to b at its dfs position.
inline void dfs2(int now,int toper){
    top[now] = toper;
    const int pos = ++id[0];
    id[now] = pos;
    b[pos] = a[now];
    if(son[now] == 0) return;
    dfs2(son[now], toper);
    for(int e = head[now]; e; e = next[e]){
        const int to = ver[e];
        if(id[to] == 0) dfs2(to, to);
    }
}
//----------------------------
int TOT;
struct node{
int l,r;
ll tag,sum;
node * ls , * rs;
ll dis(){return r-l+1;}
}Tree[4*N];
inline node * create(){return &Tree[++TOT];}
inline void pushup(node * cur){cur->sum = cur->ls->sum + cur->rs->sum;}
inline void pushdown(node * cur){
cur->ls->tag += cur->tag , cur->rs->tag += cur->tag;
cur->ls->sum += cur->tag * cur->ls->dis();
cur->rs->sum += cur->tag * cur->rs->dis();
cur->tag = 0;
return;
}
inline void build(node * cur,int l,int r){
cur->l = l , cur->r = r , cur->tag = 0;
if(l == r){cur->sum = b[l];return;}
int mid = (l+r) >> 1;
cur->ls = create() , cur->rs = create();
build(cur->ls,l,mid) , build(cur->rs,mid+1,r);
pushup(cur);
}
inline void edit(node * cur){
if(v < cur->l || cur->r < u) return;
if(u <= cur->l && cur->r <= v){
cur->tag += w , cur->sum += w * cur->dis();
return;
}
if(cur->l == cur->r) return;
pushdown(cur);
if(u <= cur->ls->r) edit(cur->ls);
if(cur->rs->l <= v) edit(cur->rs);
pushup(cur);
}
inline ll query(node * cur){
if(v < cur->l || cur->r < u) return 0;
if(u <= cur->l && cur->r <= v) return cur->sum;
pushdown(cur);
return query(cur->ls) + query(cur->rs);
}
//----------------------------
int main(){
std::ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
cin >> n >> m;
rep(i,1,n) cin >> a[i];
rep(i,1,n-1){
cin >> u >> v;
link(u,v) , link(v,u);
}
dfs1(1,1) , dfs2(1,1);
node * root = create();
build(root,1,n);
int p,x;
while(m--){
cin >> opt;
if(opt == 1){//point edit
cin >> p >> x;
u = v = id[p] , w = x;
edit(root);
} else if(opt == 2){//tree edit
cin >> p >> x;
u = id[p] , v = id[p] + size[p] - 1 , w = x;
edit(root);
} else {//query
cin >> p;
ll ans = 0;
while(top[1] != top[p]){
u = id[ top[p] ] , v = id[p];
ans += query(root);
p = fa[ top[p] ];
}
u = id[1] , v = id[p];
cout << ans + query(root) << "\n";
}
}
return 0;
}