首先查阅吉老师（吉如一）的 PPT：

Segment tree Beats!.pdf

只要能熟练写好 pushUp、pushDown、upDate 这三个函数，就能快速完成本题。

分步练习实现题目的要求

第一步：实现区间最值更新（对 [x, y] 内的每个 a_i 执行 a_i = min(a_i, k)）以及区间求和

  1. #include<iostream>
  2. #include<cstdio>
  3. using namespace std;
  using ll = long long;

  // Array capacity; values are stored 1-indexed, so positions 1..maxn-1 are usable.
  constexpr int maxn = 100000 + 10;

  ll n, m;      // number of elements, number of operations (read in main)
  ll a[maxn];   // input values, 1-indexed

  // One node of the segment tree (Segment Tree Beats).
  struct NODE {
      ll l, r, sum;       // covered range [l, r] and the sum of its current values
      ll maxn, se, cnt;   // maximum, strict second maximum (-1e9 when there is none), count of maximum elements
      ll add_m, add_n;    // lazy tags: add_m is added to the maximum elements; add_n is only zeroed in this version
  };

  // Heap-ordered storage: the children of node i live at 2i and 2i+1.
  NODE t[4 * maxn];
  13. /*showTree*/
  14. ll num = 0;
  15. void showTree() {
  16. cout << "show tree below" << endl;
  17. ll count = 0;
  18. for (ll i = 1; i <= num; i++) {
  19. if (i == pow(2, count)) { cout << endl; count++; }
  20. cout << "( " << t[i].l << ' ' << t[i].r << ' ' << t[i].sum << ' ' << t[i].maxn << ' ' << t[i].se <<' '<< t[i].cnt << " )";
  21. }
  22. cout << endl << "end" << endl;
  23. }
  24. /*pushUp*/
  25. void pushUp(ll i) {
  26. t[i].sum = t[2 * i].sum + t[2 * i + 1].sum;
  27. t[i].maxn = max(t[2 * i].maxn, t[2 * i + 1].maxn);
  28. if (t[2 * i].maxn == t[2 * i + 1].maxn) {
  29. t[i].se = max(t[2 * i].se, t[2 * i + 1].se);
  30. t[i].cnt = t[2 * i].cnt + t[2 * i + 1].cnt;
  31. }
  32. else if (t[2 * i].maxn > t[2 * i + 1].maxn) {
  33. t[i].se = max(t[2 * i].se, t[2 * i + 1].maxn);
  34. t[i].cnt = t[2 * i].cnt;
  35. }
  36. else {
  37. t[i].se = max(t[2 * i].maxn, t[2 * i + 1].se);
  38. t[i].cnt = t[2 * i + 1].cnt;
  39. }
  40. }
  41. /*built*/
  42. void built(ll i, ll l, ll r) {
  43. t[i].l = l; t[i].r = r;
  44. t[i].add_m = t[i].add_n = 0;
  45. num++; //测试用变量
  46. if (l == r) {
  47. t[i].sum = t[i].maxn = a[l];
  48. t[i].se = -1e9;
  49. t[i].cnt = 1;
  50. return;
  51. }
  52. ll mid = (l + r) / 2;
  53. built(2 * i, l, mid);
  54. built(2 * i + 1, mid + 1, r);
  55. pushUp(i);
  56. }
  57. /*pushDown 核心代码*/
  58. void pushDown(ll i) {
  59. ll maxn = max(t[2 * i].maxn, t[2 * i + 1].maxn);
  60. if (t[2 * i].maxn == maxn) {
  61. t[2 * i].sum += t[2 * i].cnt * t[i].add_m;
  62. t[2 * i].maxn += t[i].add_m;
  63. t[2 * i].add_m += t[i].add_m;
  64. }
  65. if (t[2 * i + 1].maxn == maxn) {
  66. t[2 * i + 1].sum += t[2 * i + 1].cnt * t[i].add_m;
  67. t[2 * i + 1].maxn += t[i].add_m;
  68. t[2 * i + 1].add_m += t[i].add_m;
  69. }
  70. }
  71. /*change_min*/
  72. void change_min(ll i, ll x, ll y, ll k) {
  73. if (y < t[i].l || x > t[i].r || t[i].maxn <= k) return;
  74. if (x <= t[i].l && t[i].r <= y && t[i].se < k) {
  75. t[i].sum += t[i].cnt * (k - t[i].maxn);
  76. t[i].add_m = k - t[i].maxn;
  77. t[i].maxn = k;
  78. return;
  79. }
  80. pushDown(i);
  81. ll mid = (t[i].l + t[i].r) / 2;
  82. change_min(2 * i, x, y, k);
  83. change_min(2 * i + 1, x, y, k);
  84. pushUp(i);
  85. }
  86. /*findSum*/
  87. ll query(ll i, ll x, ll y) {
  88. if (x <= t[i].l && t[i].r <= y) {
  89. return t[i].sum;
  90. }
  91. pushDown(i);
  92. ll mid = (t[i].l + t[i].r) / 2;
  93. ll t = 0;
  94. if (x <= mid) t = t + query(2 * i, x, y);
  95. if(y > mid) t = t + query(2 * i + 1, x, y);
  96. return t;
  97. }
  98. int main() {
  99. scanf("%lld%lld", &n, &m);
  100. for (int i = 1; i <= n; i++) {
  101. scanf("%lld", &a[i]);
  102. }
  103. built(1, 1, n);
  104. showTree(); //测试用函数
  105. int flag;
  106. ll x, y, k;
  107. for (int i = 1; i <= m; i++) {
  108. scanf("%d", &flag);
  109. if (flag == 4) {
  110. scanf("%lld%lld", &x, &y);
  111. printf("%lld\n", query(1, x, y));
  112. }
  113. else if (flag == 2) {
  114. scanf("%lld%lld%lld", &x, &y, &k);
  115. change_min(1, x, y, k);
  116. }
  117. }
  118. return 0;
  119. }

第二步：在区间最值更新的基础上增加区间加法，并实现代码复用

精髓在于用 upDate 函数统一处理两种懒惰标记（add_m：加到区间最大值上的量；add_n：加到其余元素上的量），或者说，统一处理两种更新动作！

  1. #include<iostream>
  2. #include<cstdio>
  3. using namespace std;
  using ll = long long;

  // Array capacity; values are stored 1-indexed, so positions 1..maxn-1 are usable.
  constexpr int maxn = 100000 + 10;

  ll n, m;      // number of elements, number of operations (read in main)
  ll a[maxn];   // input values, 1-indexed

  // One node of the segment tree (Segment Tree Beats).
  struct NODE {
      ll l, r, sum;       // covered range [l, r] and the sum of its current values
      ll maxn, se, cnt;   // maximum, strict second maximum (-1e9 when there is none), count of maximum elements
      ll add_m, add_n;    // lazy tags: amount added to the maximum elements / to all other elements
  };

  // Heap-ordered storage: the children of node i live at 2i and 2i+1.
  NODE t[4 * maxn];
  13. /*showTree*/
  14. ll num = 0;
  15. void showTree() {
  16. cout << "show tree below" << endl;
  17. ll count = 0;
  18. for (ll i = 1; i <= num; i++) {
  19. if (i == pow(2, count)) { cout << endl; count++; }
  20. cout << "( " << t[i].l << ' ' << t[i].r << ' ' << t[i].sum << ' ' << t[i].maxn << ' ' << t[i].se <<' '<< t[i].cnt << " )";
  21. }
  22. cout << endl << "end" << endl;
  23. }
  24. /*pushUp*/
  25. void pushUp(ll i) {
  26. t[i].sum = t[2 * i].sum + t[2 * i + 1].sum;
  27. t[i].maxn = max(t[2 * i].maxn, t[2 * i + 1].maxn);
  28. if (t[2 * i].maxn == t[2 * i + 1].maxn) {
  29. t[i].se = max(t[2 * i].se, t[2 * i + 1].se);
  30. t[i].cnt = t[2 * i].cnt + t[2 * i + 1].cnt;
  31. }
  32. else if (t[2 * i].maxn > t[2 * i + 1].maxn) {
  33. t[i].se = max(t[2 * i].se, t[2 * i + 1].maxn);
  34. t[i].cnt = t[2 * i].cnt;
  35. }
  36. else {
  37. t[i].se = max(t[2 * i].maxn, t[2 * i + 1].se);
  38. t[i].cnt = t[2 * i + 1].cnt;
  39. }
  40. }
  41. /*built*/
  42. void built(ll i, ll l, ll r) {
  43. t[i].l = l; t[i].r = r;
  44. t[i].add_m = t[i].add_n = 0;
  45. num++; //测试用变量
  46. if (l == r) {
  47. t[i].sum = t[i].maxn = a[l];
  48. t[i].se = -1e9;
  49. t[i].cnt = 1;
  50. return;
  51. }
  52. ll mid = (l + r) / 2;
  53. built(2 * i, l, mid);
  54. built(2 * i + 1, mid + 1, r);
  55. pushUp(i);
  56. }
  57. /*upDate 核心代码*/
  58. void upDate(ll i, ll add_m, ll add_n) {
  59. t[i].sum += t[i].cnt * add_m + (t[i].r - t[i].l + 1 - t[i].cnt) * add_n;
  60. t[i].maxn += add_m;
  61. //t[i].se += add_n; 错误!注意变量的范围
  62. if (t[i].se != -1e9) t[i].se += add_n;
  63. t[i].add_m += add_m;
  64. t[i].add_n = add_n;
  65. }
  66. /*pushDown 核心代码*/
  67. void pushDown(ll i) {
  68. ll maxn = max(t[2 * i].maxn, t[2 * i + 1].maxn);
  69. if (t[2 * i].maxn == maxn)
  70. upDate(2 * i, t[i].add_m, t[i].add_n);
  71. else
  72. upDate(2 * i, t[i].add_n, t[i].add_n);
  73. if (t[2 * i + 1].maxn == maxn)
  74. upDate(2 * i + 1, t[i].add_m, t[i].add_n);
  75. else
  76. upDate(2 * i + 1, t[i].add_n, t[i].add_n);
  77. t[i].add_n = 0;
  78. t[i].add_m = 0;
  79. }
  80. /*add_segment*/
  81. void add_segment(ll i, ll x, ll y, ll k) {
  82. if (x <= t[i].l && t[i].r <= y) {
  83. upDate(i, k, k);
  84. return;
  85. }
  86. pushDown(i);
  87. ll mid = (t[i].r + t[i].l) / 2;
  88. if (x <= mid) add_segment(2 * i, x, y, k);
  89. if(y > mid) add_segment(2 * i + 1, x, y, k);
  90. pushUp(i);
  91. }
  92. /*change_min*/
  93. void change_min(ll i, ll x, ll y, ll k) {
  94. if (y < t[i].l || x > t[i].r || t[i].maxn <= k) return;
  95. if (x <= t[i].l && t[i].r <= y && t[i].se < k) {
  96. upDate(i, k-t[i].maxn, 0);
  97. return;
  98. }
  99. pushDown(i);
  100. ll mid = (t[i].l + t[i].r) / 2;
  101. change_min(2 * i, x, y, k);
  102. change_min(2 * i + 1, x, y, k);
  103. pushUp(i);
  104. }
  105. /*findSum*/
  106. ll query(ll i, ll x, ll y) {
  107. if (x <= t[i].l && t[i].r <= y) {
  108. return t[i].sum;
  109. }
  110. pushDown(i);
  111. ll mid = (t[i].l + t[i].r) / 2;
  112. ll t = 0;
  113. if (x <= mid) t = t + query(2 * i, x, y);
  114. if(y > mid) t = t + query(2 * i + 1, x, y);
  115. return t;
  116. }
  117. int main() {
  118. scanf("%lld%lld", &n, &m);
  119. for (int i = 1; i <= n; i++) {
  120. scanf("%lld", &a[i]);
  121. }
  122. built(1, 1, n);
  123. showTree(); //测试用函数
  124. int flag;
  125. ll x, y, k;
  126. for (int i = 1; i <= m; i++) {
  127. scanf("%d", &flag);
  128. if (flag == 4) {
  129. scanf("%lld%lld", &x, &y);
  130. printf("%lld\n", query(1, x, y));
  131. }
  132. else if (flag == 2) {
  133. scanf("%lld%lld%lld", &x, &y, &k);
  134. change_min(1, x, y, k);
  135. }
  136. else if (flag == 1) {
  137. scanf("%lld%lld%lld", &x, &y, &k);
  138. add_segment(1, x, y, k);
  139. }
  140. showTree(); //测试用函数
  141. }
  142. return 0;
  143. }