树状数组入门

参考链接:

树状数组」这个数据结构用于高效地解决「前缀和查询」与「单点更新」问题
image.png
image.png
image.png
image.png
image.png
image.png
image.png
image.png

例题

1. 计算右侧小于当前元素的个数

描述
image.png
思路

考虑到「树状数组」的底层是数组(线性结构),为了避免开辟多余的「树状数组」空间,需要进行「离散化」;
「离散化」的作用是:针对数值的大小做一个排名的「映射」,把原始数据映射到 [1, len] 这个区间,这样「树状数组」底层的数组空间会更紧凑,更易于维护。

从右向左读取排名;
先查询严格小于当前排名的「前缀和」,这里「前缀和」指的是,严格小于当前排名的元素的个数,这一步对应「前缀和查询」;
然后给「当前排名」加 1,这一步对应「单点更新」。
代码
Java代码:

  1. class Solution {
  2. private int[] c;
  3. private int[] a;
  4. public List<Integer> countSmaller(int[] nums) {
  5. List<Integer> resultList = new ArrayList<Integer>();
  6. discretization(nums);
  7. init(nums.length + 5);
  8. for (int i = nums.length - 1; i >= 0; --i) {
  9. int id = getId(nums[i]);
  10. resultList.add(query(id - 1));
  11. update(id);
  12. }
  13. Collections.reverse(resultList);
  14. return resultList;
  15. }
  16. private void init(int length) {
  17. c = new int[length];
  18. Arrays.fill(c, 0);
  19. }
  20. private int lowBit(int x) {
  21. return x & (-x);
  22. }
  23. private void update(int pos) {
  24. while (pos < c.length) {
  25. c[pos] += 1;
  26. pos += lowBit(pos);
  27. }
  28. }
  29. private int query(int pos) {
  30. int ret = 0;
  31. while (pos > 0) {
  32. ret += c[pos];
  33. pos -= lowBit(pos);
  34. }
  35. return ret;
  36. }
  37. private void discretization(int[] nums) {
  38. Set<Integer> set = new HashSet<Integer>();
  39. for (int num : nums) {
  40. set.add(num);
  41. }
  42. int size = set.size();
  43. a = new int[size];
  44. int index = 0;
  45. for (int num : set) {
  46. a[index++] = num;
  47. }
  48. Arrays.sort(a);
  49. }
  50. private int getId(int x) {
  51. return Arrays.binarySearch(a, x) + 1;
  52. }
  53. }

Python代码:

  1. from bisect import bisect_left
  2. class Solution:
  3. def countSmaller(self, nums: List[int]) -> List[int]:
  4. m = len(nums)
  5. counts = [0] * m
  6. treeArr = FenwickTree(m + 1)
  7. nums = treeArr.discretization(nums)
  8. for i in range(m-1, -1, -1):
  9. treeArr.update(nums[i], 1)
  10. counts[i] = treeArr.ask(nums[i]-1)
  11. return counts
  12. class FenwickTree:
  13. def __init__(self, n):
  14. self.tree = [0] * (n + 1)
  15. self.size = n
  16. def lowbit(self, x):
  17. return x & (-x)
  18. def update(self, i, k):
  19. # 从下到上,可以等于size
  20. while i <= self.size:
  21. self.tree[i] += k
  22. i += self.lowbit(i)
  23. def ask(self, i):
  24. res = 0
  25. while i > 0:
  26. res += self.tree[i]
  27. i -= self.lowbit(i)
  28. return res
  29. def get_mapping_list(self, nums):
  30. return list(sorted(set(nums)))
  31. def discretization(self, nums):
  32. mapping = self.get_mapping_list(nums)
  33. return [bisect_left(mapping, num) + 1 for num in nums]

离散化可以采用treeSet和堆来实现
Java代码:

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

public class Solution {

    public List<Integer> countSmaller(int[] nums) {
        List<Integer> res = new ArrayList<>();
        int len = nums.length;
        if (len == 0) {
            return res;
        }

        // 使用二分搜索树方便排序
        Set<Integer> set = new TreeSet();
        for (int i = 0; i < len; i++) {
            set.add(nums[i]);
        }

        // 排名表
        Map<Integer, Integer> map = new HashMap<>();
        int rank = 1;
        for (Integer num : set) {
            map.put(num, rank);
            rank++;
        }

        FenwickTree fenwickTree = new FenwickTree(set.size() + 1);
        // 从后向前填表
        for (int i = len - 1; i >= 0; i--) {
            // 1、查询排名
            rank = map.get(nums[i]);
            // 2、在树状数组排名的那个位置 + 1
            fenwickTree.update(rank, 1);
            // 3、查询一下小于等于“当前排名 - 1”的元素有多少
            res.add(fenwickTree.query(rank - 1));
        }
        Collections.reverse(res);
        return res;
    }


    private class FenwickTree {
        private int[] tree;
        private int len;

        public FenwickTree(int n) {
            this.len = n;
            tree = new int[n + 1];
        }

        // 单点更新:将 index 这个位置 + 1
        public void update(int i, int delta) {
            // 从下到上,最多到 size,可以等于 size
            while (i <= this.len) {
                tree[i] += delta;
                i += lowbit(i);
            }
        }


        // 区间查询:查询小于等于 index 的元素个数
        // 查询的语义是"前缀和"
        public int query(int i) {
            // 从右到左查询
            int sum = 0;
            while (i > 0) {
                sum += tree[i];
                i -= lowbit(i);
            }
            return sum;
        }

        public int lowbit(int x) {
            return x & (-x);
        }
    }


    public static void main(String[] args) {
        int[] nums = new int[]{5, 2, 6, 1};
        Solution solution = new Solution();
        List<Integer> countSmaller = solution.countSmaller(nums);
        System.out.println(countSmaller);
    }
}

Python代码:

from typing import List


class Solution:
    def countSmaller(self, nums: List[int]) -> List[int]:
        class FenwickTree:
            def __init__(self, n):
                self.size = n
                self.tree = [0 for _ in range(n + 1)]

            def __lowbit(self, index):
                return index & (-index)

            # 单点更新:将 index 这个位置 + 1
            def update(self, index, delta):
                # 从下到上,最多到 size,可以等于 size
                while index <= self.size:
                    self.tree[index] += delta
                    index += self.__lowbit(index)

            # 区间查询:查询小于等于 index 的元素个数
            # 查询的语义是"前缀和"
            def query(self, index):
                res = 0
                # 从上到下,最少到 1,可以等于 1
                while index > 0:
                    res += self.tree[index]
                    index -= self.__lowbit(index)
                return res

        # 特判
        size = len(nums)
        if size == 0:
            return []
        if size == 1:
            return [0]

        # 去重,方便离散化
        s = list(set(nums))

        s_len = len(s)

        # 离散化,借助堆
        import heapq
        heapq.heapify(s)

        rank_map = dict()
        rank = 1
        for _ in range(s_len):
            num = heapq.heappop(s)
            rank_map[num] = rank
            rank += 1

        fenwick_tree = FenwickTree(s_len)

        # 从后向前填表
        res = [None for _ in range(size)]
        # 从后向前填表
        for index in range(size - 1, -1, -1):
            # 1、查询排名
            rank = rank_map[nums[index]]
            # 2、在树状数组排名的那个位置 + 1
            fenwick_tree.update(rank, 1)
            # 3、查询一下小于等于“当前排名 - 1”的元素有多少
            res[index] = fenwick_tree.query(rank - 1)
        return res


if __name__ == '__main__':
    nums = [5, 2, 6, 1]
    solution = Solution()
    result = solution.countSmaller(nums)
    print(result)

2. AcWing 242. 一个简单的整数问题

给定长度为N的数列A,然后输入M行操作指令。
第一类指令形如“C l r d”,表示把数列中第l~r个数都加d。
第二类指令形如“Q X”,表示询问数列中第x个数的值。
对于每个询问,输出一个整数表示答案。
输入格式
第一行包含两个整数N和M。
第二行包含N个整数A[i]。
接下来M行表示M条指令,每条指令的格式如题目描述所示。
输出格式
对于每个询问,输出一个整数表示答案。
每个答案占一行。
数据范围
1≤N,M≤1051≤N,M≤105,
|d|≤10000|d|≤10000,
|A[i]|≤1000000000|A[i]|≤1000000000
输入样例:

10 5
1 2 3 4 5 6 7 8 9 10
Q 4
Q 1
Q 2
C 1 6 3
Q 2

输出样例:

4
1
2
5

思路

  • 区间更新
  • 单点查询
  • 树状数组

本题的指令有“区间增加”和“单点查询”,而树状数组仅支持“单点查询”,需要作出一些转化来解决问题。
新建一个数组b,期初为全零。对于每条指令“C l r d”,我们把它转化成以下两条指令:
1.把b[l]加上d.
2.再把b[r+1]减去d。
执行上面两条指令后,我们再来考虑一下b数组的前缀和(b[1~x]的和)的情况:
1.对于1<=x2.对于l<=x3.对于r我们发现,b数组的前缀和b[1~x]就反映了指令“C l r d”对a[x]产生的影响。
于是,我们可以用树状数组来维护数组b的前缀(对b只有单点增加操作)。
因为各次操作之间具有可累加性,所以在树状数组上查询前缀和b[1~x],就得到了到目前为止所有“C”指令在a[x]上增加的数值总和。再加上a[x]的初始值,就得到了“Qx”的答案。
代码

"""

@Author: Li Zenghui
@Date: 2020-08-16 09:46
"""

class FW:
    def __init__(self, n):
        self.a = [0]
        self.b = [0] * (n + 1)
        self.size = n

    def lowbit(self, x):
        return x & -x

    def update(self, i, k):
        while i <= self.size:
            self.b[i] += k
            i += self.lowbit(i)

    def ask(self, i):
        ans = self.a[i]
        while i > 0:
            ans += self.b[i]
            i -= self.lowbit(i)
        return ans


if __name__ == '__main__':
    N, M = map(int, input().split())
    A = list(map(int, input().split()))
    n = len(A)
    fw = FW(n)
    fw.a.extend(A)
    for _ in range(M):
        op = list(input().split())
        if op[0] == "Q":
            print(fw.ask(int(op[1])))
        if op[0] == "C":
            fw.update(int(op[1]), int(op[3]))
            fw.update(int(op[2]) + 1, -int(op[3]))

3.