AVL

概念及实现

我们在研究二分搜索树时发现,如果我们将数据顺序添加进树中时,它有会退化成一棵链表,即所有的元素都添加到一个孩子上,这样树结构的优势就体现不出来,为了不使左右孩子的高度相差太大,我们需要对树进行调整,使树达到平衡,成为一棵平衡二叉树,AVL就是一种经典的平衡二叉树
AVL树 - 图1
在AVL中,我们定义的平衡二叉树为,对于任意一个节点,左子树和右子树的高度相差不能超过1。
AVL树 - 图2
我们为每一个节点标注好高度值,计算方法为取左右子树高度较高的高度,然后+1
AVL树 - 图3
然后我们还有记录节点左右子树的高度差,我们称之为平衡因子(规定用左子树的高度-右子树的高度)
AVL树 - 图4
由于我们只是在添加元素和删除元素时对树进行调整,其余的代码同二分搜索树是相同的,所以就不贴出所有的代码,只给出不同的代码,首先我们需要在Node类中添加一个height变量来记录高度

  1. private class Node {
  2. public E e;
  3. public Node left;
  4. public Node right;
  5. //高度
  6. public int height;
  7. public Node(E e) {
  8. this.e = e;
  9. left = null;
  10. right = null;
  11. //高度初始为1
  12. height = 1;
  13. }
  14. @Override
  15. public String toString() {
  16. return e.toString();
  17. }
  18. }

新增加一个获得某节点高度的函数和平衡因子的函数

  1. private int getHeight(Node node) {
  2. if (node == null) {
  3. return 0;
  4. }
  5. return node.height;
  6. }
  7. private int getBalanceFactor(Node node) {
  8. if (node == null) {
  9. return 0;
  10. }
  11. return getHeight(node.left) - getHeight(node.right);
  12. }

有了这些因素,我们一般需要在添加元素时进行维护,重新计算高度和平衡因子,从而进行调整

  1. private Node add(Node node, E e) {
  2. if (node == null) {
  3. size++;
  4. return new Node(e);
  5. }
  6. if (e.compareTo(node.e) < 0) {
  7. node.left = add(node.left, e);
  8. } else if (e.compareTo(node.e) > 0) {
  9. node.right = add(node.right,e);
  10. }
  11. //更新高度
  12. node.height = Math.max(getHeight(node.left),getHeight(node.right)) + 1;
  13. //计算平衡因子
  14. int balanceFactor = getBalanceFactor(node);
  15. if (Math.abs(balanceFactor) > 1) {
  16. //进行调整
  17. }
  18. return node;
  19. }

我们后面的内容主要是如何调整,后面所以只给出如何调整的代码,在学如何调整之前,我们来写两个辅助函数来判断这棵树是不是二分搜索树和AVL树,因为如果我们的代码有问题的话,有可能破坏二分搜索树的性质,这样有利于我们检查,那怎么检查一棵树是不是二分搜索树,我们根据二分搜索树的性质,它的中序遍历的结果是从小到大的特性,我们重写中序遍历为

  1. public boolean isBST() {
  2. ArrayList<E> arrayList = new ArrayList<>();
  3. inOrder(root, arrayList);
  4. for (int i = 1; i < arrayList.size(); i++) {
  5. if (arrayList.get(i-1).compareTo(arrayList.get(i)) > 0)
  6. return false;
  7. }
  8. }
  9. return true;
  10. }
  11. private void inOrder(Node node, ArrayList<E> arrayList) {
  12. if (node == null) {
  13. return;
  14. }
  15. inOrder(node.left, arrayList);
  16. arrayList.add(node.e);
  17. inOrder(node.right, arrayList);
  18. }

现在我们判断这棵树是不是平衡二叉树

  1. public boolean isBalanced() {
  2. return isBalanced(root);
  3. }
  4. //判断某个节点是不是平衡
  5. private boolean isBalanced(Node node) {
  6. if (node == null) {
  7. return true;
  8. }
  9. int balanceFactor = getBalanceFactor(node);
  10. if (Math.abs(balanceFactor) > 1) {
  11. return false;
  12. }
  13. return isBalanced(node.left) && isBalanced(node.right);
  14. }

下面对不平衡的四种情形进行讨论,并给出调整方法
AVL树 - 图5

  1. // 对节点y进行向右旋转操作,返回旋转后新的根节点x
  2. // y x
  3. // / \ / \
  4. // x T4 向右旋转 (y) z y
  5. // / \ - - - - - - - -> / \ / \
  6. // z T3 T1 T2 T3 T4
  7. // / \
  8. // T1 T2
  9. private Node rightRotate(Node y) {
  10. Node x = y.left;
  11. Node T3 = x.right;
  12. x.right = y;
  13. y.left = T3;
  14. //更新x和y的高度值 先更新y的,因为y是x的右孩子,x的更新取决于y
  15. y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
  16. x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
  17. return x;
  18. }
  19. // 对节点y进行向左旋转操作,返回旋转后新的根节点x
  20. // y x
  21. // / \ / \
  22. // T4 x 向左旋转 (y) y z
  23. // / \ - - - - - - - -> / \ / \
  24. // T3 z T4 T3 T1 T2
  25. // / \
  26. // T1 T2
  27. private Node leftRotate(Node y) {
  28. Node x = y.right;
  29. Node T3 = x.left;
  30. x.left = y;
  31. y.right = T3;
  32. y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
  33. x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
  34. return x;
  35. }
  36. public void add(E e) {
  37. root = add(root, e);
  38. }
  39. private Node add(Node node, E e) {
  40. if (node == null) {
  41. size++;
  42. return new Node(e);
  43. }
  44. if (e.compareTo(node.e) < 0) {
  45. node.left = add(node.left, e);
  46. } else if (e.compareTo(node.e) > 0) {
  47. node.right = add(node.right,e);
  48. }
  49. //更新高度
  50. node.height = Math.max(getHeight(node.left),getHeight(node.right)) + 1;
  51. //计算平衡因子
  52. int balanceFactor = getBalanceFactor(node);
  53. //调整
  54. if (balanceFactor > 1 && getBalanceFactor(node.left) >= 0) {
  55. return rightRotate(node);
  56. }
  57. if (balanceFactor < -1 && getBalanceFactor(node.right) <= 0) {
  58. return leftRotate(node);
  59. }
  60. if (balanceFactor > 1 && getBalanceFactor(node.left) < 0) {
  61. node.left = leftRotate(node.left);
  62. return rightRotate(node);
  63. }
  64. if (balanceFactor < -1 && getBalanceFactor(node.right) > 0) {
  65. node.right = rightRotate(node.right);
  66. return leftRotate(node);
  67. }
  68. return node;
  69. }
  70. public void remove(E e) {
  71. root = remove(root, e);
  72. }
  73. private Node remove(Node node, E e) {
  74. if (node == null) {
  75. return null;
  76. }
  77. Node retNode;
  78. if (e.equals(node.e)) {
  79. if (node.right == null) {
  80. Node leftNode = node.left;
  81. node.left = null;
  82. size--;
  83. retNode = leftNode;
  84. } else if (node.left == null) {
  85. Node rightNode = node.right;
  86. node.right = null;
  87. size--;
  88. retNode = rightNode;
  89. } else {
  90. Node successor = minimum(node.right);
  91. ////由于removeMin没有维持balance,所以我们复用remove
  92. successor.right = remove(node.right,successor.e);
  93. successor.left = node.left;
  94. node.left = node.right = null;
  95. retNode = successor;
  96. }
  97. } else if (e.compareTo(node.e) < 0) {
  98. node.left = remove(node.left, e);
  99. retNode = node;
  100. } else {
  101. node.right = remove(node.right, e);
  102. retNode = node;
  103. }
  104. //否则retNode.height会有空指针异常
  105. if (retNode == null) {
  106. return null;
  107. }
  108. //更新高度
  109. retNode.height = Math.max(getHeight(retNode.left),getHeight(retNode.right)) + 1;
  110. //计算平衡因子
  111. int balanceFactor = getBalanceFactor(retNode);
  112. if (balanceFactor > 1 && getBalanceFactor(retNode.left) >= 0) {
  113. return rightRotate(retNode);
  114. }
  115. if (balanceFactor < -1 && getBalanceFactor(retNode.right) <= 0) {
  116. return leftRotate(retNode);
  117. }
  118. if (balanceFactor > 1 && getBalanceFactor(retNode.left) < 0) {
  119. retNode.left = leftRotate(retNode.left);
  120. return rightRotate(retNode);
  121. }
  122. if (balanceFactor < -1 && getBalanceFactor(retNode.right) > 0) {
  123. retNode.right = rightRotate(retNode.right);
  124. return leftRotate(retNode);
  125. }
  126. return retNode;
  127. }

完整代码

  1. import java.util.ArrayList;
  2. public class AVLTree<E extends Comparable<E>> {
  3. private class Node {
  4. public E e;
  5. public Node left;
  6. public Node right;
  7. public int height;
  8. public Node(E e) {
  9. this.e = e;
  10. left = null;
  11. right = null;
  12. height = 1;
  13. }
  14. @Override
  15. public String toString() {
  16. return e.toString();
  17. }
  18. }
  19. //根节点
  20. private Node root;
  21. //树中元素的个数
  22. private int size;
  23. public AVLTree() {
  24. root = null;
  25. size = 0;
  26. }
  27. public int size() {
  28. return size;
  29. }
  30. public boolean isEmpty() {
  31. return size == 0;
  32. }
  33. private int getHeight(Node node) {
  34. if (node == null) {
  35. return 0;
  36. }
  37. return node.height;
  38. }
  39. private int getBalanceFactor(Node node) {
  40. if (node == null) {
  41. return 0;
  42. }
  43. return getHeight(node.left) - getHeight(node.right);
  44. }
  45. public boolean isBST() {
  46. ArrayList<E> arrayList = new ArrayList<>();
  47. inOrder(root, arrayList);
  48. for (int i = 1; i < arrayList.size(); i++) {
  49. if (arrayList.get(i-1).compareTo(arrayList.get(i)) > 0) {
  50. return false;
  51. }
  52. }
  53. return true;
  54. }
  55. private void inOrder(Node node, ArrayList<E> arrayList) {
  56. if (node == null) {
  57. return;
  58. }
  59. inOrder(node.left, arrayList);
  60. arrayList.add(node.e);
  61. inOrder(node.right, arrayList);
  62. }
  63. public boolean isBalanced() {
  64. return isBalanced(root);
  65. }
  66. //判断某个节点是不是平衡
  67. private boolean isBalanced(Node node) {
  68. if (node == null) {
  69. return true;
  70. }
  71. int balanceFactor = getBalanceFactor(node);
  72. if (Math.abs(balanceFactor) > 1) {
  73. return false;
  74. }
  75. return isBalanced(node.left) && isBalanced(node.right);
  76. }
  77. private Node rightRotate(Node y) {
  78. Node x = y.left;
  79. Node T3 = x.right;
  80. x.right = y;
  81. y.left = T3;
  82. //更新x和y的高度值 先更新y的,因为y是x的右孩子,x的更新取决于y
  83. y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
  84. x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
  85. return x;
  86. }
  87. private Node leftRotate(Node y) {
  88. Node x = y.right;
  89. Node T3 = x.left;
  90. x.left = y;
  91. y.right = T3;
  92. y.height = Math.max(getHeight(y.left), getHeight(y.right)) + 1;
  93. x.height = Math.max(getHeight(x.left), getHeight(x.right)) + 1;
  94. return x;
  95. }
  96. public void add(E e) {
  97. root = add(root, e);
  98. }
  99. private Node add(Node node, E e) {
  100. if (node == null) {
  101. size++;
  102. return new Node(e);
  103. }
  104. if (e.compareTo(node.e) < 0) {
  105. node.left = add(node.left, e);
  106. } else if (e.compareTo(node.e) > 0) {
  107. node.right = add(node.right,e);
  108. }
  109. //更新高度
  110. node.height = Math.max(getHeight(node.left),getHeight(node.right)) + 1;
  111. //计算平衡因子
  112. int balanceFactor = getBalanceFactor(node);
  113. if (balanceFactor > 1 && getBalanceFactor(node.left) >= 0) {
  114. return rightRotate(node);
  115. }
  116. if (balanceFactor < -1 && getBalanceFactor(node.right) <= 0) {
  117. return leftRotate(node);
  118. }
  119. if (balanceFactor > 1 && getBalanceFactor(node.left) < 0) {
  120. node.left = leftRotate(node.left);
  121. return rightRotate(node);
  122. }
  123. if (balanceFactor < -1 && getBalanceFactor(node.right) > 0) {
  124. node.right = rightRotate(node.right);
  125. return leftRotate(node);
  126. }
  127. return node;
  128. }
  129. public boolean contains(E e) {
  130. return contains(root, e);
  131. }
  132. private boolean contains(Node node, E e) {
  133. if (node == null) {
  134. return false;
  135. }
  136. if (e.equals(node.e)) {
  137. return true;
  138. } else if (e.compareTo(node.e) < 0) {
  139. return contains(node.left, e);
  140. } else {
  141. return contains(node.right,e);
  142. }
  143. }
  144. public E minimum() {
  145. if (size == 0) {
  146. throw new IllegalArgumentException("树为空");
  147. }
  148. return minimum(root).e;
  149. }
  150. private Node minimum(Node node) {
  151. if (node.left == null) {
  152. return node;
  153. }
  154. return minimum(node.left);
  155. }
  156. public void remove(E e) {
  157. root = remove(root, e);
  158. }
  159. private Node remove(Node node, E e) {
  160. if (node == null) {
  161. return null;
  162. }
  163. Node retNode;
  164. if (e.equals(node.e)) {
  165. if (node.right == null) {
  166. Node leftNode = node.left;
  167. node.left = null;
  168. size--;
  169. retNode = leftNode;
  170. } else if (node.left == null) {
  171. Node rightNode = node.right;
  172. node.right = null;
  173. size--;
  174. retNode = rightNode;
  175. } else {
  176. Node successor = minimum(node.right);
  177. successor.right = remove(node.right,successor.e);//由于removeMin没有维持balance,所以我们用remove
  178. successor.left = node.left;
  179. node.left = node.right = null;
  180. //size--; 在removeMin中已经维护size了
  181. retNode = successor;
  182. }
  183. } else if (e.compareTo(node.e) < 0) {
  184. node.left = remove(node.left, e);
  185. retNode = node;
  186. } else {
  187. node.right = remove(node.right, e);
  188. retNode = node;
  189. }
  190. //否则retNode.height会有空指针异常
  191. if (retNode == null) {
  192. return null;
  193. }
  194. //更新高度
  195. retNode.height = Math.max(getHeight(retNode.left),getHeight(retNode.right)) + 1;
  196. //计算平衡因子
  197. int balanceFactor = getBalanceFactor(retNode);
  198. if (balanceFactor > 1 && getBalanceFactor(retNode.left) >= 0) {
  199. return rightRotate(retNode);
  200. }
  201. if (balanceFactor < -1 && getBalanceFactor(retNode.right) <= 0) {
  202. return leftRotate(retNode);
  203. }
  204. if (balanceFactor > 1 && getBalanceFactor(retNode.left) < 0) {
  205. retNode.left = leftRotate(retNode.left);
  206. return rightRotate(retNode);
  207. }
  208. if (balanceFactor < -1 && getBalanceFactor(retNode.right) > 0) {
  209. retNode.right = rightRotate(retNode.right);
  210. return leftRotate(retNode);
  211. }
  212. return retNode;
  213. }
  214. }