【强化学习环境配置+github 无人机强化学习demo复现】

打印 上一主题 下一主题

主题 909|帖子 909|积分 2727


1. N卡驱动安装

起首用体系管家查看电脑硬件信息,本次采用RTX 2080Ti 显卡配置环境,体系接入N卡官网地址:https://www.nvidia.cn/drivers/lookup/, 手动填写体系配置点击查找

如果不玩游戏的朋侪,推荐选择NVIDIA Studio驱动程序,点击查看按钮

这里点击下载最新的驱动程序

双击运行,选择安装位置

耐心等待安装过程




安装(更新)好了显卡驱动以后查看对应版本。我们按下win+R组合键,打开cmd命令窗口。输入如下的命令。
  1. nvidia-smi
复制代码
得到如下图的信息图,可以看到驱动的版本是565.90;最高支持的CUDA版本是12.7版本。得到显卡的最高支持的CUDA版本,我们就可以根据这个信息来安装环境了。

2. Anaconda 安装

打开网址:https://www.anaconda.com/download/success,现在是2024年11月,对应的anaconda版本是支持python3.12。如果想下载之前的版本,或者更低python版本的anaconda。大家可以根据本身空间大小选择anaconda或者miniconda

安装conda
以管理员身份运行软件,点击next



修改位置



上面就是安装完成,可以检查环境变量,添加一下,方便后期编译器识别
注:根据Anaconda安装的位置修改(E:\Anaconda)部门
  1. E:\Anaconda(Python需要)
  2. E:\Anaconda\Scripts(conda自带脚本)
  3. E:\Anaconda\Library\mingw-w64\bin(使用C with python的时候)
  4. E:\Anaconda\Library\usr\bin
  5. E:\Anaconda\Library\bin(jupyter notebook动态库)
复制代码

3. Pytorch环境安装

  1.     按下开始键(win键),点击如图中的图标。打开anaconda的终端Prompt。
复制代码

实验如下的指令查看有哪些环境
  1. conda env list
复制代码
可以看出来,新安装的anaconda只有一个base环境。


修改环境位置

安装的文件夹也需要设置权限

  1. conda info
  2. conda create -n rltorch python=3.10
复制代码

当安装好了以后,实验conda env list
这个命令,就可以看到比一开始多了一个pytorch这个环境。现在我们可以在这个环境里面安装深度学习框架和一些Python包了。
实验如下命令,激活这个环境。conda activate 假造环境名称
  1. conda activate rltorch
复制代码

  1.     安装pytorch-gup版的环境,由于[pytorch的官网](https://pytorch.org/)在国外,下载相关的环境包是比较慢的,所以我们给环境换源。在pytorch环境下执行如下的命名给环境换清华源。
复制代码

  1. conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  2. conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  3. conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  4. conda config --set show_channel_urls yes
复制代码
然后打开pytorch的官网,由于开头我们通过驱动检测到我的显卡为 RTX2080Ti,最高支持cuda11.6版本,以是我们选择cuda11.3版本的cuda,然后将下面红色框框中的内容复制下来,肯定不要把后面的-c pytorch也复制下来,因为如许运行就是照旧在国外源下载,如许就会很慢。
  1. # CUDA 11.8
  2. conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
  3. # CUDA 12.1
  4. conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=12.1 -c pytorch -c nvidia
  5. # CPU Only
  6. conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 cpuonly -c pytorch
复制代码
安装过程需要耐心等待

  1. pip list
复制代码

安装 gym 模块: 你可以使用 pip 来安装 gym 模块。pygame模块一般用来交互显示训练效果,打开终端或命令提示符,然后运行以下命令:
  1. pip install gym
  2. pip install pygame
复制代码
4. Vscode安装

我们编写代码采用Vscode,Vscode和Git软件配置教程参考:https://vor2345.blog.csdn.net/article/details/142727918
5. Github demo复现

我们来复现西工大的一篇开源论文《基于MASAC强化学习算法的多无人机协同路径规划》文章DOI: https://doi.org/10.1360/SSI-2024-0050
github代码开源地址:https://github.com/henbudidiao/UAV-path-planning
打开Vscode gitbash
  1. git clone "https://github.com/henbudidiao/UAV-path-planning.git"
复制代码
我们需要打开motion plan项目文件夹

修改四处代码
5.1. main_SAC.py

是作者论文重要计划的强化学习方法:
  1. # -*- coding: utf-8 -*-
  2. #开发者:Bright Fang
  3. #开发时间:2023/7/30 18:13
  4. from rl_env.path_env import RlGame
  5. # import pygame
  6. # from assignment import constants as C
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import numpy as np
  11. from matplotlib import pyplot as plt
  12. import os
  13. import pickle as pkl
  14. filedir = os.path.dirname(__file__)
  15. shoplistfile = filedir + "\\MASAC_new1"  #保存文件数据所在文件的文件名
  16. shoplistfile_test = filedir + "\\MASAC_d_test2"  #保存文件数据所在文件的文件名
  17. shoplistfile_test1 = filedir + "\\MASAC_compare"  #保存文件数据所在文件的文件名
  18. os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
  19. N_Agent=1
  20. M_Enemy=4
  21. RENDER=True
  22. TRAIN_NUM = 1
  23. TEST_EPIOSDE=100
  24. env = RlGame(n=N_Agent,m=M_Enemy,render=RENDER).unwrapped
  25. state_number=7
  26. action_number=env.action_space.shape[0]
  27. max_action = env.action_space.high[0]
  28. min_action = env.action_space.low[0]
  29. EP_MAX = 500
  30. EP_LEN = 1000
  31. GAMMA = 0.9
  32. q_lr = 3e-4
  33. value_lr = 3e-3
  34. policy_lr = 1e-3
  35. BATCH = 128
  36. tau = 1e-2
  37. MemoryCapacity=20000
  38. Switch=1
  39. class Ornstein_Uhlenbeck_Noise:
  40.     def __init__(self, mu, sigma=0.1, theta=0.1, dt=1e-2, x0=None):
  41.         self.theta = theta
  42.         self.mu = mu
  43.         self.sigma = sigma
  44.         self.dt = dt
  45.         self.x0 = x0
  46.         self.reset()
  47.     def __call__(self):
  48.         x = self.x_prev + \
  49.             self.theta * (self.mu - self.x_prev) * self.dt + \
  50.             self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
  51.         '''
  52.         后两行是dXt,其中后两行的前一行是θ(μ-Xt)dt,后一行是σεsqrt(dt)
  53.         '''
  54.         self.x_prev = x
  55.         return x
  56.     def reset(self):
  57.         if self.x0 is not None:
  58.             self.x_prev = self.x0
  59.         else:
  60.             self.x_prev = np.zeros_like(self.mu)
  61. class ActorNet(nn.Module):
  62.     def __init__(self,inp,outp):
  63.         super(ActorNet, self).__init__()
  64.         self.in_to_y1=nn.Linear(inp,256)
  65.         self.in_to_y1.weight.data.normal_(0,0.1)
  66.         self.y1_to_y2=nn.Linear(256,256)
  67.         self.y1_to_y2.weight.data.normal_(0,0.1)
  68.         self.out=nn.Linear(256,outp)
  69.         self.out.weight.data.normal_(0,0.1)
  70.         self.std_out = nn.Linear(256, outp)
  71.         self.std_out.weight.data.normal_(0, 0.1)
  72.     def forward(self,inputstate):
  73.         inputstate=self.in_to_y1(inputstate)
  74.         inputstate=F.relu(inputstate)
  75.         inputstate=self.y1_to_y2(inputstate)
  76.         inputstate=F.relu(inputstate)
  77.         mean=max_action*torch.tanh(self.out(inputstate))#输出概率分布的均值mean
  78.         log_std=self.std_out(inputstate)#softplus激活函数的值域>0
  79.         log_std=torch.clamp(log_std,-20,2)
  80.         std=log_std.exp()
  81.         return mean,std
  82. class CriticNet(nn.Module):
  83.     def __init__(self,input,output):
  84.         super(CriticNet, self).__init__()
  85.         #q1
  86.         self.in_to_y1=nn.Linear(input+output,256)
  87.         self.in_to_y1.weight.data.normal_(0,0.1)
  88.         self.y1_to_y2=nn.Linear(256,256)
  89.         self.y1_to_y2.weight.data.normal_(0,0.1)
  90.         self.out=nn.Linear(256,1)
  91.         self.out.weight.data.normal_(0,0.1)
  92.         #q2
  93.         self.q2_in_to_y1 = nn.Linear(input+output, 256)
  94.         self.q2_in_to_y1.weight.data.normal_(0, 0.1)
  95.         self.q2_y1_to_y2 = nn.Linear(256, 256)
  96.         self.q2_y1_to_y2.weight.data.normal_(0, 0.1)
  97.         self.q2_out = nn.Linear(256, 1)
  98.         self.q2_out.weight.data.normal_(0, 0.1)
  99.     def forward(self,s,a):
  100.         inputstate = torch.cat((s, a), dim=1)
  101.         #q1
  102.         q1=self.in_to_y1(inputstate)
  103.         q1=F.relu(q1)
  104.         q1=self.y1_to_y2(q1)
  105.         q1=F.relu(q1)
  106.         q1=self.out(q1)
  107.         #q2
  108.         q2 = self.q2_in_to_y1(inputstate)
  109.         q2 = F.relu(q2)
  110.         q2 = self.q2_y1_to_y2(q2)
  111.         q2 = F.relu(q2)
  112.         q2 = self.q2_out(q2)
  113.         return q1,q2
  114. class Memory():
  115.     def __init__(self,capacity,dims):
  116.         self.capacity=capacity
  117.         self.mem=np.zeros((capacity,dims))
  118.         self.memory_counter=0
  119.     '''存储记忆'''
  120.     def store_transition(self,s,a,r,s_):
  121.         tran = np.hstack((s, a,r, s_))  # 把s,a,r,s_困在一起,水平拼接
  122.         index = self.memory_counter % self.capacity#除余得索引
  123.         self.mem[index, :] = tran  # 给索引存值,第index行所有列都为其中一次的s,a,r,s_;mem会是一个capacity行,(s+a+r+s_)列的数组
  124.         self.memory_counter+=1
  125.     '''随机从记忆库里抽取'''
  126.     def sample(self,n):
  127.         assert self.memory_counter>=self.capacity,'记忆库没有存满记忆'
  128.         sample_index = np.random.choice(self.capacity, n)#从capacity个记忆里随机抽取n个为一批,可得到抽样后的索引号
  129.         new_mem = self.mem[sample_index, :]#由抽样得到的索引号在所有的capacity个记忆中  得到记忆s,a,r,s_
  130.         return new_mem
  131. class Actor():
  132.     def __init__(self):
  133.         self.action_net=ActorNet(state_number,action_number)#这只是均值mean
  134.         self.optimizer=torch.optim.Adam(self.action_net.parameters(),lr=policy_lr)
  135.     def choose_action(self,s):
  136.         inputstate = torch.FloatTensor(s)
  137.         mean,std=self.action_net(inputstate)
  138.         dist = torch.distributions.Normal(mean, std)
  139.         action=dist.sample()
  140.         action=torch.clamp(action,min_action,max_action)
  141.         return action.detach().numpy()
  142.     def evaluate(self,s):
  143.         inputstate = torch.FloatTensor(s)
  144.         mean,std=self.action_net(inputstate)
  145.         dist = torch.distributions.Normal(mean, std)
  146.         noise = torch.distributions.Normal(0, 1)
  147.         z = noise.sample()
  148.         action=torch.tanh(mean+std*z)
  149.         action=torch.clamp(action,min_action,max_action)
  150.         action_logprob=dist.log_prob(mean+std*z)-torch.log(1-action.pow(2)+1e-6)
  151.         return action,action_logprob
  152.     def learn(self,actor_loss):
  153.         loss=actor_loss
  154.         self.optimizer.zero_grad()
  155.         loss.backward()
  156.         self.optimizer.step()
  157. class Entroy():
  158.     def __init__(self):
  159.         self.target_entropy = -0.1
  160.         self.log_alpha = torch.zeros(1, requires_grad=True)
  161.         self.alpha = self.log_alpha.exp()
  162.         self.optimizer = torch.optim.Adam([self.log_alpha], lr=q_lr)
  163.     def learn(self,entroy_loss):
  164.         loss=entroy_loss
  165.         self.optimizer.zero_grad()
  166.         loss.backward()
  167.         self.optimizer.step()
  168. class Critic():
  169.     def __init__(self):
  170.         self.critic_v,self.target_critic_v=CriticNet(state_number*(N_Agent+M_Enemy),action_number),CriticNet(state_number*(N_Agent+M_Enemy),action_number)#改网络输入状态,生成一个Q值
  171.         self.target_critic_v.load_state_dict(self.critic_v.state_dict())
  172.         self.optimizer = torch.optim.Adam(self.critic_v.parameters(), lr=value_lr,eps=1e-5)
  173.         self.lossfunc = nn.MSELoss()
  174.     def soft_update(self):
  175.         for target_param, param in zip(self.target_critic_v.parameters(), self.critic_v.parameters()):
  176.             target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
  177.     def get_v(self,s,a):
  178.         return self.critic_v(s,a)
  179.     def target_get_v(self,s,a):
  180.         return self.target_critic_v(s,a)
  181.     def learn(self,current_q1,current_q2,target_q):
  182.         loss = self.lossfunc(current_q1, target_q) + self.lossfunc(current_q2, target_q)
  183.         self.optimizer.zero_grad()
  184.         loss.backward()
  185.         self.optimizer.step()
  186. def main():
  187.     run(env)
  188. def run(env):
  189.     if Switch==0:
  190.         try:
  191.             assert M_Enemy == 1
  192.         except:
  193.             print('程序终止,被逮到~嘿嘿,哥们儿预判到你会犯错,这段程序中变量\'M_Enemy\'的值必须为1,请把它的值改为1。\n'
  194.                   '改为1之后程序一定会报错,这是因为组数越界,更改path_env.py文件中的跟随者无人机初始化个数;删除多余的\n'
  195.                   '求距离函数,即变量dis_1_agent_0_to_3等,以及提到变量dis_1_agent_0_to_3等的地方;删除画无人机轨迹的\n'
  196.                   '函数;删除step函数的最后一个返回值dis_1_agent_0_to_1;将player.py文件中的变量dt改为1;即可开始训练!\n'
  197.                   '如果实在不会改也无妨,我会在不久之后出一个视频来手把手教大伙怎么改,可持续关注此项目github中的README文件。\n')
  198.         else:
  199.             print('SAC训练中...')
  200.             all_ep_r = [[] for i in range(TRAIN_NUM)]
  201.             all_ep_r0 = [[] for i in range(TRAIN_NUM)]
  202.             all_ep_r1 = [[] for i in range(TRAIN_NUM)]
  203.             for k in range(TRAIN_NUM):
  204.                 actors = [None for _ in range(N_Agent+M_Enemy)]
  205.                 critics = [None for _ in range(N_Agent+M_Enemy)]
  206.                 entroys = [None for _ in range(N_Agent+M_Enemy)]
  207.                 for i in range(N_Agent+M_Enemy):
  208.                     actors[i] = Actor()
  209.                     critics[i] = Critic()
  210.                     entroys[i] = Entroy()
  211.                 M = Memory(MemoryCapacity, 2 * state_number*(N_Agent+M_Enemy) + action_number*(N_Agent+M_Enemy) + 1*(N_Agent+M_Enemy))
  212.                 ou_noise = Ornstein_Uhlenbeck_Noise(mu=np.zeros(((N_Agent+M_Enemy), action_number)))
  213.                 action=np.zeros(((N_Agent+M_Enemy), action_number))
  214.                 # aaa = np.zeros((N_Agent, state_number))
  215.                 for episode in range(EP_MAX):
  216.                     observation = env.reset()  # 环境重置
  217.                     reward_totle,reward_totle0,reward_totle1 = 0,0,0
  218.                     for timestep in range(EP_LEN):
  219.                         for i in range(N_Agent+M_Enemy):
  220.                             action[i] = actors[i].choose_action(observation[i])
  221.                         # action[0]=actor0.choose_action(observation[0])
  222.                         # action[1] = actor0.choose_action(observation[1])
  223.                         if episode <= 20:
  224.                             noise = ou_noise()
  225.                         else:
  226.                             noise = 0
  227.                         action = action + noise
  228.                         action = np.clip(action, -max_action, max_action)
  229.                         observation_, reward,done,win,team_counter= env.step(action)  # 单步交互
  230.                         M.store_transition(observation.flatten(), action.flatten(), reward.flatten(), observation_.flatten())
  231.                         # 记忆库存储
  232.                         # 有的2000个存储数据就开始学习
  233.                         if M.memory_counter > MemoryCapacity:
  234.                             b_M = M.sample(BATCH)
  235.                             b_s = b_M[:, :state_number*(N_Agent+M_Enemy)]
  236.                             b_a = b_M[:, state_number*(N_Agent+M_Enemy): state_number*(N_Agent+M_Enemy) + action_number*(N_Agent+M_Enemy)]
  237.                             b_r = b_M[:, -state_number*(N_Agent+M_Enemy) - 1*(N_Agent+M_Enemy): -state_number*(N_Agent+M_Enemy)]
  238.                             b_s_ = b_M[:, -state_number*(N_Agent+M_Enemy):]
  239.                             b_s = torch.FloatTensor(b_s)
  240.                             b_a = torch.FloatTensor(b_a)
  241.                             b_r = torch.FloatTensor(b_r)
  242.                             b_s_ = torch.FloatTensor(b_s_)
  243.                             # if not done[0]:
  244.                             #     new_action_0, log_prob_0 = actor0.evaluate(b_s_[:, 0:state_number])
  245.                             #     target_q10, target_q20 = critic0.target_critic_v(b_s_[:, 0:state_number], new_action_0)
  246.                             #     target_q0 = b_r[:, 0:1] + GAMMA * (1 - b_done[0]) *(torch.min(target_q10, target_q20) - entroy0.alpha * log_prob_0)
  247.                             #     current_q10, current_q20 = critic0.get_v(b_s[:,0:state_number], b_a[:, 0:action_number*1])
  248.                             #     critic0.learn(current_q10, current_q20, target_q0.detach())
  249.                             #     a0, log_prob0 = actor0.evaluate(b_s[:, 0:state_number*1])
  250.                             #     q10, q20 = critic0.get_v(b_s[:, 0:state_number*1], a0)
  251.                             #     q0 = torch.min(q10, q20)
  252.                             #     actor_loss0 = (entroy0.alpha * log_prob0 - q0).mean()
  253.                             #     alpha_loss0 = -(entroy0.log_alpha.exp() * (
  254.                             #                     log_prob0 + entroy0.target_entropy).detach()).mean()
  255.                             #     actor0.learn(actor_loss0)
  256.                             #     entroy0.learn(alpha_loss0)
  257.                             #     entroy0.alpha = entroy0.log_alpha.exp()
  258.                             #     # 软更新
  259.                             #     critic0.soft_update()
  260.                             # if not done[1]:
  261.                             #     new_action_1, log_prob_1 = actor1.evaluate(b_s_[:, state_number:state_number*2])
  262.                             #     target_q11, target_q21 = critic1.target_critic_v(b_s_[:, state_number:state_number*2], new_action_1)
  263.                             #     target_q1= b_r[:, 1:2] + GAMMA * (1 - b_done[1])*(torch.min(target_q11, target_q21) - entroy1.alpha * log_prob_1)
  264.                             #     current_q11, current_q21 = critic1.get_v(b_s[:, state_number:state_number*2], b_a[:, action_number:action_number * 2])
  265.                             #     critic1.learn(current_q11, current_q21, target_q1.detach())
  266.                             #     a1, log_prob1 = actor1.evaluate(b_s[:, state_number:state_number*2])
  267.                             #     q11, q21 = critic1.get_v(b_s[:, state_number:state_number*2], a1)
  268.                             #     q1 = torch.min(q11, q21)
  269.                             #     actor_loss1 = (entroy1.alpha * log_prob1 - q1).mean()
  270.                             #     alpha_loss1 = -(entroy1.log_alpha.exp() * (
  271.                             #             log_prob1 + entroy1.target_entropy).detach()).mean()
  272.                             #     actor1.learn(actor_loss1)
  273.                             #     entroy1.learn(alpha_loss1)
  274.                             #     entroy1.alpha = entroy1.log_alpha.exp()
  275.                             #     # 软更新
  276.                             #     critic1.soft_update()
  277.                             for i in range(N_Agent+M_Enemy):
  278.                             # # # TODO 方法二
  279.                             # new_action_0, log_prob_0 = actor0.evaluate(b_s_[:, :state_number])
  280.                             # new_action_1, log_prob_1 = actor0.evaluate(b_s_[:, state_number:state_number * 2])
  281.                             # new_action = torch.hstack((new_action_0, new_action_1))
  282.                             # # new_action = torch.cat((new_action_0, new_action_1),dim=1)
  283.                             # log_prob_ = (log_prob_0 + log_prob_1) / 2
  284.                             # # log_prob_=torch.hstack((log_prob_0.mean(axis=1).unsqueeze(dim=1),log_prob_1.mean(axis=1).unsqueeze(dim=1)))
  285.                             # target_q1, target_q2 = critic0.target_critic_v(b_s_, new_action)
  286.                             #
  287.                             # target_q = b_r + GAMMA * (torch.min(target_q1, target_q2) - entroy0.alpha * log_prob_)
  288.                             #
  289.                             # current_q1, current_q2 = critic0.get_v(b_s, b_a)
  290.                             # critic0.learn(current_q1, current_q2, target_q.detach())
  291.                             # a0, log_prob0 = actor0.evaluate(b_s[:, :state_number])
  292.                             # a1, log_prob1 = actor0.evaluate(b_s[:, state_number:state_number * 2])
  293.                             # a = torch.hstack((a0, a1))
  294.                             # # a = torch.cat((a0, a1),dim=1)
  295.                             # log_prob = (log_prob0 + log_prob1) / 2
  296.                             # # log_prob = torch.hstack((log_prob0.mean(axis=1).unsqueeze(dim=1), log_prob1.mean(axis=1).unsqueeze(dim=1)))
  297.                             # q1, q2 = critic0.get_v(b_s, a)
  298.                             # q = torch.min(q1, q2)
  299.                             #
  300.                             # actor_loss = (entroy0.alpha * log_prob - q).mean()
  301.                             # alpha_loss = -(entroy0.log_alpha.exp() * (log_prob + entroy0.target_entropy).detach()).mean()
  302.                             #
  303.                             # actor0.learn(actor_loss)
  304.                             # # actor1.learn(actor_loss)
  305.                             # entroy0.learn(alpha_loss)
  306.                             # entroy0.alpha = entroy0.log_alpha.exp()
  307.                             # # 软更新
  308.                             # critic0.soft_update()
  309.                                 # TODO 方法零
  310.                                 # if not done[i]:
  311.                                 new_action, log_prob_ = actors[i].evaluate(b_s_[:, state_number*i:state_number*(i+1)])
  312.                                 target_q1, target_q2 = critics[i].target_critic_v(b_s_, new_action)
  313.                                 target_q = b_r[:, i:(i+1)] + GAMMA * (torch.min(target_q1, target_q2) - entroys[i].alpha * log_prob_)
  314.                                 current_q1, current_q2 = critics[i].get_v(b_s, b_a[:, action_number*i:action_number*(i+1)])
  315.                                 critics[i].learn(current_q1, current_q2, target_q.detach())
  316.                                 a, log_prob = actors[i].evaluate(b_s[:, state_number*i:state_number*(i+1)])
  317.                                 q1, q2 = critics[i].get_v(b_s, a)
  318.                                 q = torch.min(q1, q2)
  319.                                 actor_loss = (entroys[i].alpha * log_prob - q).mean()
  320.                                 alpha_loss = -(entroys[i].log_alpha.exp() * (
  321.                                                 log_prob + entroys[i].target_entropy).detach()).mean()
  322.                                 actors[i].learn(actor_loss)
  323.                                 entroys[i].learn(alpha_loss)
  324.                                 entroys[i].alpha = entroys[i].log_alpha.exp()
  325.                                 # 软更新
  326.                                 critics[i].soft_update()
  327.                                     # #TODO 方法一
  328.                                     # new_action_0, log_prob_0 = actors[i].evaluate(b_s_[:, :state_number])
  329.                                     # new_action_1, log_prob_1 = actors[i].evaluate(b_s_[:, state_number:state_number * 2])
  330.                                     # new_action = torch.hstack((new_action_0, new_action_1))
  331.                                     # # new_action = torch.cat((new_action_0, new_action_1),dim=1)
  332.                                     # # log_prob_ = (log_prob_0 + log_prob_1) / 2
  333.                                     # # log_prob_=torch.hstack((log_prob_0.mean(axis=1).unsqueeze(dim=1),log_prob_1.mean(axis=1).unsqueeze(dim=1)))
  334.                                     # target_q1, target_q2 = critics[i].target_critic_v(b_s_, new_action)
  335.                                     # if i==0:
  336.                                     #     target_q = b_r[:, i:(i+1)] + GAMMA * (torch.min(target_q1, target_q2) - entroys[i].alpha * log_prob_0)
  337.                                     # elif i==1:
  338.                                     #     target_q = b_r[:, i:(i+1)] + GAMMA * (torch.min(target_q1, target_q2) - entroys[i].alpha * log_prob_1)
  339.                                     # current_q1, current_q2 = critics[i].get_v(b_s, b_a)
  340.                                     # critics[i].learn(current_q1, current_q2, target_q.detach())
  341.                                     # a0, log_prob0 = actors[i].evaluate(b_s[:, :state_number])
  342.                                     # a1, log_prob1 = actors[i].evaluate(b_s[:, state_number:state_number * 2])
  343.                                     # a = torch.hstack((a0, a1))
  344.                                     # # a = torch.cat((a0, a1),dim=1)
  345.                                     # # log_prob = (log_prob0 + log_prob1) / 2
  346.                                     # # log_prob = torch.hstack((log_prob0.mean(axis=1).unsqueeze(dim=1), log_prob1.mean(axis=1).unsqueeze(dim=1)))
  347.                                     # q1, q2 = critics[i].get_v(b_s, a)
  348.                                     # q = torch.min(q1, q2)
  349.                                     # if i == 0:
  350.                                     #     actor_loss = (entroys[i].alpha * log_prob0 - q).mean()
  351.                                     #     alpha_loss = -(entroys[i].log_alpha.exp() * (log_prob0 + entroys[i].target_entropy).detach()).mean()
  352.                                     # elif i == 1:
  353.                                     #     actor_loss = (entroys[i].alpha * log_prob1 - q).mean()
  354.                                     #     alpha_loss = -(entroys[i].log_alpha.exp() * (log_prob1 + entroys[i].target_entropy).detach()).mean()
  355.                                     # actors[i].learn(actor_loss)
  356.                                     # entroys[i].learn(alpha_loss)
  357.                                     # entroys[i].alpha = entroys[i].log_alpha.exp()
  358.                                     # # 软更新
  359.                                     # critics[i].soft_update()
  360.                                 # #TODO 方法二
  361.                                 # new_action_0, log_prob_0 = actors[i].evaluate(b_s_[:, :state_number])
  362.                                 # new_action_1, log_prob_1 = actors[i].evaluate(b_s_[:, state_number:state_number * 2])
  363.                                 # new_action = torch.hstack((new_action_0, new_action_1))
  364.                                 # log_prob_ = (log_prob_0 + log_prob_1) / 2
  365.                                 # # log_prob_=torch.hstack((log_prob_0.mean(axis=1).unsqueeze(dim=1),log_prob_1.mean(axis=1).unsqueeze(dim=1)))
  366.                                 # target_q1, target_q2 = critics[i].target_critic_v(b_s_, new_action)
  367.                                 # target_q = b_r + GAMMA * (torch.min(target_q1, target_q2) - entroys[i].alpha * log_prob_)
  368.                                 # current_q1, current_q2 = critics[i].get_v(b_s, b_a)
  369.                                 # critics[i].learn(current_q1, current_q2, target_q.detach())
  370.                                 # a0, log_prob0 = actors[i].evaluate(b_s[:, :state_number])
  371.                                 # a1, log_prob1 = actors[i].evaluate(b_s[:, state_number:state_number * 2])
  372.                                 # a = torch.hstack((a0, a1))
  373.                                 # log_prob = (log_prob0 + log_prob1) / 2
  374.                                 # # log_prob = torch.hstack((log_prob0.mean(axis=1).unsqueeze(dim=1), log_prob1.mean(axis=1).unsqueeze(dim=1)))
  375.                                 # q1, q2 = critics[i].get_v(b_s, a)
  376.                                 # q = torch.min(q1, q2)
  377.                                 # actor_loss = ( entroys[i].alpha * log_prob - q).mean()
  378.                                 # actors[i].learn(actor_loss)
  379.                                 # alpha_loss = -( entroys[i].log_alpha.exp() * (log_prob +  entroys[i].target_entropy).detach()).mean()
  380.                                 # entroys[i].learn(alpha_loss)
  381.                                 # entroys[i].alpha =  entroys[i].log_alpha.exp()
  382.                                 # # 软更新
  383.                                 # critics[i].soft_update()
  384.                             # new_action_0, log_prob_0 = actor.evaluate(b_s_[:, :state_number])
  385.                             # new_action_1, log_prob_1 = actor.evaluate(b_s_[:, state_number:state_number*2])
  386.                             # new_action=torch.hstack((new_action_0,new_action_1))
  387.                             # log_prob_=(log_prob_0+log_prob_1)/2
  388.                             # # log_prob_=torch.hstack((log_prob_0.mean(axis=1).unsqueeze(dim=1),log_prob_1.mean(axis=1).unsqueeze(dim=1)))
  389.                             # target_q1,target_q2=critic.target_critic_v(b_s_,new_action)
  390.                             # target_q=b_r+GAMMA*(torch.min(target_q1,target_q2)-entroy.alpha*log_prob_)
  391.                             # current_q1, current_q2 = critic.get_v(b_s, b_a)
  392.                             # critic.learn(current_q1,current_q2,target_q.detach())
  393.                             # a0,log_prob0=actor.evaluate(b_s[:, :state_number])
  394.                             # a1, log_prob1 = actor.evaluate(b_s[:, state_number:state_number*2])
  395.                             # a = torch.hstack((a0, a1))
  396.                             # log_prob=(log_prob0+log_prob1)/2
  397.                             # # log_prob = torch.hstack((log_prob0.mean(axis=1).unsqueeze(dim=1), log_prob1.mean(axis=1).unsqueeze(dim=1)))
  398.                             # q1,q2=critic.get_v(b_s,a)
  399.                             # q=torch.min(q1,q2)
  400.                             # actor_loss = (entroy.alpha * log_prob - q).mean()
  401.                             # actor.learn(actor_loss)
  402.                             # alpha_loss = -(entroy.log_alpha.exp() * (log_prob + entroy.target_entropy).detach()).mean()
  403.                             # entroy.learn(alpha_loss)
  404.                             # entroy.alpha=entroy.log_alpha.exp()
  405.                             # # 软更新
  406.                             # critic.soft_update()
  407.                         observation = observation_
  408.                         reward_totle += reward.mean()
  409.                         reward_totle0 += float(reward[0])
  410.                         reward_totle1 += float(reward[1])
  411.                         if RENDER:
  412.                             env.render()
  413.                         if done:
  414.                             break
  415.                     print("Ep: {} rewards: {}".format(episode, reward_totle))
  416.                     all_ep_r[k].append(reward_totle)
  417.                     all_ep_r0[k].append(reward_totle0)
  418.                     all_ep_r1[k].append(reward_totle1)
  419.                     if episode % 20 == 0 and episode > 200:#保存神经网络参数
  420.                         save_data = {'net': actors[0].action_net.state_dict(), 'opt': actors[0].optimizer.state_dict()}
  421.                         torch.save(save_data, filedir + "\\Path_SAC_actor_L1.pth")
  422.                         save_data = {'net': actors[1].action_net.state_dict(), 'opt': actors[1].optimizer.state_dict()}
  423.                         torch.save(save_data, filedir + "\\Path_SAC_actor_F1.pth")
  424.                 # plt.plot(np.arange(len(all_ep_r)), all_ep_r)
  425.                 # plt.xlabel('Episode')
  426.                 # plt.ylabel('Total reward')
  427.                 # plt.figure(2, figsize=(8, 4), dpi=150)
  428.                 # plt.plot(np.arange(len(all_ep_r0)), all_ep_r0)
  429.                 # plt.xlabel('Episode')
  430.                 # plt.ylabel('Leader reward')
  431.                 # plt.figure(3, figsize=(8, 4), dpi=150)
  432.                 # plt.plot(np.arange(len(all_ep_r1)), all_ep_r1)
  433.                 # plt.xlabel('Episode')
  434.                 # plt.ylabel('Follower reward')
  435.                 # plt.show()
  436.                 # env.close()
  437.             all_ep_r_mean = np.mean((np.array(all_ep_r)), axis=0)
  438.             all_ep_r_std = np.std((np.array(all_ep_r)), axis=0)
  439.             all_ep_L_mean = np.mean((np.array(all_ep_r0)), axis=0)
  440.             all_ep_L_std = np.std((np.array(all_ep_r0)), axis=0)
  441.             all_ep_F_mean = np.mean((np.array(all_ep_r1)), axis=0)
  442.             all_ep_F_std = np.std((np.array(all_ep_r1)), axis=0)
  443.             d = {"all_ep_r_mean": all_ep_r_mean, "all_ep_r_std": all_ep_r_std,
  444.                  "all_ep_L_mean": all_ep_L_mean, "all_ep_L_std": all_ep_L_std,
  445.                  "all_ep_F_mean": all_ep_F_mean, "all_ep_F_std": all_ep_F_std,}
  446.             f = open(shoplistfile, 'wb')  # 二进制打开,如果找不到该文件,则创建一个
  447.             pkl.dump(d, f, pkl.HIGHEST_PROTOCOL)  # 写入文件
  448.             f.close()
  449.             all_ep_r_max = all_ep_r_mean + all_ep_r_std * 0.95
  450.             all_ep_r_min = all_ep_r_mean - all_ep_r_std * 0.95
  451.             all_ep_L_max = all_ep_L_mean + all_ep_L_std * 0.95
  452.             all_ep_L_min = all_ep_L_mean - all_ep_L_std * 0.95
  453.             all_ep_F_max = all_ep_F_mean + all_ep_F_std * 0.95
  454.             all_ep_F_min = all_ep_F_mean - all_ep_F_std * 0.95
  455.             plt.margins(x=0)
  456.             plt.plot(np.arange(len(all_ep_r_mean)), all_ep_r_mean, label='MASAC', color='#e75840')
  457.             plt.fill_between(np.arange(len(all_ep_r_mean)), all_ep_r_max, all_ep_r_min, alpha=0.6, facecolor='#e75840')
  458.             plt.xlabel('Episode')
  459.             plt.ylabel('Total reward')
  460.             plt.figure(2, figsize=(8, 4), dpi=150)
  461.             plt.margins(x=0)
  462.             plt.plot(np.arange(len(all_ep_L_mean)), all_ep_L_mean, label='MASAC', color='#e75840')
  463.             plt.fill_between(np.arange(len(all_ep_L_mean)), all_ep_L_max, all_ep_L_min, alpha=0.6,
  464.                              facecolor='#e75840')
  465.             plt.xlabel('Episode')
  466.             plt.ylabel('Leader reward')
  467.             plt.figure(3, figsize=(8, 4), dpi=150)
  468.             plt.margins(x=0)
  469.             plt.plot(np.arange(len(all_ep_F_mean)), all_ep_F_mean, label='MASAC', color='#e75840')
  470.             plt.fill_between(np.arange(len(all_ep_F_mean)), all_ep_F_max, all_ep_F_min, alpha=0.6,
  471.                              facecolor='#e75840')
  472.             plt.xlabel('Episode')
  473.             plt.ylabel('Follower reward')
  474.             plt.legend()
  475.             plt.show()
  476.             env.close()
  477.     else:
  478.         print('SAC测试中...')
  479.         aa = Actor()
  480.         checkpoint_aa = torch.load(filedir + "\\Path_SAC_actor_L1.pth")
  481.         aa.action_net.load_state_dict(checkpoint_aa['net'])
  482.         bb = Actor()
  483.         checkpoint_bb = torch.load(filedir + "\\Path_SAC_actor_F1.pth")
  484.         bb.action_net.load_state_dict(checkpoint_bb['net'])
  485.         action = np.zeros((N_Agent+M_Enemy, action_number))
  486.         win_times = 0
  487.         average_FKR=0
  488.         average_timestep=0
  489.         average_integral_V=0
  490.         average_integral_U= 0
  491.         all_ep_V, all_ep_U, all_ep_T, all_ep_F = [], [], [], []
  492.         for j in range(TEST_EPIOSDE):
  493.             state = env.reset()
  494.             total_rewards = 0
  495.             integral_V=0
  496.             integral_U=0
  497.             v,v1,Dis=[],[],[]
  498.             for timestep in range(EP_LEN):
  499.                 for i in range(N_Agent):
  500.                     action[i] = aa.choose_action(state[i])
  501.                 for i in range(M_Enemy):
  502.                     action[i+1] = bb.choose_action(state[i+1])
  503.                 # action[0] = aa.choose_action(state[0])
  504.                 # action[1] = bb.choose_action(state[1])
  505.                 new_state, reward,done,win,team_counter,dis = env.step(action)  # 执行动作
  506.                 if win:
  507.                     win_times += 1
  508.                 v.append(state[0][2]*30)
  509.                 v1.append(state[1][2]*30)
  510.                 Dis.append(dis)
  511.                 integral_V+=state[0][2]
  512.                 integral_U+=abs(action[0]).sum()
  513.                 total_rewards += reward.mean()
  514.                 state = new_state
  515.                 if RENDER:
  516.                     env.render()
  517.                 if done:
  518.                     break
  519.             FKR=team_counter/timestep
  520.             average_FKR += FKR
  521.             average_timestep += timestep
  522.             average_integral_V += integral_V
  523.             average_integral_U += integral_U
  524.             print("Score", total_rewards)
  525.             all_ep_V.append(integral_V)
  526.             all_ep_U.append(integral_U)
  527.             all_ep_T.append(timestep)
  528.             all_ep_F.append(FKR)
  529.             # print('最大编队保持率',FKR)
  530.             # print('最短飞行时间',timestep)
  531.             # print('最短飞行路程', integral_V)
  532.             # print('最小能量损耗', integral_U)
  533.             # d = {"leader": v, "follower": v1 }
  534.             # d = {"distance": Dis}
  535.             # f = open(shoplistfile_test, 'wb')  # 二进制打开,如果找不到该文件,则创建一个
  536.             # pkl.dump(d, f, pkl.HIGHEST_PROTOCOL)  # 写入文件
  537.             # f.close()
  538.             # plt.plot(np.arange(len(v)), v)
  539.             # plt.plot(np.arange(len(v1)), v1)
  540.             # plt.plot(np.arange(len(Dis)), Dis)
  541.             # plt.show()
  542.         print('任务完成率',win_times / TEST_EPIOSDE)
  543.         print('平均最大编队保持率', average_FKR/TEST_EPIOSDE)
  544.         print('平均最短飞行时间', average_timestep/TEST_EPIOSDE)
  545.         print('平均最短飞行路程', average_integral_V/TEST_EPIOSDE)
  546.         print('平均最小能量损耗', average_integral_U/TEST_EPIOSDE)
  547.         # d = {"all_ep_V": all_ep_V, "all_ep_U": all_ep_U, "all_ep_T": all_ep_T, "all_ep_F": all_ep_F, }
  548.         # f = open(shoplistfile_test1, 'wb')  # 二进制打开,如果找不到该文件,则创建一个
  549.         # pkl.dump(d, f, pkl.HIGHEST_PROTOCOL)  # 写入文件
  550.         # f.close()
  551.         env.close()
  552. if __name__ == '__main__':
  553.     main()
复制代码
5.2 main_DDPG.py

是作者论文对比的一种强化学习方法
  1. # -*- coding: utf-8 -*-
  2. #开发者:Bright Fang
  3. #开发时间:2023/7/20 23:34
  4. from rl_env.path_env import RlGame
  5. # import pygame
  6. # from assignment import constants as C
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import numpy as np
  11. from matplotlib import pyplot as plt
  12. import os
  13. import pickle as pkl
  14. filedir = os.path.dirname(__file__)
  15. shoplistfile = filedir + "\\MADDPG"  #保存文件数据所在文件的文件名
  16. shoplistfile_test = filedir + "\\MADDPG_compare"  #保存文件数据所在文件的文件名
  17. os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
  18. N_Agent=1
  19. M_Enemy=4
  20. RENDER=True
  21. env = RlGame(n=N_Agent,m=M_Enemy,render=RENDER).unwrapped
  22. state_number=7
  23. TEST_EPIOSDE=100
  24. TRAIN_NUM = 3
  25. EP_LEN = 1000
  26. EPIOSDE_ALL=500
  27. action_number=env.action_space.shape[0]
  28. max_action = env.action_space.high[0]
  29. min_action = env.action_space.low[0]
  30. LR_A = 1e-3    # learning rate for actor
  31. LR_C = 1e-3    # learning rate for critic
  32. GAMMA = 0.95
  33. MemoryCapacity=20000
  34. Batch=128
  35. Switch=1
  36. tau = 0.005
  37. '''DDPG第一步 设计A-C框架的Actor(DDPG算法,只有critic的部分才会用到记忆库)'''
  38. '''第一步 设计A-C框架形式的网络部分'''
  39. class ActorNet(nn.Module):
  40.     def __init__(self,inp,outp):
  41.         super(ActorNet, self).__init__()
  42.         self.in_to_y1=nn.Linear(inp,50)
  43.         self.in_to_y1.weight.data.normal_(0,0.1)
  44.         self.y1_to_y2=nn.Linear(50,20)
  45.         self.y1_to_y2.weight.data.normal_(0,0.1)
  46.         self.out=nn.Linear(20,outp)
  47.         self.out.weight.data.normal_(0,0.1)
  48.     def forward(self,inputstate):
  49.         inputstate=self.in_to_y1(inputstate)
  50.         inputstate=F.relu(inputstate)
  51.         inputstate=self.y1_to_y2(inputstate)
  52.         inputstate=torch.sigmoid(inputstate)
  53.         act=max_action*torch.tanh(self.out(inputstate))
  54.         # return F.softmax(act,dim=-1)
  55.         return act
  56. class CriticNet(nn.Module):
  57.     def __init__(self,input,output):
  58.         super(CriticNet, self).__init__()
  59.         self.in_to_y1=nn.Linear(input+output,40)
  60.         self.in_to_y1.weight.data.normal_(0,0.1)
  61.         self.y1_to_y2=nn.Linear(40,20)
  62.         self.y1_to_y2.weight.data.normal_(0,0.1)
  63.         self.out=nn.Linear(20,1)
  64.         self.out.weight.data.normal_(0,0.1)
  65.     def forward(self,s,a):
  66.         inputstate = torch.cat((s, a), dim=1)
  67.         inputstate=self.in_to_y1(inputstate)
  68.         inputstate=F.relu(inputstate)
  69.         inputstate=self.y1_to_y2(inputstate)
  70.         inputstate=torch.sigmoid(inputstate)
  71.         Q=self.out(inputstate)
  72.         return Q
  73. class Actor():
  74.     def __init__(self):
  75.         self.actor_estimate_eval,self.actor_reality_target = ActorNet(state_number,action_number),ActorNet(state_number,action_number)
  76.         self.optimizer = torch.optim.Adam(self.actor_estimate_eval.parameters(), lr=LR_A)
  77.     '''第二步 编写根据状态选择动作的函数'''
  78.     def choose_action(self, s):
  79.         inputstate = torch.FloatTensor(s)
  80.         probs = self.actor_estimate_eval(inputstate)
  81.         return probs.detach().numpy()
  82.     '''第四步 编写A的学习函数'''
  83.     '''生成输入为s的actor估计网络,用于传给critic估计网络,虽然这与choose_action函数一样,但如果直接用choose_action
  84.     函数生成的动作,DDPG是不会收敛的,原因在于choose_action函数生成的动作经过了记忆库,动作从记忆库出来后,动作的梯度数据消失了
  85.     所以再次编写了learn_a函数,它生成的动作没有过记忆库,是带有梯度的'''
  86.     def learn_a(self, s):
  87.         s = torch.FloatTensor(s)
  88.         A_prob = self.actor_estimate_eval(s)
  89.         return A_prob
  90.     '''把s_输入给actor现实网络,生成a_,a_将会被传给critic的实现网络'''
  91.     def learn_a_(self, s_):
  92.         s_ = torch.FloatTensor(s_)
  93.         A_prob=self.actor_reality_target(s_).detach()
  94.         return A_prob
  95.     '''actor的学习函数接受来自critic估计网络算出的Q_estimate_eval当做自己的loss,即负的critic_estimate_eval(s,a),使loss
  96.     最小化,即最大化critic网络生成的价值'''
  97.     def learn(self, a_loss):
  98.         loss = a_loss
  99.         self.optimizer.zero_grad()
  100.         loss.backward()
  101.         self.optimizer.step()
  102.     '''第六步,最后一步  编写软更新程序,Actor部分与critic部分都会有软更新代码'''
  103.     '''DQN是硬更新,即固定时间更新,而DDPG采用软更新,w_老_现实=τ*w_新_估计+(1-τ)w_老_现实'''
  104.     def soft_update(self):
  105.         for target_param, param in zip(self.actor_reality_target.parameters(), self.actor_estimate_eval.parameters()):
  106.             target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
  107. class Critic():
  108.     def __init__(self):
  109.         self.critic_estimate_eval,self.critic_reality_target=CriticNet(state_number*(N_Agent+M_Enemy),action_number),CriticNet(state_number*(N_Agent+M_Enemy),action_number)
  110.         self.optimizer = torch.optim.Adam(self.critic_estimate_eval.parameters(), lr=LR_C)
  111.         self.lossfun=nn.MSELoss()
  112.     '''第五步 编写critic的学习函数'''
  113.     '''使用critic估计网络得到 actor的loss,这里的输入参数a是带梯度的'''
  114.     def learn_loss(self, s, a):
  115.         s = torch.FloatTensor(s)
  116.         # a = a.view(-1, 1)
  117.         Q_estimate_eval = -self.critic_estimate_eval(s, a).mean()
  118.         return Q_estimate_eval
  119.     '''这里的输入参数a与a_是来自记忆库的,不带梯度,根据公式我们会得到critic的loss'''
  120.     def learn(self, s, a, r, s_, a_):
  121.         s = torch.FloatTensor(s)
  122.         a = torch.FloatTensor(a)#当前动作a来自记忆库
  123.         r = torch.FloatTensor(r)
  124.         s_ = torch.FloatTensor(s_)
  125.         # a_ = a_.view(-1, 1)  # view中一个参数定为-1,代表动态调整这个维度上的元素个数,以保证元素的总数不变
  126.         Q_estimate_eval = self.critic_estimate_eval(s, a)
  127.         Q_next = self.critic_reality_target(s_, a_).detach()
  128.         Q_reality_target = r + GAMMA * Q_next
  129.         loss = self.lossfun(Q_estimate_eval, Q_reality_target)
  130.         self.optimizer.zero_grad()
  131.         loss.backward()
  132.         self.optimizer.step()
  133.     def soft_update(self):
  134.         for target_param, param in zip(self.critic_reality_target.parameters(), self.critic_estimate_eval.parameters()):
  135.             target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
  136. '''第三步  建立记忆库'''
  137. class Memory():
  138.     def __init__(self,capacity,dims):
  139.         self.capacity=capacity
  140.         self.mem=np.zeros((capacity,dims))
  141.         self.memory_counter=0
  142.     '''存储记忆'''
  143.     def store_transition(self,s,a,r,s_):
  144.         tran = np.hstack((s, a,r, s_))  # 把s,a,r,s_困在一起,水平拼接
  145.         index = self.memory_counter % self.capacity#除余得索引
  146.         self.mem[index, :] = tran  # 给索引存值,第index行所有列都为其中一次的s,a,r,s_;mem会是一个capacity行,(s+a+r+s_)列的数组
  147.         self.memory_counter+=1
  148.     '''随机从记忆库里抽取'''
  149.     def sample(self,n):
  150.         assert self.memory_counter>=self.capacity,'记忆库没有存满记忆'
  151.         sample_index = np.random.choice(self.capacity, n)#从capacity个记忆里随机抽取n个为一批,可得到抽样后的索引号
  152.         new_mem = self.mem[sample_index, :]#由抽样得到的索引号在所有的capacity个记忆中  得到记忆s,a,r,s_
  153.         return new_mem
  154.     '''OU噪声'''
  155. class Ornstein_Uhlenbeck_Noise:
  156.     def __init__(self, mu, sigma=0.1, theta=0.1, dt=1e-2, x0=None):
  157.         self.theta = theta
  158.         self.mu = mu
  159.         self.sigma = sigma
  160.         self.dt = dt
  161.         self.x0 = x0
  162.         self.reset()
  163.     def __call__(self):
  164.         x = self.x_prev + \
  165.             self.theta * (self.mu - self.x_prev) * self.dt + \
  166.             self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
  167.         '''
  168.         后两行是dXt,其中后两行的前一行是θ(μ-Xt)dt,后一行是σεsqrt(dt)
  169.         '''
  170.         self.x_prev = x
  171.         return x
  172.     def reset(self):
  173.         if self.x0 is not None:
  174.             self.x_prev = self.x0
  175.         else:
  176.             self.x_prev = np.zeros_like(self.mu)
  177. def main():
  178.     os.environ["SDL_VIDEODRIVER"] = "dummy"
  179.     run(env)
  180. def run(env):
  181.     if Switch==0:
  182.         all_ep_r = [[] for i in range(TRAIN_NUM)]
  183.         all_ep_r0 = [[] for i in range(TRAIN_NUM)]
  184.         all_ep_r1 = [[] for i in range(TRAIN_NUM)]
  185.         for k in range(TRAIN_NUM):
  186.             actors = [None for _ in range(N_Agent + M_Enemy)]
  187.             critics = [None for _ in range(N_Agent + M_Enemy)]
  188.             for i in range(N_Agent + M_Enemy):
  189.                 actors[i] = Actor()
  190.                 critics[i] = Critic()
  191.             M = Memory(MemoryCapacity, 2 * state_number*(N_Agent+M_Enemy) + action_number*(N_Agent+M_Enemy) + 1*(N_Agent+M_Enemy))  # 奖惩是一个浮点数
  192.             ou_noise = Ornstein_Uhlenbeck_Noise(mu=np.zeros(((N_Agent+M_Enemy), action_number)))
  193.             action = np.zeros(((N_Agent + M_Enemy), action_number))
  194.             # all_ep_r = []
  195.             for episode in range(EPIOSDE_ALL):
  196.                 observation=env.reset()
  197.                 reward_totle, reward_totle0, reward_totle1 = 0, 0, 0
  198.                 for timestep in range(EP_LEN):
  199.                     for i in range(N_Agent + M_Enemy):
  200.                         action[i] = actors[i].choose_action(observation[i])
  201.                         # action[0]=actor0.choose_action(observation[0])
  202.                         # action[1] = actor0.choose_action(observation[1])
  203.                     if episode <= 50:
  204.                         noise = ou_noise()
  205.                     else:
  206.                         noise = 0
  207.                     action = action + noise
  208.                     action = np.clip(action, -max_action, max_action)
  209.                     observation_, reward,done,win,team_counter = env.step(action)  # 单步交互
  210.                     M.store_transition(observation.flatten(), action.flatten(), reward.flatten()/1000, observation_.flatten())
  211.                     if M.memory_counter > MemoryCapacity:
  212.                         b_M = M.sample(Batch)
  213.                         b_s = b_M[:, :state_number*(N_Agent+M_Enemy)]
  214.                         b_a = b_M[:,
  215.                               state_number * (N_Agent + M_Enemy): state_number * (N_Agent + M_Enemy) + action_number * (
  216.                                           N_Agent + M_Enemy)]
  217.                         b_r = b_M[:, -state_number * (N_Agent + M_Enemy) - 1 * (N_Agent + M_Enemy): -state_number * (
  218.                                     N_Agent + M_Enemy)]
  219.                         b_s_ = b_M[:, -state_number * (N_Agent + M_Enemy):]
  220.                         for i in range(N_Agent + M_Enemy):
  221.                             actor_action_0 = actors[i].learn_a(b_s[:, state_number*i:state_number*(i+1)])
  222.                             # actor_action_1 = actors[1].learn_a(b_s[:, state_number:state_number * 2])
  223.                             # actor_action = torch.hstack((actor_action_0, actor_action_1))
  224.                             actor_action_0_ = actors[i].learn_a_(b_s_[:, state_number*i:state_number*(i+1)])
  225.                             # actor_action_1_ = actors[1].learn_a_(b_s_[:, state_number:state_number * 2])
  226.                             # actor_action_ = torch.hstack((actor_action_0_, actor_action_1_))
  227.                             critics[i].learn(b_s, b_a[:, action_number*i:action_number*(i+1)], b_r, b_s_, actor_action_0_)
  228.                             Q_c_to_a_loss = critics[i].learn_loss(b_s, actor_action_0)
  229.                             actors[i].learn(Q_c_to_a_loss)
  230.                             # 软更新
  231.                             actors[i].soft_update()
  232.                             critics[i].soft_update()
  233.                     observation = observation_
  234.                     reward_totle += reward.mean()
  235.                     reward_totle0 += float(reward[0])
  236.                     reward_totle1 += float(reward[1])
  237.                     if RENDER:
  238.                         env.render()
  239.                     if done:
  240.                         break
  241.                 print('Episode {},奖励:{}'.format(episode, reward_totle))
  242.                 # all_ep_r.append(reward_totle)
  243.                 all_ep_r[k].append(reward_totle)
  244.                 all_ep_r0[k].append(reward_totle0)
  245.                 all_ep_r1[k].append(reward_totle1)
  246.                 if episode % 50 == 0 and episode > 200:#保存神经网络参数
  247.                     save_data = {'net': actors[0].actor_estimate_eval.state_dict(), 'opt': actors[0].optimizer.state_dict()}
  248.                     torch.save(save_data, filedir + "\\Path_DDPG_actor_new.pth")
  249.                     save_data = {'net': actors[1].actor_estimate_eval.state_dict(), 'opt': actors[1].optimizer.state_dict()}
  250.                     torch.save(save_data, filedir + "\\Path_DDPG_actor_1_new.pth")
  251.             # plt.plot(np.arange(len(all_ep_r)), all_ep_r)
  252.             # plt.xlabel('Episode')
  253.             # plt.ylabel('Moving averaged episode reward')
  254.             # plt.show()
  255.             # env.close()
  256.         all_ep_r_mean = np.mean((np.array(all_ep_r)), axis=0)
  257.         all_ep_r_std = np.std((np.array(all_ep_r)), axis=0)
  258.         all_ep_L_mean = np.mean((np.array(all_ep_r0)), axis=0)
  259.         all_ep_L_std = np.std((np.array(all_ep_r0)), axis=0)
  260.         all_ep_F_mean = np.mean((np.array(all_ep_r1)), axis=0)
  261.         all_ep_F_std = np.std((np.array(all_ep_r1)), axis=0)
  262.         d = {"all_ep_r_mean": all_ep_r_mean, "all_ep_r_std": all_ep_r_std,
  263.              "all_ep_L_mean": all_ep_L_mean, "all_ep_L_std": all_ep_L_std,
  264.              "all_ep_F_mean": all_ep_F_mean, "all_ep_F_std": all_ep_F_std,}
  265.         f = open(shoplistfile, 'wb')  # 二进制打开,如果找不到该文件,则创建一个
  266.         pkl.dump(d, f, pkl.HIGHEST_PROTOCOL)  # 写入文件
  267.         f.close()
  268.         all_ep_r_max = all_ep_r_mean + all_ep_r_std * 0.95
  269.         all_ep_r_min = all_ep_r_mean - all_ep_r_std * 0.95
  270.         all_ep_L_max = all_ep_L_mean + all_ep_L_std * 0.95
  271.         all_ep_L_min = all_ep_L_mean - all_ep_L_std * 0.95
  272.         all_ep_F_max = all_ep_F_mean + all_ep_F_std * 0.95
  273.         all_ep_F_min = all_ep_F_mean - all_ep_F_std * 0.95
  274.         plt.margins(x=0)
  275.         plt.plot(np.arange(len(all_ep_r_mean)), all_ep_r_mean, label='MADDPG', color='#e75840')
  276.         plt.fill_between(np.arange(len(all_ep_r_mean)), all_ep_r_max, all_ep_r_min, alpha=0.6, facecolor='#e75840')
  277.         plt.xlabel('Episode')
  278.         plt.ylabel('Total reward')
  279.         plt.figure(2, figsize=(8, 4), dpi=150)
  280.         plt.margins(x=0)
  281.         plt.plot(np.arange(len(all_ep_L_mean)), all_ep_L_mean, label='MADDPG', color='#e75840')
  282.         plt.fill_between(np.arange(len(all_ep_L_mean)), all_ep_L_max, all_ep_L_min, alpha=0.6,
  283.                          facecolor='#e75840')
  284.         plt.xlabel('Episode')
  285.         plt.ylabel('Leader reward')
  286.         plt.figure(3, figsize=(8, 4), dpi=150)
  287.         plt.margins(x=0)
  288.         plt.plot(np.arange(len(all_ep_F_mean)), all_ep_F_mean, label='MADDPG', color='#e75840')
  289.         plt.fill_between(np.arange(len(all_ep_F_mean)), all_ep_F_max, all_ep_F_min, alpha=0.6,
  290.                          facecolor='#e75840')
  291.         plt.xlabel('Episode')
  292.         plt.ylabel('Follower reward')
  293.         plt.legend()
  294.         plt.show()
  295.         env.close()
  296.     else:
  297.         print('MADDPG测试中...')
  298.         aa = Actor()
  299.         checkpoint_aa = torch.load(filedir + "\\Path_DDPG_actor_new.pth")
  300.         aa.actor_estimate_eval.load_state_dict(checkpoint_aa['net'])
  301.         bb = Actor()
  302.         checkpoint_bb = torch.load(filedir + "\\Path_DDPG_actor_1_new.pth")
  303.         bb.actor_estimate_eval.load_state_dict(checkpoint_bb['net'])
  304.         action = np.zeros((N_Agent + M_Enemy, action_number))
  305.         win_times = 0
  306.         average_FKR = 0
  307.         average_timestep = 0
  308.         average_integral_V = 0
  309.         average_integral_U = 0
  310.         all_ep_V, all_ep_U, all_ep_T, all_ep_F = [], [], [], []
  311.         for j in range(TEST_EPIOSDE):
  312.             state = env.reset()
  313.             total_rewards = 0
  314.             integral_V = 0
  315.             integral_U = 0
  316.             v, v1 = [], []
  317.             for timestep in range(EP_LEN):
  318.                 for i in range(N_Agent):
  319.                     action[i] = aa.choose_action(state[i])
  320.                 for i in range(M_Enemy):
  321.                     action[i + 1] = bb.choose_action(state[i + 1])
  322.                 # action[0] = aa.choose_action(state[0])
  323.                 # action[1] = bb.choose_action(state[1])
  324.                 new_state, reward, done, win, team_counter,d = env.step(action)  # 执行动作
  325.                 if win:
  326.                     win_times += 1
  327.                 v.append(state[0][2])
  328.                 v1.append(state[1][2])
  329.                 integral_V += state[0][2]
  330.                 integral_U += abs(action[0]).sum()
  331.                 total_rewards += reward.mean()
  332.                 state = new_state
  333.                 if RENDER:
  334.                     env.render()
  335.                 if done:
  336.                     break
  337.             FKR = team_counter / timestep
  338.             average_FKR += FKR
  339.             average_timestep += timestep
  340.             average_integral_V += integral_V
  341.             average_integral_U += integral_U
  342.             print("Score", total_rewards)
  343.             all_ep_V.append(integral_V)
  344.             all_ep_U.append(integral_U)
  345.             all_ep_T.append(timestep)
  346.             all_ep_F.append(FKR)
  347.             # print('最大编队保持率',FKR)
  348.             # print('最短飞行时间',timestep)
  349.             # print('最短飞行路程', integral_V)
  350.             # print('最小能量损耗', integral_U)
  351.             # plt.plot(np.arange(len(v)), v)
  352.             # plt.plot(np.arange(len(v1)), v1)
  353.             # plt.show()
  354.         print('任务完成率', win_times / TEST_EPIOSDE)
  355.         print('平均最大编队保持率', average_FKR / TEST_EPIOSDE)
  356.         print('平均最短飞行时间', average_timestep / TEST_EPIOSDE)
  357.         print('平均最短飞行路程', average_integral_V / TEST_EPIOSDE)
  358.         print('平均最小能量损耗', average_integral_U / TEST_EPIOSDE)
  359.         d = {"all_ep_V": all_ep_V, "all_ep_U": all_ep_U, "all_ep_T": all_ep_T, "all_ep_F": all_ep_F, }
  360.         f = open(shoplistfile_test, 'wb')  # 二进制打开,如果找不到该文件,则创建一个
  361.         pkl.dump(d, f, pkl.HIGHEST_PROTOCOL)  # 写入文件
  362.         f.close()
  363.         env.close()
  364. if __name__ == '__main__':
  365.     main()
复制代码
5.3 main.py

是作者测试训练完评估权重的演示demo
  1. # -*- coding: utf-8 -*-
  2. #开发者:Bright Fang
  3. #开发时间:2023/7/30 18:13
  4. from rl_env.path_env import RlGame
  5. # import pygame
  6. # from assignment import constants as C
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import numpy as np
  11. from matplotlib import pyplot as plt
  12. import os
  13. import pickle as pkl
  14. shoplistfile_test = "MASAC_compare"  #保存文件数据所在文件的文件名
  15. os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
  16. N_Agent=1
  17. M_Enemy=4
  18. RENDER=True
  19. TRAIN_NUM = 1
  20. TEST_EPIOSDE=100
  21. env = RlGame(n=N_Agent,m=M_Enemy,render=RENDER).unwrapped
  22. state_number=7
  23. action_number=env.action_space.shape[0]
  24. max_action = env.action_space.high[0]
  25. min_action = env.action_space.low[0]
  26. EP_MAX = 500
  27. EP_LEN = 1000
  28. def main():
  29.     run(env)
  30. def run(env):
  31.     print('随机测试中...')
  32.     action = np.zeros((N_Agent+M_Enemy, action_number))
  33.     win_times = 0
  34.     average_FKR=0
  35.     average_timestep=0
  36.     average_integral_V=0
  37.     average_integral_U= 0
  38.     all_ep_V,all_ep_U,all_ep_T,all_ep_F=[],[],[],[]
  39.     for j in range(TEST_EPIOSDE):
  40.         state = env.reset()
  41.         total_rewards = 0
  42.         integral_V=0
  43.         integral_U=0
  44.         v,v1=[],[]
  45.         for timestep in range(EP_LEN):
  46.             for i in range(N_Agent+M_Enemy):
  47.                 action[i] = env.action_space.sample()
  48.             # action[0] = aa.choose_action(state[0])
  49.             # action[1] = bb.choose_action(state[1])
  50.             new_state, reward,done,win,team_counter,d = env.step(action)  # 执行动作
  51.             if win:
  52.                 win_times += 1
  53.             v.append(state[0][2])
  54.             v1.append(state[1][2])
  55.             integral_V+=state[0][2]
  56.             integral_U+=abs(action[0]).sum()
  57.             total_rewards += reward.mean()
  58.             state = new_state
  59.             if RENDER:
  60.                 env.render()
  61.             if done:
  62.                 break
  63.         FKR=team_counter/timestep
  64.         average_FKR += FKR
  65.         average_timestep += timestep
  66.         average_integral_V += integral_V
  67.         average_integral_U += integral_U
  68.         print("Score", total_rewards)
  69.         all_ep_V.append(integral_V)
  70.         all_ep_U.append(integral_U)
  71.         all_ep_T.append(timestep)
  72.         all_ep_F.append(FKR)
  73.         # print('最大编队保持率',FKR)
  74.         # print('最短飞行时间',timestep)
  75.         # print('最短飞行路程', integral_V)
  76.         # print('最小能量损耗', integral_U)
  77.         # plt.plot(np.arange(len(v)), v)
  78.         # plt.plot(np.arange(len(v1)), v1)
  79.         # plt.show()
  80.     print('任务完成率',win_times / TEST_EPIOSDE)
  81.     print('平均最大编队保持率', average_FKR/TEST_EPIOSDE)
  82.     print('平均最短飞行时间', average_timestep/TEST_EPIOSDE)
  83.     print('平均最短飞行路程', average_integral_V/TEST_EPIOSDE)
  84.     print('平均最小能量损耗', average_integral_U/TEST_EPIOSDE)
  85.     # d = {"all_ep_V": all_ep_V,"all_ep_U": all_ep_U,"all_ep_T": all_ep_T,"all_ep_F": all_ep_F,}
  86.     # f = open(shoplistfile_test, 'wb')  # 二进制打开,如果找不到该文件,则创建一个
  87.     # pkl.dump(d, f, pkl.HIGHEST_PROTOCOL)  # 写入文件
  88.     # f.close()
  89.     env.close()
  90. if __name__ == '__main__':
  91.     main()
复制代码
第四处是环境配置位置

5.4. path_env.py

环境显示调用的类与方法
  1. # -*- coding: utf-8 -*-
  2. #开发者:Bright Fang
  3. #开发时间:2023/7/20 23:30
  4. import numpy as np
  5. import os
  6. import copy
  7. import gym
  8. from assignment import constants as C
  9. from gym import spaces
  10. import math
  11. import random
  12. import pygame
  13. from assignment.components import player
  14. from assignment import tools
  15. from assignment.components import info
  16. class RlGame(gym.Env):
  17.     def __init__(self, n,m,render=False):
  18.         self.hero_num = n
  19.         self.enemy_num = m
  20.         self.obstacle_num=1
  21.         self.goal_num=1
  22.         self.Render=render
  23.         self.game_info = {
  24.             'epsoide': 0,
  25.             'hero_win': 0,
  26.             'enemy_win': 0,
  27.             'win': '未知',
  28.         }
  29.         if self.Render:
  30.             pygame.init()
  31.             pygame.mixer.init()
  32.             self.SCREEN = pygame.display.set_mode((C.SCREEN_W, C.SCREEN_H))
  33.             pygame.display.set_caption("基于深度强化学习的空战场景无人机路径规划软件")
  34.             self.GRAPHICS = tools.load_graphics(".\\assignment\\source\\image")
  35.             self.SOUND = tools.load_sound(".\\assignment\\source\\music")
  36.             self.clock = pygame.time.Clock()
  37.             self.mouse_pos=(100,100)
  38.             pygame.time.set_timer(C.CREATE_ENEMY_EVENT, C.ENEMY_MAKE_TIME)
  39.             # self.res, init_extra, update_extra, skip_override, waypoints = simulate(filename='')
  40.         # else:
  41.         #     self.dispaly=None
  42.         low = np.array([-1,-1])
  43.         high=np.array([1,1])
  44.         # self.action_space =spaces.Discrete(21)
  45.         # self.action_space = spaces.Discrete(2)
  46.         self.action_space=spaces.Box(low=low,high=high,dtype=np.float32)
  47.         # self.action_space = [spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),
  48.         #                      spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),
  49.         #                      spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),
  50.         #                      spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),
  51.         #                      spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),spaces.Discrete(2),]
  52.     def start(self):
  53.         # self.game_info=game_info
  54.         self.finished=False
  55.         # self.next='game_over'
  56.         self.set_battle_background()#战斗的背景
  57.         self.set_enemy_image()
  58.         self.set_hero_image()
  59.         self.set_obstacle_image()
  60.         self.set_goal_image()
  61.         self.info = info.Info('battle_screen',self.game_info)
  62.         # self.state = 'battle'
  63.         self.counter_1 = 0
  64.         self.counter_hero = 0
  65.         self.enemy_counter=0
  66.         self.enemy_counter_1 = 0
  67.         #又定义了一个参数,为了放在start函数里重置
  68.         self.enemy_num_start=self.enemy_num
  69.         self.trajectory_x,self.trajectory_y=[],[]
  70.         self.enemy_trajectory_x,self.enemy_trajectory_y=[[] for i in range(self.enemy_num)],[[] for i in range(self.enemy_num)]
  71.         # RL状态
  72.         # self.hero_state = np.zeros((self.hero_num, 4))
  73.         # self.hero_α = np.zeros((self.hero_num, 1))
  74.         self.uav_obs_check= np.zeros((self.hero_num, 1))
  75.     def set_battle_background(self):
  76.         self.battle_background = self.GRAPHICS['background']
  77.         self.battle_background = pygame.transform.scale(self.battle_background,C.SCREEN_SIZE)  # 缩放
  78.         self.view = self.SCREEN.get_rect()
  79.         #若要移动的背景图像,请用下面的代码替换
  80.         # bg1=player.BackgroundSprite(image_name='background3',size=C.SCREEN_SIZE)
  81.         # bg2=player.BackgroundSprite(image_name='background3',size=C.SCREEN_SIZE)
  82.         # bg2.rect.y=-bg2.rect.height
  83.         # self.background_group=pygame.sprite.Group(bg1,bg2)
  84.     def set_hero_image(self):
  85.         self.hero = self.__dict__
  86.         self.hero_group = pygame.sprite.Group()
  87.         self.hero_image = self.GRAPHICS['fighter-blue']
  88.         for i in range(self.hero_num):
  89.             self.hero['hero'+str(i)]=player.Hero(image=self.hero_image)
  90.             self.hero_group.add(self.hero['hero'+str(i)])
  91.     def set_enemy_image(self):
  92.         self.enemy = self.__dict__
  93.         self.enemy_group = pygame.sprite.Group()
  94.         self.enemy_image = self.GRAPHICS['fighter-green']
  95.         for i in range(self.enemy_num):
  96.             self.enemy['enemy'+str(i)]=player.Enemy(image=self.enemy_image)
  97.             self.enemy_group.add(self.enemy['enemy'+str(i)])
  98.     def set_hero(self):
  99.         self.hero = self.__dict__
  100.         self.hero_group = pygame.sprite.Group()
  101.         for i in range(self.hero_num):
  102.             self.hero['hero'+str(i)]=player.Hero()
  103.             self.hero_group.add(self.hero['hero'+str(i)])
  104.     def set_enemy(self):
  105.         self.enemy = self.__dict__
  106.         self.enemy_group = pygame.sprite.Group()
  107.         for i in range(self.enemy_num):
  108.             self.enemy['enemy'+str(i)]=player.Enemy()
  109.             self.enemy_group.add(self.enemy['enemy'+str(i)])
  110.     def set_obstacle_image(self):
  111.         self.obstacle = self.__dict__
  112.         self.obstacle_group = pygame.sprite.Group()
  113.         self.obstacle_image = self.GRAPHICS['hole']
  114.         for i in range(self.obstacle_num):
  115.             self.obstacle['obstacle'+str(i)]=player.Obstacle(image=self.obstacle_image)
  116.             self.obstacle_group.add(self.obstacle['obstacle'+str(i)])
  117.     def set_obstacle(self):
  118.         self.obstacle = self.__dict__
  119.         self.obstacle_group = pygame.sprite.Group()
  120.         for i in range(self.obstacle_num):
  121.             self.obstacle['obstacle'+str(i)]=player.Obstacle()
  122.             self.obstacle_group.add(self.obstacle['obstacle'+str(i)])
  123.     def set_goal_image(self):
  124.         self.goal = self.__dict__
  125.         self.goal_group = pygame.sprite.Group()
  126.         self.goal_image = self.GRAPHICS['goal']
  127.         for i in range(self.goal_num):
  128.             self.goal['goal'+str(i)]=player.Goal(image=self.goal_image)
  129.             self.goal_group.add(self.goal['goal'+str(i)])
  130.     def set_goal(self):
  131.         self.goal = self.__dict__
  132.         self.goal_group = pygame.sprite.Group()
  133.         for i in range(self.goal_num):
  134.             self.goal['goal'+str(i)]=player.Goal()
  135.             self.goal_group.add(self.goal['goal'+str(i)])
  136.     def update_game_info(self):#死亡后重置数据
  137.         self.game_info['epsoide'] += 1
  138.         self.game_info['enemy_win'] = self.game_info['epsoide'] - self.game_info['hero_win']
  139.     def reset(self):#reset的仅是环境状态,
  140.         # obs=np.zeros((self.n, 4))#这是个二维矩阵,n*2维,现在只考虑一个己方无人机,所以现在是一个一维的
  141.         # game_info=self.my_game.state.game_info
  142.         # self.my_game.state.start(game_info)
  143.         if self.Render:
  144.             self.start()
  145.         else:
  146.             self.set_hero()
  147.             self.set_enemy()
  148.             self.set_goal()
  149.             self.set_obstacle()
  150.         self.team_counter = 0
  151.         self.done = False
  152.         self.hero_state = np.zeros((self.hero_num+self.enemy_num,7))
  153.         self.hero_α = np.zeros((self.hero_num, 1))
  154.         # self.goal_x,self.goal_y=random.randint(100, 500), random.randint(100, 200)
  155.         return np.array([[self.hero0.init_x/1000,self.hero0.init_y/1000,self.hero0.speed/30,self.hero0.theta*57.3/360
  156.                             ,self.goal0.init_x/1000, self.goal0.init_y/1000,0],
  157.                          [self.enemy0.init_x / 1000, self.enemy0.init_y / 1000, self.enemy0.speed / 30,
  158.                          self.enemy0.theta * 57.3 / 360
  159.                             , self.hero0.init_x/1000, self.hero0.init_y/1000,self.hero0.speed / 30],
  160.                          [self.enemy1.init_x / 1000, self.enemy1.init_y / 1000, self.enemy1.speed / 30,
  161.                           self.enemy1.theta * 57.3 / 360
  162.                              , self.hero0.init_x / 1000, self.hero0.init_y / 1000, self.hero0.speed / 30],
  163.                          [self.enemy2.init_x / 1000, self.enemy2.init_y / 1000, self.enemy2.speed / 30,
  164.                           self.enemy2.theta * 57.3 / 360
  165.                              , self.hero0.init_x / 1000, self.hero0.init_y / 1000, self.hero0.speed / 30],
  166.                          [self.enemy3.init_x / 1000, self.enemy3.init_y / 1000, self.enemy3.speed / 30,
  167.                           self.enemy3.theta * 57.3 / 360
  168.                              , self.hero0.init_x / 1000, self.hero0.init_y / 1000, self.hero0.speed / 30],
  169.                          ])#np.array([self.my_game.state.hero['hero0'].posx/1000,self.my_game.state.hero['hero0'].posy/1000,self.my_game.state.hero['hero0'].speed/2,self.my_game.state.hero['hero0'].theta*57.3/360])#np.zeros((self.n,2)).flatten()
  170.     def step(self,action):
  171.         dis_1_obs = np.zeros((self.hero_num, 1))
  172.         dis_1_goal = np.zeros((self.hero_num+self.enemy_num, 1))
  173.         r=np.zeros((self.hero_num+self.enemy_num, 1))
  174.         o_flag = 0
  175.         o_flag1 = 0
  176.         #空气阻力系数
  177.         F_k=0.08
  178.         #无人机质量,100是像素与现实速度的比例,因为10像素/帧对应现实的100m/s
  179.         m=12000/100
  180.         #扰动的加速度
  181.         F_a=0
  182.         #边界奖励
  183.         edge_r=np.zeros((self.hero_num, 1))
  184.         edge_r_f = np.zeros((self.enemy_num, 1))
  185.         #避障奖励
  186.         obstacle_r = np.zeros((self.hero_num, 1))
  187.         obstacle_r1 = np.zeros((self.enemy_num, 1))
  188.         #目标奖励
  189.         goal_r = np.zeros((self.hero_num, 1))
  190.         # 编队奖励
  191.         follow_r = np.zeros((self.enemy_num, 1))
  192.         follow_r0 = 0
  193.         speed_r=0
  194.         # print(self.goal0.init_x)
  195.         dis_1_agent_0_to_1=math.hypot(self.hero0.posx - self.enemy0.posx, self.hero0.posy - self.enemy0.posy)
  196.         dis_1_agent_0_to_2 = math.hypot(self.hero0.posx - self.enemy1.posx, self.hero0.posy - self.enemy1.posy)
  197.         dis_1_agent_0_to_3 = math.hypot(self.hero0.posx - self.enemy2.posx, self.hero0.posy - self.enemy2.posy)
  198.         dis_1_agent_0_to_4 = math.hypot(self.hero0.posx - self.enemy3.posx, self.hero0.posy - self.enemy3.posy)
  199.         for i in range(self.hero_num+self.enemy_num):
  200.             #空气阻力
  201.             # self.hero['hero' + str(i)].F=F_k*math.pow(self.hero['hero' + str(i)].speed,2)
  202.             # F_a=(self.hero['hero' + str(i)].F/m)*math.cos(self.hero['hero' + str(i)].theta * 57.3)
  203.             # 己方与障碍物的碰撞检测
  204.             # self.hero['hero' + str(i)].enemies = pygame.sprite.spritecollide(self.hero['hero' + str(i)],
  205.             #                                                                      self.obstacle_group, False)
  206.             if i==0:#leader
  207.                 dis_1_obs[i] = math.hypot(self.hero['hero' + str(i)].posx - self.obstacle0.init_x,
  208.                                           self.hero['hero' + str(i)].posy - self.obstacle0.init_y)
  209.                 dis_1_goal[i] = math.hypot(self.hero['hero' + str(i)].posx - self.goal0.init_x,
  210.                                            self.hero['hero' + str(i)].posy - self.goal0.init_y)
  211.                 if self.hero['hero' + str(i)].posx <= C.ENEMY_AREA_X + 50:
  212.                     edge_r[i] = -1
  213.                 elif self.hero['hero' + str(i)].posx >= C.ENEMY_AREA_WITH:
  214.                     edge_r[i] = -1
  215.                 if self.hero['hero' + str(i)].posy >= C.ENEMY_AREA_HEIGHT:
  216.                     edge_r[i] = -1
  217.                 elif self.hero['hero' + str(i)].posy <= C.ENEMY_AREA_Y + 50:
  218.                     edge_r[i] = -1
  219.                 if 0 < dis_1_agent_0_to_1 < 50 and dis_1_agent_0_to_2<50 and dis_1_agent_0_to_3<50and dis_1_agent_0_to_4<50:
  220.                     follow_r0=0
  221.                     self.team_counter+=1
  222.                     if abs(self.hero0.speed-self.enemy0.speed)<1:
  223.                         speed_r=1
  224.                 else:
  225.                     follow_r0=-0.001*dis_1_agent_0_to_1
  226.                 if dis_1_goal[i] < 40 and not self.hero['hero' + str(i)].dead:
  227.                     goal_r[i] = 1000.0
  228.                     self.hero['hero' + str(i)].win = True
  229.                     self.hero['hero' + str(i)].die()
  230.                     self.done= True
  231.                     # self.game_info['hero_win'] += 1
  232.                     # self.update_game_info()
  233.                     print('aa')
  234.                 # elif dis_1_goal < 100:
  235.                 #     r = 1.0
  236.                 elif dis_1_obs[i] < 20 and not self.hero['hero' + str(i)].dead:
  237.                     o_flag = 1
  238.                     obstacle_r[i] = -500
  239.                     self.hero['hero' + str(i)].die()
  240.                     self.hero['hero' + str(i)].win = False
  241.                     self.done = True
  242.                     # self.update_game_info()
  243.                     print('gg')
  244.                 elif dis_1_obs[i] < 40 and not self.hero['hero' + str(i)].dead:
  245.                     o_flag = 1
  246.                     # print(-100000*math.pow(1/dis_1_obs[i],2))
  247.                     obstacle_r[i] = -2#-100000*math.pow(1/dis_1_obs[i],2)
  248.                 elif not self.hero['hero' + str(i)].dead:
  249.                     # print(math.exp(100/dis_1_goal[i])/10)
  250.                     goal_r[i] =-0.001 * dis_1_goal[i]# math.exp(100/dis_1_goal[i])/10
  251.                 r[i] = edge_r[i] + obstacle_r[i] + goal_r[i]+speed_r+follow_r0
  252.                 self.hero_state[i] = [self.hero['hero' + str(i)].posx / 1000, self.hero['hero' + str(i)].posy / 1000,
  253.                                       self.hero['hero' + str(i)].speed / 30,
  254.                                       self.hero['hero' + str(i)].theta * 57.3 / 360,
  255.                                       self.goal0.init_x / 1000, self.goal0.init_y / 1000, o_flag]
  256.                 self.hero['hero' + str(i)].update(action[i], self.Render)
  257.                 self.trajectory_x.append(self.hero['hero' + str(i)].posx)
  258.                 self.trajectory_y.append(self.hero['hero' + str(i)].posy)
  259.             else:
  260.                 dis_2_obs = math.hypot(self.enemy['enemy' + str(i-1)].posx - self.obstacle0.init_x,
  261.                                           self.enemy['enemy' + str(i-1)].posy - self.obstacle0.init_y)
  262.                 dis_1_goal[i] = math.hypot(self.enemy['enemy' + str(i-1)].posx - self.goal0.init_x,
  263.                                            self.enemy['enemy' + str(i-1)].posy- self.goal0.init_y)
  264.                 if dis_2_obs < 40:
  265.                     o_flag1 = 1
  266.                     obstacle_r1 = -2
  267.                 if self.enemy['enemy' + str(i-1)].posx <= C.ENEMY_AREA_X + 50:
  268.                     edge_r_f[i-1] = -1
  269.                 elif self.enemy['enemy' + str(i-1)].posx >= C.ENEMY_AREA_WITH:
  270.                     edge_r_f[i-1] = -1
  271.                 if self.enemy['enemy' + str(i-1)].posy >= C.ENEMY_AREA_HEIGHT:
  272.                     edge_r_f[i-1] = -1
  273.                 elif self.enemy['enemy' + str(i-1)].posy <= C.ENEMY_AREA_Y + 50:
  274.                     edge_r_f[i-1] = -1
  275.                 if 0 < dis_1_agent_0_to_1 < 50 and dis_1_goal[0]<dis_1_goal[1]:
  276.                     # follow_r[i-1]=5-abs(self.hero0.theta-self.enemy0.theta)
  277.                     # print('hhh')
  278.                     if abs(self.hero0.speed-self.enemy0.speed)<1:
  279.                         speed_r=1
  280.                         # print('hh')
  281.                 else:
  282.                     follow_r[i-1]=-0.001*dis_1_agent_0_to_1
  283.                 r[i] =  follow_r[i-1]+speed_r
  284.                 self.hero_state[i] = [self.enemy['enemy' + str(i-1)].posx / 1000, self.enemy['enemy' + str(i-1)].posy / 1000,
  285.                                       self.enemy['enemy' + str(i-1)].speed / 30,
  286.                                       self.enemy['enemy' + str(i-1)].theta * 57.3 / 360,
  287.                                       self.hero0.posx / 1000, self.hero0.posy / 1000,self.hero0.speed / 30 ]
  288.                 self.enemy['enemy' + str(i-1)].update(action[i], self.Render)
  289.                 self.enemy_trajectory_x[i-1].append(self.enemy['enemy' + str(i-1)].posx)
  290.                 self.enemy_trajectory_y[i-1].append(self.enemy['enemy' + str(i-1)].posy)
  291.                 # print(self.hero_state[i])
  292.             # init_to_goal=math.atan2((-150+self.hero['hero'+str(i)].init_y),(200-self.hero['hero'+str(i)].init_x))
  293.             # uav_to_goal = math.atan2((-self.goal0.init_y + self.hero['hero' + str(i)].posy), (self.goal0.init_x - self.hero['hero' + str(i)].posx))
  294.             # uav_to_obstacle = math.atan2((-self.obstacle0.init_y + self.hero['hero' + str(i)].posy), (self.obstacle0.init_x - self.hero['hero' + str(i)].posx))
  295.             # self.hero_α[i] =  0.1*abs(uav_to_obstacle - self.hero['hero' + str(i)].theta)
  296.         # 自己更新位置
  297.         # self.hero_group.update(action[0], action[1],self.Render)
  298.         hero_state = copy.deepcopy(self.hero_state)
  299.         done = copy.deepcopy(self.done)
  300.         return hero_state,r,done,self.hero['hero0'].win,self.team_counter,dis_1_agent_0_to_1
  301.     def render(self):
  302.         for event in pygame.event.get():
  303.             if event.type == pygame.QUIT:
  304.                 pygame.display.quit()
  305.                 quit()
  306.             elif event.type == pygame.MOUSEMOTION:
  307.                 self.mouse_pos = pygame.mouse.get_pos()
  308.             elif event.type == C.CREATE_ENEMY_EVENT:
  309.                 C.ENEMY_FLAG = True
  310.         # 画背景
  311.         self.SCREEN.blit(self.battle_background, self.view)
  312.         # 文字显示
  313.         self.info.update(self.mouse_pos)
  314.         # 画图
  315.         self.draw(self.SCREEN)
  316.         pygame.display.update()
  317.         self.clock.tick(C.FPS)
  318.     def draw(self,surface):
  319.         # self.background_group.draw(surface)
  320.         #敌占区的矩形
  321.         pygame.draw.rect(surface, C.BLACK, C.ENEMY_AREA, 3)
  322.         #目标星星
  323.         # pygame.draw.polygon(surface, C.GREEN,[(200, 135), (205, 145), (215, 145), (210, 155), (213, 165), (200, 160), (187, 165), (190, 155), (185, 145), (195, 145)])
  324.         pygame.draw.circle(surface, C.RED, (self.goal0.init_x, self.goal0.init_y), 1)
  325.         pygame.draw.circle(surface, C.RED, (self.goal0.init_x, self.goal0.init_y), 40,1)
  326.         # pygame.draw.circle(surface, C.GREEN, (self.goal0.init_x, self.goal0.init_y),100, 1)
  327.         pygame.draw.circle(surface, C.BLACK, (self.obstacle0.init_x, self.obstacle0.init_y), 20, 1)
  328.         # 画轨迹
  329.         for i in range(1, len(self.trajectory_x)):
  330.             pygame.draw.line(surface, C.BLUE, (self.trajectory_x[i - 1], self.trajectory_y[i - 1]), (self.trajectory_x[i], self.trajectory_y[i]))
  331.         for j in range(self.enemy_num):
  332.             for i in range(1, len(self.trajectory_x)):
  333.                 pygame.draw.line(surface, C.GREEN, (self.enemy_trajectory_x[j][i - 1], self.enemy_trajectory_y[j][i - 1]),
  334.                                  (self.enemy_trajectory_x[j][i], self.enemy_trajectory_y[j][i]))
  335.         #障碍物
  336.         # pygame.draw.circle(surface, C.BLACK, (250, 300), 20)
  337.         # 画自己
  338.         self.hero_group.draw(surface)
  339.         self.enemy_group.draw(surface)
  340.         #障碍物
  341.         self.obstacle_group.draw(surface)
  342.         # 目标星星
  343.         self.goal_group.draw(surface)
  344.         #画文字信息
  345.         self.info.draw(surface)
  346.     def close(self):
  347.         pygame.display.quit()
  348.         quit()
复制代码
然后直接运行main_SAC.py

完美!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

大连密封材料

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表