实现

  class SegmentTree:
      """Segment tree over a fixed-length sequence: range queries, point updates.

      Nodes live in a flat list using heap layout: the children of node ``i``
      are at ``2*i + 1`` and ``2*i + 2``. Each internal node stores
      ``merger(left, right)`` of its children; by default that is ``left + right``,
      so ``query`` returns range sums. Any associative binary function
      (e.g. ``min`` or ``max``) may be supplied instead.
      """

      def __init__(self, data: list, merger=None):
          """Copy *data* and build the tree over it.

          Args:
              data: initial element values.
              merger: associative binary function used to combine two child
                  values. Defaults to ``a + b`` (range sums).
          """
          # Work on a private copy so the caller's list can't desync the tree.
          self.data = list(data)
          self.merger = merger if merger is not None else (lambda a, b: a + b)
          # 4 * n slots always suffice for the heap layout of n leaves.
          self.tree = [None] * (4 * len(self.data))
          # Build eagerly so query()/set() work immediately; an explicit
          # build_segment_tree() call afterwards simply rebuilds the same tree.
          self.build_segment_tree()

      def get_size(self):
          """Return the number of elements."""
          return len(self.data)

      def get(self, index: int):
          """Return the element at *index*; requires 0 <= index < size."""
          assert 0 <= index < len(self.data), "index error"
          return self.data[index]

      @staticmethod
      def get_left_child(index):
          """Index of the left child of node *index*."""
          return 2 * index + 1

      @staticmethod
      def get_right_child(index):
          """Index of the right child of node *index*."""
          return 2 * index + 2

      def build_segment_tree(self):
          """(Re)build every node from ``self.data``; a no-op for empty data."""
          if not self.data:
              # The range [0, -1] is empty: recursing into it never reaches a
              # leaf and would recurse without bound.
              return
          self.__build_segment_tree(0, 0, len(self.data) - 1)

      def __build_segment_tree(self, tree_index, l: int, r: int):
          """Build the subtree rooted at *tree_index* covering data[l..r]."""
          if l == r:
              self.tree[tree_index] = self.data[l]
              return
          left_index = self.get_left_child(tree_index)
          right_index = self.get_right_child(tree_index)
          # Written this way (rather than (l + r) // 2) to mirror the classic
          # overflow-safe midpoint form.
          mid = l + (r - l) // 2
          self.__build_segment_tree(left_index, l, mid)
          self.__build_segment_tree(right_index, mid + 1, r)
          self.tree[tree_index] = self.merger(
              self.tree[left_index], self.tree[right_index]
          )

      def query(self, query_l, query_r):
          """Return the merged value of data[query_l..query_r] (inclusive).

          Raises:
              IndexError: unless 0 <= query_l <= query_r < size. Without this
                  check a reversed range recursed forever and an out-of-range
                  one walked past the leaves.
          """
          if not 0 <= query_l <= query_r < len(self.data):
              raise IndexError(f"invalid query range [{query_l}, {query_r}]")
          return self.__query(0, 0, len(self.data) - 1, query_l, query_r)

      def __query(self, tree_index, l, r, query_l, query_r):
          """Merge data[query_l..query_r] within the subtree covering data[l..r]."""
          if l == query_l and r == query_r:
              return self.tree[tree_index]
          mid = l + (r - l) // 2
          left_tree_index = self.get_left_child(tree_index)
          right_tree_index = self.get_right_child(tree_index)
          if query_l >= mid + 1:
              # Entirely inside the right half.
              return self.__query(right_tree_index, mid + 1, r, query_l, query_r)
          if query_r <= mid:
              # Entirely inside the left half.
              return self.__query(left_tree_index, l, mid, query_l, query_r)
          # Straddles the midpoint: split at mid and combine both halves.
          left_res = self.__query(left_tree_index, l, mid, query_l, mid)
          right_res = self.__query(right_tree_index, mid + 1, r, mid + 1, query_r)
          return self.merger(left_res, right_res)

      def set(self, index, value):
          """Replace data[index] with *value* and update affected nodes.

          Raises:
              IndexError: unless 0 <= index < size. Negative indices are
                  rejected because they would update data[index] from the end
                  while the tree walk updated leaf 0, desyncing the two.
          """
          if not 0 <= index < len(self.data):
              raise IndexError(f"index {index} out of range")
          self.data[index] = value
          self.__set(0, 0, len(self.data) - 1, index, value)

      def __set(self, tree_index, l, r, index, value):
          """Set leaf *index* to *value* in the subtree covering data[l..r]."""
          if l == r:
              self.tree[tree_index] = value
              return
          mid = l + (r - l) // 2
          left_tree_index = self.get_left_child(tree_index)
          right_tree_index = self.get_right_child(tree_index)
          if index >= mid + 1:
              self.__set(right_tree_index, mid + 1, r, index, value)
          else:
              self.__set(left_tree_index, l, mid, index, value)
          # Recompute this node from its (possibly updated) children.
          self.tree[tree_index] = self.merger(
              self.tree[left_tree_index], self.tree[right_tree_index]
          )