给你一个points 数组,表示 2D 平面上的一些点,其中 points[i] = [xi, yi] 。

连接点 [xi, yi] 和点 [xj, yj] 的费用为它们之间的 曼哈顿距离 :|xi - xj| + |yi - yj| ,其中 |val| 表示 val 的绝对值。

请你返回将所有点连接的最小总费用。只有任意两点之间 有且仅有 一条简单路径时,才认为所有点都已连接。

输入:points = [[0,0],[2,2],[3,10],[5,2],[7,0]]
输出:20
解释:
1584. 连接所有点的最小费用 - 图1
我们可以按照上图所示连接所有点得到最小总费用,总费用为 20 。
注意到任意两个点之间只有唯一一条路径互相到达。

解法一:Kruskal + 并查集

计算所有点之间的距离,按edge长度从小到大遍历,使用并查集检查当前edge所连接的两个顶点是否已经合并。如果未合并,则将其合并,并累加edge的长度。

  1. class UF:
  2. def __init__(self, n: int):
  3. self.id = [i for i in range(n)]
  4. self.sz = [1 for _ in range(n)]
  5. self.count = n
  6. def find(self, p: int) -> int:
  7. while p != self.id[p]:
  8. p = self.id[p]
  9. return p
  10. def connected(self, p: int, q: int) -> bool:
  11. return self.find(p) == self.find(q)
  12. def union(self, p: int, q: int) -> bool:
  13. """当p, q是新的连接时返回True"""
  14. p_id, q_id = self.find(p), self.find(q)
  15. if p_id == q_id:
  16. return False
  17. if self.sz[p_id] < self.sz[q_id]:
  18. self.id[p_id] = q_id
  19. self.sz[q_id] += self.sz[p_id]
  20. else:
  21. self.id[q_id] = p_id
  22. self.sz[p_id] += self.sz[q_id]
  23. self.count -= 1
  24. return True
  25. class Solution:
  26. def minCostConnectPoints(self, points: List[List[int]]) -> int:
  27. """Kruskal算法"""
  28. dist = lambda i, j: abs(points[i][0] - points[j][0]) + abs(points[i][1] - points[j][1])
  29. n = len(points)
  30. edges = sorted(((i, j, dist(i, j)) for i in range(n) for j in range(i + 1, n)), key=lambda x: x[2])
  31. uf = UF(n)
  32. ans = 0
  33. for p, q, dist in edges:
  34. # p, q是新的连接
  35. if uf.union(p, q):
  36. ans += dist
  37. # 当所有节点都已连接时退出循环
  38. if uf.count == 1:
  39. break
  40. return ans

解法二:Prim + DP

Prim的定义是初始化added顶点几何,将points[0]加入,然后选择剩余的顶点集至added集的最短edge,将edge的顶点加入added。算法实现要点在于如何找到符合要求的最短edge。

added用来记录已经连接的顶点,初始化为points[0]
remains用来记录还未连接的各顶点p到added的最短距离w。
当remains不为空时,遍历remains中的各点p,计算p到added各点的距离。当该距离小于w时,更新w;遍历的同时找出w值最小的p,从remains中删除,并加入added。

以上DP可以做空间优化,只需要记录最后一次added的顶点即可,即last。因为remains中的值要么不变,要么因为新加入的last而需要更新。

算法时间复杂度为O(n^2)。

  1. class Solution:
  2. def minCostConnectPoints(self, points: List[List[int]]) -> int:
  3. ans = 0
  4. remains = {(x, y): math.inf for x, y in points}
  5. last, _ = remains.popitem()
  6. while remains:
  7. # 计算remains中每个点和last的距离,并找出距离最小的点
  8. minp, minw = None, math.inf
  9. for p, w in remains.items():
  10. dist = abs(p[0] - last[0]) + abs(p[1] - last[1])
  11. if w > dist:
  12. remains[p] = w = dist
  13. if w < minw:
  14. minp, minw = p, w
  15. # 距离最小的点作为下一步的起点
  16. last = minp
  17. del remains[minp]
  18. ans += minw
  19. return ans