视频讲解
感谢大神liweiwei的视频讲解。
点击查看【bilibili】
并查集核心知识
- 「并查集」是一种建立在「数组」上的树形结构,并且这棵树的特点是孩子结点指向父亲结点;
- 「并查集」主要用于解决「动态连通性」问题,重点关注的是连接问题,并不关注路径问题;
- 「并查集」是树,所以优化的策略依然是和树的高度较劲,优化思路有「按秩合并」与「路径压缩」。
两种优化方式
隔代压缩



按秩合并
「并查集」与「路径问题」
并查集主要用于解决连通问题,即抽象概念中结点和结点是否连接。
路径问题不仅要考虑连通问题,往往还需要求出最短路径,这不是并查集要做的事情。因此并查集能做的事情比路径问题少,它更专注于:
- 判断连接状态(查);
- 改变连接状态(并)。
具体说来,并查集的代码需要实现以下的 3 个功能:
1、 find(p):查找元素 p 所在的集合(返回该集合的根结点);
说明:这个函数有些时候仅作为私有函数被下面两个函数调用。
2、 union(p, q):合并元素 p 和元素 q 所在的集合。
3、 isConnected(p, q):查询元素 p 和元素 q 是不是在同一个集合中。
因此,我们要实现的并查集其实就是要实现下面的这个接口:
/**
 * Interface of a union-find (disjoint-set) structure over the elements 0..N-1.
 */
public interface IUnionFind {

    // Version name of this union-find implementation, chosen by the implementer.
    String versionName();

    // Returns the identifier (root) of the component containing p (0 <= p <= N-1).
    int find(int p);

    // Returns true if p and q belong to the same component.
    boolean isConnected(int p, int q);

    // Adds a connection between p and q, merging their components.
    void union(int p, int q);
}
并查集实现
基础路径压缩版
Java代码:
private class UnionFind {
    // parent[i] is the parent of element i; a root points to itself.
    private int[] parent;

    public UnionFind(int n) {
        parent = new int[n];
        int i = 0;
        while (i < n) {
            parent[i] = i;
            i++;
        }
    }

    // Returns the root of x, making every visited node skip to its grandparent (path halving).
    public int find(int x) {
        int node = x;
        while (parent[node] != node) {
            int grandparent = parent[parent[node]];
            parent[node] = grandparent;
            node = grandparent;
        }
        return node;
    }

    // Attaches the root of x under the root of y.
    public void union(int x, int y) {
        int rootOfX = find(x);
        int rootOfY = find(y);
        parent[rootOfX] = rootOfY;
    }

    // True when x and y share the same root.
    public boolean isConnected(int x, int y) {
        return find(x) == find(y);
    }
}
Python代码:
class UnionFind:
    """Array-based union-find over the elements 0..n-1 with path halving."""

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = list(range(n))

    def find(self, x):
        """Return the root of x, pointing each visited node at its grandparent."""
        parent = self.parent
        while parent[x] != x:
            grandparent = parent[parent[x]]
            parent[x] = grandparent
            x = grandparent
        return x

    def union(self, x, y):
        """Attach the root of x under the root of y."""
        root_x = self.find(x)
        self.parent[root_x] = self.find(y)

    def isConnected(self, x, y):
        """Return True when x and y share the same root."""
        return self.find(x) == self.find(y)
注意初始化部分的代码:每个元素的父结点初始化为它自己(parent[i] = i),即一开始每个元素各自构成一个集合。
基于 rank 的优化版
Java数组版:
/**
 * Union-find, version 4: union by rank.
 *
 * find() climbs parent links without path compression, so rank[root] tracks
 * the height of the tree rooted there; union hangs the shorter tree under the
 * taller one to keep trees shallow.
 */
public class UnionFind4 implements IUnionFind {

    // parent[i]: parent of element i; a root points to itself.
    private int[] parent;

    // Number of connected components; starts at n and drops by 1 per effective union.
    private int count;

    // rank[i]: depth of the tree rooted at element i (its deepest level).
    private int[] rank;

    public UnionFind4(int n) {
        this.count = n;
        parent = new int[n];
        rank = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            // Initially every element is a set on its own, so rank[i] = 1.
            rank[i] = 1;
        }
    }

    // Required by IUnionFind.
    @Override
    public String versionName() {
        return "UnionFind4: union by rank";
    }

    // Returns the root of element p.
    @Override
    public int find(int p) {
        while (p != parent[p]) {
            p = parent[p];
        }
        return p;
    }

    @Override
    public boolean isConnected(int p, int q) {
        return find(p) == find(q);
    }

    @Override
    public void union(int p, int q) {
        int pRoot = find(p);
        int qRoot = find(q);
        if (pRoot == qRoot) {
            return;
        }
        // Union by rank: this step is what differs from version 3.
        if (rank[pRoot] > rank[qRoot]) {
            parent[qRoot] = pRoot;
        } else if (rank[pRoot] < rank[qRoot]) {
            parent[pRoot] = qRoot;
        } else {
            // Equal heights: pick pRoot as the new root; the merged tree grows by one level.
            parent[qRoot] = pRoot;
            rank[pRoot]++;
        }
        // Every effective union reduces the number of components by 1.
        count--;
    }

    // Returns the current number of connected components.
    public int getCount() {
        return count;
    }
}
Python字典版(按集合大小合并,并带路径压缩):
class UF:
    """Dictionary-based union-find with union by size and path halving.

    Elements are registered lazily the first time ``find`` sees them, so any
    hashable value can be used as an element.
    """

    def __init__(self, M):
        # f maps an element to its parent; roots map to themselves.
        self.f = {}
        # s maps a root to the number of elements in its set.
        self.s = {}
        # Number of sets; starts at len(M) (one set per element of M).
        self.count = len(M)

    def find(self, x):
        """Return the root of x, registering x as a singleton on first sight."""
        self.f.setdefault(x, x)
        while x != self.f[x]:
            # Path halving: point x at its grandparent, then jump there.
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        """Merge the sets of x and y, hanging the smaller tree under the larger."""
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return
        if self.s.setdefault(x_root, 1) < self.s.setdefault(y_root, 1):
            self.f[x_root] = y_root
            self.s[y_root] += self.s[x_root]
        else:
            self.f[y_root] = x_root
            # x_root now owns every element of y_root's set, so add the sizes.
            self.s[x_root] += self.s[y_root]
        self.count -= 1

    def connected(self, x, y):
        """Return True when x and y are in the same set."""
        return self.find(x) == self.find(y)
例题
1. 等式方程的可满足性
描述 
代码
Java代码:
class Solution {

    /**
     * Returns true if all equations of the form "a==b" / "a!=b" over lowercase
     * letters can hold simultaneously.
     */
    public boolean equationsPossible(String[] equations) {
        UnionFind letters = new UnionFind(26);
        // Pass 1: merge the two sides of every equality.
        for (String equation : equations) {
            if (equation.charAt(1) == '=') {
                letters.union(equation.charAt(0) - 'a', equation.charAt(3) - 'a');
            }
        }
        // Pass 2: an inequality between two already-connected letters is a contradiction.
        for (String equation : equations) {
            boolean inequality = equation.charAt(1) == '!';
            if (inequality && letters.isConnected(equation.charAt(0) - 'a', equation.charAt(3) - 'a')) {
                return false;
            }
        }
        return true;
    }

    // Array-based union-find over 0..n-1 with path halving.
    private class UnionFind {
        private int[] parent;

        public UnionFind(int n) {
            parent = new int[n];
            for (int k = 0; k < n; k++) {
                parent[k] = k;
            }
        }

        public int find(int x) {
            int node = x;
            while (parent[node] != node) {
                int grandparent = parent[parent[node]];
                parent[node] = grandparent;
                node = grandparent;
            }
            return node;
        }

        public void union(int x, int y) {
            int rootOfX = find(x);
            int rootOfY = find(y);
            parent[rootOfX] = rootOfY;
        }

        public boolean isConnected(int x, int y) {
            return find(x) == find(y);
        }
    }
}
Python代码:
class Solution:
    def equationsPossible(self, equations: List[str]) -> bool:
        """Return True if every "x==y" / "x!=y" equation over lowercase letters can hold at once.

        First merge the two sides of every equality, then reject if any
        inequality relates two letters that ended up in the same set.
        """
        union_find = UnionFind(26)
        for eq in equations:
            if eq[1] == "=":
                union_find.union(ord(eq[0]) - ord('a'), ord(eq[3]) - ord('a'))
        for eq in equations:
            if eq[1] == "!":
                index1 = ord(eq[0]) - ord('a')
                index2 = ord(eq[3]) - ord('a')
                if union_find.isConnected(index1, index2):
                    return False
        return True


class UnionFind:
    """Array-based union-find over the elements 0..n-1 with path halving."""

    def __init__(self, n):
        # Every element starts as the root of its own singleton set.
        self.parent = [i for i in range(n)]

    def find(self, x):
        """Return the root of x, pointing each visited node at its grandparent."""
        while x != self.parent[x]:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x, y):
        """Attach the root of x under the root of y."""
        rootx = self.find(x)
        rooty = self.find(y)
        self.parent[rootx] = rooty

    def isConnected(self, x, y):
        """Return True when x and y share the same root."""
        return self.find(x) == self.find(y)
2. 朋友圈个数
并查集
Python代码:
class Solution:
    def findCircleNum(self, M: List[List[int]]) -> int:
        """Count friend circles (connected components) of the adjacency matrix M."""
        uf = UF(M)
        for i, row in enumerate(M):
            for j, linked in enumerate(row):
                if linked == 1:
                    uf.union(i, j)
        return uf.count


class UF:
    """Dictionary-based union-find with union by size and path halving."""

    def __init__(self, M):
        # f maps an element to its parent; roots map to themselves.
        self.f = {}
        # s maps a root to the number of elements in its set.
        self.s = {}
        # Number of sets; starts at len(M) (one set per row of M).
        self.count = len(M)

    def find(self, x):
        """Return the root of x, registering x as a singleton on first sight."""
        self.f.setdefault(x, x)
        while x != self.f[x]:
            # Path halving: point x at its grandparent, then jump there.
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        """Merge the sets of x and y, hanging the smaller tree under the larger."""
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return
        if self.s.setdefault(x_root, 1) < self.s.setdefault(y_root, 1):
            self.f[x_root] = y_root
            self.s[y_root] += self.s[x_root]
        else:
            self.f[y_root] = x_root
            # x_root now owns every element of y_root's set, so add the sizes.
            self.s[x_root] += self.s[y_root]
        self.count -= 1

    def connected(self, x, y):
        """Return True when x and y are in the same set."""
        return self.find(x) == self.find(y)
广度优先搜索
Java代码:
/**
 * Counts friend circles (connected components) of the adjacency matrix M using BFS.
 * A node is marked visited when it is enqueued, so each node enters the queue at
 * most once instead of once per neighbour that sees it.
 */
public int findCircleNum(int[][] M) {
    int m = M.length;
    int[] visited = new int[m];
    Queue<Integer> q = new LinkedList<>();
    int res = 0;
    for (int i = 0; i < m; i++) {
        if (visited[i] != 0) {
            continue;
        }
        visited[i] = 1;
        q.add(i);
        while (!q.isEmpty()) {
            int node = q.remove();
            for (int j = 0; j < m; j++) {
                if (M[node][j] == 1 && visited[j] == 0) {
                    visited[j] = 1;
                    q.add(j);
                }
            }
        }
        // Every BFS started from an unvisited node discovers one new circle.
        res++;
    }
    return res;
}
Python代码:
class Solution:
    def findCircleNum(self, M: List[List[int]]) -> int:
        """Count friend circles (connected components) of the adjacency matrix M with BFS.

        A node is marked visited when it is enqueued, and a deque gives O(1)
        pops from the front.
        """
        from collections import deque

        n = len(M)
        visited = [0] * n
        res = 0
        for start in range(n):
            # A row with no 1 at all never starts a search (impossible when M[i][i] == 1).
            if visited[start] or 1 not in M[start]:
                continue
            visited[start] = 1
            queue = deque([start])
            res += 1
            while queue:
                cur = queue.popleft()
                for nbr, linked in enumerate(M[cur]):
                    if linked == 1 and not visited[nbr]:
                        visited[nbr] = 1
                        queue.append(nbr)
        return res
深度优先搜索
Java代码:
class Solution {

    /**
     * Counts friend circles (connected components) of the adjacency matrix M
     * with recursive DFS. A search starts from every unvisited row that
     * contains at least one 1.
     */
    public int findCircleNum(int[][] M) {
        int n = M.length;
        int[] visited = new int[n];
        int circles = 0;
        for (int person = 0; person < n; person++) {
            if (visited[person] == 0 && hasAnyLink(M[person], n)) {
                dfs(M, visited, person);
                circles++;
            }
        }
        return circles;
    }

    // True when one of the first `limit` entries of row equals 1.
    private static boolean hasAnyLink(int[] row, int limit) {
        for (int col = 0; col < limit; col++) {
            if (row[col] == 1) {
                return true;
            }
        }
        return false;
    }

    // Marks i as visited and recursively visits all unvisited neighbours of i.
    public void dfs(int[][] M, int[] visited, int i) {
        visited[i] = 1;
        int n = M.length;
        for (int next = 0; next < n; next++) {
            if (M[i][next] == 1 && visited[next] == 0) {
                dfs(M, visited, next);
            }
        }
    }
}
python代码:
class Solution:
    def findCircleNum(self, M: List[List[int]]) -> int:
        """Count friend circles (connected components) of the adjacency matrix M with recursive DFS."""
        seen = [0] * len(M)

        def explore(node):
            # Mark node and recursively visit every unvisited neighbour.
            seen[node] = 1
            for other in range(len(M)):
                if M[node][other] == 1 and seen[other] == 0:
                    explore(other)

        circles = 0
        for person in range(len(M)):
            # Start a search from an unvisited row that contains at least one 1.
            if seen[person] == 0 and any(M[person][col] == 1 for col in range(len(M[0]))):
                explore(person)
                circles += 1
        return circles
3. 冗余连接
描述 
Python代码:
class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the last edge whose endpoints were already connected when it was read."""
        uf = UF(1000)
        res = []
        for node1, node2 in edges:
            # union() reports True when the edge closes a cycle.
            if uf.union(node1, node2):
                res = [node1, node2]
        return res


class UF:
    """Dictionary-based union-find with union by size and path halving."""

    def __init__(self, N):
        # f maps an element to its parent; roots map to themselves.
        self.f = {}
        # s maps a root to the number of elements in its set.
        self.s = {}
        # Number of sets; starts at N.
        self.count = N

    def find(self, x):
        """Return the root of x, registering x as a singleton on first sight."""
        self.f.setdefault(x, x)
        while x != self.f[x]:
            # Path halving: point x at its grandparent, then jump there.
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        """Merge the sets of x and y.

        Returns True if x and y were already in the same set (nothing merged),
        False if two sets were merged.
        """
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return True
        if self.s.setdefault(x_root, 1) < self.s.setdefault(y_root, 1):
            self.f[x_root] = y_root
            self.s[y_root] += self.s[x_root]
        else:
            self.f[y_root] = x_root
            self.s[x_root] += self.s[y_root]
        self.count -= 1
        return False

    def connected(self, x, y):
        """Return True when x and y are in the same set."""
        return self.find(x) == self.find(y)
Java代码:
class Solution {

    /** Returns the last edge whose endpoints were already connected when it was read. */
    public int[] findRedundantConnection(int[][] edges) {
        UnionFind uf = new UnionFind(10000);
        int[] answer = new int[2];
        for (int[] edge : edges) {
            int from = edge[0];
            int to = edge[1];
            // union() returns true when both ends already share a root, i.e. the edge closes a cycle.
            if (uf.union(from, to)) {
                answer[0] = from;
                answer[1] = to;
            }
        }
        return answer;
    }

    // Array-based union-find over 0..n-1 with path halving.
    private class UnionFind {
        private int[] parent;

        public UnionFind(int n) {
            parent = new int[n];
            for (int k = 0; k < n; k++) {
                parent[k] = k;
            }
        }

        public int find(int x) {
            int node = x;
            while (parent[node] != node) {
                int grandparent = parent[parent[node]];
                parent[node] = grandparent;
                node = grandparent;
            }
            return node;
        }

        // Returns true if x and y were already connected; otherwise merges them and returns false.
        public boolean union(int x, int y) {
            int rootOfX = find(x);
            int rootOfY = find(y);
            if (rootOfX != rootOfY) {
                parent[rootOfX] = rootOfY;
                return false;
            }
            return true;
        }

        public boolean isConnected(int x, int y) {
            return find(x) == find(y);
        }
    }
}
4. 冗余连接II

思路
分为两种情况:
若所有结点的入度都不超过 1(即不存在入度为 2 的结点),则返回构成环的最后一条边;
若存在入度为 2 的结点,即有两条边 A->B 和 C->B 指向同一个结点 B,则需要删除的边一定是这两条边之一;
先不将 C->B 加入并查集:若此时不构成环,则 C->B 就是需要删除的边;反之,则 A->B 是需要删除的边(去掉 C->B 后仍然有环,说明 C->B 一定不是要删除的边)。
Python代码:
class Solution:
    def findRedundantDirectedConnection(self, edges: List[List[int]]) -> List[int]:
        """Return the edge to remove so that the directed graph becomes a rooted tree.

        If some node B has two incoming edges A->B and C->B, the answer is one
        of them: C->B is kept out of the union-find, and if a cycle still
        appears the answer is A->B, otherwise C->B. Without such a node, the
        answer is the last edge that closes a cycle.
        """
        uf = UnionFind(len(edges) + 1)
        candidates = []   # [A->B, C->B] when some B has two parents
        cycle_edge = []   # last edge found to close a cycle
        parent_of = {}
        for start, end in edges:
            if end in parent_of:
                # Second edge into `end`: record both, and keep this one out of the union-find.
                candidates.extend(([parent_of[end], end], [start, end]))
                continue
            parent_of[end] = start
            if uf.union(start, end):
                cycle_edge = [start, end]
        if not candidates:
            return cycle_edge
        if cycle_edge:
            return candidates[0]
        return candidates[1]


class UnionFind:
    """Union-find over 0..n-1 with path halving; union reports whether the ends were already joined."""

    def __init__(self, n):
        self.f = list(range(n))

    def find(self, x):
        """Return the root of x, pointing each visited node at its grandparent."""
        f = self.f
        while f[x] != x:
            f[x] = f[f[x]]
            x = f[x]
        return x

    def union(self, x, y):
        """Return True if x and y already share a root; otherwise merge them and return False."""
        x_root, y_root = self.find(x), self.find(y)
        if x_root != y_root:
            self.f[x_root] = y_root
            return False
        return True
5. 除法求值(带权并查集)
描述 
思路
在「等式方程的可满足性」一题中,我们把元素放进并查集的时候做了一个转换:因为等式中都是小写字母,所以可以与 0 - 25 做一个映射。这里我们还需要为每个元素绑定一个权值信息(该元素与其父结点的比值),因此并查集内部的数组使用哈希表代替:一个哈希表记录每个元素的父结点,另一个哈希表记录权值。事实上不使用哈希表也是可以的,这里使用哈希表是为了展示并查集实现的灵活性;请大家自行完成内部使用两个数组的版本:一个数组表示元素的父结点,另一个表示结点之间的权值关系。

代码
Python代码:
class Solution:
    def calcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]:
        """Answer every query a / b using a weighted union-find built from the equations.

        equations[i] = [a, b] with values[i] means a / b = values[i]. A query
        returns -1.0 when a variable never appeared in any equation or the two
        variables are not related.
        """
        uf = UnionFind()
        known = set()
        for (n1, n2), val in zip(equations, values):
            known.add(n1)
            known.add(n2)
            uf.union(n1, n2, val)
        res = []
        for a, b in queries:
            if a not in known or b not in known:
                res.append(-1.0)
            else:
                res.append(uf.connected(a, b))
        return res


class UnionFind:
    """Weighted union-find over hashable elements.

    f maps an element to its parent; weight[x] is the ratio x / f[x], so a
    root always has weight 1.0.
    """

    def __init__(self):
        self.f = {}
        self.weight = {}

    def find(self, x):
        """Return the root of x, compressing the path so weight[x] becomes x / root."""
        self.f.setdefault(x, x)
        self.weight.setdefault(x, 1.0)
        if x != self.f[x]:
            origin = self.f[x]
            self.f[x] = self.find(origin)
            # origin now points at the root, so weight[origin] = origin / root.
            self.weight[x] *= self.weight[origin]
        return self.f[x]

    def union(self, x, y, val):
        """Record the relation x / y = val."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            # Already related. Rewriting weight[root_x] here would leave a root
            # with weight != 1.0 (through float rounding) and skew every later find.
            return
        self.f[root_x] = root_y
        # root_x / root_y = (x / weight[x]) / (y / weight[y]) = val * weight[y] / weight[x]
        self.weight[root_x] = self.weight[y] * val / self.weight[x]

    def connected(self, x, y):
        """Return x / y if x and y are related, otherwise -1.0."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return self.weight[x] / self.weight[y]
        return -1.0
6. AcWing 237. 程序自动分析
数据范围
1≤n≤1000000
1≤i,j≤1000000000
输入样例:
2
2
1 2 1
1 2 0
2
1 2 1
2 1 1
输出样例:
NO
YES
代码
Python代码:
class UnionFind:
    """Dictionary-based union-find; elements are registered lazily on first find.

    The constructor argument n is accepted but not used.
    """

    def __init__(self, n):
        self.f = {}

    def find(self, x):
        """Return the root of x, pointing each visited node at its grandparent."""
        parents = self.f
        parents.setdefault(x, x)
        node = x
        while parents[node] != node:
            parents[node] = parents[parents[node]]
            node = parents[node]
        return node

    def union(self, x, y):
        """Attach the root of x under the root of y."""
        root_x = self.find(x)
        root_y = self.find(y)
        self.f[root_x] = root_y

    def connected(self, x, y):
        """Return True when x and y share the same root."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y
if __name__ == '__main__':
    import sys

    # Up to 10^6 constraints per test case: read every token at once instead of
    # calling input() per line, and write all answers in one go.
    tokens = map(int, sys.stdin.buffer.read().split())
    t = next(tokens)
    answers = []
    for _ in range(t):
        n = next(tokens)
        uf = UnionFind(n)
        not_equals = []
        for _ in range(n):
            x, y, e = next(tokens), next(tokens), next(tokens)
            if e == 1:
                uf.union(x, y)
            else:
                not_equals.append((x, y))
        # Constraints are satisfiable iff no "x != y" pair was merged by the equalities.
        satisfiable = not any(uf.connected(x, y) for x, y in not_equals)
        answers.append("YES" if satisfiable else "NO")
    sys.stdout.write("".join(line + "\n" for line in answers))
7. AcWing 238. 银河英雄传说
有一个划分为N列的星际战场,各列依次编号为1,2,…,N。
有N艘战舰,也依次编号为1,2,…,N,其中第i号战舰处于第i列。
有T条指令,每条指令格式为以下两种之一:
1、M i j,表示让第i号战舰所在列的全部战舰保持原有顺序,接在第j号战舰所在列的尾部。
2、C i j,表示询问第i号战舰与第j号战舰当前是否处于同一列中,如果在同一列中,它们之间间隔了多少艘战舰。
现在需要你编写一个程序,处理一系列的指令。
输入格式
第一行包含整数T,表示共有T条指令。
接下来T行,每行一个指令,指令有两种形式:M i j或C i j。
其中M和C为大写字母表示指令类型,i和j为整数,表示指令涉及的战舰编号。
输出格式
你的程序应当依次对输入的每一条指令进行分析和处理:
如果是M i j形式,则表示舰队排列发生了变化,你的程序要注意到这一点,但是不要输出任何信息;
如果是C i j形式,你的程序要输出一行,仅包含一个整数,表示在同一列上,第i号战舰与第j号战舰之间布置的战舰数目,如果第i号战舰与第j号战舰当前不在同一列上,则输出-1。
数据范围
N≤30000,T≤500000
Java代码:
import java.util.*;
public class Main {
    /**
     * Reads T commands "M i j" / "C i j" and prints one answer per "C" command:
     * the number of ships strictly between i and j when they are in the same
     * column, otherwise -1.
     */
    public static void main(String[] args) throws java.io.IOException {
        UnionFind uf = new UnionFind(30010);
        // Up to 500000 commands: buffered input and a single batched output.
        java.io.BufferedReader in = new java.io.BufferedReader(new java.io.InputStreamReader(System.in));
        StringBuilder out = new StringBuilder();
        int T = Integer.parseInt(in.readLine().trim());
        for (int i = 0; i < T; i++) {
            StringTokenizer st = new StringTokenizer(in.readLine());
            String op = st.nextToken();
            int x = Integer.parseInt(st.nextToken());
            int y = Integer.parseInt(st.nextToken());
            if (op.equals("M")) {
                uf.union(x, y);
            } else if (uf.find(x) == uf.find(y)) {
                // After find, d[k] counts the ships ahead of k in its column.
                // Clamp at 0 so that "C i i" prints 0 instead of -1.
                out.append(Math.max(0, Math.abs(uf.d[x] - uf.d[y]) - 1)).append('\n');
            } else {
                out.append(-1).append('\n');
            }
        }
        System.out.print(out);
    }
}
/**
 * Union-find for the "Galactic Heroes" problem, where each set is a column of ships.
 * The root of a set is the ship at the head of its column.
 */
class UnionFind {
    // fa[i]: parent of ship i; a root points to itself.
    public int[] fa;
    // d[i]: position offset of ship i relative to fa[i]; after find(i), the number of ships ahead of i.
    public int[] d;
    // size[r]: number of ships in the column whose root is r (meaningful for roots only).
    public int[] size;
    // Scratch buffer for the iterative find; a path never has more than n nodes.
    private final int[] path;

    public UnionFind(int n) {
        fa = new int[n];
        size = new int[n];
        d = new int[n];
        path = new int[n];
        for (int i = 0; i < n; i++) {
            fa[i] = i;
            size[i] = 1;
        }
    }

    /**
     * Returns the root of x, compresses the path and updates d along it.
     * Iterative, because long merge chains (up to ~30000 ships) could overflow the call stack recursively.
     */
    public int find(int x) {
        int len = 0;
        int root = x;
        while (root != fa[root]) {
            path[len++] = root;
            root = fa[root];
        }
        // Walk back from the node nearest the root so each parent's d is final before it is added.
        for (int i = len - 1; i >= 0; i--) {
            int node = path[i];
            d[node] += d[fa[node]];
            fa[node] = root;
        }
        return root;
    }

    /** Appends the whole column containing x to the tail of the column containing y. */
    public void union(int x, int y) {
        int x_root = find(x);
        int y_root = find(y);
        if (x_root == y_root) {
            // Same column already: merging it with itself would corrupt size and d.
            return;
        }
        fa[x_root] = y_root;
        d[x_root] = size[y_root];
        size[y_root] += size[x_root];
    }
}
Python代码-超时:
class UnionFind:
    """Union-find for the "Galactic Heroes" problem, where each set is a column of ships.

    f[i] is the parent of ship i (a root is the head of its column); d[i] is the
    position offset of i relative to f[i], which after find(i) is the number of
    ships ahead of i; size[r] is the number of ships in the column rooted at r.
    """

    def __init__(self, n):
        self.f = list(range(n))
        self.size = [1] * n
        self.d = [0] * n

    def find(self, x):
        """Return the root of x, compressing the path and updating d along it.

        Iterative, because a chain of thousands of merges would exceed Python's
        default recursion limit (1000).
        """
        f, d = self.f, self.d
        path = []
        root = x
        while root != f[root]:
            path.append(root)
            root = f[root]
        # Walk back from the node nearest the root so each parent's d is final before it is added.
        for node in reversed(path):
            d[node] += d[f[node]]
            f[node] = root
        return root

    def merge(self, x, y):
        """Append the whole column containing x to the tail of the column containing y."""
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            # Same column already: merging it with itself would corrupt size and d.
            return
        self.f[x_root] = y_root
        self.d[x_root] = self.size[y_root]
        self.size[y_root] += self.size[x_root]
if __name__ == '__main__':
    import sys

    # Up to 500000 commands: read all input at once and write the answers in one go.
    tokens = sys.stdin.buffer.read().split()
    T = int(tokens[0])
    uf = UnionFind(30010)
    answers = []
    for k in range(T):
        command = tokens[1 + 3 * k]
        x = int(tokens[2 + 3 * k])
        y = int(tokens[3 + 3 * k])
        if command == b"M":
            uf.merge(x, y)
        elif uf.find(x) == uf.find(y):
            # Ships strictly between x and y; clamp at 0 so "C i i" prints 0, not -1.
            answers.append(str(max(0, abs(uf.d[x] - uf.d[y]) - 1)))
        else:
            answers.append("-1")
    sys.stdout.write("".join(line + "\n" for line in answers))
