A*算法

A* 算法的最大时间消耗在于将节点插入open_set中,如果我们用优先队列优化,底层实际是二叉堆,因此时间复杂度是O(log(n))。所以我们的最大优化方式有两个,一个是优化open_set的数据结构,而另外一个就是尽可能的减少在寻路过程中加入到open_set中的点。

因此可行的解决的方式有:

  1. open_set数据结构的优化可以使用优先队列实现,如下面示例代码所示
  2. 判断是否将点插入open_set的依据是F值,而如果在寻路的过程中,如果两点中间存在大阻挡,将会使得将很多无用的点塞入open_set中,因此常见的一种做法是进行分块处理,在可绕过阻挡块处设置寻路点。

以下图为例,不同块之间除了黑点之外都是阻挡(黑点是桥梁),因此可以把黑点直接设置为寻路点,这样A -> C,只要分成三个 A 寻路即可,避免了“撞墙”,大大提升了效率
image.png
如果图形是没有阻挡的,实际上使用分块的方式依然可以提升整体的性能,所以还是有必要提前划分某些寻路点。
*当然要注意,这种分块的形式,得到的一般是一个近似的最短路径,在实际项目工程中是可以接受的

  1. 第三种方式就是使用更快的寻路方式了,可以使用JPS寻路算法
  1. import math
  2. import heapq
  3. # f = g + h
  4. class Node():
  5. Weight = 10 # 这里的权重可以根据不同的地形修改,比如沙漠,那么值应该大些,平原值可以小些
  6. D = 10
  7. DD = 14
  8. def __init__(self, x, y, g=0, h=0):
  9. self.x = x # 坐标x
  10. self.y = y # 坐标y
  11. self.g = g # 当前点到起始点的代价,dijkstra算法
  12. self.h = h # 当前点到终点的代价,最佳优先搜索算法
  13. self.father = None
  14. def calAddG(self, current_node):
  15. if abs(self.x - current_node.x) == 1 and abs(self.y - current_node.y) == 1:
  16. return DD * Node.Weight # 斜线距离,代价为14
  17. else:
  18. return D * Node.Weight # 直线距离,代价为10
  19. def setG(self, val):
  20. self.g = val
  21. def calH(self, node): # 曼哈顿距离,也可以使用欧氏距离
  22. dx = abs(self.x - node.x)
  23. dy = abs(self.y - node.y)
  24. # H的计算主要有三种
  25. # 1. 曼哈顿距离,适用于只有直线的情况下:(dx + dy) * D
  26. # 2. 对角距离:适用于直线+45度角的情况:(dx + dy) + (DD - 2*D) * min(dx, dy)
  27. # 3. 欧几里得距离:适用于任意角度,但是还有sqrt操作,相对较慢
  28. # D * sqrt(dx * dx + dy * dy)
  29. return D * sqrt(dx * dx + dy * dy)
  30. def setH(self, val):
  31. self.h = val
  32. def setFather(self, node):
  33. self.father = node
  34. def __lt__(self, other):
  35. return id(self) < id(other)
  36. def __le__(self, other):
  37. return id(self) >= id(other)
  38. class AStar():
  39. def __init__(self, graph):
  40. self.graph = graph
  41. self.max_row = len(graph)
  42. self.max_col = len(graph[0])
  43. def isValid(self, node):
  44. row = node.x
  45. col = node.y
  46. return 0 <= row < self.max_row and 0 <= col < self.max_col and self.graph[row][col] != "#"
  47. def search(self, start_node, end_node):
  48. if not (self.isValid(start_node) and self.isValid(end_node)): # 检查节点的合法性
  49. return []
  50. start_node.calH(end_node)
  51. start_node.setG(0)
  52. path_list = [] # 最后的路径列表
  53. open_list = [] # 当前可选择的点(f, node)
  54. close_list = [] # 已经选择过的点(x,y)
  55. heapq.heappush(open_list, (start_node.g + start_node.h, start_node))
  56. while True:
  57. _val, current_node = heapq.heappop(open_list) # 从优先队列中弹出最小值
  58. close_list.append((current_node.x, current_node.y))
  59. self.searchNear(current_node, end_node, open_list, close_list)
  60. open_node = self.getFromOpenList(end_node, open_list)
  61. if open_node: # 最后的结束节点已在open_list中
  62. while True:
  63. path_list.append((open_node.x, open_node.y))
  64. if open_node.father != None:
  65. open_node = open_node.father
  66. else:
  67. break
  68. elif len(open_list) == 0:
  69. break
  70. return path_list
  71. def searchNear(self, current_node, end_node, open_list, close_list):
  72. """
  73. 搜索节点周围的点
  74. 按照八个方位搜索
  75. (x-1,y-1)(x-1,y)(x-1,y+1)
  76. (x ,y-1)(x ,y)(x ,y+1)
  77. (x+1,y-1)(x+1,y)(x+1,y+1)
  78. """
  79. pos_list = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
  80. for pos_x, pos_y in pos_list:
  81. self.searchOneNode(current_node, end_node, Node(current_node.x + pos_x, current_node.y + pos_y), close_list,
  82. open_list)
  83. def searchOneNode(self, current_node, end_node, node, close_list, open_list):
  84. """
  85. 搜索一个节点
  86. """
  87. row = node.x
  88. col = node.y
  89. if self.graph[row][col] == "#": # 是障碍点
  90. return
  91. if (node.x, node.y) in close_list: # 已遍历过
  92. return
  93. open_node = self.getFromOpenList(node, open_list)
  94. tempG = current_node.g + node.calAddG(current_node)
  95. tempH = node.calH(end_node)
  96. if not open_node: # 当前node不在open_list中,计算其G,H存入open_list中
  97. node.setG(tempG)
  98. node.setH(tempH)
  99. node.setFather(current_node)
  100. heapq.heappush(open_list, (node.g + node.h, node))
  101. else:
  102. if (open_node.g > tempG):
  103. # need heapify
  104. open_node.setG(tempG)
  105. open_node.setH(tempH)
  106. open_node.setFather(current_node)
  107. heapq.heapify(open_list) # 重新调整
  108. def getFromOpenList(self, node, open_list): # 如果不在open_list中则返回False,亦可作为判断用
  109. if not open_list:
  110. return None
  111. for _val, open_node in open_list:
  112. if open_node.x == node.x and open_node.y == node.y:
  113. return open_node
  114. return None
  115. if __name__ == '__main__':
  116. graph = [list("####################"),
  117. list("#*****#************#"),
  118. list("#*****#*****#*####*#"),
  119. list("#*########*##******#"),
  120. list("#*****#*****######*#"),
  121. list("#*****#####*#******#"),
  122. list("####**#*****#*######"),
  123. list("#*****#**#**#**#***#"),
  124. list("#**#*****#**#****#*#"),
  125. list("####################")]
  126. print('graph: ')
  127. for g in graph:
  128. print(g)
  129. star = AStar(graph)
  130. path_list = star.search(Node(1, 1), Node(8, 18))
  131. for x, y in path_list:
  132. graph[x][y] = "O"
  133. print('after A* graph: ')
  134. for g in graph:
  135. print(g)
  1. #include <limits.h>
  2. #include <memory>
  3. #include <vector>
  4. #include <queue>
  5. #include <math.h>
  6. #include <unordered_map>
  7. #include <iostream>
  8. using namespace std;
  9. class aStar
  10. {
  11. public:
  12. using Map = vector<vector<int>>;
  13. vector<pair<int, int>> search(const Map&, int, int, int, int);
  14. private:
  15. bool _isValid(const Map&, int, int);
  16. int _calH(int, int, int, int);
  17. int _calG(const Map&, int, int, int);
  18. static int D; // 直线运动的代价
  19. static int DD; // 斜线运动的代价
  20. static vector<pair<int, int>> path;
  21. int _maxX; // 当前地图X的最大值
  22. int _maxY; // 当前地图Y的最大值
  23. struct Node
  24. {
  25. int x; // x轴位置
  26. int y; // y轴位置
  27. int g; // g值
  28. int h; // h值
  29. Node* parent; // 父节点
  30. bool close; // 是否已经在close中,true是,false否
  31. Node(int x, int y): x(x), y(y), close(false), parent(nullptr){}
  32. };
  33. struct cmp
  34. {
  35. bool operator()(const Node* node1, const Node* node2)
  36. {
  37. return node1->g + node1->h > node2->g + node2->h;
  38. }
  39. };
  40. };
  41. int aStar::D = 10;
  42. int aStar::DD = 14;
  43. vector<pair<int, int>> aStar::path{
  44. /*
  45. 搜索节点周围的点
  46. 按照八个方位搜索
  47. (x-1,y-1)(x-1,y)(x-1,y+1)
  48. (x ,y-1)(x ,y)(x ,y+1)
  49. (x+1,y-1)(x+1,y)(x+1,y+1)
  50. */
  51. {-1, 0}, {1, 0}, {0, -1}, {0, 1}, {-1, -1}, {-1, 1}, {1, -1}, {1, 1} // 前四个方向是直线,后四个是斜线
  52. };
  53. bool aStar::_isValid(const Map& map, int x, int y)
  54. {
  55. return 0 <= x < _maxX && 0 <= y < _maxY && map[x][y] > -1;
  56. }
  57. int aStar::_calH(int current_x, int current_y, int end_x, int end_y)
  58. {
  59. int dx = abs(current_x - end_x);
  60. int dy = abs(current_y - end_y);
  61. /*
  62. 三种计算方式:
  63. 1. 曼哈顿距离,适合只能直线运动的场景:(dx + dy) * D
  64. 2. 对角距离,适合直线+45度角运动的场景:(dx + dy) + (DD - 2*D) * min(dx, dy)
  65. 3. 欧几里得距离:适合任意角度运行,有sqrt操作,相对较慢,开方可以使用别的方式优化:D * sqrt(dx * dx + dy * dy)
  66. */
  67. return D * sqrt(dx * dx + dy * dy);
  68. }
  69. int aStar::_calG(const Map& map, int dest_x, int dest_y, int weight)
  70. {
  71. /*
  72. 如果是直线运动,则结果会是 D * 目标坐标点的地形权重,比如沙漠权重高点,平地权重就低点,即优先走平地
  73. */
  74. return weight * map[dest_x][dest_y];
  75. }
  76. vector<pair<int, int>> aStar::search(const Map& map, int start_x, int start_y, int end_x, int end_y)
  77. {
  78. _maxX = map.size();
  79. _maxY = map[0].size();
  80. if (!_isValid(map, start_x, start_y) or !_isValid(map, end_x, end_y)) // 起始/终点坐标异常
  81. {
  82. return {};
  83. }
  84. unordered_map<int, Node*> nodeMap;
  85. Node* startNode = new Node(start_x, start_y);
  86. nodeMap[startNode->x + _maxX * startNode->y] = startNode; // x + maxX * y 这种方式可以保证x和y计算出来的值是唯一的
  87. startNode->g = 0;
  88. startNode->h = _calH(startNode->x, startNode->y, end_x, end_y);
  89. priority_queue<Node*, vector<Node*>, cmp> Que; // 使用 priority_queue 会有重复插入的问题,因此需要使用close字段进行判断,和dijkstra算法处理相似
  90. Que.emplace(startNode);
  91. vector<pair<int, int>> ret;
  92. while (!Que.empty())
  93. {
  94. Node* parentNode = Que.top();
  95. Que.pop();
  96. if (parentNode->close) continue; // 过滤重复
  97. parentNode->close = true;
  98. if (parentNode->x == end_x && parentNode->y == end_y)
  99. {
  100. while (parentNode->parent != nullptr)
  101. {
  102. ret.emplace_back(parentNode->x, parentNode->y);
  103. parentNode = parentNode->parent;
  104. }
  105. ret.emplace_back(parentNode->x, parentNode->y);
  106. // 内存释放,在实际项目工程中可以预先创建一定数量的node节点,然后每次申请节点只返回其中未被使用的节点,用完之后清除状态数据(x\y\close等),不用释放,这样是空间换时间
  107. for (auto iter = nodeMap.begin(); iter != nodeMap.end(); ++iter) delete iter->second;
  108. return ret;
  109. }
  110. for (int i = 0; i < path.size(); ++i)
  111. {
  112. int new_x = path[i].first + parentNode->x;
  113. int new_y = path[i].second + parentNode->y;
  114. if (!_isValid(map, new_x, new_y)) continue;
  115. int idx = new_x + _maxX * new_y;
  116. if (nodeMap.find(idx) != nodeMap.end()) // 已经遍历过
  117. {
  118. Node* childNode = nodeMap[idx];
  119. if (childNode->close) continue;
  120. int new_g = _calG(map, new_x, new_y, i < 4 ? D : DD); // 直线 or 斜线,对于重复的节点,这里就会由于父节点位置的不同产生不同的值
  121. int new_h = _calH(new_x, new_y, end_x, end_y);
  122. if (new_g + new_h < childNode->g + childNode->h) // 有更小的权重,这里会造成重复插入,上面有过滤,因此无妨
  123. {
  124. childNode->g = new_g;
  125. childNode->h = new_h;
  126. childNode->parent = parentNode;
  127. Que.emplace(childNode);
  128. }
  129. }
  130. else
  131. {
  132. Node* childNode = new Node(new_x, new_y);
  133. nodeMap[idx] = childNode;
  134. int new_g = _calG(map, new_x, new_y, i < 4 ? D : DD); // 直线 or 斜线
  135. int new_h = _calH(new_x, new_y, end_x, end_y);
  136. childNode->g = new_g;
  137. childNode->h = new_h;
  138. childNode->parent = parentNode;
  139. Que.emplace(childNode);
  140. }
  141. }
  142. }
  143. return {}; // 没有找到路径
  144. }
  145. int main()
  146. {
  147. // vector<vector<int>> map{ // -1表示阻挡,其他数值表示当前点的权重值,值越大代表地形越"难走",优先级越低
  148. // {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
  149. // {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
  150. // {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1},
  151. // {1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1},
  152. // {1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1},
  153. // {1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1},
  154. // {1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1},
  155. // {1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1},
  156. // {1, 0, 0, 1, 0, 0, 0, 0 ,0, 1, 0 ,0, 1, 0, 0, 0, 0, 1, 0, 1},
  157. // {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}
  158. // };
  159. vector<vector<int>> map{ // -1表示阻挡,其他数值表示当前点的权重值,值越大代表地形越"难走",优先级越低
  160. {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1},
  161. {-1, 0, 5, 0, 0, 0, 1, 0, -1, -1},
  162. {-1, 0, 10, 0, 0, 0, 1, 0, -1, -1},
  163. {-1, 10, -1, -1, -1, -1, -1, -1, 0, -1},
  164. {-1, 0, 0, 0, 0, 0, 1, 0, -1, -1},
  165. {-1, 0, 0, 0, 0, 0, 1, 1, 0, -1},
  166. {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1}
  167. };
  168. aStar a;
  169. auto path = a.search(map, 1, 1, 5, 8);
  170. for (int i = path.size()-1; i >= 0; --i)
  171. {
  172. cout << "(" << path[i].first << ", " << path[i].second << ")";
  173. if (i != 0) cout << " => ";
  174. }
  175. cout << endl;
  176. }