线段树
树状数组和线段树可谓是亲兄弟,但他两毕竟还有一些区别:
树状数组能有的操作,线段树一定有;
线段树有的操作,树状数组不一定有。
但是,并不是直接选择线段树而放弃学习树状数组。
与按照二进制位进行区间划分的树状数组相比,线段树是一种更加通用的结构:
- 线段树的每个节点都代表一个区间
- 线段树具有唯一的根节点,代表的区间是整个统计范围,如
- 线段树的每个叶节点都代表一个长度为 1 的元区间
- 对于每个内部节点,它的左子节点是 , 右子节点是,其中 %2F2%5Crfloor#card=math&code=mid%20%3D%5Clfloor%28l%2Br%29%2F2%5Crfloor&id=S7gk9)
上图展示了一颗线段树。可以发现,除去树的最后一层,整颗线段树一定是一颗完全二叉树,树的深度为 #card=math&code=O%28%5Clog%20N%29&id=HS08y)。因此,我们可以按照与二叉堆类似的“父子2倍”节点编号方法。
- 根节点编号为1
- 编号为 x 的节点,左子节点编号为 ,右子节点编号为 。
使用一个 struct 数组来保存线段树,树的最后一层节点在数组汇总保存的位置不是连续的,直接空出数组中多余的位置即可。
理想情况下, 个叶子节点的满二叉树有 个节点。
因为在上述存储方式下,最后还有一层产生了空余,所以保存线段树的数组长度要不小于 才能保证不会越界。
线段树的建树
给定一个长度为 的序列 ,我们可以在区间 上建立一颗线段树。
每个叶子结点 保存 的信息,由于线段树的二叉结构很容易的从下往上的传递信息,所以以求区间最大值为例。记 ,显然有 #card=math&code=val%5Bl%2Cr%5D%20%3D%20max%28val%5Bl%2Cmid%5D%2C%20val%5Bmid%2B1%2C%20r%5D%29&id=WJGwW)
struct SegTree{
int l, r;
int val;
}t[N*4]; // struct 数组存储线段树
void build(int p, int l, int r){
t[p].l = l, t[p].r = r;
if(l == r) {t[p].val = a[l]; return;}
int mid = (l + r) / 2; // l + r >> 1
build(p*2, l, mid);
build(p*2+1, mid+1, r);
t[p].val = max(t[p*2].val, t[p*2+1].val); // pushup()
}
build(1, 1, n); // 调用入口
单点修改
单调修改形如 的指令,表示把 的值修改为 v。
在线段树中,根节点(编号为 1 的节点)时执行各种命令的入口。我们需要从根节点出发,递归找到代表区间 的叶节点,然后从下往上更新 以及它的所有祖先节点上保存的信息。时间复杂度是 #card=math&code=O%28%5Clog%20N%29&id=PdOrQ)
void change(int p, int x, int v){
if(t[p].l == t[p].r) {
t[p].val = v; return;
}
int mid = (t[p].l + t[p].r) / 2;
if(x <= mid) change(p*2, x, v);
else change(p*2+1, x, v);
t[p].val = max(t[p*2].val, t[p*2+1].val);
}
change(1,x,v); // 调用入口
线段树的区间查询
区间查询形如:“” 的指令,例如查序列 在区间 上的最大值,即 ,我们只需要从根节点开始,递归执行以下过程:
- 若 完全覆盖了当前节点代表的区间,则立即回溯,并且该节点的 值为候选答案。
- 若左子节点与 有重叠部分,则递归访问左子节点
- 若右子节点与 有重叠部分,则递归访问右子节点
int ask(int p, int l, int r){
if(l <= t[p].l && r >= t[p].r) return t[p].val; // 完全包含
int mid = (t[p].l + t[p].r) / 2;
int val = - (1 << 30);
if(l <= mid) val = max(val, ask(p*2, l, r)); // 左子节点有重叠
if(r > mid) val = max(val, ask(p*2+1, l, r)); // 右子节点有重叠
return val;
}
cout << ask(1, l, r) << endl; // 调用入口
该查询过程会把询问区间 在线段树上分成 #card=math&code=O%28%5Clog%20N%29&id=U3fCy) 个节点,取它们的最大值作为答案。
为什么是 #card=math&code=O%28log%20N%29&id=VJ2Ip) 个呢?分析上述过程,在每个结点 上,设 %2F2#card=math&code=mid%20%3D%20%28p_l%2Bp_r%29%2F2&id=yvefd) (向下取整),可能出现如下情况:
- ,完全覆盖当前结点,直接返回
- ,只有 处于结点之内。
1. , 只会递归右子树
2. ,虽然递归两颗子树,但是右子节点会在递归后直接返回(对应完全覆盖的情况)- $ l\le p_l \le r \le p_r$ ,即只有 处于结点之内,与情况 2 类似。
- ,即 与 都位于结点之内。
1. 都位于 的一侧,则只会递归一颗子树
2. 分别位域 的两侧,递归左右两颗子树。可以发现,只有 4.2 会真正产生对左右两颗子树的递归,但这类情况最多只会发生一次,之后子节点上就会变成情况 2 或者情况 3.
所以 %20%3D%20O(%5Clog%20N)#card=math&code=O%282%5Clog%20N%29%20%3D%20O%28%5Clog%20N%29&id=EFK5D) 。从宏观上理解,相当于查询区间 在线段树上划分出一条递归访问路径。情况 4.2 在两条路径在从下往上第一次交汇处产生。
线段树例题
例题1 区间最大连续子段和
SP1043 GSS1 - Can you answer these queries I
注:如果提交失败,可以去vjudge 进行提交。网址:https://vjudge.net/problem/SPOJ-GSS1
#include <bits/stdc++.h>
using namespace std;
const int N = 50010;
struct SegTree{
int l, r;
int sum, lmax, rmax, mmax;
}t[N<<2];
int n, m, a[N];
void pushup(SegTree &p, SegTree ls, SegTree rs) {
p.sum = ls.sum + rs.sum;
p.lmax = max(ls.lmax, ls.sum + rs.lmax);
p.rmax = max(rs.rmax, rs.sum + ls.rmax);
p.mmax = max(ls.mmax, max(rs.mmax, ls.rmax+rs.lmax));
}
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if(l == r) {
t[p].sum = t[p].lmax = t[p].rmax = t[p].mmax = a[l];
return;
}
int mid = l + r >> 1;
build(p*2, l, mid);
build(p*2+1, mid+1, r);
pushup(t[p], t[p*2], t[p*2+1]);
}
SegTree query(int p, int l, int r) {
if(t[p].l >= l && t[p].r <= r) {
return t[p];
}
int mid = t[p].l + t[p].r >> 1;
if(l > mid) {
return query(p*2+1, l, r);
} else if(r <= mid) {
return query(p*2, l, r);
} else {
SegTree res;
pushup(res, query(p*2, l, r), query(p*2+1, l, r));
return res;
}
}
int main(){
scanf("%d", &n);
for(int i=1;i<=n;i++) scanf("%d", &a[i]);
build(1, 1, n);
scanf("%d", &m);
while(m--){
int l, r;
scanf("%d%d", &l, &r);
printf("%d\n", query(1, l, r).mmax);
}
return 0;
}
例题2 带修改的区间最大连续子段和
SP1716 GSS3 - Can you answer these queries III
VJudge 提交网址:https://vjudge.net/problem/SPOJ-GSS3
#include <bits/stdc++.h>
using namespace std;
const int N = 50010;
struct SegTree{
int l, r;
int sum, lmax, rmax, mmax;
}t[N<<2];
int n, m, a[N];
void pushup(SegTree &p, SegTree ls, SegTree rs) {
p.sum = ls.sum + rs.sum;
p.lmax = max(ls.lmax, ls.sum + rs.lmax);
p.rmax = max(rs.rmax, rs.sum + ls.rmax);
p.mmax = max(ls.mmax, max(rs.mmax, ls.rmax+rs.lmax));
}
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if(l == r) {
t[p].sum = t[p].lmax = t[p].rmax = t[p].mmax = a[l];
return;
}
int mid = l + r >> 1;
build(p*2, l, mid);
build(p*2+1, mid+1, r);
pushup(t[p], t[p*2], t[p*2+1]);
}
SegTree query(int p, int l, int r) {
if(t[p].l >= l && t[p].r <= r) {
return t[p];
}
int mid = t[p].l + t[p].r >> 1;
if(l > mid) {
return query(p*2+1, l, r);
} else if(r <= mid) {
return query(p*2, l, r);
} else {
SegTree res;
pushup(res, query(p*2, l, r), query(p*2+1, l, r));
return res;
}
}
void change(int p, int pos, int val) {
if(t[p].l == pos && t[p].r == pos) {
t[p].sum = t[p].rmax = t[p].lmax = t[p].mmax = val;
return;
}
int mid = t[p].l + t[p].r >> 1;
if(pos <= mid) change(p*2, pos, val);
else change(p*2+1, pos, val);
pushup(t[p], t[p*2], t[p*2+1]);
}
int main(){
scanf("%d", &n);
for(int i=1;i<=n;i++) scanf("%d", &a[i]);
build(1, 1, n);
scanf("%d", &m);
while(m--){
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if(op == 0) {
change(1, l, r);
} else {
printf("%d\n", query(1, l, r).mmax);
}
}
return 0;
}
例题3 区间GCD
链接:https://ac.nowcoder.com/acm/contest/949/H?&headNav=acm (注:需要提前注册牛客网账号)
简要题意
小阳手中一共有 个贝壳,每个贝壳都有颜色,且初始第 个贝壳的颜色为 。现在小阳有 3 种操作:
- :给
[l,r]
区间里所有贝壳的颜色值加上x
。 - :询问
[l,r]
区间里所有相邻贝壳 颜色值的差(取绝对值) 的最大值(若l=r
输出 0)。 - :询问
[l,r]
区间里所有贝壳颜色值的最大公约数。
根据更相减损术(欧几里得算法),我们知道 %20%3D%20gcd(x%2C%20y-x)#card=math&code=gcd%28x%2C%20y%29%20%3D%20gcd%28x%2C%20y-x%29&id=qitYO), 他可以进一步扩展到三个数字的情况:#card=math&code=gcd%28x%2C%20y-x%2C%20z-x%29&id=LF4VN)。实际上,读者用数学归纳法容易证明,该性质对任意多个整数都成立。
因此,我们构造一个长度为 N 的新数列 B,其中B[i] = A[i] - A[i-1]
,B[1]
可以为任意值(后面不参与具体运算)。数列 B 为 A 的差分序列,用线段树维护序列 B 的区间最大公约数和区间绝对值的最大值。
这样,对于操作2,可以直接在线段树中查询输出答案。对于操作3,等价于求 )#card=math&code=gcd%28A%5Bl%5D%2C%20ask%281%2C%20l%2B1%2C%20r%29%29&id=x7sjH)
对于操作 1,只有 B[l]
加了 x,B[r]
减掉了 x,所以在维护 B 的线段树上只需进行两次单点修改即可。
另外,询问时,还需要数列 A 的值,可额外用一个支持“区间增加,单点查询”的树状数组对数列 A进行维护。
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
int n, m, a[N];
int gcd(int a, int b) {return b == 0 ? a : gcd(b, a % b);}
namespace BIT {
int n, bit[N];
inline void init(int tot) {
n = tot;
memset(bit, 0, sizeof bit);
}
inline void add(int x, int y){
for(;x<=n;x+=x&-x) bit[x] += y;
}
inline int query(int x) {
int ans = 0;
for(;x;x-=x&-x) ans += bit[x];
return ans;
}
}
namespace SGT {
struct SegTree{
int l, r;
int g, Max;
}t[N<<2];
int n;
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if(l == r) {
t[p].g = a[l] - a[l-1];
t[p].Max = abs(a[l] - a[l-1]);
return;
}
int mid = l + r >> 1;
build(p*2, l, mid);
build(p*2+1, mid+1, r);
t[p].g = gcd(t[p*2].g, t[p*2+1].g);
t[p].Max = max(t[p*2].Max, t[p*2+1].Max);
}
void init(int tot){ n = tot; build(1, 1, n);}
void change(int p, int pos, int val) {
if(t[p].l == pos && t[p].r == pos) {
t[p].g += val;
t[p].Max = abs(BIT::query(t[p].l) - BIT::query(t[p].l-1));
return ;
}
int mid = t[p].l + t[p].r >> 1;
if(pos <= mid) change(p*2, pos, val);
else change(p*2+1, pos, val);
t[p].g = gcd(t[p*2].g, t[p*2+1].g);
t[p].Max = max(t[p*2].Max, t[p*2+1].Max);
}
int queryGcd(int p, int l, int r) {
if(t[p].l >= l && t[p].r <= r) return t[p].g;
int mid = t[p].l + t[p].r >> 1;
int g = 0;
if(l <= mid) g = gcd(g, queryGcd(p*2, l, r));
if(r > mid) g = gcd(g, queryGcd(p*2+1, l, r));
return g;
}
int queryMax(int p, int l, int r){
if(t[p].l >= l && t[p].r <= r) return t[p].Max;
int mid = t[p].l + t[p].r >> 1;
int Max = 0;
if(l <= mid) Max = max(Max, queryMax(p*2, l, r));
if(r > mid) Max = max(Max, queryMax(p*2+1, l, r));
return Max;
}
}
int main(){
scanf("%d%d", &n, &m);
for(int i=1;i<=n;i++) scanf("%d", &a[i]);
BIT::init(n);
for(int i=1;i<=n;i++) BIT::add(i, a[i] - a[i-1]);
SGT::init(n);
while(m--){
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if(op == 1) {
int x;scanf("%d", &x);
BIT::add(l, x);
if(r < n)
BIT::add(r+1, -x);
SGT::change(1, l, x);
if(r < n)
SGT::change(1, r+1, -x);
} else if(op == 2) {
if(l == r) puts("0");
else {
printf("%d\n", SGT::queryMax(1, l+1, r));
}
} else {
printf("%d\n", gcd(BIT::query(l), abs(SGT::queryGcd(1, l+1, r))));
}
}
return 0;
}