难点:
- 乘法懒标记和加法懒标记的处理顺序,pushDown函数
- 区间乘法对加法标记的影响,mulSegment函数
- 建树的优化,built函数
- 查询函数的正确写法：区间未被完全覆盖时，要先pushDown下推标记，再递归查询左右子树,query函数
- 数据庞大,如何贯彻执行取模
- 溢出问题（不是精度问题）：两个小于mod的数相乘可能超出int范围，干脆都用long long吧
代码
#include <iostream>
#include <cstdio>
using namespace std;

typedef long long ll;

const int maxn = 1e5 + 7;
ll a[maxn];
int n, m, mod;

// Segment-tree node covering [l, r].
// sum: range sum modulo `mod`.
// mul/add: pending lazy tags meaning "each element below = element * mul + add"
// (multiply first, then add).
struct node {
    ll l, r;
    ll sum, mul, add;
} st[4 * maxn];

/* Reduce v into [0, mod). C++ `%` keeps the sign of v, so negatives are shifted up. */
ll normMod(ll v) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

/* Build the subtree rooted at node i over a[l..r], with identity lazy tags. */
void built(int i, ll l, ll r) {
    st[i].l = l;
    st[i].r = r;
    st[i].add = 0;
    st[i].mul = 1;
    if (l == r) {
        st[i].sum = normMod(a[l]);
        return;
    }
    ll mid = (l + r) / 2;
    built(2 * i, l, mid);
    built(2 * i + 1, mid + 1, r);
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull children's sums up
}

/* Push node i's pending tags down to both children, then reset node i's tags (core logic). */
void pushDown(int i) {
    if (st[i].l == st[i].r) return; // leaf: no children to push into
    for (int c = 2 * i; c <= 2 * i + 1; ++c) {
        ll len = st[c].r - st[c].l + 1;
        st[c].sum = (st[c].sum * st[i].mul + len * st[i].add % mod) % mod;
        st[c].mul = st[c].mul * st[i].mul % mod;
        // The child's older add tag is scaled by the new multiplier too.
        st[c].add = (st[c].add * st[i].mul + st[i].add) % mod;
    }
    st[i].mul = 1;
    st[i].add = 0;
}

/* Add k to every element of [x, y] within node i's range. */
void addSegment(int i, ll x, ll y, ll k) {
    k = normMod(k); // keeps products below in range and handles negative k
    if (x <= st[i].l && st[i].r <= y) {
        st[i].sum = (st[i].sum + (st[i].r - st[i].l + 1) * k) % mod;
        st[i].add = (st[i].add + k) % mod;
        return;
    }
    pushDown(i);
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) addSegment(2 * i, x, y, k);
    if (y > mid) addSegment(2 * i + 1, x, y, k);
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up
}

/* Multiply every element of [x, y] within node i's range by k. */
void mulSegment(int i, ll x, ll y, ll k) {
    k = normMod(k); // keeps products below in range and handles negative k
    if (x <= st[i].l && st[i].r <= y) {
        st[i].sum = st[i].sum * k % mod;
        // Important: a pending add is also multiplied, since (v + add) * k = v * k + add * k.
        st[i].add = st[i].add * k % mod;
        st[i].mul = st[i].mul * k % mod;
        return;
    }
    pushDown(i);
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) mulSegment(2 * i, x, y, k);
    if (y > mid) mulSegment(2 * i + 1, x, y, k);
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull up
}

/* Return the sum of [x, y] within node i's range, modulo `mod`. */
ll query(int i, ll x, ll y) {
    if (x <= st[i].l && st[i].r <= y) return st[i].sum; // sums are kept in [0, mod)
    pushDown(i); // key point: children must reflect this node's pending tags
    ll ans = 0;
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
    if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
    return ans;
}

int main() {
    cin >> n >> m >> mod;
    for (int i = 1; i <= n; i++) {
        scanf("%lld", &a[i]);
    }
    built(1, 1, n);
    int flag;
    ll x, y, k;
    for (int i = 1; i <= m; i++) {
        scanf("%d", &flag);
        if (flag == 1) {
            scanf("%lld%lld%lld", &x, &y, &k);
            mulSegment(1, x, y, k);
        } else if (flag == 2) {
            scanf("%lld%lld%lld", &x, &y, &k);
            addSegment(1, x, y, k);
        } else {
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", query(1, x, y));
        }
    }
    return 0;
}
BUG
第一次
/* pushDown (buggy version).
 * Bug: each child's new add tag is computed as child.add * parent.mul + child.add.
 * It re-adds the child's own tag instead of the parent's pending add tag
 * (st[i].add), so the parent's pending addition is lost. */
void pushDown(int i) {
    // Leaf node: no children to push into.
    if (st[i].l == st[i].r) return;
    ll mid = (st[i].l + st[i].r) / 2; // NOTE(review): computed but never used
    // Update both children's sums: sum * mul + length * add.
    st[2 * i].sum = (ll)(st[2 * i].sum * st[i].mul + ((st[2 * i].r - st[2 * i].l + 1) * st[i].add)%mod) % mod;
    st[2 * i + 1].sum = (ll)(st[2 * i + 1].sum * st[i].mul + ((st[2 * i + 1].r - st[2 * i + 1].l + 1) * st[i].add)%mod) % mod;
    // Update both children's lazy tags.
    st[2 * i].mul = (ll)(st[2 * i].mul * st[i].mul) % mod;
    st[2 * i + 1].mul = (ll)(st[2 * i + 1].mul * st[i].mul) % mod;
    st[2 * i].add = (ll)(st[2 * i].add * st[i].mul + st[2 * i].add) % mod; // BUG: should add st[i].add
    st[2 * i + 1].add = (ll)(st[2 * i + 1].add * st[i].mul + st[2 * i + 1].add) % mod; // BUG: should add st[i].add
    // Reset the parent's tags to the identity (mul = 1, add = 0).
    st[i].mul = 1;
    st[i].add = 0;
}

/* pushDown (correct version). */
void pushDown(int i) {
    // Leaf node: no children to push into.
    if (st[i].l == st[i].r) return;
    ll mid = (st[i].l + st[i].r) / 2; // NOTE(review): computed but never used
    // Update both children's sums: sum * mul + length * add.
    st[2 * i].sum = (ll)(st[2 * i].sum * st[i].mul + ((st[2 * i].r - st[2 * i].l + 1) * st[i].add)%mod) % mod;
    st[2 * i + 1].sum = (ll)(st[2 * i + 1].sum * st[i].mul + ((st[2 * i + 1].r - st[2 * i + 1].l + 1) * st[i].add)%mod) % mod;
    // Update both children's lazy tags.
    st[2 * i].mul = (ll)(st[2 * i].mul * st[i].mul) % mod;
    st[2 * i + 1].mul = (ll)(st[2 * i + 1].mul * st[i].mul) % mod;
    st[2 * i].add = (ll)(st[2 * i].add * st[i].mul + st[i].add) % mod; // fixed: adds the parent's pending add tag
    st[2 * i + 1].add = (ll)(st[2 * i + 1].add * st[i].mul + st[i].add) % mod; // fixed: adds the parent's pending add tag
    // Reset the parent's tags to the identity (mul = 1, add = 0).
    st[i].mul = 1;
    st[i].add = 0;
}

/* mulSegment, range multiply (buggy version).
 * Bug: after tagging a fully covered node there is no `return`.
 * Execution falls through into pushDown and the recursive calls, so the multiplication
 * is applied again below this node. At a leaf, pushDown returns early and the
 * recursion continues into slots 2*i / 2*i+1, which are not real children. */
void mulSegment(int i, ll x, ll y, ll k) {
    if (x <= st[i].l && st[i].r <= y) {
        st[i].sum = (st[i].sum * k) % mod;
        st[i].add = (st[i].add * k) % mod; // (important step) a pending add is scaled by k as well
        st[i].mul = (st[i].mul * k) % mod;
    }
    pushDown(i);
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) mulSegment(2 * i, x, y, k);
    if (y > mid) mulSegment(2 * i + 1, x, y, k);
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull the children's sums back up
}

/* mulSegment, range multiply (correct version). */
void mulSegment(int i, ll x, ll y, ll k) {
    if (x <= st[i].l && st[i].r <= y) {
        st[i].sum = (st[i].sum * k) % mod;
        st[i].add = (st[i].add * k) % mod; // (important step) a pending add is scaled by k as well
        st[i].mul = (st[i].mul * k) % mod;
        return; // fixed: stop here; the tags now stand in for the whole subtree
    }
    pushDown(i);
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) mulSegment(2 * i, x, y, k);
    if (y > mid) mulSegment(2 * i + 1, x, y, k);
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // pull the children's sums back up
}

/* query (buggy version, two bugs):
 * 1. No pushDown before reading the children, so this node's pending tags are
 *    not reflected in their sums.
 * 2. It adds each child's whole sum instead of recursing. A child that only partly
 *    overlaps [x, y] therefore contributes elements outside the query range. */
ll query(int i, ll x, ll y) {
    if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
    ll ans = 0;
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) ans = (ans + st[2 * i].sum) % mod;
    if (y > mid) ans = (ans + st[2 * i + 1].sum) % mod;
    return ans;
}

/* query (correct version). */
ll query(int i, ll x, ll y) {
    if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
    pushDown(i); // key point: push pending tags down before descending
    ll ans = 0;
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod; // fixed: recurse into the left child
    if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod; // fixed: recurse into the right child
    return ans;
}
第二次
/* mul, range multiply (buggy version).
 * Bug: the branch conditions compare x and y with this node's own bounds
 * (st[i].l / st[i].r) instead of the split point `mid`. For a partially covered
 * node, a child that overlaps [x, y] can be skipped, and the update is lost. */
void mul(ll i, ll x, ll y, ll k) {
    if (x <= st[i].l && st[i].r <= y) {
        st[i].sum = (st[i].sum * k) % mod;
        st[i].mul = (st[i].mul * k) % mod;
        st[i].add = (st[i].add * k) % mod; // a multiply that follows an add also scales the pending add tag
        return;
    }
    pushDown(i); // author's "WRONG POINT" label; this line is the same in both versions, the real bug is the branch conditions
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= st[i].l) mul(2 * i, x, y, k); // BUG: should be x <= mid
    if (y > st[i].r) mul(2 * i + 1, x, y, k); // BUG: should be y > mid
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod;
}

/* mul, range multiply (correct version). */
void mul(ll i, ll x, ll y, ll k) {
    if (x <= st[i].l && st[i].r <= y) {
        st[i].sum = (st[i].sum * k) % mod;
        st[i].mul = (st[i].mul * k) % mod;
        st[i].add = (st[i].add * k) % mod; // a multiply that follows an add also scales the pending add tag
        return;
    }
    pushDown(i); // author's "WRONG POINT" label; this line is the same in both versions
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) mul(2 * i, x, y, k); // fixed: split on mid
    if (y > mid) mul(2 * i + 1, x, y, k); // fixed: split on mid
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod;
}

/* query (buggy version).
 * Bug: no pushDown before recursing, so children's sums miss this node's pending tags. */
ll query(int i, ll x, ll y) {
    if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
    ll ans = 0;
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
    if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
    return ans;
}

/* query (correct version). */
ll query(int i, ll x, ll y) {
    if (x <= st[i].l && st[i].r <= y) return st[i].sum % mod;
    pushDown(i); // fixed: push pending tags down before descending
    ll ans = 0;
    ll mid = (st[i].l + st[i].r) / 2;
    if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
    if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
    return ans;
}

/* built: the author's heading is "building the tree without applying modulo consistently".
 * NOTE(review): the code shown already reduces modulo at the leaf (a[l] % mod)
 * and at the merge. The heading presumably describes an earlier version that
 * lacked `% mod` here; confirm. */
void built(ll i, ll l, ll r) {
    st[i].l = l;
    st[i].r = r;
    st[i].add = 0;
    st[i].mul = 1;
    if (l == r) {
        st[i].sum = a[l] % mod;
        return;
    }
    ll mid = (l + r) / 2;
    built(2 * i, l, mid);
    built(2 * i + 1, mid + 1, r);
    st[i].sum = (st[2 * i].sum + st[2 * i + 1].sum) % mod; // author: "error: modulo not applied consistently" (see note above)
}
第三次 一遍AC
#include <iostream>
#include <cstdio>
using namespace std;

typedef long long ll;

const int maxn = 1e5 + 10;
ll n, m, mod;
ll a[maxn];

// Segment-tree node covering [l, r].
// sum: range sum modulo `mod`.
// mu/add: pending lazy tags meaning "each element below = element * mu + add"
// (multiply first, then add).
struct NODE {
    ll l, r, sum;
    ll add, mu;
} t[4 * maxn];

/* Reduce v into [0, mod). C++ `%` keeps the sign of v, so negatives are shifted up. */
ll normMod(ll v) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

/* Build the subtree rooted at node i over a[l..r], with identity lazy tags. */
void built(ll i, ll l, ll r) {
    t[i].l = l;
    t[i].r = r;
    t[i].mu = 1;
    t[i].add = 0;
    if (l == r) {
        t[i].sum = normMod(a[l]);
        return;
    }
    ll mid = (l + r) / 2;
    built(2 * i, l, mid);
    built(2 * i + 1, mid + 1, r);
    t[i].sum = (t[2 * i].sum + t[2 * i + 1].sum) % mod;
}

/* Push node i's pending tags down to both children, then reset node i's tags. */
void pushDown(ll i) {
    if (t[i].l == t[i].r) return; // leaf: no children to push into
    for (ll c = 2 * i; c <= 2 * i + 1; ++c) {
        ll len = t[c].r - t[c].l + 1;
        t[c].sum = (t[c].sum * t[i].mu + len * t[i].add % mod) % mod;
        t[c].mu = t[c].mu * t[i].mu % mod;
        // The child's older add tag is scaled by the new multiplier too.
        t[c].add = (t[c].add * t[i].mu + t[i].add) % mod;
    }
    t[i].mu = 1;
    t[i].add = 0;
}

/* Add k to every element of [x, y] within node i's range. */
void add(ll i, ll x, ll y, ll k) {
    k = normMod(k); // keeps products below in range and handles negative k
    if (x <= t[i].l && t[i].r <= y) {
        t[i].sum = (t[i].sum + (t[i].r - t[i].l + 1) * k) % mod;
        t[i].add = (t[i].add + k) % mod;
        return;
    }
    pushDown(i); // children must be up to date before descending
    ll mid = (t[i].l + t[i].r) / 2;
    if (x <= mid) add(2 * i, x, y, k);
    if (y > mid) add(2 * i + 1, x, y, k);
    t[i].sum = (t[2 * i].sum + t[2 * i + 1].sum) % mod;
}

/* Multiply every element of [x, y] within node i's range by k. */
void mul(ll i, ll x, ll y, ll k) {
    k = normMod(k); // keeps products below in range and handles negative k
    if (x <= t[i].l && t[i].r <= y) {
        t[i].sum = t[i].sum * k % mod;
        t[i].mu = t[i].mu * k % mod;
        // A pending add is also multiplied, since (v + add) * k = v * k + add * k.
        t[i].add = t[i].add * k % mod;
        return;
    }
    pushDown(i);
    ll mid = (t[i].l + t[i].r) / 2;
    if (x <= mid) mul(2 * i, x, y, k);
    if (y > mid) mul(2 * i + 1, x, y, k);
    t[i].sum = (t[2 * i].sum + t[2 * i + 1].sum) % mod;
}

/* Return the sum of [x, y] within node i's range, modulo `mod`. */
ll query(ll i, ll x, ll y) {
    if (x <= t[i].l && t[i].r <= y) {
        return t[i].sum; // sums are kept in [0, mod)
    }
    pushDown(i); // key point: children must reflect this node's pending tags
    ll mid = (t[i].l + t[i].r) / 2;
    ll ans = 0;
    if (x <= mid) ans = (ans + query(2 * i, x, y)) % mod;
    if (y > mid) ans = (ans + query(2 * i + 1, x, y)) % mod;
    return ans;
}

int main() {
    scanf("%lld%lld%lld", &n, &m, &mod);
    for (int i = 1; i <= n; i++) {
        scanf("%lld", &a[i]);
    }
    built(1, 1, n);
    int flag;
    ll x, y, k;
    for (int i = 1; i <= m; i++) {
        scanf("%d", &flag);
        if (flag == 1) {
            scanf("%lld%lld%lld", &x, &y, &k);
            mul(1, x, y, k);
        } else if (flag == 2) {
            scanf("%lld%lld%lld", &x, &y, &k);
            add(1, x, y, k);
        } else {
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", query(1, x, y));
        }
    }
    return 0;
}