Soft Actor-Critic (SAC) Algorithm

宁睿 | 2024-12-28 15:14:01

Code

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pygame  # not used directly, but the classic-control renderer needs pygame installed
# Define the Actor network (squashed Gaussian policy)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mu = self.mu(x)
        log_std = self.log_std(x)
        log_std = torch.clamp(log_std, -20, 2)
        std = torch.exp(log_std)
        return mu, std

    def sample(self, state):
        mu, std = self.forward(state)
        dist = torch.distributions.Normal(mu, std)
        u = dist.rsample()  # reparameterized sample, before squashing
        action = torch.tanh(u) * self.max_action
        # Log-prob of the squashed action: evaluate the Gaussian at the pre-tanh sample u,
        # then apply the numerically stable tanh correction (keepdim so the shape is (batch, 1))
        log_prob = dist.log_prob(u).sum(dim=-1, keepdim=True)
        log_prob -= (2 * (np.log(2) - u - F.softplus(-2 * u))).sum(dim=-1, keepdim=True)
        return action, log_prob
# Define the Critic network (Q-function over state-action pairs)
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# SAC agent
class SAC:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.target_critic1 = Critic(state_dim, action_dim)
        self.target_critic2 = Critic(state_dim, action_dim)
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=3e-4)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=3e-4)
        # Learnable temperature, tuned towards the common target-entropy heuristic -|A|
        self.log_alpha = torch.tensor(np.log(0.1), requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)
        self.target_entropy = -action_dim
        self.gamma = 0.99
        self.tau = 0.005

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))  # make sure state has shape (1, state_dim)
        action, _ = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()

    def update(self, replay_buffer, batch_size=256):
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        next_state = torch.FloatTensor(next_state)
        reward = torch.FloatTensor(reward).unsqueeze(1)
        done = torch.FloatTensor(done).unsqueeze(1)

        # Critic update: soft Bellman backup with clipped double-Q and entropy bonus
        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_state)
            target_q1 = self.target_critic1(next_state, next_action)
            target_q2 = self.target_critic2(next_state, next_action)
            target_q = torch.min(target_q1, target_q2) - self.log_alpha.exp() * next_log_prob
            target_q = reward + (1 - done) * self.gamma * target_q
        current_q1 = self.critic1(state, action)
        current_q2 = self.critic2(state, action)
        critic1_loss = F.mse_loss(current_q1, target_q)
        critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        self.critic2_optimizer.step()

        # Actor update: maximize min-Q minus the entropy penalty
        action_new, log_prob = self.actor.sample(state)
        q1_new = self.critic1(state, action_new)
        q2_new = self.critic2(state, action_new)
        q_new = torch.min(q1_new, q2_new)
        actor_loss = (self.log_alpha.exp() * log_prob - q_new).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Temperature update: drive the policy entropy towards the target entropy
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # Polyak averaging of the target critics
        for param, target_param in zip(self.critic1.parameters(), self.target_critic1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.critic2.parameters(), self.target_critic2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
# Simple replay buffer (ring buffer over transition tuples)
class ReplayBuffer:
    def __init__(self, max_size=1e6):
        self.buffer = []
        self.max_size = int(max_size)  # cast max_size to an integer
        self.ptr = 0

    def add(self, state, action, next_state, reward, done):
        if len(self.buffer) < self.max_size:
            self.buffer.append(None)
        self.buffer[self.ptr] = (state, action, next_state, reward, done)
        self.ptr = (self.ptr + 1) % self.max_size

    def sample(self, batch_size):
        indices = np.random.randint(0, len(self.buffer), batch_size)
        states, actions, next_states, rewards, dones = [], [], [], [], []
        for idx in indices:
            state, action, next_state, reward, done = self.buffer[idx]
            states.append(state)
            actions.append(action)
            next_states.append(next_state)
            rewards.append(reward)
            dones.append(done)
        return np.array(states), np.array(actions), np.array(next_states), np.array(rewards), np.array(dones)
# Train SAC on Pendulum-v1
env = gym.make('Pendulum-v1', render_mode='human')  # render_mode is required for env.render() with gym >= 0.26
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
sac = SAC(state_dim, action_dim, max_action)
replay_buffer = ReplayBuffer()
max_episodes = 1000
batch_size = 256
for episode in range(max_episodes):
    state = env.reset()
    if isinstance(state, tuple):  # gym >= 0.26 returns (obs, info)
        state = state[0]
    episode_reward = 0
    done = False
    while not done:
        env.render()  # optional; rendering slows training considerably
        action = sac.select_action(state)
        next_state, reward, terminated, truncated, info = env.step(action)  # 5-tuple step API of gym >= 0.26
        done = terminated or truncated
        replay_buffer.add(state, action, next_state, reward, float(done))
        state = next_state
        episode_reward += reward
        if len(replay_buffer.buffer) > batch_size:
            sac.update(replay_buffer, batch_size)
    print(f"Episode {episode + 1}, Reward: {episode_reward:.2f}")
env.close()
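After training, SAC is usually evaluated with the deterministic policy, i.e. the tanh-squashed mean of the actor rather than a sample. A minimal sketch, assuming the `sac` object trained above and the gym >= 0.26 API; `evaluate` is a hypothetical helper, not part of the original code:

# Hypothetical evaluation helper: roll out the deterministic (mean) policy for a few episodes.
def evaluate(sac, env_name='Pendulum-v1', episodes=5):
    eval_env = gym.make(env_name)
    for ep in range(episodes):
        state, _ = eval_env.reset()  # gym >= 0.26 returns (obs, info)
        done, total_reward = False, 0.0
        while not done:
            with torch.no_grad():
                mu, _ = sac.actor(torch.FloatTensor(state.reshape(1, -1)))
                action = (torch.tanh(mu) * sac.actor.max_action).numpy().flatten()
            state, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated
            total_reward += reward
        print(f"Eval episode {ep + 1}: return {total_reward:.1f}")
    eval_env.close()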
Introduction

Soft Actor-Critic (SAC) is a maximum-entropy deep reinforcement learning algorithm designed for continuous action spaces. It combines the Actor-Critic framework with entropy regularization, striking a good balance between exploration and exploitation.
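In equations, the maximum-entropy objective and the three updates implemented in the code above follow the standard SAC formulation (here $\bar{\mathcal{H}}$ is the target entropy, set to $-\dim(\mathcal{A})$ in the code):

\[ J(\pi) = \sum_t \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\big[\, r(s_t, a_t) + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \,\big] \]

Critic target (target_q in update), with clipped double-Q target networks:
\[ y = r + \gamma (1 - d)\Big( \min_{i=1,2} Q_{\bar{\theta}_i}(s', a') - \alpha \log \pi(a' \mid s') \Big), \qquad a' \sim \pi(\cdot \mid s') \]

Actor loss (actor_loss):
\[ J_\pi = \mathbb{E}\big[\, \alpha \log \pi(a \mid s) - \min_{i=1,2} Q_{\theta_i}(s, a) \,\big], \qquad a \sim \pi(\cdot \mid s) \]

Temperature loss (alpha_loss):
\[ J(\alpha) = \mathbb{E}_{a \sim \pi}\big[\, -\alpha \big( \log \pi(a \mid s) + \bar{\mathcal{H}} \big) \,\big] \]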