难点:

  • 乘法懒标记和加法懒标记的处理顺序,pushDown函数
  • 区间乘法对加法标记的影响,mulSegment函数
  • 建树的优化,built函数
  • 查询函数的优化,query函数
  • 数据庞大,如何贯彻执行取模
  • 溢出问题:取模前的中间乘积可能超出 int 的范围,因此相关变量统一使用 long long

代码

  1. #include<iostream>
  2. #include<cstdio>
  3. using namespace std;
  4. typedef long long ll;
  5. const int maxn = 1e5 + 7; // capacity bound used for the input array and the tree storage
  6. ll a[maxn]; // input values; main fills a[1..n]
  7. int n, m, mod; // read in main: element count, operation count, modulus
  8. struct node { // one segment-tree node
  9. ll l, r; // inclusive bounds of the interval covered by this node
  10. ll sum, mul, add; // interval sum (kept modulo mod), pending multiply tag, pending add tag
  11. }st[4*maxn]; // tree storage; node i uses 2*i and 2*i+1 as children
  12. /* Build the tree: give node i the interval [l, r], reset its lazy tags, and compute its sum modulo mod. */
  13. void built(int i, ll l, ll r) {
  14. st[i].l = l; st[i].r = r;
  15. st[i].add = 0; st[i].mul = 1; // identity tags: add 0, multiply by 1
  16. if (l == r) { // leaf: take the value from the input array
  17. st[i].sum = a[l] % mod;
  18. return;
  19. }
  20. ll mid = (l + r) / 2;
  21. built(2 * i, l, mid);
  22. built(2 * i + 1, mid + 1, r);
  23. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up: combine the children's sums
  24. }
  25. /* Push down lazy tags (core routine): apply node i's pending multiply and add to both children, then reset node i's tags. */
  26. void pushDown(int i) {
  27. // a leaf has no children to push to
  28. if (st[i].l == st[i].r) return;
  29. ll mid = (st[i].l + st[i].r) / 2; // note: not used in the rest of this function
  30. // update the children's sums: sum * mul + (interval length) * add, modulo mod
  31. st[2 * i].sum = (ll)(st[2 * i].sum * st[i].mul + ((st[2 * i].r - st[2 * i].l + 1) * st[i].add) % mod) % mod;
  32. st[2 * i + 1].sum = (ll)(st[2 * i + 1].sum * st[i].mul + ((st[2 * i + 1].r - st[2 * i + 1].l + 1) * st[i].add) % mod) % mod;
  33. // update the children's lazy tags
  34. st[2 * i].mul = (ll)(st[2 * i].mul * st[i].mul) % mod; // multiply tags compose by multiplication
  35. st[2 * i + 1].mul = (ll)(st[2 * i + 1].mul * st[i].mul) % mod;
  36. st[2 * i].add = (ll)(st[2 * i].add * st[i].mul + st[i].add) % mod; // child's pending add is scaled by the parent's multiplier, then the parent's add is added
  37. st[2 * i + 1].add = (ll)(st[2 * i + 1].add * st[i].mul + st[i].add) % mod;
  38. // reset the parent's tags to the identity
  39. st[i].mul = 1;
  40. st[i].add = 0;
  41. }
  42. /* Range add: add k to every position of [x, y] inside the subtree of node i. */
  43. void addSegment(int i, ll x, ll y, ll k) {
  44. if (x <= st[i].l && st[i].r <= y) { // node fully covered: update it lazily and stop
  45. st[i].sum = (ll)(st[i].sum + (st[i].r - st[i].l + 1) * k) % mod; // each of the (r - l + 1) positions gains k
  46. st[i].add = (st[i].add + k) % mod;
  47. return;
  48. }
  49. pushDown(i); // flush pending tags before touching the children
  50. ll mid = (st[i].l + st[i].r) / 2;
  51. if (x <= mid) addSegment(2 * i, x, y, k);
  52. if (y > mid) addSegment(2 * i + 1, x, y, k);
  53. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up after the recursion
  54. }
  55. /* Range multiply: multiply every position of [x, y] by k inside the subtree of node i. */
  56. void mulSegment(int i, ll x, ll y, ll k) {
  57. if (x <= st[i].l && st[i].r <= y) { // node fully covered: update it lazily and stop
  58. st[i].sum = (st[i].sum * k) % mod;
  59. st[i].add = (st[i].add * k) % mod; // (important step) the pending add must be scaled by k as well
  60. st[i].mul = (st[i].mul * k) % mod;
  61. return;
  62. }
  63. pushDown(i); // flush pending tags before touching the children
  64. ll mid = (st[i].l + st[i].r) / 2;
  65. if (x <= mid) mulSegment(2 * i, x, y, k);
  66. if (y > mid) mulSegment(2 * i + 1, x, y, k);
  67. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up after the recursion
  68. }
  69. /* Query: return the sum over [x, y] within the subtree of node i, modulo mod. */
  70. ll query(int i, ll x, ll y) {
  71. if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod; // node fully covered: its stored sum is the answer
  72. pushDown(i); // important: push pending tags down before reading the children
  73. ll ans = 0;
  74. ll mid = (st[i].l + st[i].r) / 2;
  75. if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
  76. if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
  77. return ans;
  78. }
  79. int main() { // read the input, build the tree, then process m operations
  80. cin >> n >> m >> mod;
  81. for (int i = 1; i <= n; i++) { // values are stored 1-based in a[1..n]
  82. //scanf_s("%lld", &a[i]);
  83. scanf("%lld", &a[i]);
  84. }
  85. built(1, 1, n); // root is node 1 covering [1, n]
  86. int flag;
  87. ll x, y, k;
  88. for (int i = 1; i <= m; i++) {
  89. //scanf_s("%d", &flag);
  90. scanf("%d", &flag);
  91. if (flag == 1) { // operation 1: range multiply
  92. //scanf_s("%lld%lld%lld", &x, &y, &k);
  93. scanf("%lld%lld%lld", &x, &y, &k);
  94. mulSegment(1, x, y, k);
  95. }
  96. else if (flag == 2) { // operation 2: range add
  97. //scanf_s("%lld%lld%lld", &x, &y, &k);
  98. scanf("%lld%lld%lld", &x, &y, &k);
  99. addSegment(1, x, y, k);
  100. }
  101. else { // any other flag: range-sum query, printed on its own line
  102. //scanf_s("%lld%lld", &x, &y);
  103. scanf("%lld%lld", &x, &y);
  104. //printf_s("%lld\n", query(1, x, y));
  105. printf("%lld\n", query(1, x, y));
  106. }
  107. }
  108. return 0;
  109. }

BUG

第一次

  1. /* Push down lazy tags (incorrect version, kept to show the mistake) */
  2. void pushDown(int i) {
  3. // a leaf has no children to push to
  4. if (st[i].l == st[i].r) return;
  5. ll mid = (st[i].l + st[i].r) / 2;
  6. // update the children's sums
  7. st[2 * i].sum = (ll)(st[2 * i].sum * st[i].mul + ((st[2 * i].r - st[2 * i].l + 1) * st[i].add)%mod) % mod;
  8. st[2 * i + 1].sum = (ll)(st[2 * i + 1].sum * st[i].mul + ((st[2 * i + 1].r - st[2 * i + 1].l + 1) * st[i].add)%mod) % mod;
  9. // update the children's lazy tags
  10. st[2 * i].mul = (ll)(st[2 * i].mul * st[i].mul) % mod;
  11. st[2 * i + 1].mul = (ll)(st[2 * i + 1].mul * st[i].mul) % mod;
  12. st[2 * i].add = (ll)(st[2 * i].add * st[i].mul + st[2 * i].add) % mod; // BUG: adds the child's own add tag instead of the parent's st[i].add
  13. st[2 * i + 1].add = (ll)(st[2 * i + 1].add * st[i].mul + st[2 * i + 1].add) % mod; // BUG: same mistake for the right child
  14. // reset the parent's tags to the identity
  15. st[i].mul = 1;
  16. st[i].add = 0;
  17. }
  18. /* Push down lazy tags (correct version) */
  19. void pushDown(int i) {
  20. // a leaf has no children to push to
  21. if (st[i].l == st[i].r) return;
  22. ll mid = (st[i].l + st[i].r) / 2;
  23. // update the children's sums
  24. st[2 * i].sum = (ll)(st[2 * i].sum * st[i].mul + ((st[2 * i].r - st[2 * i].l + 1) * st[i].add)%mod) % mod;
  25. st[2 * i + 1].sum = (ll)(st[2 * i + 1].sum * st[i].mul + ((st[2 * i + 1].r - st[2 * i + 1].l + 1) * st[i].add)%mod) % mod;
  26. // update the children's lazy tags
  27. st[2 * i].mul = (ll)(st[2 * i].mul * st[i].mul) % mod;
  28. st[2 * i + 1].mul = (ll)(st[2 * i + 1].mul * st[i].mul) % mod;
  29. st[2 * i].add = (ll)(st[2 * i].add * st[i].mul + st[i].add) % mod; // fixed: adds the parent's st[i].add
  30. st[2 * i + 1].add = (ll)(st[2 * i + 1].add * st[i].mul + st[i].add) % mod; // fixed: adds the parent's st[i].add
  31. // reset the parent's tags to the identity
  32. st[i].mul = 1;
  33. st[i].add = 0;
  34. }
  35. /* Range multiply (incorrect version, kept to show the mistake) */
  36. void mulSegment(int i, ll x, ll y, ll k) {
  37. if (x <= st[i].l && st[i].r <= y) {
  38. st[i].sum = (st[i].sum * k) % mod;
  39. st[i].add = (st[i].add * k) % mod; // (important step) the pending add is scaled by k as well
  40. st[i].mul = (st[i].mul * k) % mod;
  41. } // BUG: no return in this branch, so execution falls through to pushDown and the recursion below
  42. pushDown(i);
  43. ll mid = (st[i].l + st[i].r) / 2;
  44. if (x <= mid) mulSegment(2 * i, x, y, k);
  45. if (y > mid) mulSegment(2 * i + 1, x, y, k);
  46. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up after the recursion
  47. }
  48. /* Range multiply (correct version) */
  49. void mulSegment(int i, ll x, ll y, ll k) {
  50. if (x <= st[i].l && st[i].r <= y) {
  51. st[i].sum = (st[i].sum * k) % mod;
  52. st[i].add = (st[i].add * k) % mod; // (important step) the pending add is scaled by k as well
  53. st[i].mul = (st[i].mul * k) % mod;
  54. return; // fixed: stop once the node is fully covered
  55. }
  56. pushDown(i);
  57. ll mid = (st[i].l + st[i].r) / 2;
  58. if (x <= mid) mulSegment(2 * i, x, y, k);
  59. if (y > mid) mulSegment(2 * i + 1, x, y, k);
  60. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up after the recursion
  61. }
  62. /* Query (incorrect version with two mistakes, kept to show them) */
  63. ll query(int i, ll x, ll y) {
  64. if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
  65. ll ans = 0; // BUG 1: pushDown(i) is not called before the children are read
  66. ll mid = (st[i].l + st[i].r) / 2;
  67. if (x <= mid) ans = (ans + st[2 * i].sum) % mod; // BUG 2: adds the whole left child's sum instead of recursing into it
  68. if (y > mid) ans = (ans + st[2 * i + 1].sum) % mod; // BUG 2: adds the whole right child's sum instead of recursing into it
  69. return ans;
  70. }
  71. /* Query (correct version) */
  72. ll query(int i, ll x, ll y) {
  73. if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
  74. pushDown(i); // important: push pending tags down before reading the children
  75. ll ans = 0;
  76. ll mid = (st[i].l + st[i].r) / 2;
  77. if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod; // fixed: recurse into the left child
  78. if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod; // fixed: recurse into the right child
  79. return ans;
  80. }

第二次

  1. /* Range multiply (incorrect version, kept to show the mistake) */
  2. void mul(ll i, ll x, ll y, ll k) {
  3. if (x <= st[i].l && st[i].r <= y) {
  4. st[i].sum = (st[i].sum * k) % mod;
  5. st[i].mul = (st[i].mul * k) % mod;
  6. st[i].add = (st[i].add * k) % mod; // a multiply that follows earlier adds must scale the pending add too
  7. return;
  8. }
  9. pushDown(i); //WRONG POINT
  10. ll mid = (st[i].l + st[i].r) / 2;
  11. if (x <= st[i].l) mul(2 * i, x, y, k); // BUG: compares x against st[i].l instead of mid
  12. if (y > st[i].r) mul(2 * i + 1, x, y, k); // BUG: compares y against st[i].r instead of mid
  13. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod;
  14. }
  15. /* Range multiply (correct version) */
  16. void mul(ll i, ll x, ll y, ll k) {
  17. if (x <= st[i].l && st[i].r <= y) {
  18. st[i].sum = (st[i].sum * k) % mod;
  19. st[i].mul = (st[i].mul * k) % mod;
  20. st[i].add = (st[i].add * k) % mod; // a multiply that follows earlier adds must scale the pending add too
  21. return;
  22. }
  23. pushDown(i); //WRONG POINT
  24. ll mid = (st[i].l + st[i].r) / 2;
  25. if (x <= mid) mul(2 * i, x, y, k); // fixed: compare x against mid
  26. if (y > mid) mul(2 * i + 1, x, y, k); // fixed: compare y against mid
  27. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod;
  28. }
  29. /* Query (incorrect version, kept to show the mistake) */
  30. ll query(int i, ll x, ll y) {
  31. if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
  32. ll ans = 0; // BUG: pushDown(i) is not called before descending into the children
  33. ll mid = (st[i].l + st[i].r) / 2;
  34. if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
  35. if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
  36. return ans;
  37. }
  38. /* Query (correct version) */
  39. ll query(int i, ll x, ll y) {
  40. if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
  41. pushDown(i); // fixed: push pending tags down before descending
  42. ll ans = 0;
  43. ll mid = (st[i].l + st[i].r) / 2;
  44. if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
  45. if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
  46. return ans;
  47. }
  48. /* Build: the mistake noted for this attempt was not applying the modulus consistently while building */
  49. void built(ll i, ll l, ll r) {
  50. st[i].l = l; st[i].r = r;
  51. st[i].add = 0; st[i].mul = 1;
  52. if (l == r) {
  53. st[i].sum = a[l] % mod;
  54. return;
  55. }
  56. ll mid = (l + r) / 2;
  57. built(2 * i, l, mid);
  58. built(2 * i + 1, mid + 1, r);
  59. st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // author's marker: "error: modulus not applied consistently"; NOTE(review): the line as shown does take % mod, so this presumably marks where the fix went
  60. }

第三次 一遍AC

  1. #include<iostream>
  2. #include<cstdio>
  3. using namespace std;
  4. typedef long long ll;
  5. const int maxn = 1e5 + 10; // capacity bound used for the input array and the tree storage
  6. ll n, m, mod; // read in main: element count, operation count, modulus
  7. ll a[maxn]; // input values; main fills a[1..n]
  8. struct NODE { // one segment-tree node
  9. ll l, r, sum; // inclusive interval bounds and the interval sum (kept modulo mod)
  10. ll add, mu; // pending add tag and pending multiply tag
  11. }t[4*maxn]; // tree storage; node i uses 2*i and 2*i+1 as children
  12. /* built: give node i the interval [l, r], reset its lazy tags, and compute its sum modulo mod */
  13. void built(ll i, ll l, ll r) {
  14. t[i].l = l; t[i].r = r;
  15. t[i].mu = 1; t[i].add = 0; // identity tags: multiply by 1, add 0
  16. if (l == r) { // leaf: take the value from the input array
  17. t[i].sum = a[l] % mod;
  18. return;
  19. }
  20. ll mid = (l + r) / 2;
  21. built(2 * i, l, mid);
  22. built(2 * i + 1, mid + 1, r);
  23. t[i].sum = (t[2 * i].sum + t[2 * i + 1].sum) % mod; // pull up: combine the children's sums
  24. }
  25. /* pushDown: apply node i's pending multiply and add to both children, then reset node i's tags */
  26. void pushDown(ll i) {
  27. if (t[i].l == t[i].r) return; // a leaf has no children to push to
  28. t[2 * i].sum = (t[2 * i].sum * t[i].mu + (t[2 * i].r - t[2 * i].l + 1) * t[i].add) % mod; // sum * mu + (interval length) * add
  29. t[2 * i + 1].sum = (t[2 * i + 1].sum * t[i].mu + (t[2 * i + 1].r - t[2 * i + 1].l + 1) * t[i].add) % mod;
  30. t[2 * i].mu = (t[2 * i].mu * t[i].mu) % mod; // multiply tags compose by multiplication
  31. t[2 * i + 1].mu = (t[2 * i + 1].mu * t[i].mu) % mod;
  32. t[2 * i].add = (t[2 * i].add * t[i].mu + t[i].add) % mod; // child's pending add is scaled by the parent's multiplier, then the parent's add is added
  33. t[2 * i + 1].add = (t[2 * i + 1].add * t[i].mu + t[i].add) % mod;
  34. t[i].mu = 1; // reset the parent's tags to the identity
  35. t[i].add = 0;
  36. }
  37. /* add: add k to every position of [x, y] inside the subtree of node i */
  38. void add(ll i, ll x, ll y, ll k) {
  39. if (x <= t[i].l && t[i].r <= y) { // node fully covered: update it lazily and stop
  40. t[i].sum = (t[i].sum + (t[i].r - t[i].l + 1) * k) % mod; // each of the (r - l + 1) positions gains k
  41. t[i].add = (t[i].add + k) % mod;
  42. return;
  43. }
  44. pushDown(i); // POINT: flush pending tags before touching the children
  45. ll mid = (t[i].l + t[i].r) / 2;
  46. if (x <= mid) add(2 * i, x, y, k);
  47. if (y > mid) add(2 * i + 1, x, y, k);
  48. t[i].sum = (t[2 * i].sum + t[2 * i + 1].sum) % mod; // pull up after the recursion
  49. }
  50. /* multiply: multiply every position of [x, y] by k inside the subtree of node i */
  51. void mul(ll i, ll x, ll y, ll k) {
  52. if (x <= t[i].l && t[i].r <= y) { // node fully covered: update it lazily and stop
  53. t[i].sum = (t[i].sum * k) % mod;
  54. t[i].mu = (t[i].mu * k) % mod;
  55. t[i].add = (t[i].add * k) % mod; // POINT: the pending add must be scaled by k as well
  56. return;
  57. }
  58. pushDown(i); // POINT: flush pending tags before touching the children
  59. ll mid = (t[i].l + t[i].r) / 2;
  60. if (x <= mid) mul(2 * i, x, y, k);
  61. if (y > mid) mul(2 * i + 1, x, y, k);
  62. t[i].sum = (t[2 * i].sum + t[2 * i + 1].sum) % mod; // pull up after the recursion
  63. }
  64. /* find: return the sum over [x, y] within the subtree of node i, modulo mod */
  65. ll query(ll i, ll x, ll y) {
  66. if (x <= t[i].l && t[i].r <= y) { // node fully covered: its stored sum is the answer
  67. return t[i].sum % mod;
  68. }
  69. pushDown(i); // POINT: push pending tags down before reading the children
  70. ll mid = (t[i].l + t[i].r) / 2;
  71. ll ans = 0;
  72. if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
  73. if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
  74. return ans;
  75. }
  76. int main() { // read the input, build the tree, then process m operations
  77. scanf("%lld%lld%lld", &n, &m, &mod);
  78. for (int i = 1; i <= n; i++) { // values are stored 1-based in a[1..n]
  79. scanf("%lld", &a[i]);
  80. }
  81. built(1, 1, n); // root is node 1 covering [1, n]
  82. int flag;
  83. ll x, y, k;
  84. for (int i = 1; i <= m; i++) {
  85. scanf("%d", &flag);
  86. if (flag == 1) { // operation 1: range multiply
  87. scanf("%lld%lld%lld", &x, &y, &k);
  88. mul(1, x, y, k);
  89. }
  90. else if (flag == 2) { // operation 2: range add
  91. scanf("%lld%lld%lld", &x, &y, &k);
  92. add(1, x, y, k);
  93. }
  94. else { // any other flag: range-sum query, printed on its own line
  95. scanf("%lld%lld", &x, &y);
  96. printf("%lld\n", query(1, x, y));
  97. }
  98. }
  99. return 0;
  100. }