线段树
对于有一类的问题,我们主要关心的是线段(区间),比如说查询一个区间[i, j]内的最大值,最小值等等。假设你有一个网站,你想查询某年(或某年以后)的用户访问量,消费最多的用户等等,这些都是在某个区间内进行查询,一般线段树的区间是固定的,不包含删除和添加的操作,只有查询和更新的操作
线段树的表示

现在如果假设有n个元素,用数组存储的话,需要多少空间呢
public class SegmentTree<E> {private E[] tree;private E[] data;public SegmentTree(E[] arr) {data = (E[]) new Object[arr.length];for (int i = 0; i < arr.length; i++) {data[i] = arr[i];}tree = (E[]) new Object[4 * data.length];}public int getSize() {return data.length;}public E get(int index) {if (index < 0 || index >= data.length) {throw new IllegalArgumentException("参数错误");}return data[index];}private int leftChild(int index) {return 2 * index + 1;}private int rightChild(int index) {return 2 * index + 2;}}
实现
创建线段树
下面就要根据数组来创建一棵线段树,我们的方法先创建下面的子线段树,然后由这些子线段树合并成大的线段树,以此类推
在合并左右子树的过程中,我们不能写死合并的过程,具体怎么合并应该由业务决定,由用户去决定如何合并,所以合并的过程我们写一个接口,具体的实现由用户去实现
public interface Merger<E> {public E merge(E a, E b);}
然后我们在构造方法中添加创建线段树的过程(为了创建线段树,增加了一个辅助方法)
private Merger<E> merger;//merger由用户传入 用户决定如何合并public SegmentTree(E[] arr, Merger<E> merger) {this.merger = merger;data = (E[]) new Object[arr.length];for (int i = 0; i < arr.length; i++) {data[i] = arr[i];}tree = (E[]) new Object[4 * data.length];//构造线段树 创建根节点为0,范围为[0,data.length - 1]的线段树buildSegmentTree(0, 0, data.length - 1);}//在treeIndex创建一棵[l,r]的线段树private void buildSegmentTree(int treeIndex, int l, int r) {if (l == r) {tree[treeIndex] = data[l];return;}//l != r 那么就要创建子树的线段树int leftTreeIndex = leftChild(treeIndex);int rightTreeIndex = rightChild(treeIndex);int mid = l + (r - l) / 2; //(l +r) / 2中l + r可能会大于int表示的范围从而溢出buildSegmentTree(leftTreeIndex, l, mid);buildSegmentTree(rightTreeIndex, mid + 1, r);//融合的方法由用户传入tree[treeIndex] = merger.merge(tree[leftTreeIndex],tree[rightTreeIndex]);}
为了方便我们打印出线段树,我们实现一个toString()方法
@Overridepublic String toString() {StringBuilder res = new StringBuilder();res.append("[");for (int i = 0; i < tree.length; i++) {if (tree[i] != null) {res.append(tree[i]);} else {res.append("null");}if (i != tree.length - 1) {res.append(", ");}}res.append("]");return res.toString();}
查询

实现代码
public E query(int queryL, int queryR) {if (queryL < 0 || queryL >= data.length|| queryR < 0 || queryR >= data.length|| queryL > queryR) {throw new IllegalArgumentException("参数错误");}return query(0, 0, data.length - 1, queryL, queryR);}private E query(int treeIndex, int l, int r, int queryL, int queryR) {if (l == queryL && r == queryR) {return tree[treeIndex];}int leftChildIndex = leftChild(treeIndex);int rightChildIndex = rightChild(treeIndex);int mid = l + (r - l) / 2;if (queryL >= mid + 1) {return query(rightChildIndex, mid+1, r, queryL, queryR);} else if (queryR <= mid) {return query(leftChildIndex, l, mid, queryL, queryR);}E leftResult = query(leftChildIndex, l, mid, queryL, mid);E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, queryR);return merger.merge(leftResult, rightResult);}
更新
public void set(int index, E e) {if (index < 0 || index >= data.length) {throw new IllegalArgumentException("参数错误");}set(0, 0, data.length - 1, index, e);}private void set(int treeIndex, int l, int r, int index, E e) {if (l == r) {tree[treeIndex] = e;return;}int leftChildIndex = leftChild(treeIndex);int rightChildIndex = rightChild(treeIndex);int mid = l + (r - l) / 2;if (index >= mid + 1) {set(rightChildIndex, mid+1, r, index, e);} else {set(leftChildIndex, l, mid, index, e);}tree[treeIndex] = merger.merge(tree[leftChildIndex], tree[rightChildIndex]);}
完整代码
public class SegmentTree<E>{private E[] tree;private E[] data;private Merger<E> merger;public SegmentTree(E[] arr, Merger<E> merger) {this.merger = merger;data = (E[]) new Object[arr.length];for (int i = 0; i < arr.length; i++) {data[i] = arr[i];}tree = (E[]) new Object[4 * data.length];buildSegmentTree(0, 0, data.length - 1);}//在treeIndex创建一棵[l,r]的线段树private void buildSegmentTree(int treeIndex, int l, int r) {if (l == r) {tree[treeIndex] = data[l];return;}int leftTreeIndex = leftChild(treeIndex);int rightTreeIndex = rightChild(treeIndex);int mid = l + (r - l) / 2; //(l +r) / 2中l + r可能会大于int表示的范围从而溢出buildSegmentTree(leftTreeIndex, l, mid);buildSegmentTree(rightTreeIndex, mid + 1, r);tree[treeIndex] = merger.merge(tree[leftTreeIndex],tree[rightTreeIndex]);}public E query(int queryL, int queryR) {if (queryL < 0 || queryL >= data.length|| queryR < 0 || queryR >= data.length|| queryL > queryR) {throw new IllegalArgumentException("参数错误");}return query(0, 0, data.length - 1, queryL, queryR);}private E query(int treeIndex, int l, int r, int queryL, int queryR) {if (l == queryL && r == queryR) {return tree[treeIndex];}int leftChildIndex = leftChild(treeIndex);int rightChildIndex = rightChild(treeIndex);int mid = l + (r - l) / 2;if (queryL >= mid + 1) {return query(rightChildIndex, mid+1, r, queryL, queryR);} else if (queryR <= mid) {return query(leftChildIndex, l, mid, queryL, queryR);}E leftResult = query(leftChildIndex, l, mid, queryL, mid);E rightResult = query(rightChildIndex, mid + 1, r, mid + 1, queryR);return merger.merge(leftResult, rightResult);}public void set(int index, E e) {if (index < 0 || index >= data.length) {throw new IllegalArgumentException("参数错误");}set(0, 0, data.length - 1, index, e);}private void set(int treeIndex, int l, int r, int index, E e) {if (l == r) {tree[treeIndex] = e;return;}int leftChildIndex = leftChild(treeIndex);int rightChildIndex = rightChild(treeIndex);int mid = l + (r - l) / 2;if (index >= mid + 1) {set(rightChildIndex, mid+1, r, index, e);} else {set(leftChildIndex, l, mid, index, e);}tree[treeIndex] = merger.merge(tree[leftChildIndex], tree[rightChildIndex]);}public int getSize() {return data.length;}public E get(int index) {if (index < 0 || index >= data.length) {throw new IllegalArgumentException("参数错误");}return data[index];}private int leftChild(int index) {return 2 * index + 1;}private int rightChild(int index) {return 2 * index + 2;}@Overridepublic String toString() {StringBuilder res = new StringBuilder();res.append("[");for (int i = 0; i < tree.length; i++) {if (tree[i] != null) {res.append(tree[i]);} else {res.append("null");}if (i != tree.length - 1) {res.append(", ");}}res.append("]");return res.toString();}}
