背景知识
1. 二叉搜索树BST
2. 堆Heap
操作
- 插入
- 删除
- 找前驱/后继
- 找最大/最小
- 求某个值的排名
- 求排名是k的数是哪个
- 找比某个数小的最大值
- 找比某个数大的最小值
结构定义
Node {
int l, r;
// key指二叉搜索树中的值,val指堆中的值(随机值),两个同时满足要求
int key, val;
}
这样就能唯一确定一个二叉搜索树。
插入
- 直接插入至叶节点,赋随机值val
- 旋转操作(左旋zag/右旋zig),类似于堆的pushup
删除
- 找到目标节点
- 通过不断左旋或右旋降低节点高度直至叶节点
- 删除目标节点
模板
// 省略IO
public class Main {
static IntReader in;
static FastWriter out;
static String INPUT = "";
static class Node {
int l, r;
int key, val;
int cnt, size;
}
static final int INF = (int)(1e8), N = 100010;
static Node[] tr = new Node[N];
static int n, idx;
static Random random = new Random();
static void solve() {
n = ni();
int root = build();
for (int i = 1; i <= n; i++) {
int op = ni(), x = ni();
if (op == 1)
root = insert(root, x);
else if (op == 2)
root = delete(root, x);
else if (op == 3)
out.println(getRank(root, x) - 1);
else if (op == 4)
out.println(getKey(root, x + 1));
else if (op == 5)
out.println(getPrev(root, x));
else if (op == 6)
out.println(getNext(root, x));
}
}
static int getNext(int u, int x) {
if (u == 0) return INF;
if (tr[u].key <= x)
return getNext(tr[u].r, x);
else
return Math.min(tr[u].key, getNext(tr[u].l, x));
}
static int getPrev(int u, int x) {
if (u == 0) return -INF;
if (tr[u].key >= x)
return getPrev(tr[u].l, x);
else
return Math.max(tr[u].key, getPrev(tr[u].r, x));
}
static int getKey(int u, int rank) {
if (u == 0) return 0;
if (tr[tr[u].l].size >= rank)
return getKey(tr[u].l, rank);
else if (tr[tr[u].l].size + tr[u].cnt >= rank)
return tr[u].key;
else
return getKey(tr[u].r, rank - tr[u].cnt - tr[tr[u].l].size);
}
static int getRank(int u, int key) {
if (u == 0) return 0;
if (tr[u].key == key) {
return tr[tr[u].l].size + 1;
} else if (tr[u].key > key) {
return getRank(tr[u].l, key);
} else {
return tr[tr[u].l].size + tr[u].cnt + getRank(tr[u].r, key);
}
}
static int delete(int u, int key) {
if (u == 0) {
return 0;
}
if (tr[u].key == key) {
if (tr[u].cnt > 1)
tr[u].cnt--;
else if (tr[u].l != 0 || tr[u].r != 0) {
if (tr[u].r == 0 || tr[u].l != 0 && tr[tr[u].l].val > tr[tr[u].r].val) {
u = zig(u);
tr[u].r = delete(tr[u].r, key);
} else {
u = zag(u);
tr[u].l = delete(tr[u].l, key);
}
} else return 0;
} else if (tr[u].key > key) {
tr[u].l = delete(tr[u].l, key);
} else {
tr[u].r = delete(tr[u].r, key);
}
pushup(u);
return u;
}
static int insert(int u, int key) {
if (u == 0) {
int p = createNode(key);
return p;
}
if (tr[u].key == key) {
tr[u].cnt++;
} else if (tr[u].key > key) {
tr[u].l = insert(tr[u].l, key);
if (tr[u].val < tr[tr[u].l].val)
u = zig(u);
} else {
tr[u].r = insert(tr[u].r, key);
if (tr[u].val < tr[tr[u].r].val)
u = zag(u);
}
pushup(u);
return u;
}
static void pushup(int u) {
tr[u].size = tr[tr[u].l].size + tr[tr[u].r].size + tr[u].cnt;
}
static int zag(int u) {
int right = tr[u].r;
tr[u].r = tr[right].l;
tr[right].l = u;
pushup(u);
pushup(right);
return right;
}
static int zig(int u) {
int left = tr[u].l;
tr[u].l = tr[left].r;
tr[left].r = u;
pushup(u);
pushup(left);
return left;
}
static int build() {
tr[0] = new Node();
int root = createNode(-INF), right = createNode(INF);
tr[root].r = right;
pushup(root);
if (tr[root].val < tr[right].val)
root = zag(root);
return root;
}
static int createNode(int x) {
++idx;
tr[idx] = new Node();
tr[idx].key = x;
tr[idx].val = random.nextInt(2 * INF) + 1;
tr[idx].size = tr[idx].cnt = 1;
return idx;
}
public static void main(String[] args) throws Exception {
in = INPUT.isEmpty() ? new IntReader(System.in) : new IntReader(new ByteArrayInputStream(INPUT.getBytes()));
out = new FastWriter(System.out);
solve();
out.flush();
}
}