首先查询吉老师的PPT
Segment tree Beats!.pdf
能熟练写好 pushUp、pushDown、upDate 这三个函数,就能快速完成。
分步练习实现题目的要求
实现区间最值更新(区间取 min,即 a[i] = min(a[i], k))以及区间求和查询
#include <algorithm>
#include <climits>
#include <cstdio>
#include <iostream>
using namespace std;

typedef long long ll;

// Capacity of the 1-based input array; the tree uses four times as many slots.
const int maxn = 100000 + 10;

// "Second maximum" of a node whose elements are all equal. It is never used in
// arithmetic, only compared, so the smallest representable value is safe and
// cannot collide with a real element the way a finite value like -1e9 can.
const ll NO_SECOND = LLONG_MIN;

ll n, m;
ll a[maxn];

// One segment tree node covering the closed index range [l, r].
struct NODE {
    ll l, r, sum;      // covered range and the sum of its elements
    ll maxn, se, cnt;  // maximum, strict second maximum, number of maxima
    ll add_m, add_n;   // pending delta for the maxima; add_n is unused in this step
} t[4 * maxn];

// Debug helper state: the largest heap index that built() has initialised.
ll num = 0;

/*
 * showTree: debug dump of the tree, one heap level per output line.
 *
 * Heap indices are not contiguous when n is not a power of two, so the loop
 * runs up to the largest built index and skips slots that were never built
 * (their l is still 0 because built() only creates ranges starting at 1).
 * NOTE: this prints to stdout; remove the calls before submitting to a judge.
 */
void showTree() {
    cout << "show tree below" << '\n';
    for (ll i = 1; i <= num; i++) {
        if ((i & (i - 1)) == 0) cout << '\n';  // i is a power of two: new level
        if (t[i].l == 0) continue;             // unused slot
        cout << "( " << t[i].l << ' ' << t[i].r << ' ' << t[i].sum << ' '
             << t[i].maxn << ' ' << t[i].se << ' ' << t[i].cnt << " )";
    }
    cout << '\n' << "end" << '\n';
}

/*
 * pushUp: recompute sum / maximum / second maximum / maximum count of node i
 * from its two children.
 */
void pushUp(ll i) {
    const NODE &lc = t[2 * i];
    const NODE &rc = t[2 * i + 1];
    t[i].sum = lc.sum + rc.sum;
    t[i].maxn = max(lc.maxn, rc.maxn);
    if (lc.maxn == rc.maxn) {
        t[i].se = max(lc.se, rc.se);
        t[i].cnt = lc.cnt + rc.cnt;
    } else if (lc.maxn > rc.maxn) {
        // The smaller child's maximum competes for the second place.
        t[i].se = max(lc.se, rc.maxn);
        t[i].cnt = lc.cnt;
    } else {
        t[i].se = max(lc.maxn, rc.se);
        t[i].cnt = rc.cnt;
    }
}

/*
 * built: build the node i for the range [l, r] from the global array a.
 * Requires l <= r.
 */
void built(ll i, ll l, ll r) {
    t[i].l = l;
    t[i].r = r;
    t[i].add_m = t[i].add_n = 0;
    num = max(num, i);  // debug bookkeeping for showTree
    if (l == r) {
        t[i].sum = t[i].maxn = a[l];
        t[i].se = NO_SECOND;
        t[i].cnt = 1;
        return;
    }
    const ll mid = (l + r) / 2;
    built(2 * i, l, mid);
    built(2 * i + 1, mid + 1, r);
    pushUp(i);
}

// Add delta to every maximum element of node i and remember it as a lazy tag.
static void applyToMax(ll i, ll delta) {
    t[i].sum += t[i].cnt * delta;
    t[i].maxn += delta;
    t[i].add_m += delta;  // accumulate: an older tag may still be pending
}

/*
 * pushDown: hand the pending "add to maxima" tag of node i to its children.
 *
 * The children still hold their pre-tag values, so a child contains maxima of
 * the parent exactly when its own maximum equals the larger of the two child
 * maxima. The tag is cleared afterwards; otherwise it would be applied again
 * on every later descent through this node.
 */
void pushDown(ll i) {
    const ll tag = t[i].add_m;
    if (tag == 0) return;
    const ll lc = 2 * i;
    const ll rc = 2 * i + 1;
    const ll best = max(t[lc].maxn, t[rc].maxn);  // taken before any child changes
    if (t[lc].maxn == best) applyToMax(lc, tag);
    if (t[rc].maxn == best) applyToMax(rc, tag);
    t[i].add_m = 0;
}

/*
 * change_min: set a[p] = min(a[p], k) for every p in [x, y].
 *
 * A node is handled without descending when it lies inside [x, y] and only
 * its maxima exceed k (se < k < maxn); nodes whose maximum is already <= k
 * or that do not intersect [x, y] are left untouched.
 */
void change_min(ll i, ll x, ll y, ll k) {
    if (y < t[i].l || x > t[i].r || t[i].maxn <= k) return;
    if (x <= t[i].l && t[i].r <= y && t[i].se < k) {
        applyToMax(i, k - t[i].maxn);
        return;
    }
    pushDown(i);
    change_min(2 * i, x, y, k);
    change_min(2 * i + 1, x, y, k);
    pushUp(i);
}

/*
 * query: return the sum of a[p] for p in [x, y] within node i.
 * Returns 0 when the node does not intersect the range, which also keeps the
 * recursion from descending below a leaf on an empty or out-of-range query.
 */
ll query(ll i, ll x, ll y) {
    if (y < t[i].l || x > t[i].r) return 0;
    if (x <= t[i].l && t[i].r <= y) return t[i].sum;
    pushDown(i);
    return query(2 * i, x, y) + query(2 * i + 1, x, y);
}

/*
 * Input: n m, then n values, then m operations:
 *   2 x y k  -> chmin on [x, y] with k
 *   4 x y    -> print the sum of [x, y]
 * Stops quietly on malformed input or when n does not fit the arrays.
 */
int main() {
    if (scanf("%lld%lld", &n, &m) != 2 || n < 1 || n >= maxn) return 0;
    for (ll i = 1; i <= n; i++) {
        if (scanf("%lld", &a[i]) != 1) return 0;
    }
    built(1, 1, n);
    showTree();  // debug output
    int flag;
    ll x, y, k;
    for (ll op = 1; op <= m; op++) {
        if (scanf("%d", &flag) != 1) break;
        if (flag == 4) {
            if (scanf("%lld%lld", &x, &y) != 2) break;
            printf("%lld\n", query(1, x, y));
        } else if (flag == 2) {
            if (scanf("%lld%lld%lld", &x, &y, &k) != 3) break;
            change_min(1, x, y, k);
        }
    }
    return 0;
}
在区间最值更新基础上增加区间加法,且实现代码复用
精髓在于通过 upDate 函数统一处理两种懒惰标记(add_m 作用于区间内的最大值,add_n 作用于其余元素),或者说,两种更新动作(区间取 min 与区间加)都归结为一次 upDate 调用!
#include <algorithm>
#include <climits>
#include <cstdio>
#include <iostream>
using namespace std;

typedef long long ll;

// Capacity of the 1-based input array; the tree uses four times as many slots.
const int maxn = 100000 + 10;

// "Second maximum" of a node whose elements are all equal. It is only ever
// compared, never added to, so the smallest representable value is safe and
// cannot collide with a real element after range additions.
const ll NO_SECOND = LLONG_MIN;

ll n, m;
ll a[maxn];

// One segment tree node covering the closed index range [l, r].
struct NODE {
    ll l, r, sum;      // covered range and the sum of its elements
    ll maxn, se, cnt;  // maximum, strict second maximum, number of maxima
    ll add_m, add_n;   // pending deltas: add_m for the maxima, add_n for the rest
} t[4 * maxn];

// Debug helper state: the largest heap index that built() has initialised.
ll num = 0;

/*
 * showTree: debug dump of the tree, one heap level per output line.
 *
 * Heap indices are not contiguous when n is not a power of two, so the loop
 * runs up to the largest built index and skips slots that were never built
 * (their l is still 0 because built() only creates ranges starting at 1).
 * NOTE: this prints to stdout; remove the calls before submitting to a judge.
 */
void showTree() {
    cout << "show tree below" << '\n';
    for (ll i = 1; i <= num; i++) {
        if ((i & (i - 1)) == 0) cout << '\n';  // i is a power of two: new level
        if (t[i].l == 0) continue;             // unused slot
        cout << "( " << t[i].l << ' ' << t[i].r << ' ' << t[i].sum << ' '
             << t[i].maxn << ' ' << t[i].se << ' ' << t[i].cnt << " )";
    }
    cout << '\n' << "end" << '\n';
}

/*
 * pushUp: recompute sum / maximum / second maximum / maximum count of node i
 * from its two children.
 */
void pushUp(ll i) {
    const NODE &lc = t[2 * i];
    const NODE &rc = t[2 * i + 1];
    t[i].sum = lc.sum + rc.sum;
    t[i].maxn = max(lc.maxn, rc.maxn);
    if (lc.maxn == rc.maxn) {
        t[i].se = max(lc.se, rc.se);
        t[i].cnt = lc.cnt + rc.cnt;
    } else if (lc.maxn > rc.maxn) {
        // The smaller child's maximum competes for the second place.
        t[i].se = max(lc.se, rc.maxn);
        t[i].cnt = lc.cnt;
    } else {
        t[i].se = max(lc.maxn, rc.se);
        t[i].cnt = rc.cnt;
    }
}

/*
 * built: build the node i for the range [l, r] from the global array a.
 * Requires l <= r.
 */
void built(ll i, ll l, ll r) {
    t[i].l = l;
    t[i].r = r;
    t[i].add_m = t[i].add_n = 0;
    num = max(num, i);  // debug bookkeeping for showTree
    if (l == r) {
        t[i].sum = t[i].maxn = a[l];
        t[i].se = NO_SECOND;
        t[i].cnt = 1;
        return;
    }
    const ll mid = (l + r) / 2;
    built(2 * i, l, mid);
    built(2 * i + 1, mid + 1, r);
    pushUp(i);
}

/*
 * upDate: apply one combined update to node i and record it lazily.
 *
 * add_m is added to every maximum element of the node, add_n to every other
 * element. Both tags are accumulated: overwriting add_n would drop an
 * earlier, not yet propagated range addition as soon as a later update
 * (for example a chmin, which passes add_n = 0) touches the same node.
 */
void upDate(ll i, ll add_m, ll add_n) {
    const ll len = t[i].r - t[i].l + 1;
    t[i].sum += t[i].cnt * add_m + (len - t[i].cnt) * add_n;
    t[i].maxn += add_m;
    // The sentinel means "no second maximum" and must not drift.
    if (t[i].se != NO_SECOND) t[i].se += add_n;
    t[i].add_m += add_m;
    t[i].add_n += add_n;
}

/*
 * pushDown: hand both pending tags of node i to its children and clear them.
 *
 * The children still hold their pre-tag values, so a child contains maxima of
 * the parent exactly when its own maximum equals the larger of the two child
 * maxima. Such a child receives (add_m, add_n); in the other child every
 * element is a non-maximum of the parent and receives add_n throughout.
 */
void pushDown(ll i) {
    const ll forMax = t[i].add_m;
    const ll forRest = t[i].add_n;
    if (forMax == 0 && forRest == 0) return;
    const ll lc = 2 * i;
    const ll rc = 2 * i + 1;
    const ll best = max(t[lc].maxn, t[rc].maxn);  // taken before any child changes
    const bool leftHasMax = (t[lc].maxn == best);
    const bool rightHasMax = (t[rc].maxn == best);
    upDate(lc, leftHasMax ? forMax : forRest, forRest);
    upDate(rc, rightHasMax ? forMax : forRest, forRest);
    t[i].add_m = 0;
    t[i].add_n = 0;
}

/*
 * add_segment: add k to every a[p] with p in [x, y].
 * Nodes that do not intersect the range are skipped, which also keeps the
 * recursion from descending below a leaf on an empty or out-of-range request.
 */
void add_segment(ll i, ll x, ll y, ll k) {
    if (y < t[i].l || x > t[i].r) return;
    if (x <= t[i].l && t[i].r <= y) {
        upDate(i, k, k);  // maxima and non-maxima shift by the same amount
        return;
    }
    pushDown(i);
    const ll mid = (t[i].l + t[i].r) / 2;
    if (x <= mid) add_segment(2 * i, x, y, k);
    if (y > mid) add_segment(2 * i + 1, x, y, k);
    pushUp(i);
}

/*
 * change_min: set a[p] = min(a[p], k) for every p in [x, y].
 *
 * A node is handled without descending when it lies inside [x, y] and only
 * its maxima exceed k (se < k < maxn); nodes whose maximum is already <= k
 * or that do not intersect [x, y] are left untouched.
 */
void change_min(ll i, ll x, ll y, ll k) {
    if (y < t[i].l || x > t[i].r || t[i].maxn <= k) return;
    if (x <= t[i].l && t[i].r <= y && t[i].se < k) {
        upDate(i, k - t[i].maxn, 0);  // only the maxima move, down to k
        return;
    }
    pushDown(i);
    change_min(2 * i, x, y, k);
    change_min(2 * i + 1, x, y, k);
    pushUp(i);
}

/*
 * query: return the sum of a[p] for p in [x, y] within node i.
 * Returns 0 when the node does not intersect the range.
 */
ll query(ll i, ll x, ll y) {
    if (y < t[i].l || x > t[i].r) return 0;
    if (x <= t[i].l && t[i].r <= y) return t[i].sum;
    pushDown(i);
    return query(2 * i, x, y) + query(2 * i + 1, x, y);
}

/*
 * Input: n m, then n values, then m operations:
 *   1 x y k  -> add k on [x, y]
 *   2 x y k  -> chmin on [x, y] with k
 *   4 x y    -> print the sum of [x, y]
 * Stops quietly on malformed input or when n does not fit the arrays.
 */
int main() {
    if (scanf("%lld%lld", &n, &m) != 2 || n < 1 || n >= maxn) return 0;
    for (ll i = 1; i <= n; i++) {
        if (scanf("%lld", &a[i]) != 1) return 0;
    }
    built(1, 1, n);
    showTree();  // debug output
    int flag;
    ll x, y, k;
    for (ll op = 1; op <= m; op++) {
        if (scanf("%d", &flag) != 1) break;
        if (flag == 4) {
            if (scanf("%lld%lld", &x, &y) != 2) break;
            printf("%lld\n", query(1, x, y));
        } else if (flag == 2) {
            if (scanf("%lld%lld%lld", &x, &y, &k) != 3) break;
            change_min(1, x, y, k);
        } else if (flag == 1) {
            if (scanf("%lld%lld%lld", &x, &y, &k) != 3) break;
            add_segment(1, x, y, k);
        }
        showTree();  // debug output
    }
    return 0;
}