并查集概念

并查集背景:

  1. 有若干个样本 a、b、c、d…,类型假设是 V
  2. 在并查集中一开始认为每个样本都在单独的集合里
  3. 用户可以在任何时候调用如下两个方法:
    1. boolean isSameSet(V x, V y) : 查询样本x和样本y是否属于一个集合
    2. void union(V x, V y) : 把x和y各自所在集合的所有样本合并成一个集合
  4. isSameSet和union方法的代价越低越好

并查集实现:

  1. 每个节点都有一条往上指的指针
  2. 节点a往上找到的头节点,叫做a所在集合的代表节点
  3. 查询x和y是否属于同一个集合,就是看看找到的代表节点是不是一个
  4. 把x和y各自所在集合的所有点合并成一个集合,只需要把小集合的代表点挂在大集合的代表点的下方即可

并查集实现优化:

  1. 节点往上找代表点的过程,把沿途的链变成扁平的
  2. 小集合挂在大集合的下面
  3. 如果方法调用很频繁,那么单次调用的均摊代价可以认为是O(1)(严格地说是O(α(N)),α为增长极其缓慢的反阿克曼函数),两个方法都如此

并查集应用:

  1. 解决两大块区域的合并问题
  2. 常用在图等领域中

第15节.pptx

并查集代码实现

  /**
   * Generic union-find (disjoint-set) structure.
   *
   * <p>Uses union by size (the smaller set is hung under the larger set's root) and path
   * compression (every node visited while looking for a root is re-pointed straight at it).
   */
  public class Code01_UnionFind {

      /**
       * Wraps a user value so it can be stored as a tree node. Equality and hash code are
       * delegated to the wrapped value.
       */
      private static class Node<V> {
          V value;

          public Node(V value) {
              this.value = value;
          }

          @Override
          public boolean equals(Object o) {
              if (this == o) {
                  return true;
              }
              if (o == null || getClass() != o.getClass()) {
                  return false;
              }
              Node<?> node = (Node<?>) o;
              return Objects.equals(value, node.value);
          }

          @Override
          public int hashCode() {
              return Objects.hash(value);
          }
      }

      /**
       * Union-find over values of type {@code V}.
       *
       * <p>Values enter the structure through the list constructor, {@link #add}, or implicitly
       * through {@link #union}.
       */
      public static class UnionFind<V> {
          /** Wrapper node of every registered value (exactly one node per value). */
          private final HashMap<V, Node<V>> nodes;
          /** Parent of every registered node; a root maps to itself. */
          private final HashMap<Node<V>, Node<V>> parents;
          /** Size of the set rooted at each root node; non-root nodes have no entry. */
          private final HashMap<Node<V>, Integer> sizes;

          /** Creates an empty union-find. */
          public UnionFind() {
              nodes = new HashMap<>();
              parents = new HashMap<>();
              sizes = new HashMap<>();
          }

          /**
           * Creates a union-find in which every value of {@code values} is its own singleton set.
           * A value that occurs several times is registered once.
           *
           * @param values initial values
           */
          public UnionFind(List<V> values) {
              this();
              for (V value : values) {
                  // Private helper instead of the overridable add(): no virtual call from a constructor.
                  register(value);
              }
          }

          /**
           * Adds a value to the union-find as a new singleton set.
           * Does nothing if the value is already present.
           *
           * @param value the element to add
           */
          public void add(V value) {
              register(value);
          }

          /**
           * Checks whether two values belong to the same set.
           *
           * @param one       first value
           * @param antherOne second value
           * @return true if both values are present and share the same root, false otherwise
           *         (including when either value was never added)
           */
          public boolean isSameSet(V one, V antherOne) {
              Node<V> oneNode = nodes.get(one);
              Node<V> antherNode = nodes.get(antherOne);
              if (oneNode == null || antherNode == null) {
                  return false;
              }
              return findRoot(oneNode) == findRoot(antherNode);
          }

          /**
           * Merges the sets containing the two values.
           *
           * <p>Values that are not present yet are first added as singleton sets. The root of the
           * smaller set is attached under the root of the larger set; on a tie, the first value's
           * root goes under the second value's root.
           *
           * @param one       first value
           * @param antherOne second value
           */
          public void union(V one, V antherOne) {
              Node<V> oneRoot = findRoot(register(one));
              Node<V> antherRoot = findRoot(register(antherOne));
              if (oneRoot == antherRoot) {
                  return;
              }
              int oneSize = sizes.get(oneRoot);
              int antherSize = sizes.get(antherRoot);
              Node<V> bigRoot = oneSize > antherSize ? oneRoot : antherRoot;
              Node<V> smallRoot = bigRoot == oneRoot ? antherRoot : oneRoot;
              parents.put(smallRoot, bigRoot);
              // Only roots keep a size entry.
              sizes.remove(smallRoot);
              sizes.put(bigRoot, oneSize + antherSize);
          }

          /**
           * Returns the size of the set containing the given value.
           *
           * @param vales the element
           * @return number of elements in its set, or 0 if the value was never added
           */
          public int size(V vales) {
              Node<V> node = nodes.get(vales);
              if (node == null) {
                  return 0;
              }
              return sizes.get(findRoot(node));
          }

          /**
           * Returns the node of {@code value}, creating it as a singleton set (its own parent,
           * size 1) if the value is not registered yet.
           */
          private Node<V> register(V value) {
              Node<V> node = nodes.get(value);
              if (node == null) {
                  node = new Node<>(value);
                  nodes.put(value, node);
                  parents.put(node, node);
                  sizes.put(node, 1);
              }
              return node;
          }

          /**
           * Walks up from {@code node} to the root of its tree and compresses the path.
           *
           * @param node a registered node
           * @return the root (representative) node of its set
           */
          private Node<V> findRoot(Node<V> node) {
              Node<V> root = node;
              while (parents.get(root) != root) {
                  root = parents.get(root);
              }
              // Second pass: hang every node on the path directly under the root.
              Node<V> cur = node;
              while (cur != root) {
                  Node<V> next = parents.get(cur);
                  parents.put(cur, root);
                  cur = next;
              }
              return root;
          }
      }
  }

并查集应用

题目1 省份数量

LeetCode 547:https://leetcode-cn.com/problems/number-of-provinces/
有 n 个城市,其中一些彼此相连,另一些没有相连。如果城市 a 与城市 b 直接相连,且城市 b 与城市 c 直接相连,那么城市 a 与城市 c 间接相连。
省份 是一组直接或间接相连的城市,组内不含其他没有相连的城市。
给你一个 n x n 的矩阵 isConnected ,其中 isConnected[i][j] = 1 表示第 i 个城市和第 j 个城市直接相连,而 isConnected[i][j] = 0 表示二者不直接相连。
返回矩阵中 省份 的数量

public class Code02_FindCrcle {

    // LeetCode 547 "Number of Provinces".
    // isConnected is an n x n matrix: isConnected[i][j] == 1 means city i and city j are directly
    // connected, 0 means they are not. A province is a group of directly or indirectly connected
    // cities. Return the number of provinces.
    // Source: https://leetcode-cn.com/problems/number-of-provinces

    /**
     * Counts provinces by unioning every pair of directly connected cities.
     *
     * @param isConnected adjacency matrix; only the upper triangle (column > row) is read
     * @return number of provinces, or 0 for a null or empty matrix
     */
    public int findCircleNum(int[][] isConnected) {
        int cities = isConnected == null ? 0 : isConnected.length;
        if (cities == 0) {
            return 0;
        }
        UnionSet provinces = new UnionSet(cities);
        for (int a = 0; a < cities; a++) {
            int[] row = isConnected[a];
            for (int b = a + 1; b < cities; b++) {
                if (row[b] == 1) {
                    provinces.union(a, b);
                }
            }
        }
        return provinces.sets();
    }

    /**
     * Array-backed union-find over the indices 0..n-1, using union by size and path compression.
     */
    private static class UnionSet {
        /** parents[x] is the parent of x; a root is its own parent. */
        private int[] parents;
        /** size[r] is the number of elements in the set rooted at r (meaningful for roots only). */
        private int[] size;
        /** Scratch buffer recording the path walked by findRoot, reused between calls. */
        private int[] help;
        /** Current number of disjoint sets. */
        private int setSize;

        public UnionSet(int n) {
            parents = new int[n];
            size = new int[n];
            help = new int[n];
            setSize = n;
            for (int k = 0; k < n; k++) {
                parents[k] = k;
                size[k] = 1;
            }
        }

        /** Returns true when k is not a valid element index. */
        private boolean outOfRange(int k) {
            return k < 0 || k >= parents.length;
        }

        /**
         * Finds the root of element i and points every element on the way directly at it.
         *
         * @param i element index
         * @return index of the root, or -1 if i is out of range
         */
        private int findRoot(int i) {
            if (outOfRange(i)) {
                return -1;
            }
            int depth = 0;
            int cur = i;
            while (parents[cur] != cur) {
                help[depth++] = cur;
                cur = parents[cur];
            }
            while (depth > 0) {
                parents[help[--depth]] = cur;
            }
            return cur;
        }

        /**
         * Merges the sets containing elements i and j; out-of-range indices are ignored.
         *
         * @param i first element index
         * @param j second element index
         */
        public void union(int i, int j) {
            if (outOfRange(i) || outOfRange(j)) {
                return;
            }
            int rootI = findRoot(i);
            int rootJ = findRoot(j);
            if (rootI == rootJ) {
                return;
            }
            // Union by size; on a tie i's root is attached under j's root.
            int big = size[rootI] > size[rootJ] ? rootI : rootJ;
            int small = big == rootI ? rootJ : rootI;
            parents[small] = big;
            size[big] += size[small];
            setSize--;
        }

        /**
         * Returns the current number of disjoint sets.
         *
         * @return number of sets
         */
        public int sets() {
            return setSize;
        }
    }
}

题目2:岛问题

给定一个二维数组matrix,里面的值不是1就是0,
上、下、左、右相邻的1认为是一片岛,
返回matrix中岛的数量

//给定一个二维数组matrix,里面的值不是1就是0,上、下、左、右相邻的1认为是一片岛,
//返回matrix中岛的数量

/**
 * Counts islands by flood fill. Whenever an unvisited '1' cell is found, the whole island it
 * belongs to is overwritten by {@code infect} and the counter is incremented.
 * The input grid is modified in place.
 *
 * @param grid grid of '1' (land) and '0' (water) characters
 * @return number of islands; 0 for a null or empty grid
 */
public static int numberOfIsland1(char[][] grid) {
    if (grid == null || grid.length == 0) {
        return 0;
    }
    final int width = grid[0].length;
    int islands = 0;
    for (int row = 0; row < grid.length; row++) {
        for (int col = 0; col < width; col++) {
            if (grid[row][col] != '1') {
                continue;
            }
            infect(grid, row, col);
            islands++;
        }
    }
    return islands;
}

/**
 * Flood-fills ("infects") the island containing cell (i, j). Every '1' cell reachable through
 * up/down/left/right moves is overwritten with '0', so it is not counted again.
 *
 * <p>This is a recursive depth-first search. The recursion depth can reach the number of cells
 * in one island, so very large islands may need a bigger thread stack.
 *
 * @param arr the grid, modified in place
 * @param i   row index; out-of-range values return immediately
 * @param j   column index; out-of-range values return immediately
 */
private static void infect(char[][] arr, int i, int j) {
    // Check against the row's own length so ragged rows are handled too.
    if (i < 0 || i >= arr.length || j < 0 || j >= arr[i].length || arr[i][j] != '1') {
        return;
    }
    // Use the character '0' (water). A plain 0 literal would store the NUL char (char) 0
    // and leave the caller's grid with characters that are neither '0' nor '1'.
    arr[i][j] = '0';
    infect(arr, i - 1, j);
    infect(arr, i + 1, j);
    infect(arr, i, j - 1);
    infect(arr, i, j + 1);
}
/**
 * Counts islands with the array-based union-find {@code UnionSet2}.
 *
 * <p>Every pair of horizontally or vertically adjacent '1' cells is unioned, and the number of
 * remaining sets is the answer. Each cell only looks at its left and upper neighbours, which
 * covers every adjacent pair exactly once.
 *
 * @param board grid of '1' (land) and '0' (water) characters; not modified
 * @return number of islands; 0 for a null grid or a grid without rows or columns
 */
public static int numIslands3(char[][] board) {
    // Degenerate grids would otherwise throw at board[0] or board[i][0].
    if (board == null || board.length == 0 || board[0].length == 0) {
        return 0;
    }
    int row = board.length;
    int col = board[0].length;
    UnionSet2 uf = new UnionSet2(board);
    // First row: only left neighbours exist.
    for (int j = 1; j < col; j++) {
        if (board[0][j - 1] == '1' && board[0][j] == '1') {
            uf.union(0, j - 1, 0, j);
        }
    }
    // First column: only upper neighbours exist.
    for (int i = 1; i < row; i++) {
        if (board[i - 1][0] == '1' && board[i][0] == '1') {
            uf.union(i - 1, 0, i, 0);
        }
    }
    // All other cells: check the left and upper neighbours.
    for (int i = 1; i < row; i++) {
        for (int j = 1; j < col; j++) {
            if (board[i][j] == '1') {
                if (board[i][j - 1] == '1') {
                    uf.union(i, j - 1, i, j);
                }
                if (board[i - 1][j] == '1') {
                    uf.union(i - 1, j, i, j);
                }
            }
        }
    }
    return uf.sets();
}


/**
 * Array-backed union-find over the cells of a char grid. Only '1' cells are initialised as
 * singleton sets. Cell (r, c) is stored at index r * col + c.
 */
private static class UnionSet2 {
    /** parent[k] is the parent of cell k; a root is its own parent. Unused for '0' cells. */
    private final int[] parent;
    /** size[k] is the size of the set rooted at k (meaningful for roots only). */
    private final int[] size;
    /** Scratch buffer holding the path walked by find, reused between calls. */
    private final int[] help;
    /** Number of disjoint sets (islands) currently present. */
    private int sets;
    /** Number of columns, used to flatten (r, c) into a single index. */
    private final int col;

    public UnionSet2(char[][] board) {
        col = board[0].length;
        int rows = board.length;
        int cells = rows * col;
        parent = new int[cells];
        size = new int[cells];
        help = new int[cells];
        int landCells = 0;
        for (int r = 0; r < rows; r++) {
            char[] line = board[r];
            for (int c = 0; c < col; c++) {
                if (line[c] == '1') {
                    int k = index(r, c);
                    parent[k] = k;
                    size[k] = 1;
                    landCells++;
                }
            }
        }
        sets = landCells;
    }

    /** Flattens a (row, column) pair into an array index. */
    private int index(int r, int c) {
        return r * col + c;
    }

    /**
     * Returns the root of cell i and re-points every cell on the walked path at that root.
     *
     * @param i flattened cell index (must be a land cell)
     * @return index of the root
     */
    private int find(int i) {
        int root = i;
        int depth = 0;
        while (parent[root] != root) {
            help[depth++] = root;
            root = parent[root];
        }
        while (depth > 0) {
            parent[help[--depth]] = root;
        }
        return root;
    }

    /**
     * Merges the islands containing land cells (r1, c1) and (r2, c2).
     * On equal sizes the first cell's root stays the root.
     */
    public void union(int r1, int c1, int r2, int c2) {
        int root1 = find(index(r1, c1));
        int root2 = find(index(r2, c2));
        if (root1 == root2) {
            return;
        }
        int keep = size[root1] >= size[root2] ? root1 : root2;
        int merge = keep == root1 ? root2 : root1;
        size[keep] += size[merge];
        parent[merge] = keep;
        sets--;
    }

    /**
     * Returns the current number of islands.
     *
     * @return number of disjoint sets
     */
    public int sets() {
        return sets;
    }
}
/**
 * Counts islands with the generic, HashMap-based union-find {@code UnionSet1}.
 *
 * <p>Every land cell gets its own {@code Dot} object, adjacent land cells are unioned, and the
 * number of remaining sets is the answer. Each cell only checks its upper and left neighbours,
 * which covers every adjacent pair exactly once.
 *
 * @param grid grid of '1' (land) and '0' (water) characters; not modified
 * @return number of islands; 0 for a null grid or a grid without rows or columns
 */
public static int numIslands2(char[][] grid) {
    // Degenerate grids would otherwise throw at grid[0] or grid[i][0].
    if (grid == null || grid.length == 0 || grid[0].length == 0) {
        return 0;
    }
    int r = grid.length;
    int c = grid[0].length;
    Dot[][] dots = new Dot[r][c];
    List<Dot> list = new ArrayList<>();
    for (int i = 0; i < r; i++) {
        for (int j = 0; j < c; j++) {
            if (grid[i][j] == '1') {
                dots[i][j] = new Dot();
                list.add(dots[i][j]);
            }
        }
    }

    UnionSet1<Dot> unionSet = new UnionSet1<>(list);
    // First row: only left neighbours exist.
    for (int i = 1; i < c; i++) {
        if (grid[0][i - 1] == '1' && grid[0][i] == '1') {
            unionSet.union(dots[0][i - 1], dots[0][i]);
        }
    }

    // First column: only upper neighbours exist.
    for (int i = 1; i < r; i++) {
        if (grid[i][0] == '1' && grid[i - 1][0] == '1') {
            unionSet.union(dots[i][0], dots[i - 1][0]);
        }
    }

    // All other cells: check the upper and left neighbours.
    for (int i = 1; i < r; i++) {
        for (int j = 1; j < c; j++) {
            if (grid[i][j] == '1') {
                if (grid[i - 1][j] == '1') {
                    unionSet.union(dots[i - 1][j], dots[i][j]);
                }
                if (grid[i][j - 1] == '1') {
                    unionSet.union(dots[i][j - 1], dots[i][j]);
                }
            }
        }
    }
    return unionSet.sets();
}

/**
 * Land-cell marker created by numIslands2 for each '1' cell. It has no fields and does not
 * override equals/hashCode, so every instance is a distinct key in hash-based maps.
 */
private static class Dot {
}

/**
 * Element wrapper used by the union-find. Object's identity-based equals/hashCode are kept
 * as is, so two wrappers are the same map key only when they are the same instance.
 *
 * @param <V> type of the wrapped value
 */
private static class Node<V> {
    /** The wrapped value. */
    V val;

    public Node(V value) {
        val = value;
    }
}

/**
 * HashMap-based union-find over arbitrary values, using union by size and path compression.
 * Each value is wrapped in its own identity-compared {@code Node}.
 */
private static class UnionSet1<V> {
    /** Parent of every node; a root maps to itself. */
    private final HashMap<Node<V>, Node<V>> parents;
    /** Wrapper node of every registered value. */
    private final HashMap<V, Node<V>> nodes;
    /** Set size for each current root; its entry count equals the number of sets. */
    private final HashMap<Node<V>, Integer> sizeMap;

    public UnionSet1(List<V> list) {
        parents = new HashMap<>();
        nodes = new HashMap<>();
        sizeMap = new HashMap<>();
        list.forEach(value -> {
            Node<V> wrapper = new Node<>(value);
            nodes.put(value, wrapper);
            parents.put(wrapper, wrapper);
            sizeMap.put(wrapper, 1);
        });
    }

    /**
     * Finds the root above {@code v} and then, in a second walk, points every node on the
     * path directly at that root.
     *
     * @param v a registered node (a null node yields null)
     * @return the root node
     */
    private Node<V> findRoot(Node<V> v) {
        Node<V> root = v;
        while (parents.get(root) != root) {
            root = parents.get(root);
        }
        Node<V> walker = v;
        while (walker != root) {
            Node<V> next = parents.get(walker);
            parents.put(walker, root);
            walker = next;
        }
        return root;
    }

    /**
     * Merges the sets containing the two values. The larger set's root survives; on a tie the
     * second value's root survives.
     */
    public void union(V one, V anotherOne) {
        Node<V> a = findRoot(nodes.get(one));
        Node<V> b = findRoot(nodes.get(anotherOne));
        if (a == b) {
            return;
        }
        int sizeA = sizeMap.get(a);
        int sizeB = sizeMap.get(b);
        Node<V> winner = sizeA > sizeB ? a : b;
        Node<V> loser = winner == a ? b : a;
        parents.put(loser, winner);
        sizeMap.put(winner, sizeA + sizeB);
        sizeMap.remove(loser);
    }

    /**
     * Returns the number of disjoint sets.
     *
     * @return number of sets
     */
    public int sets() {
        return sizeMap.size();
    }
}

题目3:岛问题扩展

由m行和n列组成的二维网格图最初充满了水。我们可以执行一个addLand操作,将位置(row, col)的水变成陆地。给定要操作的位置列表,计算每个addLand操作后的岛屿数量。岛屿被水包围,通过水平或垂直连接相邻的陆地而形成。你可以假设网格的四边都被水包围着。

示例:
输入: m = 3, n = 3, 
    positions = [[0,0], [0,1], [1,2], [2,1]]
输出: [1,1,2,3]
解析:
起初,二维网格 grid 被全部注入「水」。(0 代表「水」,1 代表「陆地」)
0 0 0
0 0 0
0 0 0
操作 #1:addLand(0, 0) 将 grid[0][0] 的水变为陆地。
1 0 0
0 0 0   岛屿的数量为 1
0 0 0
操作 #2:addLand(0, 1) 将 grid[0][1] 的水变为陆地。
1 1 0
0 0 0   岛屿的数量为 1
0 0 0
操作 #3:addLand(1, 2) 将 grid[1][2] 的水变为陆地。
1 1 0
0 0 1   岛屿的数量为 2
0 0 0
操作 #4:addLand(2, 1) 将 grid[2][1] 的水变为陆地。
1 1 0
0 0 1   岛屿的数量为 3
0 1 0
拓展:
你能否在 O(k log(mn)) 的时间复杂度内完成全部计算?
(k 表示 positions 的长度)
/**
 * Online island counting (LeetCode 305). Turns the given positions into land one by one and
 * records the number of islands after each operation, using the array-based {@code UnionSet}.
 * The array has m * n cells, so m * n must fit in an int.
 *
 * @param m         number of grid rows
 * @param n         number of grid columns
 * @param positions cells {row, col} to turn into land, in order
 * @return island count after each operation. An out-of-grid position leaves the count
 *         unchanged. The list is empty when {@code positions} is null or empty.
 */
public static List<Integer> numIslands2(int m, int n, int[][] positions) {
    if (positions == null || positions.length == 0) {
        return new ArrayList<>();
    }
    List<Integer> ans = new ArrayList<>(positions.length);
    if (m <= 0 || n <= 0) {
        // A grid without cells can never hold land: every operation leaves 0 islands.
        for (int k = 0; k < positions.length; k++) {
            ans.add(0);
        }
        return ans;
    }
    UnionSet unionSet = new UnionSet(m, n);
    int islands = 0;
    for (int[] position : positions) {
        boolean inGrid = position != null && position.length >= 2
                && position[0] >= 0 && position[0] < m
                && position[1] >= 0 && position[1] < n;
        if (inGrid) {
            islands = unionSet.connect(position);
        }
        // An invalid position adds no land, so the current count is repeated instead of 0.
        ans.add(islands);
    }
    return ans;
}

/**
 * Array-backed union-find over the cells of an m x n grid that starts as all water.
 *
 * <p>Cell (r, c) is stored at index r * column + c. {@code size[k] == 0} marks water. Once a
 * cell becomes land its size stays at least 1, because only roots' sizes change and they only
 * grow.
 */
private static class UnionSet {
    /** Parent of each land cell; a root is its own parent. */
    private final int[] parents;
    /** Set size for roots; 0 for water cells; stale but positive for non-root land cells. */
    private final int[] size;
    /** Scratch buffer holding the path walked by find, reused between calls. */
    private final int[] help;
    /** Current number of islands. */
    private int sets;

    private final int column;
    private final int row;

    /**
     * Creates an all-water grid.
     *
     * @param m number of rows
     * @param n number of columns (m * n must fit in an int)
     */
    public UnionSet(int m, int n) {
        int len = m * n;
        parents = new int[len];
        size = new int[len];
        help = new int[len];
        sets = 0;
        column = n;
        row = m;
    }

    /** Flattens (r, c) into an array index. */
    private int index(int r, int c) {
        return r * column + c;
    }

    /** Returns true when (r, c) lies inside the grid. */
    private boolean inGrid(int r, int c) {
        return r >= 0 && r < row && c >= 0 && c < column;
    }

    /**
     * Returns the root of cell i and re-points every cell on the walked path at that root.
     *
     * @param i flattened index of a land cell
     * @return index of the root
     */
    private int find(int i) {
        int hi = 0;
        while (i != parents[i]) {
            help[hi++] = i;
            i = parents[i];
        }
        for (hi--; hi >= 0; hi--) {
            parents[help[hi]] = i;
        }
        return i;
    }

    /**
     * Merges the islands of (r1, c1) and (r2, c2). Nothing happens if either cell is outside
     * the grid or is water.
     */
    public void union(int r1, int c1, int r2, int c2) {
        // Use range checks rather than "== row"/"== column": with those, a column past the edge
        // would alias a cell in the next row.
        if (!inGrid(r1, c1) || !inGrid(r2, c2)) {
            return;
        }
        int i1 = index(r1, c1);
        int i2 = index(r2, c2);
        if (size[i1] == 0 || size[i2] == 0) {
            return;
        }
        int f1 = find(i1);
        int f2 = find(i2);

        if (f1 != f2) {
            if (size[f1] > size[f2]) {
                parents[f2] = f1;
                size[f1] += size[f2];
            } else {
                parents[f1] = f2;
                size[f2] += size[f1];
            }
            sets--;
        }
    }

    /**
     * Turns the cell {row, col} into land and merges it with any adjacent land.
     *
     * @param location {row, col}; a cell outside the grid or already land changes nothing
     * @return number of islands after the operation
     */
    public int connect(int[] location) {
        int r = location[0];
        int c = location[1];
        if (!inGrid(r, c)) {
            return sets;
        }
        int i = index(r, c);
        if (size[i] == 0) {
            parents[i] = i;
            size[i] = 1;
            sets++;
            union(r, c, r, c - 1);
            union(r, c, r, c + 1);
            union(r, c, r - 1, c);
            union(r, c, r + 1, c);
        }
        return sets;
    }
}
// 当 m、n 很大而 positions 较少时,基于数组的并查集需要按 m*n 初始化,占用过多资源(m*n 甚至可能溢出 int),因此改用只记录陆地格子的哈希表实现进行优化

/**
 * Online island counting backed by the hash-map union-find {@code UnionSet2}, which stores only
 * cells that have become land. Memory grows with the number of positions rather than with
 * m * n, which suits huge grids with few operations.
 *
 * @param m         number of grid rows
 * @param n         number of grid columns
 * @param positions cells {row, col} to turn into land, in order
 * @return island count after each operation. An out-of-grid position leaves the count
 *         unchanged. The list is empty when {@code positions} is null or empty.
 */
public static List<Integer> numIslands23(int m, int n, int[][] positions) {
    if (positions == null || positions.length == 0) {
        return new ArrayList<>();
    }
    UnionSet2 unionSet2 = new UnionSet2();
    List<Integer> list = new ArrayList<>(positions.length);
    int islands = 0;
    for (int[] position : positions) {
        boolean inGrid = position != null && position.length >= 2
                && position[0] >= 0 && position[0] < m
                && position[1] >= 0 && position[1] < n;
        if (inGrid) {
            islands = unionSet2.connect(position[0], position[1]);
        }
        // An invalid position adds no land, so the current count is repeated instead of 0.
        list.add(islands);
    }
    return list;
}

/**
 * Hash-map based union-find over grid cells keyed by the string "row-col". Only cells that
 * have been turned into land are stored.
 */
private static class UnionSet2 {

    /** Parent key of every land cell; a root maps to itself. */
    private final HashMap<String, String> parents;
    /** Set size for each root key; non-root keys have no entry. */
    private final HashMap<String, Integer> size;
    /** Scratch list recording the path walked by find; cleared after every call. */
    private final ArrayList<String> help;
    /** Current number of islands. */
    private int sets;

    public UnionSet2() {
        parents = new HashMap<>();
        size = new HashMap<>();
        help = new ArrayList<>();
        sets = 0;
    }

    /** Builds the map key of cell (r, c), e.g. "3-4" or "-1-0". */
    private static String cellKey(int r, int c) {
        return r + "-" + c;
    }

    /**
     * Returns the root key of {@code s} and re-points every key on the walked path at it.
     *
     * @param s key of a land cell
     * @return root key
     */
    private String find(String s) {
        while (!s.equals(parents.get(s))) {
            help.add(s);
            s = parents.get(s);
        }

        for (String s1 : help) {
            parents.put(s1, s);
        }
        help.clear();
        return s;
    }

    /**
     * Merges the islands of two cells. Nothing happens if either cell is not land.
     * The smaller set is attached under the larger set's root; on a tie, s1's root goes
     * under s2's root.
     */
    public void union(String s1, String s2) {

        if (!parents.containsKey(s1) || !parents.containsKey(s2)) {
            return;
        }

        String f1 = find(s1);
        String f2 = find(s2);
        if (f1.equals(f2)) {
            return;
        }
        int size1 = size.get(f1);
        int size2 = size.get(f2);
        // Compare the two different roots' sizes. Comparing a root with itself would always
        // be false and would defeat union by size.
        boolean firstIsBigger = size1 > size2;
        String big = firstIsBigger ? f1 : f2;
        String small = firstIsBigger ? f2 : f1;
        parents.put(small, big);
        size.put(big, size1 + size2);
        size.remove(small);
        sets--;
    }

    /**
     * Turns cell (r, c) into land and merges it with any adjacent land.
     * Repeating an already-land cell changes nothing.
     *
     * @param r row index
     * @param c column index
     * @return number of islands after the operation
     */
    public int connect(int r, int c) {
        String key = cellKey(r, c);
        if (!parents.containsKey(key)) {
            parents.put(key, key);
            size.put(key, 1);
            sets++;

            // Neighbours outside the grid or still water are not in the map; union skips them.
            union(key, cellKey(r - 1, c));
            union(key, cellKey(r + 1, c));
            union(key, cellKey(r, c - 1));
            union(key, cellKey(r, c + 1));
        }
        return sets;
    }
}