视频讲解

感谢大神liweiwei的视频讲解。
点击查看【bilibili】

并查集核心知识

  • 「并查集」是一种建立在「数组」上的树形结构,并且这棵树的特点是孩子结点指向父亲结点
  • 「并查集」主要用于解决「动态连通性」问题,重点关注的是连接问题,并不关注路径问题;
  • 「并查集」是树,所以优化的策略依然是和树的高度较劲,优化思路有「按秩合并」与「路径压缩」。

    两种优化方式

    隔代压缩

    image.png
    image.png
    image.png

按秩合并

image.png

「并查集」与「路径问题」

并查集主要用于解决连通问题,即抽象概念中结点和结点是否连接。

路径问题,不仅仅要考虑连通问题,我们还要往往还需要求出最短路径,这不是并查集做的事情。因此并查集问题能做的事情比路径问题少,它更专注于

  • 判断连接状态(查);
  • 改变连接状态(并)。

具体说来,并查集的代码需要实现以下的 3 个功能:

1、 find(p):查找元素 p 所对应的集合,
说明:这个函数有些时候仅作为私有函数被下面两个函数调用。
2、 union(p, q):合并元素 p 和元素 q 所在的集合。
3、 isConnected(p, q):查询元素 p 和元素 q 是不是在同一个集合中。

因此,我们要实现的并查集其实就是要实现下面的这个接口:

  1. public interface IUnionFind {
  2. // 并查集的版本名称,由开发者指定
  3. String versionName();
  4. // p (0 到 N-1)所在的分量的标识符
  5. int find(int p);
  6. // 如果 p 和 q 存在于同一分量中则返回 true
  7. boolean isConnected(int p, int q);
  8. // 在 p 与 q 之间添加一条连接
  9. void union(int p, int q);
  10. }

并查集实现

基础路径压缩版

Java代码:

  1. private class UnionFind {
  2. private int[] parent;
  3. public UnionFind(int n) {
  4. parent = new int[n];
  5. for (int i = 0; i < n; i++) {
  6. parent[i] = i;
  7. }
  8. }
  9. public int find(int x) {
  10. while (x != parent[x]) {
  11. parent[x] = parent[parent[x]];
  12. x = parent[x];
  13. }
  14. return x;
  15. }
  16. public void union(int x, int y) {
  17. int rootx = find(x);
  18. int rooty = find(y);
  19. parent[rootx] = rooty;
  20. }
  21. public boolean isConnected(int x, int y) {
  22. return find(x) == find(y);
  23. }
  24. }

Python代码:

class UnionFind:
    def __init__(self, n):
        self.parent = [i for i in range(n)]

    def find(self, x):
        while x != self.parent[x]:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x, y):
        rootx = self.find(x)
        rooty = self.find(y)
        self.parent[rootx] = rooty

    def isConnected(self, x, y):
        return self.find(x) == self.find(y)

注意初始化部分的代码

基于 rank 的优化版

Java数组版:

public class UnionFind4 implements IUnionFind {

    private int[] parent;

    private int count;

    // 以下标为 i 的元素为根结点的树的深度(最深的那个深度)
    private int[] rank;

    public UnionFind4(int n) {
        this.count = n;
        parent = new int[n];
        rank = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            // 初始化时,所有的元素只包含它自己,只有一个元素,所以 rank[i] = 1
            rank[i] = 1;
        }
    }
    // 返回下标为 p 的元素的根结点
    @Override
    public int find(int p) {
        while (p != parent[p]) {
            p = parent[p];
        }
        return p;
    }

    @Override
    public boolean isConnected(int p, int q) {
        int pRoot = find(p);
        int qRoot = find(q);
        return pRoot == qRoot;
    }


    @Override
    public void union(int p, int q) {
        int pRoot = find(p);
        int qRoot = find(q);
        if (pRoot == qRoot) {
            return;
        }
        // 这一步是与第 3 版不同的地方
        if (rank[pRoot] > rank[qRoot]) {
            parent[qRoot] = pRoot;
        } else if (rank[pRoot] < rank[qRoot]) {
            parent[pRoot] = qRoot;
        } else {
            parent[qRoot] = pRoot;
            rank[pRoot]++;
        }
        // 每次 union 以后,连通分量减 1
        count--;
    }
}

Python字典版:

class UF:
    def __init__(self, M):
        self.f = {}
        self.s = {}
        self.count = len(M)

    def find(self, x):
        self.f.setdefault(x, x)
        while x != self.f[x]:
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return
        else:
            if self.s.setdefault(x_root, 1) < self.s.setdefault(y_root, 1):
                self.f[x_root] = y_root
                self.s[y_root] += self.s[x_root]
            else:
                self.f[y_root] = x_root
                self.s[x_root] = y_root
        self.count -= 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)

例题

1. 等式方程的可满足性

描述
image.png
代码
Java代码:

class Solution {
    public boolean equationsPossible(String[] equations) {
        UnionFind unionFind = new UnionFind(26);

        for (String eq : equations) {
            char[] charArray = eq.toCharArray();
            if (charArray[1] == '=') {
                int index1 = charArray[0] - 'a';
                int index2 = charArray[3] - 'a';
                unionFind.union(index1, index2);
            }
        }

         for (String eq : equations) {
            char[] charArray = eq.toCharArray();
            if (charArray[1] == '!') {
                int index1 = charArray[0] - 'a';
                int index2 = charArray[3] - 'a';
                if (unionFind.isConnected(index1, index2)) {
                    return false;
                }
            }
        }
        return true;
    }

    private class UnionFind {
        private int[] parent;

        public UnionFind(int n) {
            parent = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }

        public int find(int x) {
            while (x != parent[x]) {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }

        public void union(int x, int y) {
            int rootx = find(x);
            int rooty = find(y);
            parent[rootx] = rooty;
        }

        public boolean isConnected(int x, int y) {
            return find(x) == find(y);
        }
    }
}

Python代码:

class Solution:
    def equationsPossible(self, equations: List[str]) -> bool:

        unionFind = UnionFind(26)
        for eq in equations:
            if eq[1] == "=":
                index1 = ord(eq[0]) - ord('a')
                index2 = ord(eq[3]) - ord('a')
                unionFind.union(index1, index2)

        for eq in equations:
            if eq[1] == "!":
                index1 = ord(eq[0]) - ord('a')
                index2 = ord(eq[3]) - ord('a')
                print(index1, index2)
                if unionFind.isConnected(index1, index2):
                    return False
        print(unionFind.parent)
        return True


class UnionFind:
    def __init__(self, n):
        self.parent = [i for i in range(n)]

    def find(self, x):
        while x != self.parent[x]:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x, y):
        rootx = self.find(x)
        rooty = self.find(y)
        self.parent[rootx] = rooty

    def isConnected(self, x, y):
        return self.find(x) == self.find(y)

2. 朋友圈个数

描述
image.png

并查集

Python代码:

class Solution:
    def findCircleNum(self, M: List[List[int]]) -> int:
        uf = UF(M)
        for i in range(len(M)):
            for j in range(len(M[0])):
                if M[i][j] == 1:
                    uf.union(i, j)
        return uf.count

class UF:
    def __init__(self, M):
        self.f = {}
        self.s = {}
        self.count = len(M)

    def find(self, x):
        self.f.setdefault(x, x)
        while x != self.f[x]:
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return
        else:
            if self.s.setdefault(x_root, 1) < self.s.setdefault(y_root, 1):
                self.f[x_root] = y_root
                self.s[y_root] += self.s[x_root]
            else:
                self.f[y_root] = x_root
                self.s[x_root] = y_root
        self.count -= 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)

广度优先搜索

Java代码:

public int findCircleNum(int[][] M) {
        int m = M.length;
        int[] visited = new int[m];
        Queue<Integer> q = new LinkedList<>();
        int res = 0;
        for (int i = 0; i < m; i++) {
            if (visited[i] == 0) {
                q.add(i);
                while (!q.isEmpty()) {
                    int node = q.remove();
                    visited[node] = 1;
                    for (int j = 0; j < m; j++) {
                        if (M[node][j] == 1 && visited[j] == 0) {
                            q.add(j);
                        }
                    }
                }
            res++;
            }
        }
        return res;
    }

Python代码:

class Solution:
    def findCircleNum(self, M: List[List[int]]) -> int:
        q = []
        visited = [0] * len(M)
        res = 0
        for i in range(len(M)):
            for j in range(len(M[0])):
                if M[i][j] == 1 and visited[i] == 0:
                    q.append(i)
                    res += 1
                    while q:
                        size = len(q)
                        for k in range(size):
                            cur = q.pop(0)
                            for nbr in range(len(M)):
                                if M[cur][nbr] == 1 and visited[nbr] == 0:
                                    q.append(nbr)
                                    visited[nbr] = 1
        return res

深度优先搜索

Java代码:

class Solution {
    public int findCircleNum(int[][] M) {
        int[] visited = new int[M.length];
        int res = 0;
        for (int i = 0; i < M.length; i++) {
            for (int j = 0; j < M.length; j++) {
                if (M[i][j] == 1 && visited[i] == 0) {
                    dfs(M, visited, i);
                    res++;
                }
            }
        }
        return res;
    }

    public void dfs(int[][] M, int[] visited, int i) {
        visited[i] = 1;
        for (int j = 0; j < M.length; j++) {
            if (M[i][j] == 1 && visited[j] == 0) {
                dfs(M, visited, j);
            }
        }
    }
}

python代码:

class Solution:
    def findCircleNum(self, M: List[List[int]]) -> int:
        def dfs(M, visited, i):
            visited[i] = 1
            for j in range(len(M)):
                if M[i][j] == 1 and visited[j] == 0:
                    dfs(M, visited, j)

        visited = [0] * len(M)
        res = 0
        for i in range(len(M)):
            for j in range(len(M[0])):
                if M[i][j] == 1 and visited[i] == 0:
                    dfs(M, visited, i)
                    res += 1
        return res;

3. 冗余连接

描述
image.png
Python代码:

class Solution:
    def findRedundantConnection(self, edges: List[List[int]]) -> List[int]:
        uf = UF(1000)
        res = []
        for edge in edges:
            node1 = edge[0]
            node2 = edge[1]
            if (uf.union(node1, node2)):
                res = [node1, node2]
        return res


class UF:
    def __init__(self, N):
        self.f = {}
        self.s = {}
        self.count = N

    def find(self, x):
        self.f.setdefault(x, x);
        while x != self.f[x]:
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        x_root = self.find(x)
        y_root = self.find(y)

        if x_root == y_root:
            return True
        if self.s.setdefault(x_root, 1) < self.s.setdefault(y_root, 1):
            self.f[x_root] = y_root
            self.s[y_root] += self.s[x_root]
        else:
            self.f[y_root] = x_root
            self.s[x_root] += self.s[y_root]
        self.count -= 1

    def connected(self, x, y):
        return self.find(x) == self.find(y)

Java代码:

class Solution {
    public int[] findRedundantConnection(int[][] edges) {
        UnionFind uf = new UnionFind(10000);
        int[] res = new int[2];
        for(int[] edge : edges) {
            int node1 = edge[0];
            int node2 = edge[1];
            if (uf.isConnected(node1, node2)) {
                res[0] = node1;
                res[1] = node2;
            }
            uf.union(node1, node2);
        }
        return res;
    }

    private class UnionFind {
        private int[] parent;

        public UnionFind(int n) {
            parent = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }

        public int find(int x) {
            while (x != parent[x]) {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }

        public boolean union(int x, int y) {
            int rootx = find(x);
            int rooty = find(y);
            if(rootx == rooty) {
                return true;
            }
            parent[rootx] = rooty;
            return false;
        }

        public boolean isConnected(int x, int y) {
            return find(x) == find(y);
        }
    }
}

4. 冗余连接II

image.png
思路
分为两种情况:

都是度为1, 则找出构成环的最后一条边
有度为2的两条边(A->B, C->B),则删除的边一定是在其中
先不将C->B加入并查集中,若不能构成环,则C->B是需要删除的点边,反之,则A->B是删除的边(去掉C->B还能构成环,则C->B一定不是要删除的边)
Python代码:

class Solution:
    def findRedundantDirectedConnection(self, edges: List[List[int]]) -> List[int]:
        n = len(edges)

        uf = UnionFind(n + 1)
        candinates = []
        last = []
        parent = {}
        for start, end in edges:
            if end in parent:
                candinates.append([parent[end], end])
                candinates.append([start, end])
            else:
                parent[end] = start
                if uf.union(start, end):
                    last = [start, end]

        if not candinates:
            return last
        return candinates[0] if last else candinates[1]



class UnionFind:
    def __init__(self, n):
        self.f = list(range(n))

    def find(self, x):
        while x != self.f[x]:
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return True
        else:
            self.f[x_root] = y_root
            return False

5. 除法求余(带权并查集)

描述
image.png
思路
上一题,我们把元素放进并查集的时候,做了一个转换,因为等式都是小写字母,因此可以与 0 - 25 做一个映射,这里我们还需要绑定一个权值信息,进而我们将一个元素和一个结点类绑定在一起,为此我们设计一个结点类,并且并查集内部的数组我们使用哈希表代替(事实上不使用哈希表也是可以的,这里为了展示并查集实现的灵活性,使用哈希表代替,请大家自行完并查集内部使用两个数组,一个表示元素,另一个表示结点之间关系)。

并查集 - 图10
代码
Python代码:

class Solution:
    def calcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]:
        uf = UnionFind()
        dic = {}
        for i in range(len(equations)):
            n1 = equations[i][0]
            n2 = equations[i][1]
            val = values[i]
            dic[n1] = 1
            dic[n2] = 1
            uf.union(n1, n2, val)

        res = [-1.0] * len(queries)
        for i in range(len(queries)):
            q = queries[i]
            if q[0] not in dic or q[1] not in dic:
                res[i] = -1.0
            else:
                res[i] = uf.connected(q[0], q[1])
        return res


class UnionFind:
    def __init__(self):
        self.f = {}
        self.weight = {}

    def find(self, x):
        self.f.setdefault(x, x)
        self.weight.setdefault(x, 1.0)
        if x != self.f[x]:
            origin = self.f[x]
            self.f[x] = self.find(self.f[x])
            self.weight[x] *= self.weight[origin]
        return self.f[x]

    def union(self, x, y, val):
        root_x = self.find(x)
        root_y = self.find(y)
        self.f[root_x] = root_y
        self.weight[root_x] = self.weight[y] * val / self.weight[x]

    def connected(self, x, y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return self.weight[x] / self.weight[y]
        else:
            return -1.0

6. AcWing 237. 程序自动分析

数据范围
1≤n≤10000001≤n≤1000000
1≤i,j≤10000000001≤i,j≤1000000000
输入样例:

2
2
1 2 1
1 2 0
2
1 2 1
2 1 1

输出样例:

NO
YES

代码
Python代码:

class UnionFind:
    def __init__(self, n):
        self.f = {}

    def find(self, x):
        self.f.setdefault(x, x)
        while x != self.f[x]:
            self.f[x] = self.f[self.f[x]]
            x = self.f[x]
        return x

    def union(self, x, y):
        x_root = self.find(x)
        y_root = self.find(y)
        self.f[x_root] = y_root

    def connected(self, x, y):
        return self.find(x) == self.find(y)


if __name__ == '__main__':
    t = int(input())
    for i in range(t):
        n = int(input())
        flag = True
        uf = UnionFind(n)
        not_equals = []
        for i in range(n):
            x, y, e = map(int, input().split())
            if e == 1:
                uf.union(x, y)
            else:
                not_equals.append([x, y])
        for x, y in not_equals:
            if uf.connected(x, y):
                flag = False
                break

        if flag:
            print("YES")
        else:
            print("NO")

Java代码:

7. AcWing 238. 银河英雄传说

有一个划分为N列的星际战场,各列依次编号为1,2,…,N。
有N艘战舰,也依次编号为1,2,…,N,其中第i号战舰处于第i列。
有T条指令,每条指令格式为以下两种之一:
1、M i j,表示让第i号战舰所在列的全部战舰保持原有顺序,接在第j号战舰所在列的尾部。
2、C i j,表示询问第i号战舰与第j号战舰当前是否处于同一列中,如果在同一列中,它们之间间隔了多少艘战舰。
现在需要你编写一个程序,处理一系列的指令。
输入格式
第一行包含整数T,表示共有T条指令。
接下来T行,每行一个指令,指令有两种形式:M i j或C i j。
其中M和C为大写字母表示指令类型,i和j为整数,表示指令涉及的战舰编号。
输出格式
你的程序应当依次对输入的每一条指令进行分析和处理:
如果是M i j形式,则表示舰队排列发生了变化,你的程序要注意到这一点,但是不要输出任何信息;
如果是C i j形式,你的程序要输出一行,仅包含一个整数,表示在同一列上,第i号战舰与第j号战舰之间布置的战舰数目,如果第i号战舰与第j号战舰当前不在同一列上,则输出-1。
数据范围
N≤30000,T≤500000
Java代码:


import java.util.*;

public class Main {
    public static void main(String[] args) {
        UnionFind uf = new UnionFind(30010);
        Scanner in = new Scanner(System.in);
        int T = Integer.valueOf(in.nextLine());
        for (int i = 0; i < T; i++) {
            String[] ss = in.nextLine().split(" ");
            int x = Integer.valueOf(ss[1]);
            int y = Integer.valueOf(ss[2]);
            if (ss[0].equals("M")) {
                uf.union(x, y);
            } else {
                int xRoot = uf.find(x);
                int yRoot = uf.find(y);
                if (xRoot == yRoot) {
                    System.out.println(Math.abs(uf.d[x] - uf.d[y]) - 1);
                } else {
                    System.out.println(-1);
                }
            }
        }
    }
}
class UnionFind {
    public int[] fa;
    public int[] d;
    public int[] size;

    public UnionFind(int n) {
        fa = new int[n];
        size = new int[n];
        d = new int[n];
        for (int i = 0; i < n; i++) {
            fa[i] = i;
            size[i] = 1;
        }
    }

    public int find(int x) {
        if (x != fa[x]) {
            int root = find(fa[x]);
            d[x] += d[fa[x]];
            fa[x] = root;
        }
        return fa[x];
    }
    public void union(int x, int y) {
        int x_root = find(x);
        int y_root = find(y);
        fa[x_root] = y_root;
        d[x_root] = size[y_root];
        size[y_root] += size[x_root];
    }
}

Python代码-超时:

class UnionFind:
    def __init__(self, n):
        self.f = [i for i in range(n)]
        self.size = [1] * n
        self.d = [0] * n

    def find(self, x):
        if x != self.f[x]:
            root = self.find(self.f[x])
            self.d[x] += self.d[self.f[x]]
            self.f[x] = root
        return self.f[x]

    def merge(self, x, y):
        x_root = self.find(x)
        y_root = self.find(y)
        self.f[x_root] = y_root
        self.d[x_root] = self.size[y_root]
        self.size[y_root] += self.size[x_root]


if __name__ == '__main__':

    T = int(input())
    uf = UnionFind(30010)
    for _ in range(T):
        command, x, y = input().split()
        x, y = int(x), int(y)
        if command == "M":
            uf.merge(x, y)
        else:
            if uf.find(x) == uf.find(y):
                print(abs(uf.d[x] - uf.d[y]) - 1)
            else:
                print(-1)