核心:不管是插入还是查询,都会将待操作数splay到树根。
可以证明:无论在何种数据下,使用splay操作,都能保证平均意义下O(logn)
的时间复杂度。
数据结构
static class Node {
int[] s = new int[2]; // 子节点
int p; // 父节点
int v; // 节点维护的数据信息
int size, flag; // 以当前节点为根的子树的总节点数,懒标记
}
static final int N = 100010;
static Node[] tr = new Node[N];
static int root, idx;
static void splay(int x, int k); // 伸展操作
static void rotate(int x); // 旋转操作
static void insert(int v); // 插入值为v的节点
static void delete(int v); // 删除值为v的节点
static void delete(int vl, int vr); // 删除值为[vl, vr]的一段节点
static void pushup(int x); // 向上更新
static void pushdown(int x); // 向下更新
static int get_r(int k); // 获取中序遍历的第k个数(不一定是第k大)
static void output(int u); // 中序遍历输出整棵树
splay
操作
splay(x, k)
将x
指向的节点旋转至k
节点下方。
分两种情况:
第一种:直线型
z y x
/ / \ \
y => x z => z
/ \
x y
第二种:折线型
z z x
/ / / \
y => x => y z
\ /
x y
作用:将x节点转至k点下方,并使树更平衡
static void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y))
rotate(x);
else rotate(y);
rotate(x);
}
if (k == 0) root = x;
}
rotate
操作
static void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x ? 1 : 0;
tr[z].s[tr[z].s[1] == y ? 1 : 0] = x;
tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1];
tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y;
tr[y].p = x;
pushup(y);
pushup(x);
}
insert
操作
插入值为v的节点
static void insert(int v) {
int u = root, p = 0;
while (u != 0) {
p = u;
u = v > tr[u].v ? tr[u].s[1] : tr[u].s[0];
}
u = ++idx;
tr[u] = new Node(v, p);
if (p != 0) tr[p].s[tr[p].v > v ? 0 : 1] = u;
splay(u, 0);
}
delete
操作
pushup
操作
在rotate
结尾调用
static void pushup(int x) {
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}
pushdown
操作
在output
和get_r
开始调用,即递归操作的开始调用
static void pushdown(int x) {
if (tr[x].flag == 1) {
swap(x);
tr[tr[x].s[0]].flag ^= 1;
tr[tr[x].s[1]].flag ^= 1;
tr[x].flag = 0;
}
}
static void swap(int x) {
int t = tr[x].s[0];
tr[x].s[0] = tr[x].s[1];
tr[x].s[1] = t;
}
get_r
操作
获取第k个节点
static int get_r(int k) {
int u = root;
while (u != 0) {
pushdown(u);
if (tr[tr[u].s[0]].size + 1 < k) {
k -= tr[tr[u].s[0]].size + 1;
u = tr[u].s[1];
} else if (tr[tr[u].s[0]].size + 1 == k) {
return u;
} else u = tr[u].s[0];
}
return -1;
}
output
操作
输出整棵树的中序遍历结果
static void output(int u) {
pushdown(u);
if (tr[u].s[0] != 0)
output(tr[u].s[0]);
if (tr[u].v >= 1 && tr[u].v <= n)
System.out.print(tr[u].v + " ");
if (tr[u].s[1] != 0)
output(tr[u].s[1]);
}
例题
Acwing 2437. Splay模板题
给定一个长度为 nn 的整数序列,初始时序列为 {1,2,…,n−1,n}。
序列中的位置从左到右依次标号为 1∼n。
我们用 [l,r] 来表示从位置 l 到位置 r 之间(包括两端点)的所有数字构成的子序列。
现在要对该序列进行 m 次操作,每次操作选定一个子序列 [l,r],并将该子序列中的所有数字进行翻转。
例如,对于现有序列 1 3 2 4 6 5 7,如果某次操作选定翻转子序列为 [3,6],那么经过这次操作后序列变为 1 3 5 6 4 2 7。
请你求出经过 m 次操作后的序列。
输入格式
第一行包含两个整数 n,m。
接下来 m 行,每行包含两个整数 l,r,用来描述一次操作。
输出格式
共一行,输出经过 m 次操作后的序列。
数据范围
1≤n,m≤105,
1≤l≤r≤n
输入样例:
6 3 2 4 1 5 3 5
输出样例:
5 2 1 4 3 6
import java.util.*;
public class Main {
static class Node {
int[] s = new int[2];
int v, p;
int size, flag;
Node(int v, int p) {
this.v = v;
this.p = p;
}
}
static final int N = 100010;
static Node[] tr = new Node[N];
static int n, m, idx, root;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
tr[0] = new Node(0, 0);
for (int i = 0; i <= n + 1; i++)
insert(i);
while (m-- > 0) {
int l = sc.nextInt(), r = sc.nextInt();
l = get_r(l);
r = get_r(r + 2);
splay(l, 0);
splay(r, l);
tr[tr[r].s[0]].flag ^= 1;
}
output(root);
}
static void insert(int v) {
int u = root, p = 0;
while (u != 0) {
p = u;
u = v > tr[u].v ? tr[u].s[1] : tr[u].s[0];
}
u = ++idx;
tr[u] = new Node(v, p);
if (p != 0) tr[p].s[tr[p].v > v ? 0 : 1] = u;
splay(u, 0);
}
static void splay(int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
if ((tr[y].s[1] == x) ^ (tr[z].s[1] == y))
rotate(x);
else rotate(y);
rotate(x);
}
if (k == 0) root = x;
}
static void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x ? 1 : 0;
tr[z].s[tr[z].s[1] == y ? 1 : 0] = x;
tr[x].p = z;
tr[y].s[k] = tr[x].s[k ^ 1];
tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y;
tr[y].p = x;
pushup(y);
pushup(x);
}
static void pushup(int x) {
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1;
}
static void pushdown(int x) {
if (tr[x].flag == 1) {
swap(x);
tr[tr[x].s[0]].flag ^= 1;
tr[tr[x].s[1]].flag ^= 1;
tr[x].flag = 0;
}
}
static void swap(int x) {
int t = tr[x].s[0];
tr[x].s[0] = tr[x].s[1];
tr[x].s[1] = t;
}
static int get_r(int k) {
int u = root;
while (u != 0) {
pushdown(u);
if (tr[tr[u].s[0]].size + 1 < k) {
k -= tr[tr[u].s[0]].size + 1;
u = tr[u].s[1];
} else if (tr[tr[u].s[0]].size + 1 == k) {
return u;
} else u = tr[u].s[0];
}
return -1;
}
static void output(int u) {
pushdown(u);
if (tr[u].s[0] != 0)
output(tr[u].s[0]);
if (tr[u].v >= 1 && tr[u].v <= n)
System.out.print(tr[u].v + " ");
if (tr[u].s[1] != 0)
output(tr[u].s[1]);
}
}