1. package com.zhangyong.DataStructures.Tree.RedBlackTree;
    2. import com.zhangyong.DataStructures.Tree.AVL.FileOperation;
    3. import java.util.ArrayList;
    4. public class RBTree<K extends Comparable<K>, V> {
    5. private static final boolean RED = true;
    6. private static final boolean BLACK = false;
    7. private class Node {
    8. public K key;
    9. public V value;
    10. public Node left, right;
    11. public boolean color;
    12. public Node(K key, V value) {
    13. this.key = key;
    14. this.value = value;
    15. left = null;
    16. right = null;
    17. color = RED;
    18. }
    19. }
    20. private Node root;
    21. private int size;
    22. public RBTree() {
    23. root = null;
    24. size = 0;
    25. }
    26. public int getSize() {
    27. return size;
    28. }
    29. public boolean isEmpty() {
    30. return size == 0;
    31. }
    32. // 判断节点node的颜色
    33. private boolean isRed(Node node) {
    34. if (node == null)
    35. return BLACK;
    36. return node.color;
    37. }
    38. /**
    39. * a c
    40. * / \ / \
    41. * b c ===> a e
    42. * / \ / \ \
    43. * d e b d f
    44. * \
    45. * f
    46. * 左旋转
    47. */
    48. private Node leftRotate(Node node) {
    49. Node x = node.right;
    50. node.right = x.left;
    51. x.left = node;
    52. x.color = node.color;
    53. node.color = RED;
    54. return x;
    55. }
    56. /**
    57. * a b
    58. * / \ / \
    59. * b c ===> d a
    60. * / \ / / \
    61. * d e f e c
    62. * /
    63. * f
    64. */
    65. private Node rightRotate(Node node) {
    66. Node x = node.left;
    67. node.left = x.right;
    68. x.right = node;
    69. x.color = node.color;
    70. node.color = RED;
    71. return x;
    72. }
    73. /**
    74. * 颜色翻转
    75. *
    76. * @param node
    77. */
    78. private void flipColor(Node node) {
    79. node.color = RED;
    80. node.left.color = BLACK;
    81. node.right.color = BLACK;
    82. }
    83. // 向二分搜索树中添加新的元素(key, value)
    84. public void add(K key, V value) {
    85. root = add(root, key, value);
    86. root.color = BLACK; //保持最终的根节点为黑色;
    87. }
    88. // 向以node为根的红黑树中插入元素(key, value),递归算法
    89. // 返回插入新节点后红黑树的根
    90. private Node add(Node node, K key, V value) {
    91. if (node == null) {
    92. size++;
    93. return new Node(key, value);
    94. }
    95. if (key.compareTo(node.key) < 0)
    96. node.left = add(node.left, key, value);
    97. else if (key.compareTo(node.key) > 0)
    98. node.right = add(node.right, key, value);
    99. else // key.compareTo(node.key) == 0
    100. node.value = value;
    101. if (isRed(node.right) && !isRed(node.left)) {
    102. node = leftRotate(node);
    103. }
    104. if (isRed(node.left) && isRed(node.left.left)) {
    105. node = rightRotate(node);
    106. }
    107. if (isRed(node.left) && isRed(node.right)) {
    108. flipColor(node);
    109. }
    110. return node;
    111. }
    112. // 返回以node为根节点的二分搜索树中,key所在的节点
    113. private Node getNode(Node node, K key) {
    114. if (node == null)
    115. return null;
    116. if (key.equals(node.key))
    117. return node;
    118. else if (key.compareTo(node.key) < 0)
    119. return getNode(node.left, key);
    120. else // if(key.compareTo(node.key) > 0)
    121. return getNode(node.right, key);
    122. }
    123. public boolean contains(K key) {
    124. return getNode(root, key) != null;
    125. }
    126. public V get(K key) {
    127. Node node = getNode(root, key);
    128. return node == null ? null : node.value;
    129. }
    130. public void set(K key, V newValue) {
    131. Node node = getNode(root, key);
    132. if (node == null)
    133. throw new IllegalArgumentException(key + " doesn't exist!");
    134. node.value = newValue;
    135. }
    136. // 返回以node为根的二分搜索树的最小值所在的节点
    137. private Node minimum(Node node) {
    138. if (node.left == null)
    139. return node;
    140. return minimum(node.left);
    141. }
    142. // 删除掉以node为根的二分搜索树中的最小节点
    143. // 返回删除节点后新的二分搜索树的根
    144. private Node removeMin(Node node) {
    145. if (node.left == null) {
    146. Node rightNode = node.right;
    147. node.right = null;
    148. size--;
    149. return rightNode;
    150. }
    151. node.left = removeMin(node.left);
    152. return node;
    153. }
    154. // 从二分搜索树中删除键为key的节点
    155. public V remove(K key) {
    156. Node node = getNode(root, key);
    157. if (node != null) {
    158. root = remove(root, key);
    159. return node.value;
    160. }
    161. return null;
    162. }
    163. private Node remove(Node node, K key) {
    164. if (node == null)
    165. return null;
    166. if (key.compareTo(node.key) < 0) {
    167. node.left = remove(node.left, key);
    168. return node;
    169. } else if (key.compareTo(node.key) > 0) {
    170. node.right = remove(node.right, key);
    171. return node;
    172. } else { // key.compareTo(node.key) == 0
    173. // 待删除节点左子树为空的情况
    174. if (node.left == null) {
    175. Node rightNode = node.right;
    176. node.right = null;
    177. size--;
    178. return rightNode;
    179. }
    180. // 待删除节点右子树为空的情况
    181. if (node.right == null) {
    182. Node leftNode = node.left;
    183. node.left = null;
    184. size--;
    185. return leftNode;
    186. }
    187. // 待删除节点左右子树均不为空的情况
    188. // 找到比待删除节点大的最小节点, 即待删除节点右子树的最小节点
    189. // 用这个节点顶替待删除节点的位置
    190. Node successor = minimum(node.right);
    191. successor.right = removeMin(node.right);
    192. successor.left = node.left;
    193. node.left = node.right = null;
    194. return successor;
    195. }
    196. }
    197. public static void main(String[] args) {
    198. System.out.println("Pride and Prejudice");
    199. ArrayList<String> words = new ArrayList<>();
    200. if (FileOperation.readFile("pride-and-prejudice.txt", words)) {
    201. System.out.println("Total words: " + words.size());
    202. RBTree<String, Integer> map = new RBTree<>();
    203. for (String word : words) {
    204. if (map.contains(word))
    205. map.set(word, map.get(word) + 1);
    206. else
    207. map.add(word, 1);
    208. }
    209. System.out.println("Total different words: " + map.getSize());
    210. System.out.println("Frequency of PRIDE: " + map.get("pride"));
    211. System.out.println("Frequency of PREJUDICE: " + map.get("prejudice"));
    212. }
    213. System.out.println();
    214. }
    215. }