Astar路径规划算法复现-python实现

打印 上一主题 下一主题

主题 576|帖子 576|积分 1728

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Fri May 24 09:04:23 2024
  4. """
  5. import os
  6. import sys
  7. import math
  8. import heapq
  9. import matplotlib.pyplot as plt
  10. import time
  11. '''
  12. 传统A*算法
  13. '''
  14. class Astar:
  15.     '''
  16.     AStar set the cost + heuristics as the priority
  17.     AStar将成本+启发式设置为优先级
  18.     '''
  19.     def __init__(self, s_start, s_goal, heuristic_type, xI, xG):
  20.         self.s_start = s_start
  21.         self.s_goal = s_goal
  22.         self.heuristic_type = heuristic_type
  23.         self.u_set = [(-1, 0), (0, 1), (1, 0), (0, -1)]
  24.         # self.obs = self.obs_map()  # 障碍物位置
  25.         self.Open = []  # 优先级序列,open集合
  26.         self.Closed = []  # 相邻点集合,访问visited序列
  27.         self.parent = dict()  # 相邻父节点
  28.         self.g = dict()  # 成本
  29.         self.x_range = 51  # 设置背景大小
  30.         self.y_range = 51
  31.         self.xI, self.xG = xI, xG
  32.         self.obs = self.obs_map()
  33.     def animation(self, path_l, visited_l, name, path_color='g'):
  34.         # 绘制地图基础元素
  35.         obs_x = [x[0] for x in self.obs]
  36.         obs_y = [x[1] for x in self.obs]
  37.         plt.plot(self.xI[0], self.xI[1], "bs")  # 起点
  38.         plt.plot(self.xG[0], self.xG[1], "gs")  # 终点
  39.         plt.plot(obs_x, obs_y, "sk")  # 障碍物
  40.         plt.title(name)
  41.         plt.axis("equal")
  42.         # 移除起点和终点于visited_l列表中,避免它们被标记为已访问点
  43.         visited_l = [node for node in visited_l if node != self.xI and node != self.xG]
  44.         # 绘制所有已访问节点
  45.         for x in visited_l:
  46.             plt.plot(x[0], x[1], color='gray', marker='o')
  47.         # 绘制路径
  48.         path_x = [point[0] for point in path_l]
  49.         path_y = [point[1] for point in path_l]
  50.         plt.plot(path_x, path_y, linewidth=3, color=path_color)
  51.         # 显示最终图形
  52.         plt.show(block=True)
  53.     def obs_map(self):
  54.         """
  55.         Initialize obstacles' positions
  56.         :return: map of obstacles
  57.         初始化障碍物位置
  58.         返回:障碍物
  59.         """
  60.         x = 51
  61.         y = 31
  62.         self.obs = set()
  63.         # 绘制边界
  64.         self.obs.update((i, 0) for i in range(x))
  65.         self.obs.update((i, y - 1) for i in range(x))
  66.         self.obs.update((0, i) for i in range(y))
  67.         self.obs.update((x - 1, i) for i in range(y))
  68.         # 给出障碍物坐标集1
  69.         self.obs.update((i, 15) for i in range(10, 21))
  70.         self.obs.update((20, i) for i in range(15))
  71.         # 给出障碍物坐标集2
  72.         self.obs.update((30, i) for i in range(15, 30))
  73.         # 给出障碍物坐标集3
  74.         self.obs.update((40, i) for i in range(16))
  75.         return self.obs
  76.     def searching(self):
  77.         """
  78.         A_star Searching.
  79.         :return: path, visited order
  80.         Astart搜索,返回路径、访问集,
  81.         """
  82.         self.parent[self.s_start] = self.s_start  # 初始化 起始父节点中只有起点。
  83.         self.g[self.s_start] = 0  # 初始代价为0
  84.         self.g[self.s_goal] = math.inf  # 目标节点代价为无穷大
  85.         # 将元素(self.f_value(self.s_start), self.s_start)插入到Open堆中,
  86.         # 并保持堆的性质(最小堆中父节点的值总是小于或等于其子节点的值))
  87.         # 这行代码的意思是:计算起始节点s_start的评估值f_value(self.s_start),
  88.         # 然后将这对值(f_value, self.s_start)作为一个元组插入到self.Open这个最小堆中。
  89.         # 这样做的目的是在诸如A*搜索算法等需要高效管理待探索节点的场景下,
  90.         # 确保每次可以从堆顶(也就是当前评估值最小的节点)取出下一个待处理的节点。
  91.         # 这对于寻找最短路径、最小成本解决方案等问题非常有用。
  92.         heapq.heappush(self.Open, (self.f_value(self.s_start), self.s_start))
  93.         while self.Open:
  94.             # heappop会取出栈顶元素并将原始数据从堆栈中删除
  95.             # 在这个例子中,heappop返回的元素假设是一个包含两个元素的元组,
  96.             # 但代码中只关心第二个元素(实际的数据,比如一个状态、节点或其他任何类型的数据),
  97.             # 所以用_占位符丢弃了第一个元素(通常是评估值或优先级),而把第二个元素赋值给了变量s
  98.             _, s_current = heapq.heappop(self.Open)  # s_current存储的是当前位置的坐标
  99.             # print('栈顶元素为{0}'.format(s_current))
  100.             self.Closed.append(s_current)
  101.             if s_current == self.s_goal:  # 迭代停止条件,判断出栈顶元素是否为目标点,如果为目标点,则退出
  102.                 break
  103.             # 如果不是,更新该点附近的代价值
  104.             # get_neighbor为获取附近点的坐标
  105.             for s_next in self.get_neighbor(s_current):
  106.                 new_cost = self.g[s_current] + self.cost(s_current, s_next)
  107.                 if s_next not in self.g:
  108.                     self.g[s_next] = math.inf
  109.                 if new_cost < self.g[s_next]:
  110.                     self.g[s_next] = new_cost
  111.                     self.parent[s_next] = s_current
  112.                     # heappush入栈时需要存储的该点的代价值的计算方式为
  113.                     heapq.heappush(self.Open, (self.f_value(s_next), s_next))
  114.         # self.animation(self.extract_path(self.parent), self.Closed, "A*")
  115.         return self.extract_path(self.parent), self.Closed
  116.     def get_neighbor(self, s_current):
  117.         """
  118.         :param s_current:
  119.         :return: 相邻点集合
  120.         """
  121.         return [(s_current[0] + u[0], s_current[1] + u[1]) for u in self.u_set]
  122.     def cost(self, s_current, s_next):
  123.         """
  124.         :param s_current 表示当前点
  125.         :param s_next 表示相邻点
  126.         :return 若与障碍物无冲突,则范围欧式距离成本,否则为无穷大成本
  127.         计算当前点与相邻点的距离成本
  128.         """
  129.         # 判断是否与障碍物冲突
  130.         if self.is_collision(s_current, s_next):
  131.             return math.inf
  132.         # 这里返回欧式距离成本
  133.         return math.hypot(s_next[0] - s_current[0], s_next[1] - s_current[1])
  134.     def is_collision(self, s_current, s_next):
  135.         """
  136.             check if the line segment (s_start, s_end) is collision.
  137.             :param s_current: start node
  138.             :param s_next: end node
  139.             :return: True: is collision / False: not collision
  140.             检查起终点线段与障碍物是否冲突
  141.         如果线段的起点或终点之一位于障碍物集合 self.obs 内,则直接判定为碰撞,返回 True。
  142.         若线段不垂直也不水平(即斜线段),则分为两种情况检查:
  143.             若线段为45度线(斜率为1:1或-1),则检查线段的端点形成的矩形框内是否有障碍物。
  144.             否则检查线段端点形成的另一矩形框内是否有障碍物。
  145.         若上述任一矩形框内有障碍,则判定为碰撞,返回 True
  146.         若无碰撞情况,则返回 False
  147.         """
  148.         # obs是障碍物,如果遇到障碍物,则距离(成本)无穷大
  149.         if s_current in self.obs or s_next in self.obs:
  150.             return True
  151.         ''''''
  152.         # 如果该点s_start与相邻点s_end不相同
  153.         if s_current[0] != s_next[0] and s_current[1] != s_next[1]:
  154.             # 如果两点横纵坐标之差相等,即线段不垂直也不水平。135度线
  155.             if s_next[0] - s_current[0] == s_current[1] - s_next[1]:
  156.                 s1 = (min(s_current[0], s_next[0]), min(s_current[1], s_next[1]))
  157.                 s2 = (max(s_current[0], s_next[0]), max(s_current[1], s_next[1]))
  158.             # 如果两点横纵坐标之差不相等
  159.             else:
  160.                 s1 = (min(s_current[0], s_next[0]), max(s_current[1], s_next[1]))
  161.                 s2 = (max(s_current[0], s_next[0]), min(s_current[1], s_next[1]))
  162.             # obs是障碍物
  163.             if s1 in self.obs or s2 in self.obs:
  164.                 return True
  165.         return False
  166.     def f_value(self, s_currrent):
  167.         """
  168.         f = g + h. (g: Cost to come, h: heuristic value)
  169.         :param s: current state
  170.         :return: f
  171.         """
  172.         return self.g[s_currrent] + self.heuristic(s_currrent)
  173.     def extract_path(self, parent):
  174.         path = [self.s_goal]
  175.         s = self.s_goal
  176.         while True:
  177.             s = parent[s]
  178.             path.append(s)
  179.             if s == self.s_start:
  180.                 break
  181.         return list(path)
  182.     def heuristic(self, s_current):
  183.         heuristic_type = self.heuristic_type  # heuristic type
  184.         goal = self.s_goal  # goal node
  185.         # 如果为manhattan,则采用曼哈顿距离,s存储的是中间点
  186.         if heuristic_type == "manhattan":
  187.             return abs(goal[0] - s_current[0]) + abs(goal[1] - s_current[1])
  188.         # 否则就是欧几里得距离,符合勾股定理
  189.         else:
  190.             return math.hypot(goal[0] - s_current[0], goal[1] - s_current[1])
  191. if __name__ == '__main__':
  192.     time_start = time.time()
  193.     s_start = (5, 5)
  194.     s_goal = (45, 26)
  195.     star_m = Astar(s_start, s_goal, "ee", s_start, s_goal)
  196.     path, visited = star_m.searching()
  197.     star_m.animation(path, visited, "A*")  # animation
  198.     time_end = time.time()
  199.     print("程序运行时间:", time_end - time_start)
  200.        
复制代码


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

西河刘卡车医

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表