线段树

对于有一类的问题,我们主要关心的是线段(区间),比如说查询一个区间[i, j]内的最大值,最小值等等。假设你有一个网站,你想查询某年(或某年以后)的用户访问量,消费最多的用户等等,这些都是在某个区间内进行查询,一般线段树的区间是固定的,不包含删除和添加的操作,只有查询和更新的操作
线段树 - 图1

线段树的表示

线段树 - 图2
现在如果假设有n个元素,用数组存储的话,需要多少空间呢
线段树 - 图3

  1. public class SegmentTree<E> {
  2. private E[] tree;
  3. private E[] data;
  4. public SegmentTree(E[] arr) {
  5. data = (E[]) new Object[arr.length];
  6. for (int i = 0; i < arr.length; i++) {
  7. data[i] = arr[i];
  8. }
  9. tree = (E[]) new Object[4 * data.length];
  10. }
  11. public int getSize() {
  12. return data.length;
  13. }
  14. public E get(int index) {
  15. if (index < 0 || index >= data.length) {
  16. throw new IllegalArgumentException("参数错误");
  17. }
  18. return data[index];
  19. }
  20. private int leftChild(int index) {
  21. return 2 * index + 1;
  22. }
  23. private int rightChild(int index) {
  24. return 2 * index + 2;
  25. }
  26. }

实现

创建线段树

下面就要根据数组来创建一棵线段树,我们的方法先创建下面的子线段树,然后由这些子线段树合并成大的线段树,以此类推
线段树 - 图4
在合并左右子树的过程中,我们不能写死合并的过程,具体怎么合并应该由业务决定,由用户去决定如何合并,所以合并的过程我们写一个接口,具体的实现由用户去实现

  1. public interface Merger<E> {
  2. public E merge(E a, E b);
  3. }

然后我们在构造方法中添加创建线段树的过程(为了创建线段树,增加了一个辅助方法)

  1. private Merger<E> merger;
  2. //merger由用户传入 用户决定如何合并
  3. public SegmentTree(E[] arr, Merger<E> merger) {
  4. this.merger = merger;
  5. data = (E[]) new Object[arr.length];
  6. for (int i = 0; i < arr.length; i++) {
  7. data[i] = arr[i];
  8. }
  9. tree = (E[]) new Object[4 * data.length];
  10. //构造线段树 创建根节点为0,范围为[0,data.length - 1]的线段树
  11. buildSegmentTree(0, 0, data.length - 1);
  12. }
  13. //在treeIndex创建一棵[l,r]的线段树
  14. private void buildSegmentTree(int treeIndex, int l, int r) {
  15. if (l == r) {
  16. tree[treeIndex] = data[l];
  17. return;
  18. }
  19. //l != r 那么就要创建子树的线段树
  20. int leftTreeIndex = leftChild(treeIndex);
  21. int rightTreeIndex = rightChild(treeIndex);
  22. int mid = l + (r - l) / 2; //(l +r) / 2中l + r可能会大于int表示的范围从而溢出
  23. buildSegmentTree(leftTreeIndex, l, mid);
  24. buildSegmentTree(rightTreeIndex, mid + 1, r);
  25. //融合的方法由用户传入
  26. tree[treeIndex] = merger.merge(tree[leftTreeIndex],tree[rightTreeIndex]);
  27. }

为了方便我们打印出线段树,我们实现一个toString()方法

  1. @Override
  2. public String toString() {
  3. StringBuilder res = new StringBuilder();
  4. res.append("[");
  5. for (int i = 0; i < tree.length; i++) {
  6. if (tree[i] != null) {
  7. res.append(tree[i]);
  8. } else {
  9. res.append("null");
  10. }
  11. if (i != tree.length - 1) {
  12. res.append(", ");
  13. }
  14. }
  15. res.append("]");
  16. return res.toString();
  17. }

查询

线段树 - 图5
实现代码

  1. public E query(int queryL, int queryR) {
  2. if (queryL < 0 || queryL >= data.length
  3. || queryR < 0 || queryR >= data.length
  4. || queryL > queryR) {
  5. throw new IllegalArgumentException("参数错误");
  6. }
  7. return query(0, 0, data.length - 1, queryL, queryR);
  8. }
  9. private E query(int treeIndex, int l, int r, int queryL, int queryR) {
  10. if (l == queryL && r == queryR) {
  11. return tree[treeIndex];
  12. }
  13. int leftChildIndex = leftChild(treeIndex);
  14. int rightChildIndex = rightChild(treeIndex);
  15. int mid = l + (r - l) / 2;
  16. if (queryL >= mid + 1) {
  17. return query(rightChildIndex, mid+1, r, queryL, queryR);
  18. } else if (queryR <= mid) {
  19. return query(leftChildIndex, l, mid, queryL, queryR);
  20. }
  21. E leftResult = query(leftChildIndex, l, mid, queryL, mid);
  22. E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, queryR);
  23. return merger.merge(leftResult, rightResult);
  24. }

更新

  1. public void set(int index, E e) {
  2. if (index < 0 || index >= data.length) {
  3. throw new IllegalArgumentException("参数错误");
  4. }
  5. set(0, 0, data.length - 1, index, e);
  6. }
  7. private void set(int treeIndex, int l, int r, int index, E e) {
  8. if (l == r) {
  9. tree[treeIndex] = e;
  10. return;
  11. }
  12. int leftChildIndex = leftChild(treeIndex);
  13. int rightChildIndex = rightChild(treeIndex);
  14. int mid = l + (r - l) / 2;
  15. if (index >= mid + 1) {
  16. set(rightChildIndex, mid+1, r, index, e);
  17. } else {
  18. set(leftChildIndex, l, mid, index, e);
  19. }
  20. tree[treeIndex] = merger.merge(tree[leftChildIndex], tree[rightChildIndex]);
  21. }

完整代码

  1. public class SegmentTree<E>{
  2. private E[] tree;
  3. private E[] data;
  4. private Merger<E> merger;
  5. public SegmentTree(E[] arr, Merger<E> merger) {
  6. this.merger = merger;
  7. data = (E[]) new Object[arr.length];
  8. for (int i = 0; i < arr.length; i++) {
  9. data[i] = arr[i];
  10. }
  11. tree = (E[]) new Object[4 * data.length];
  12. buildSegmentTree(0, 0, data.length - 1);
  13. }
  14. //在treeIndex创建一棵[l,r]的线段树
  15. private void buildSegmentTree(int treeIndex, int l, int r) {
  16. if (l == r) {
  17. tree[treeIndex] = data[l];
  18. return;
  19. }
  20. int leftTreeIndex = leftChild(treeIndex);
  21. int rightTreeIndex = rightChild(treeIndex);
  22. int mid = l + (r - l) / 2; //(l +r) / 2中l + r可能会大于int表示的范围从而溢出
  23. buildSegmentTree(leftTreeIndex, l, mid);
  24. buildSegmentTree(rightTreeIndex, mid + 1, r);
  25. tree[treeIndex] = merger.merge(tree[leftTreeIndex],tree[rightTreeIndex]);
  26. }
  27. public E query(int queryL, int queryR) {
  28. if (queryL < 0 || queryL >= data.length
  29. || queryR < 0 || queryR >= data.length
  30. || queryL > queryR) {
  31. throw new IllegalArgumentException("参数错误");
  32. }
  33. return query(0, 0, data.length - 1, queryL, queryR);
  34. }
  35. private E query(int treeIndex, int l, int r, int queryL, int queryR) {
  36. if (l == queryL && r == queryR) {
  37. return tree[treeIndex];
  38. }
  39. int leftChildIndex = leftChild(treeIndex);
  40. int rightChildIndex = rightChild(treeIndex);
  41. int mid = l + (r - l) / 2;
  42. if (queryL >= mid + 1) {
  43. return query(rightChildIndex, mid+1, r, queryL, queryR);
  44. } else if (queryR <= mid) {
  45. return query(leftChildIndex, l, mid, queryL, queryR);
  46. }
  47. E leftResult = query(leftChildIndex, l, mid, queryL, mid);
  48. E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, queryR);
  49. return merger.merge(leftResult, rightResult);
  50. }
  51. public void set(int index, E e) {
  52. if (index < 0 || index >= data.length) {
  53. throw new IllegalArgumentException("参数错误");
  54. }
  55. set(0, 0, data.length - 1, index, e);
  56. }
  57. private void set(int treeIndex, int l, int r, int index, E e) {
  58. if (l == r) {
  59. tree[treeIndex] = e;
  60. return;
  61. }
  62. int leftChildIndex = leftChild(treeIndex);
  63. int rightChildIndex = rightChild(treeIndex);
  64. int mid = l + (r - l) / 2;
  65. if (index >= mid + 1) {
  66. set(rightChildIndex, mid+1, r, index, e);
  67. } else {
  68. set(leftChildIndex, l, mid, index, e);
  69. }
  70. tree[treeIndex] = merger.merge(tree[leftChildIndex], tree[rightChildIndex]);
  71. }
  72. public int getSize() {
  73. return data.length;
  74. }
  75. public E get(int index) {
  76. if (index < 0 || index >= data.length) {
  77. throw new IllegalArgumentException("参数错误");
  78. }
  79. return data[index];
  80. }
  81. private int leftChild(int index) {
  82. return 2 * index + 1;
  83. }
  84. private int rightChild(int index) {
  85. return 2 * index + 2;
  86. }
  87. @Override
  88. public String toString() {
  89. StringBuilder res = new StringBuilder();
  90. res.append("[");
  91. for (int i = 0; i < tree.length; i++) {
  92. if (tree[i] != null) {
  93. res.append(tree[i]);
  94. } else {
  95. res.append("null");
  96. }
  97. if (i != tree.length - 1) {
  98. res.append(", ");
  99. }
  100. }
  101. res.append("]");
  102. return res.toString();
  103. }
  104. }