简介

线段树是一种可以用来维护区间信息常用的数据结构。线段树可以在线段树 - 图1的时间复杂度完成单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

查询时我们使用懒标记,当要用到时,再给他加上。

学当初学了一点,但是并没有完全掌握,现在重新学一下,毕竟比赛的时候就遇到了线段树的题目,结果一题都不会写。

主要是写的线段树的题目太少了,根本不可能做到灵活运用。

先上一个板子

  1. const int N=100010
  2. struct Node
  3. {
  4. int l, r;
  5. // TODO: 需要维护的信息和懒标记
  6. }tr[N * 4];
  7. void pushup(int u)
  8. {
  9. // TODO: 利用左右儿子信息维护当前节点的信息
  10. }
  11. void pushdown(int u)
  12. {
  13. // TODO: 将懒标记下传
  14. }
  15. void build(int u, int l, int r)
  16. {
  17. if (l == r) tr[u] = {l, r};
  18. else
  19. {
  20. tr[u] = {l, r};
  21. int mid = l + r >> 1;
  22. build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
  23. pushup(u);
  24. }
  25. }
  26. void update(int u, int l, int r, int d)
  27. {
  28. if (tr[u].l >= l && tr[u].r <= r)
  29. {
  30. // TODO: 修改区间
  31. }
  32. else
  33. {
  34. pushdown(u);
  35. int mid = tr[u].l + tr[u].r >> 1;
  36. if (l <= mid) update(u << 1, l, r, d);
  37. if (r > mid) update(u << 1 | 1, l, r, d);
  38. pushup(u);
  39. }
  40. }
  41. int query(int u, int l, int r)
  42. {
  43. if (tr[u].l >= l && tr[u].r <= r)
  44. {
  45. return ; // TODO 需要补充返回值
  46. }
  47. else
  48. {
  49. pushdown(u);
  50. int mid = tr[u].l + tr[u].r >> 1;
  51. int res = 0;
  52. if (l <= mid ) res = query(u << 1, l, r);
  53. if (r > mid) res += query(u << 1 | 1, l, r);
  54. return res;
  55. }
  56. }

模板题

洛谷P3372

区间更新,区间查询,在一段区间加上一个数,查询一段区间的和。

  1. #include<bits/stdc++.h>
  2. using namespace std;
  3. typedef long long ll;
  4. const int N=1000010;
  5. ll a[N],sum[N<<2],lazy[N<<4]; //树状数组长度最大不会超过4*n
  6. //构造线段树
  7. void build(int rt,int l,int r){
  8. if(l==r){
  9. sum[rt] = a[l];
  10. lazy[rt] = 0;
  11. return;
  12. }//递归构造左右子树
  13. int mid=l+r>>1;
  14. build(rt<<1,l,mid);
  15. build(rt<<1|1,mid+1,r);
  16. sum[rt]=sum[rt<<1] + sum[rt<<1|1];//把父区间[l,r],分成左右两个[l,mid],(mid,r]区间。
  17. }
  18. //向下传递
  19. void push_down(int rt,int l,int r){
  20. int mid=l+r>>1;
  21. if(lazy[rt]){
  22. sum[rt<<1]+=lazy[rt]*(mid-l+1);//维护区间和,如果是别的可以改
  23. sum[rt<<1|1]+=lazy[rt]*(r-mid);
  24. lazy[rt<<1]+=lazy[rt];
  25. lazy[rt<<1|1]+=lazy[rt];//懒标记的值传递
  26. lazy[rt]=0;
  27. }
  28. }
  29. //更新值
  30. void update(int rt,int l,int r,int L,int R,int z){
  31. int mid=l+r>>1;
  32. if(L<=l&&r<=R){
  33. sum[rt]+=z*(r-l+1);//维护区间和,如果是别的可以改
  34. lazy[rt]+=z;
  35. return;
  36. }
  37. push_down(rt,l,r);
  38. if(L<=mid) update(rt<<1,l,mid,L,R,z);
  39. if(R>mid) update(rt<<1|1,mid+1,r,L,R,z);
  40. sum[rt]=sum[rt<<1]+sum[rt<<1|1];
  41. }
  42. //查询值
  43. ll query(int rt,int l,int r,int L,int R){
  44. int mid=l+r>>1;
  45. if(L<=l&&r<=R){
  46. return sum[rt];
  47. }
  48. push_down(rt,l,r);
  49. ll res=0;
  50. if(L<=mid) res+=query(rt<<1,l,mid,L,R);
  51. if(R>mid) res+=query(rt<<1|1,mid+1,r,L,R);
  52. return res;
  53. }
  54. int main(){
  55. int n,m;
  56. scanf("%d%d",&n,&m);
  57. for(int i=1;i<=n;i++){
  58. scanf("%lld",&a[i]);
  59. }
  60. build(1,1,n);
  61. while(m--){
  62. int op;
  63. scanf("%d",&op);
  64. if(op&1){
  65. int x,y,z;
  66. scanf("%d%d%d",&x,&y,&z);
  67. update(1,1,n,x,y,z);
  68. }
  69. else{
  70. int x,y;
  71. scanf("%d%d",&x,&y);
  72. printf("%lld\n",query(1,1,n,x,y));
  73. }
  74. }
  75. return 0;
  76. }

P3373

区间乘上一个数,求一段区间的和

代码:

  1. #include<bits/stdc++.h>
  2. using namespace std;
  3. #define int long long
  4. const int N=100010;
  5. int mod;
  6. int sum[N<<2],mul[N<<2],lazy[N<<2],a[N];
  7. void push_up(int rt){
  8. sum[rt]=(sum[rt<<1]+sum[rt<<1|1])%mod;
  9. }
  10. void push_down(int rt,int l,int r){
  11. int mid=l+r>>1;
  12. if(mul[rt]!=1){
  13. mul[rt<<1]=(mul[rt<<1]*mul[rt])%mod;
  14. mul[rt<<1|1]=(mul[rt<<1|1]*mul[rt])%mod;
  15. lazy[rt<<1]=(lazy[rt<<1]*mul[rt])%mod;
  16. lazy[rt<<1|1]=(lazy[rt<<1|1]*mul[rt])%mod;
  17. sum[rt<<1]=(sum[rt<<1]*mul[rt])%mod;
  18. sum[rt<<1|1]=(sum[rt<<1|1]*mul[rt])%mod;
  19. mul[rt]=1;
  20. }
  21. if(lazy[rt]){
  22. sum[rt<<1]=(sum[rt<<1]+lazy[rt]*(mid-l+1))%mod;
  23. sum[rt<<1|1]=(sum[rt<<1|1]+lazy[rt]*(r-mid))%mod;
  24. lazy[rt<<1]=(lazy[rt<<1]+lazy[rt])%mod;
  25. lazy[rt<<1|1]=(lazy[rt<<1|1]+lazy[rt])%mod;
  26. lazy[rt]=0;
  27. }
  28. }
  29. void build(int rt,int l,int r){
  30. mul[rt]=1;
  31. if(l==r){
  32. sum[rt]=a[l];
  33. return;
  34. }
  35. int mid=l+r>>1;
  36. build(rt<<1,l,mid);build(rt<<1|1,mid+1,r);
  37. push_up(rt);
  38. }
  39. void update1(int rt,int l,int r,int L,int R,int d){
  40. if(L<=l && r<=R){
  41. mul[rt]=mul[rt]*d%mod;
  42. lazy[rt]=lazy[rt]*d%mod;
  43. sum[rt]=sum[rt]*d%mod;
  44. return;
  45. }
  46. push_down(rt,l,r);
  47. int mid=l+r>>1;
  48. if(L<=mid) update1(rt<<1,l,mid,L,R,d);
  49. if(R>mid) update1(rt<<1|1,mid+1,r,L,R,d);
  50. push_up(rt);
  51. }
  52. void update2(int rt,int l,int r,int L,int R,int d){
  53. if(L<=l && r<=R){
  54. sum[rt]=(sum[rt]+d*(r-l+1))%mod;
  55. lazy[rt]=(lazy[rt]+d)%mod;
  56. return;
  57. }
  58. push_down(rt,l,r);
  59. int mid=l+r>>1;
  60. if(L<=mid) update2(rt<<1,l,mid,L,R,d);
  61. if(R>mid) update2(rt<<1|1,mid+1,r,L,R,d);
  62. push_up(rt);
  63. }
  64. int query(int rt,int l,int r,int L,int R){
  65. if(L<=l&& r<=R) return sum[rt];
  66. push_down(rt,l,r);
  67. int res=0;
  68. int mid=l+r>>1;
  69. if(L<=mid) res+=query(rt<<1,l,mid,L,R);
  70. res%=mod;
  71. if(R>mid) res+=query(rt<<1|1,mid+1,r,L,R);
  72. return res%mod;
  73. }
  74. signed main(){
  75. int n,m,p;
  76. cin>>n>>m>>mod;
  77. for(int i=1;i<=n;i++){
  78. cin>>a[i];
  79. }
  80. build(1,1,n);
  81. while(m--){
  82. int op,l,r,d;
  83. cin>>op;
  84. if(op==1){
  85. cin>>l>>r>>d;
  86. update1(1,1,n,l,r,d);
  87. }
  88. else if(op==2){
  89. cin>>l>>r>>d;
  90. update2(1,1,n,l,r,d);
  91. }
  92. else if(op==3){
  93. cin>>l>>r;
  94. cout<<query(1,1,n,l,r)<<endl;
  95. }
  96. }
  97. return 0;
  98. }

维护区间最小值

#include <bits/stdc++.h>
#define endl '\n'
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int N = 5e5 + 7;
const int M = 1e6 + 7;
const int mod = 1e9 + 7;
const double eps = 1e-6;
ll t, n, k,m;
int a[N],tree[N<<2];
void build(int rt,int l,int r){
    if(l==r){
        tree[rt]=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(rt<<1,l,mid);
    build(rt<<1|1,mid+1,r);
    tree[rt]=min(tree[rt<<1],tree[rt<<1|1]);
}

void update(int rt,int l,int r,int x,int y){
    if(l==r){
        tree[rt]=y;
        return;
    }
    int mid=(l+r)>>1;
    if(x<=mid){
        update(rt<<1,l,mid,x,y);
    }
    else{
        update(rt<<1|1,mid+1,r,x,y);
    }
    tree[rt]=min(tree[rt<<1],tree[rt<<1|1]);
}

int query(int rt,int l,int r,int L,int R){
    if(L<=l&&r<=R){
        return tree[rt];
    }
    int mid=(l+r)>>1;
    int ans=1e9+10;
    if(L<=mid) ans=min(ans,query(rt<<1,l,mid,L,R));
    if(R>mid) ans=min(ans,query(rt<<1|1,mid+1,r,L,R));
    return ans;
}

void solve() {
    cin>>n>>m;
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    build(1,1,n);
    while(m--){
        int op,x,y;
        cin>>op;
        if(op==1){
            cin>>x>>y;
            update(1,1,n,x,y);
            a[x]=y;
        }
        else{
            cin>>x;
            int al,ar,l=1,r=x;
            while(l<r){
                //TODO
                int mid=(l+r)>>1;
                if(query(1,1,n,mid,x)>=a[x]) r=mid;
                else l=mid+1;
            }
              al=x-l+1;
              l=x,r=n+1;
              while(l<r){
                  int mid=(l+r)>>1;
                  if(query(1,1,n,x,mid)<a[x]) r=mid;
                  else l=mid+1;
            }
            l--;
            ar=l-x+1;
            cout<<1ll*al*ar<<endl;
        }
    }
}

signed main() {
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    t=1;
    //cin >> t;
    while(t--) solve();
    return 0;
}

代码:

#include<bits/stdc++.h>
using namespace std;
const int N=100010;
const int INF=0x3f3f3f3f;
int a[N],mn[N<<2],need[N<<2],he[N],res=1,h=0;

void update(int rt,int l,int r,int x,int val){
    if(l==r){
        mn[rt]=val;
        return;
    }
    int mid=l+r>>1;
    if(x<=mid) update(rt<<1,l,mid,x,val);
    if(x>mid) update(rt<<1|1,mid+1,r,x,val);
    mn[rt]=min(mn[rt<<1],mn[rt<<1|1]);
}


int query(int rt,int l,int r,int x,int y){
    if(x<=l&&r<=y) return mn[rt];
    int mid=l+r>>1,ans=INF;
    if(x<=mid) ans=min(query(rt<<1,l,mid,x,y),ans);
    if(y>mid) ans=min(query(rt<<1|1,mid+1,r,x,y),ans);
    return ans;
}

int main(){
    int n,k;
    cin>>k>>n;
    for(int i=0;i<n;i++){
        int x,y;
        cin>>x>>y;
        if(i<k){//前面k个,先初始化 
            update(1,0,k-1,i,x+y);
            he[i]++;
            continue;
        }
        if(mn[1]>x) continue;//如果当前所有都不能满足,直接跳过
        int d=i%k,l,r,ans;
        int tot=query(1,0,k-1,d,k-1);
        if(tot<=x){
            l=d,r=k-1;//i%k的后半段 
        } 
        else{//i%k的前半段 
            l=0,r=d-1; 
        }
        //二分查找区间长度,确定位置最靠近左边的合适的位置
        while(l<=r){
            int mid=l+r>>1;
            if(query(1,0,k-1,l,mid)<=x) ans=mid,r=mid-1;
            else l=mid+1;
        } 
        update(1,0,k-1,ans,x+y);
        he[ans]++;
        res=max(res,he[ans]);//记录处理最多事件的数量
        //printf("res = %d\n",res); 
    }
    for(int i=0;i<k;i++){
        //printf("%d ",he[i]);
        if(he[i]==res) need[++h]=i;
    }
    //puts("");
    for(int i=1;i<=h;i++){
        if(i!=1) printf(" ");
        printf("%d",need[i]);
    }
}