并查集概念
并查集背景:
- 有若干个样本a、b、c、d…类型假设是V
- 在并查集中一开始认为每个样本都在单独的集合里
- 用户可以在任何时候调用如下两个方法:
- boolean isSameSet(V x, V y) : 查询样本x和样本y是否属于一个集合
- void union(V x, V y) : 把x和y各自所在集合的所有样本合并成一个集合
- isSameSet和union方法的代价越低越好
并查集实现:
- 每个节点都有一条往上指的指针
- 节点a往上找到的头节点,叫做a所在集合的代表节点
- 查询x和y是否属于同一个集合,就是看看找到的代表节点是不是一个
- 把x和y各自所在集合的所有点合并成一个集合,只需要小集合的代表点挂在大集合的代表点的下方即可
并查集实现优化:
- 节点往上找代表点的过程,把沿途的链变成扁平的
- 小集合挂在大集合的下面
- 如果方法调用很频繁(调用次数达到或超过样本数量的规模),那么单次调用的均摊代价可以视为O(1),两个方法都如此
并查集应用:
- 解决两大块区域的合并问题
- 常用在图等领域中
并查集代码实现
/**
 * Map-backed union-find (disjoint-set) over arbitrary values.
 */
public class Code01_UnionFind {

    /**
     * Forest node wrapping one user value. Equality and hashing delegate to the wrapped value.
     */
    private static class Node<V> {
        V value;

        public Node(V value) {
            this.value = value;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Node<?> node = (Node<?>) o;
            return Objects.equals(value, node.value);
        }

        @Override
        public int hashCode() {
            return Objects.hash(value);
        }
    }

    /**
     * Disjoint-set structure with union by size and path compression.
     *
     * @param <V> type of the stored values; used as a hash-map key
     */
    public static class UnionFind<V> {
        /**
         * Maps every registered value to the single node that represents it.
         */
        private final HashMap<V, Node<V>> nodes;
        /**
         * Maps every node to its parent; a root maps to itself.
         */
        private final HashMap<Node<V>, Node<V>> parents;
        /**
         * Maps every root node to the number of elements in its set.
         */
        private final HashMap<Node<V>, Integer> sizes;

        /**
         * Creates an empty structure.
         */
        public UnionFind() {
            nodes = new HashMap<>();
            parents = new HashMap<>();
            sizes = new HashMap<>();
        }

        /**
         * Creates a structure in which every listed value forms its own set.
         * Values that occur more than once in the list are registered only once.
         *
         * @param values initial elements
         */
        public UnionFind(List<V> values) {
            this();
            for (V value : values) {
                register(value);
            }
        }

        /**
         * Adds a value as a new singleton set; does nothing if the value is already present.
         *
         * @param value element to add
         */
        public void add(V value) {
            register(value);
        }

        /**
         * Tells whether two values currently belong to the same set.
         *
         * @param one       first value
         * @param antherOne second value
         * @return true if both values are registered and share a root, false otherwise
         */
        public boolean isSameSet(V one, V antherOne) {
            Node<V> oneNode = nodes.get(one);
            Node<V> antherNode = nodes.get(antherOne);
            // An unregistered value belongs to no set.
            if (oneNode == null || antherNode == null) {
                return false;
            }
            return findRoot(oneNode) == findRoot(antherNode);
        }

        /**
         * Merges the sets containing the two values. A value that has not been seen before is
         * first registered as a singleton set and then merged like any other.
         *
         * @param one       first value
         * @param antherOne second value
         */
        public void union(V one, V antherOne) {
            Node<V> oneRoot = findRoot(register(one));
            Node<V> antherRoot = findRoot(register(antherOne));
            if (oneRoot == antherRoot) {
                return;
            }
            int oneSize = sizes.get(oneRoot);
            int antherSize = sizes.get(antherRoot);
            // Hang the smaller tree under the larger one; ties go to antherRoot.
            Node<V> bigRoot = oneSize > antherSize ? oneRoot : antherRoot;
            Node<V> smallRoot = bigRoot == oneRoot ? antherRoot : oneRoot;
            parents.put(smallRoot, bigRoot);
            // Only roots carry a size entry.
            sizes.remove(smallRoot);
            sizes.put(bigRoot, oneSize + antherSize);
        }

        /**
         * Returns the size of the set that contains the given value.
         *
         * @param vales the value to look up
         * @return number of elements in its set, or 0 if the value is not registered
         */
        public int size(V vales) {
            Node<V> node = nodes.get(vales);
            if (node == null) {
                return 0;
            }
            return sizes.get(findRoot(node));
        }

        /**
         * Returns the node for a value, creating a singleton set (self-parent, size 1) on first sight.
         *
         * @param value the value to look up or register
         * @return the unique node representing the value
         */
        private Node<V> register(V value) {
            Node<V> node = nodes.get(value);
            if (node == null) {
                node = new Node<>(value);
                nodes.put(value, node);
                parents.put(node, node);
                sizes.put(node, 1);
            }
            return node;
        }

        /**
         * Walks up from a node to its root and re-points every node on the walked path
         * directly at the root.
         *
         * @param node a registered node
         * @return the root of the node's set
         */
        private Node<V> findRoot(Node<V> node) {
            // First pass: locate the root.
            Node<V> root = node;
            while (parents.get(root) != root) {
                root = parents.get(root);
            }
            // Second pass: flatten the path that was just walked.
            Node<V> cur = node;
            while (cur != root) {
                Node<V> next = parents.get(cur);
                parents.put(cur, root);
                cur = next;
            }
            return root;
        }
    }
}
并查集应用
题目1 省份数量
LeetCode 547:https://leetcode-cn.com/problems/number-of-provinces/
有 n 个城市,其中一些彼此相连,另一些没有相连。如果城市 a 与城市 b 直接相连,且城市 b 与城市 c 直接相连,那么城市 a 与城市 c 间接相连。
省份 是一组直接或间接相连的城市,组内不含其他没有相连的城市。
给你一个 n x n 的矩阵 isConnected ,其中 isConnected[i][j] = 1 表示第 i 个城市和第 j 个城市直接相连,而 isConnected[i][j] = 0 表示二者不直接相连。
返回矩阵中 省份 的数量
public class Code02_FindCrcle {

    /**
     * Counts provinces: groups of cities that are connected directly or indirectly.
     * Source: LeetCode, https://leetcode-cn.com/problems/number-of-provinces
     *
     * @param isConnected n x n matrix where isConnected[i][j] == 1 means city i and city j are
     *                    directly connected
     * @return the number of provinces, or 0 when the matrix is null or empty
     */
    public int findCircleNum(int[][] isConnected) {
        final int cities = (isConnected == null) ? 0 : isConnected.length;
        if (cities == 0) {
            return 0;
        }
        UnionSet provinces = new UnionSet(cities);
        // Only the cells above the diagonal are inspected.
        for (int a = 0; a < cities; a++) {
            for (int b = a + 1; b < cities; b++) {
                if (isConnected[a][b] != 1) {
                    continue;
                }
                provinces.union(a, b);
            }
        }
        return provinces.sets();
    }

    /**
     * Array-backed union-find over the indices 0..n-1.
     */
    private static class UnionSet {
        /**
         * parents[i] is the parent index of i; a root is its own parent.
         */
        private int[] parents;
        /**
         * size[r] is the element count of the set rooted at r.
         */
        private int[] size;
        /**
         * Number of disjoint sets currently held.
         */
        private int setSize;

        public UnionSet(int n) {
            parents = new int[n];
            size = new int[n];
            setSize = n;
            int i = 0;
            while (i < n) {
                parents[i] = i;
                size[i] = 1;
                i++;
            }
        }

        /**
         * Finds the root of an index and flattens the path walked on the way.
         *
         * @param i index to look up
         * @return the root index, or -1 when i is out of range
         */
        private int findRoot(int i) {
            if (i < 0 || i >= parents.length) {
                return -1;
            }
            int root = i;
            while (parents[root] != root) {
                root = parents[root];
            }
            // Re-point every node on the path directly at the root.
            while (parents[i] != root) {
                int next = parents[i];
                parents[i] = root;
                i = next;
            }
            return root;
        }

        /**
         * Merges the sets containing indices i and j; out-of-range indices are ignored.
         *
         * @param i first index
         * @param j second index
         */
        public void union(int i, int j) {
            boolean iValid = i >= 0 && i < parents.length;
            boolean jValid = j >= 0 && j < parents.length;
            if (!iValid || !jValid) {
                return;
            }
            int rootI = findRoot(i);
            int rootJ = findRoot(j);
            if (rootI == rootJ) {
                return;
            }
            int merged = size[rootI] + size[rootJ];
            // The strictly larger tree stays on top; on a tie rootJ becomes the root.
            int winner = size[rootI] > size[rootJ] ? rootI : rootJ;
            int loser = (winner == rootI) ? rootJ : rootI;
            parents[loser] = winner;
            size[winner] = merged;
            setSize--;
        }

        /**
         * @return the current number of disjoint sets
         */
        public int sets() {
            return setSize;
        }
    }
}
题目2:岛问题
给定一个二维数组matrix,里面的值不是1就是0,
上、下、左、右相邻的1认为是一片岛,
返回matrix中岛的数量
//给定一个二维数组matrix,里面的值不是1就是0,上、下、左、右相邻的1认为是一片岛,
//返回matrix中岛的数量
/**
 * Counts islands (groups of '1' cells connected up, down, left or right) by flood fill.
 * The grid is modified in place: every land cell that is visited is overwritten with '0'.
 *
 * @param grid matrix of '0' and '1' characters; the column count is taken from the first row
 * @return the number of islands, or 0 when the grid is null or has no rows
 */
public static int numberOfIsland1(char[][] grid) {
    if (grid == null || grid.length == 0) {
        return 0;
    }
    int r = grid.length;
    int c = grid[0].length;
    int ans = 0;
    for (int i = 0; i < r; i++) {
        for (int j = 0; j < c; j++) {
            if (grid[i][j] == '1') {
                // A '1' that is still standing has not been reached from any earlier island.
                infect(grid, i, j);
                ans++;
            }
        }
    }
    return ans;
}

/**
 * Recursively sinks the island containing cell (i, j): the cell and every '1' reachable from it
 * through up/down/left/right steps are set to '0'.
 * Recursion depth grows with the size of the island.
 *
 * @param arr the grid being flooded
 * @param i   row index; may be one step outside the grid
 * @param j   column index; may be one step outside the grid
 */
private static void infect(char[][] arr, int i, int j) {
    if (i < 0 || i == arr.length || j < 0 || j == arr[0].length || arr[i][j] != '1') {
        return;
    }
    // Write the character '0', not the integer 0 (NUL), so the grid keeps holding only '0' and '1'.
    arr[i][j] = '0';
    infect(arr, i - 1, j);
    infect(arr, i + 1, j);
    infect(arr, i, j - 1);
    infect(arr, i, j + 1);
}
/**
 * Counts islands with an array-backed union-find: every '1' cell starts as its own set and is
 * merged with its left and upper '1' neighbours.
 *
 * @param board matrix of '0' and '1' characters; the column count is taken from the first row
 * @return the number of islands, or 0 when the board is null, has no rows or has no columns
 */
public static int numIslands3(char[][] board) {
    // Guard first: board[0] below would otherwise throw on an empty board.
    if (board == null || board.length == 0 || board[0].length == 0) {
        return 0;
    }
    int row = board.length;
    int col = board[0].length;
    UnionSet2 uf = new UnionSet2(board);
    for (int i = 0; i < row; i++) {
        for (int j = 0; j < col; j++) {
            if (board[i][j] != '1') {
                continue;
            }
            // Looking only left and up visits each adjacent pair exactly once.
            if (j > 0 && board[i][j - 1] == '1') {
                uf.union(i, j - 1, i, j);
            }
            if (i > 0 && board[i - 1][j] == '1') {
                uf.union(i - 1, j, i, j);
            }
        }
    }
    return uf.sets();
}
/**
 * Array-backed union-find over the cells of a character board. Cell (r, c) is stored at the
 * flat index r * col + c, and only cells holding '1' are initialised as sets.
 */
private static class UnionSet2 {
    private final int[] parent;
    private final int[] size;
    private int sets;
    private final int col;

    /**
     * Builds one singleton set per '1' cell of the board.
     *
     * @param board matrix of characters; the column count is taken from the first row
     */
    public UnionSet2(char[][] board) {
        col = board[0].length;
        final int rows = board.length;
        final int cells = rows * col;
        parent = new int[cells];
        size = new int[cells];
        int land = 0;
        for (int r = 0; r < rows; r++) {
            char[] line = board[r];
            for (int c = 0; c < col; c++) {
                if (line[c] != '1') {
                    continue;
                }
                int flat = index(r, c);
                parent[flat] = flat;
                size[flat] = 1;
                land++;
            }
        }
        sets = land;
    }

    /**
     * Converts a (row, column) pair into the flat array index.
     */
    private int index(int r, int c) {
        return r * col + c;
    }

    /**
     * Returns the root of flat index i and points every node on the walked path at that root.
     */
    private int find(int i) {
        int root = i;
        while (root != parent[root]) {
            root = parent[root];
        }
        while (i != root) {
            int next = parent[i];
            parent[i] = root;
            i = next;
        }
        return root;
    }

    /**
     * Merges the sets containing cells (r1, c1) and (r2, c2).
     * The smaller set is attached under the larger; on a tie the first cell's root stays on top.
     */
    public void union(int r1, int c1, int r2, int c2) {
        int f1 = find(index(r1, c1));
        int f2 = find(index(r2, c2));
        if (f1 == f2) {
            return;
        }
        int keep = f1;
        int drop = f2;
        if (size[f1] < size[f2]) {
            keep = f2;
            drop = f1;
        }
        size[keep] += size[drop];
        parent[drop] = keep;
        sets--;
    }

    /**
     * @return the current number of disjoint sets
     */
    public int sets() {
        return sets;
    }
}
/**
 * Counts islands with the generic, map-backed union-find: each '1' cell is represented by its
 * own Dot object and merged with its upper and left '1' neighbours.
 *
 * @param grid matrix of '0' and '1' characters; the column count is taken from the first row
 * @return the number of islands, or 0 when the grid is null, has no rows or has no columns
 */
public static int numIslands2(char[][] grid) {
    // Guard first: grid[0] below would otherwise throw on an empty grid.
    if (grid == null || grid.length == 0 || grid[0].length == 0) {
        return 0;
    }
    int r = grid.length;
    int c = grid[0].length;
    Dot[][] dots = new Dot[r][c];
    List<Dot> list = new ArrayList<>();
    for (int i = 0; i < r; i++) {
        for (int j = 0; j < c; j++) {
            if (grid[i][j] == '1') {
                dots[i][j] = new Dot();
                list.add(dots[i][j]);
            }
        }
    }
    UnionSet1<Dot> unionSet = new UnionSet1<>(list);
    for (int i = 0; i < r; i++) {
        for (int j = 0; j < c; j++) {
            if (grid[i][j] != '1') {
                continue;
            }
            // Looking only up and left visits each adjacent pair exactly once.
            if (i > 0 && grid[i - 1][j] == '1') {
                unionSet.union(dots[i - 1][j], dots[i][j]);
            }
            if (j > 0 && grid[i][j - 1] == '1') {
                unionSet.union(dots[i][j - 1], dots[i][j]);
            }
        }
    }
    return unionSet.sets();
}
/**
 * Empty marker type. numIslands2 creates one instance per '1' cell so that every land cell is
 * represented by a distinct object.
 */
private static class Dot {
}
/**
 * Union-find node that wraps a single value.
 *
 * @param <V> type of the wrapped value
 */
private static class Node<V> {
    V val;

    public Node(V wrapped) {
        val = wrapped;
    }
}
/**
 * Map-backed union-find over arbitrary values, with union by size and path compression.
 *
 * @param <V> type of the stored values; used as a hash-map key
 */
private static class UnionSet1<V> {
    /** Node to parent node; a root maps to itself. */
    private final HashMap<Node<V>, Node<V>> parents;
    /** Value to the node that wraps it. */
    private final HashMap<V, Node<V>> nodes;
    /** Root node to the size of its set; only roots have an entry. */
    private final HashMap<Node<V>, Integer> sizeMap;

    /**
     * Registers every list element as its own singleton set.
     *
     * @param list initial elements
     */
    public UnionSet1(List<V> list) {
        parents = new HashMap<>();
        nodes = new HashMap<>();
        sizeMap = new HashMap<>();
        for (int k = 0; k < list.size(); k++) {
            V element = list.get(k);
            Node<V> wrapper = new Node<>(element);
            parents.put(wrapper, wrapper);
            nodes.put(element, wrapper);
            sizeMap.put(wrapper, 1);
        }
    }

    /**
     * Climbs from a node to its root, then re-points every node on the climbed path at the root.
     */
    private Node<V> findRoot(Node<V> v) {
        Node<V> root = v;
        while (parents.get(root) != root) {
            root = parents.get(root);
        }
        Node<V> walker = v;
        while (walker != root) {
            Node<V> next = parents.get(walker);
            parents.put(walker, root);
            walker = next;
        }
        return root;
    }

    /**
     * Merges the sets of two values. The strictly larger set keeps its root; on a tie the
     * root of the second value stays on top.
     */
    public void union(V one, V anotherOne) {
        Node<V> oneRoot = findRoot(nodes.get(one));
        Node<V> anotherOneRoot = findRoot(nodes.get(anotherOne));
        if (oneRoot == anotherOneRoot) {
            return;
        }
        int oneSize = sizeMap.get(oneRoot);
        int anotherSize = sizeMap.get(anotherOneRoot);
        Node<V> big;
        Node<V> small;
        if (oneSize > anotherSize) {
            big = oneRoot;
            small = anotherOneRoot;
        } else {
            big = anotherOneRoot;
            small = oneRoot;
        }
        parents.put(small, big);
        sizeMap.put(big, oneSize + anotherSize);
        sizeMap.remove(small);
    }

    /**
     * @return the number of disjoint sets, i.e. the number of roots that still have a size entry
     */
    public int sets() {
        return sizeMap.size();
    }
}
题目3:岛问题扩展
由m行和n列组成的二维网格图最初充满了水。我们可以执行一个addLand操作,将位置(row, col)的水变成陆地。给定要操作的位置列表,计算每个addLand操作后的岛屿数量。岛屿被水包围,通过水平或垂直连接相邻的陆地而形成。你可以假设网格的四边都被水包围着。
示例:
输入: m = 3, n = 3,
positions = [[0,0], [0,1], [1,2], [2,1]]
输出: [1,1,2,3]
解析:
起初,二维网格 grid 被全部注入「水」。(0 代表「水」,1 代表「陆地」)
0 0 0
0 0 0
0 0 0
操作 #1:addLand(0, 0) 将 grid[0][0] 的水变为陆地。
1 0 0
0 0 0 岛屿的数量为 1
0 0 0
操作 #2:addLand(0, 1) 将 grid[0][1] 的水变为陆地。
1 1 0
0 0 0 岛屿的数量为 1
0 0 0
操作 #3:addLand(1, 2) 将 grid[1][2] 的水变为陆地。
1 1 0
0 0 1 岛屿的数量为 2
0 0 0
操作 #4:addLand(2, 1) 将 grid[2][1] 的水变为陆地。
1 1 0
0 0 1 岛屿的数量为 3
0 1 0
拓展:
你是否能在 O(k log mn) 的时间复杂度程度内完成每次的计算?
(k 表示 positions 的长度)
/**
 * Replays a list of addLand operations on an m x n grid of water and records the island count
 * after each one.
 *
 * @param m         number of rows
 * @param n         number of columns
 * @param positions operations as {row, column} pairs
 * @return one count per operation (0 is recorded for a position outside the grid), or null
 *         when m or n is not positive or positions is null or empty
 */
public static List<Integer> numIslands2(int m, int n, int[][] positions) {
    boolean noGrid = m <= 0 || n <= 0;
    if (noGrid || positions == null || positions.length == 0) {
        return null;
    }
    final int total = positions.length;
    List<Integer> counts = new ArrayList<>(total);
    UnionSet islands = new UnionSet(m, n);
    for (int k = 0; k < total; k++) {
        int[] cell = positions[k];
        boolean inside = cell[0] >= 0 && cell[0] < m && cell[1] >= 0 && cell[1] < n;
        if (inside) {
            counts.add(islands.connect(cell));
        } else {
            counts.add(0);
        }
    }
    return counts;
}
/**
 * Array-backed union-find over an m x n grid in which cells become land one at a time.
 * A cell whose size entry is 0 has not been turned into land yet.
 */
private static class UnionSet {
    /**
     * parents[i] is the parent of flat index i; a root is its own parent.
     */
    private final int[] parents;
    /**
     * size[i] is the set size when i is a root; 0 marks a cell that is still water.
     */
    private final int[] size;
    /**
     * Current number of disjoint sets.
     */
    private int sets;
    private final int column;
    private final int row;

    public UnionSet(int m, int n) {
        row = m;
        column = n;
        final int cells = m * n;
        parents = new int[cells];
        size = new int[cells];
        sets = 0;
    }

    /**
     * Converts a (row, column) pair into the flat array index.
     */
    private int index(int r, int c) {
        return r * column + c;
    }

    /**
     * Returns the root of flat index i and points every node on the walked path at that root.
     */
    private int find(int i) {
        int root = i;
        while (root != parents[root]) {
            root = parents[root];
        }
        while (i != root) {
            int next = parents[i];
            parents[i] = root;
            i = next;
        }
        return root;
    }

    /**
     * True when (r, c) lies exactly one step past the grid border or has a negative coordinate.
     */
    private boolean offGrid(int r, int c) {
        return r < 0 || r == row || c < 0 || c == column;
    }

    /**
     * Merges the sets of two cells. Nothing happens when either cell is off the grid or is
     * still water. The strictly larger set keeps its root; on a tie the second cell's root wins.
     */
    public void union(int r1, int c1, int r2, int c2) {
        if (offGrid(r1, c1) || offGrid(r2, c2)) {
            return;
        }
        int i1 = index(r1, c1);
        int i2 = index(r2, c2);
        boolean bothLand = size[i1] != 0 && size[i2] != 0;
        if (!bothLand) {
            return;
        }
        int f1 = find(i1);
        int f2 = find(i2);
        if (f1 == f2) {
            return;
        }
        int keep = f2;
        int drop = f1;
        if (size[f1] > size[f2]) {
            keep = f1;
            drop = f2;
        }
        parents[drop] = keep;
        size[keep] += size[drop];
        sets--;
    }

    /**
     * Turns the given cell into land (if it is not land already) and merges it with its left,
     * right, upper and lower neighbours, in that order.
     *
     * @param location {row, column} of the cell
     * @return the number of sets after the operation
     */
    public int connect(int[] location) {
        int r = location[0];
        int c = location[1];
        int i = index(r, c);
        if (size[i] != 0) {
            return sets;
        }
        parents[i] = i;
        size[i] = 1;
        sets++;
        int[][] steps = {{0, -1}, {0, 1}, {-1, 0}, {1, 0}};
        for (int[] step : steps) {
            union(r, c, r + step[0], c + step[1]);
        }
        return sets;
    }
}
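// When m and n are large but positions is short, allocating m * n array slots up front wastes memory; the variant below creates union-find entries only for cells that actually become land.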
/**
 * Same contract as the array-backed version, but uses the map-backed union-find.
 *
 * @param m         number of rows
 * @param n         number of columns
 * @param positions operations as {row, column} pairs
 * @return one count per operation (0 is recorded for a position outside the grid), or null
 *         when m or n is not positive or positions is null or empty
 */
public static List<Integer> numIslands23(int m, int n, int[][] positions) {
    boolean noGrid = m <= 0 || n <= 0;
    if (noGrid || positions == null || positions.length == 0) {
        return null;
    }
    UnionSet2 islands = new UnionSet2();
    List<Integer> counts = new ArrayList<>();
    for (int k = 0; k < positions.length; k++) {
        int[] cell = positions[k];
        boolean inside = cell[0] >= 0 && cell[1] >= 0 && cell[0] < m && cell[1] < n;
        if (inside) {
            counts.add(islands.connect(cell[0], cell[1]));
        } else {
            counts.add(0);
        }
    }
    return counts;
}
/**
 * Map-backed union-find whose elements are grid cells identified by the string "row-column".
 * Entries are created only when a cell is connected.
 */
private static class UnionSet2 {
    /** Cell key to parent key; a root maps to itself. */
    private final HashMap<String, String> parents;
    /** Root key to the size of its set; only roots have an entry. */
    private final HashMap<String, Integer> size;
    /** Scratch list holding the keys walked during one find call. */
    private final ArrayList<String> help;
    /** Current number of disjoint sets. */
    private int sets;

    public UnionSet2() {
        parents = new HashMap<>();
        size = new HashMap<>();
        help = new ArrayList<>();
        sets = 0;
    }

    /**
     * Builds the map key for a cell, e.g. (2, 3) becomes "2-3".
     */
    private static String key(int r, int c) {
        return String.valueOf(r).concat("-").concat(String.valueOf(c));
    }

    /**
     * Returns the root key of s and re-points every key walked on the way at that root.
     *
     * @param s a key that is present in parents
     */
    private String find(String s) {
        while (!s.equals(parents.get(s))) {
            help.add(s);
            s = parents.get(s);
        }
        for (String s1 : help) {
            parents.put(s1, s);
        }
        help.clear();
        return s;
    }

    /**
     * Merges the sets of two cell keys; does nothing when either key has not been connected yet.
     * The strictly larger set keeps its root; on a tie the root of s2 stays on top.
     *
     * @param s1 first cell key
     * @param s2 second cell key
     */
    public void union(String s1, String s2) {
        if (!parents.containsKey(s1) || !parents.containsKey(s2)) {
            return;
        }
        String f1 = find(s1);
        String f2 = find(s2);
        if (f1.equals(f2)) {
            return;
        }
        int size1 = size.get(f1);
        int size2 = size.get(f2);
        // Compare the two different roots: comparing f1 with itself is never true and
        // silently disables union by size.
        if (size1 > size2) {
            parents.put(f2, f1);
            size.put(f1, size1 + size2);
            size.remove(f2);
        } else {
            parents.put(f1, f2);
            size.put(f2, size1 + size2);
            size.remove(f1);
        }
        sets--;
    }

    /**
     * Turns cell (r, c) into land (if it is not land already) and merges it with whichever of
     * its upper, lower, left and right neighbours are already land.
     *
     * @param r row of the cell
     * @param c column of the cell
     * @return the number of sets after the operation
     */
    public int connect(int r, int c) {
        String key = key(r, c);
        if (!parents.containsKey(key)) {
            parents.put(key, key);
            size.put(key, 1);
            sets++;
            // Neighbours that were never connected have no entry, so union skips them.
            union(key, key(r - 1, c));
            union(key, key(r + 1, c));
            union(key, key(r, c - 1));
            union(key, key(r, c + 1));
        }
        return sets;
    }
}