树状数组入门
参考链接:
- https://www.cnblogs.com/xenny/p/9739600.html
- https://leetcode-cn.com/problems/count-of-smaller-numbers-after-self/solution/ji-suan-you-ce-xiao-yu-dang-qian-yuan-su-de-ge-s-7/
- https://leetcode-cn.com/problems/count-of-smaller-numbers-after-self/solution/gui-bing-pai-xu-suo-yin-shu-zu-python-dai-ma-java-/
树状数组」这个数据结构用于高效地解决「前缀和查询」与「单点更新」问题;







例题
1. 计算右侧小于当前元素的个数
描述 
思路
考虑到「树状数组」的底层是数组(线性结构),为了避免开辟多余的「树状数组」空间,需要进行「离散化」;
「离散化」的作用是:针对数值的大小做一个排名的「映射」,把原始数据映射到 [1, len] 这个区间,这样「树状数组」底层的数组空间会更紧凑,更易于维护。
从右向左读取排名;
先查询严格小于当前排名的「前缀和」,这里「前缀和」指的是,严格小于当前排名的元素的个数,这一步对应「前缀和查询」;
然后给「当前排名」加 1,这一步对应「单点更新」。
代码
Java代码:
class Solution {private int[] c;private int[] a;public List<Integer> countSmaller(int[] nums) {List<Integer> resultList = new ArrayList<Integer>();discretization(nums);init(nums.length + 5);for (int i = nums.length - 1; i >= 0; --i) {int id = getId(nums[i]);resultList.add(query(id - 1));update(id);}Collections.reverse(resultList);return resultList;}private void init(int length) {c = new int[length];Arrays.fill(c, 0);}private int lowBit(int x) {return x & (-x);}private void update(int pos) {while (pos < c.length) {c[pos] += 1;pos += lowBit(pos);}}private int query(int pos) {int ret = 0;while (pos > 0) {ret += c[pos];pos -= lowBit(pos);}return ret;}private void discretization(int[] nums) {Set<Integer> set = new HashSet<Integer>();for (int num : nums) {set.add(num);}int size = set.size();a = new int[size];int index = 0;for (int num : set) {a[index++] = num;}Arrays.sort(a);}private int getId(int x) {return Arrays.binarySearch(a, x) + 1;}}
Python代码:
from bisect import bisect_leftclass Solution:def countSmaller(self, nums: List[int]) -> List[int]:m = len(nums)counts = [0] * mtreeArr = FenwickTree(m + 1)nums = treeArr.discretization(nums)for i in range(m-1, -1, -1):treeArr.update(nums[i], 1)counts[i] = treeArr.ask(nums[i]-1)return countsclass FenwickTree:def __init__(self, n):self.tree = [0] * (n + 1)self.size = ndef lowbit(self, x):return x & (-x)def update(self, i, k):# 从下到上,可以等于sizewhile i <= self.size:self.tree[i] += ki += self.lowbit(i)def ask(self, i):res = 0while i > 0:res += self.tree[i]i -= self.lowbit(i)return resdef get_mapping_list(self, nums):return list(sorted(set(nums)))def discretization(self, nums):mapping = self.get_mapping_list(nums)return [bisect_left(mapping, num) + 1 for num in nums]
离散化可以采用treeSet和堆来实现
Java代码:
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
public class Solution {
public List<Integer> countSmaller(int[] nums) {
List<Integer> res = new ArrayList<>();
int len = nums.length;
if (len == 0) {
return res;
}
// 使用二分搜索树方便排序
Set<Integer> set = new TreeSet();
for (int i = 0; i < len; i++) {
set.add(nums[i]);
}
// 排名表
Map<Integer, Integer> map = new HashMap<>();
int rank = 1;
for (Integer num : set) {
map.put(num, rank);
rank++;
}
FenwickTree fenwickTree = new FenwickTree(set.size() + 1);
// 从后向前填表
for (int i = len - 1; i >= 0; i--) {
// 1、查询排名
rank = map.get(nums[i]);
// 2、在树状数组排名的那个位置 + 1
fenwickTree.update(rank, 1);
// 3、查询一下小于等于“当前排名 - 1”的元素有多少
res.add(fenwickTree.query(rank - 1));
}
Collections.reverse(res);
return res;
}
private class FenwickTree {
private int[] tree;
private int len;
public FenwickTree(int n) {
this.len = n;
tree = new int[n + 1];
}
// 单点更新:将 index 这个位置 + 1
public void update(int i, int delta) {
// 从下到上,最多到 size,可以等于 size
while (i <= this.len) {
tree[i] += delta;
i += lowbit(i);
}
}
// 区间查询:查询小于等于 index 的元素个数
// 查询的语义是"前缀和"
public int query(int i) {
// 从右到左查询
int sum = 0;
while (i > 0) {
sum += tree[i];
i -= lowbit(i);
}
return sum;
}
public int lowbit(int x) {
return x & (-x);
}
}
public static void main(String[] args) {
int[] nums = new int[]{5, 2, 6, 1};
Solution solution = new Solution();
List<Integer> countSmaller = solution.countSmaller(nums);
System.out.println(countSmaller);
}
}
Python代码:
from typing import List
class Solution:
def countSmaller(self, nums: List[int]) -> List[int]:
class FenwickTree:
def __init__(self, n):
self.size = n
self.tree = [0 for _ in range(n + 1)]
def __lowbit(self, index):
return index & (-index)
# 单点更新:将 index 这个位置 + 1
def update(self, index, delta):
# 从下到上,最多到 size,可以等于 size
while index <= self.size:
self.tree[index] += delta
index += self.__lowbit(index)
# 区间查询:查询小于等于 index 的元素个数
# 查询的语义是"前缀和"
def query(self, index):
res = 0
# 从上到下,最少到 1,可以等于 1
while index > 0:
res += self.tree[index]
index -= self.__lowbit(index)
return res
# 特判
size = len(nums)
if size == 0:
return []
if size == 1:
return [0]
# 去重,方便离散化
s = list(set(nums))
s_len = len(s)
# 离散化,借助堆
import heapq
heapq.heapify(s)
rank_map = dict()
rank = 1
for _ in range(s_len):
num = heapq.heappop(s)
rank_map[num] = rank
rank += 1
fenwick_tree = FenwickTree(s_len)
# 从后向前填表
res = [None for _ in range(size)]
# 从后向前填表
for index in range(size - 1, -1, -1):
# 1、查询排名
rank = rank_map[nums[index]]
# 2、在树状数组排名的那个位置 + 1
fenwick_tree.update(rank, 1)
# 3、查询一下小于等于“当前排名 - 1”的元素有多少
res[index] = fenwick_tree.query(rank - 1)
return res
if __name__ == '__main__':
nums = [5, 2, 6, 1]
solution = Solution()
result = solution.countSmaller(nums)
print(result)
2. AcWing 242. 一个简单的整数问题
给定长度为N的数列A,然后输入M行操作指令。
第一类指令形如“C l r d”,表示把数列中第l~r个数都加d。
第二类指令形如“Q X”,表示询问数列中第x个数的值。
对于每个询问,输出一个整数表示答案。
输入格式
第一行包含两个整数N和M。
第二行包含N个整数A[i]。
接下来M行表示M条指令,每条指令的格式如题目描述所示。
输出格式
对于每个询问,输出一个整数表示答案。
每个答案占一行。
数据范围
1≤N,M≤1051≤N,M≤105,
|d|≤10000|d|≤10000,
|A[i]|≤1000000000|A[i]|≤1000000000
输入样例:
10 5
1 2 3 4 5 6 7 8 9 10
Q 4
Q 1
Q 2
C 1 6 3
Q 2
输出样例:
4
1
2
5
思路
- 区间更新
- 单点查询
- 树状数组
本题的指令有“区间增加”和“单点查询”,而树状数组仅支持“单点查询”,需要作出一些转化来解决问题。
新建一个数组b,期初为全零。对于每条指令“C l r d”,我们把它转化成以下两条指令:
1.把b[l]加上d.
2.再把b[r+1]减去d。
执行上面两条指令后,我们再来考虑一下b数组的前缀和(b[1~x]的和)的情况:
1.对于1<=x
于是,我们可以用树状数组来维护数组b的前缀(对b只有单点增加操作)。
因为各次操作之间具有可累加性,所以在树状数组上查询前缀和b[1~x],就得到了到目前为止所有“C”指令在a[x]上增加的数值总和。再加上a[x]的初始值,就得到了“Qx”的答案。
代码
"""
@Author: Li Zenghui
@Date: 2020-08-16 09:46
"""
class FW:
def __init__(self, n):
self.a = [0]
self.b = [0] * (n + 1)
self.size = n
def lowbit(self, x):
return x & -x
def update(self, i, k):
while i <= self.size:
self.b[i] += k
i += self.lowbit(i)
def ask(self, i):
ans = self.a[i]
while i > 0:
ans += self.b[i]
i -= self.lowbit(i)
return ans
if __name__ == '__main__':
N, M = map(int, input().split())
A = list(map(int, input().split()))
n = len(A)
fw = FW(n)
fw.a.extend(A)
for _ in range(M):
op = list(input().split())
if op[0] == "Q":
print(fw.ask(int(op[1])))
if op[0] == "C":
fw.update(int(op[1]), int(op[3]))
fw.update(int(op[2]) + 1, -int(op[3]))
