Code
```python
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pygame  # rendering backend used by gym's classic-control environments
# Define the Actor (policy) network: a squashed Gaussian policy
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mu = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        mu = self.mu(x)
        log_std = self.log_std(x)
        log_std = torch.clamp(log_std, -20, 2)  # keep the std in a numerically safe range
        std = torch.exp(log_std)
        return mu, std

    def sample(self, state):
        mu, std = self.forward(state)
        dist = torch.distributions.Normal(mu, std)
        x = dist.rsample()                        # reparameterized pre-tanh sample
        action = torch.tanh(x) * self.max_action
        # Log-prob of the squashed action: evaluate the Gaussian at the pre-tanh
        # sample x, then apply the numerically stable tanh change-of-variables correction.
        log_prob = dist.log_prob(x).sum(dim=-1, keepdim=True)
        log_prob -= (2 * (np.log(2) - x - F.softplus(-2 * x))).sum(dim=-1, keepdim=True)
        return action, log_prob
# Define the Critic (Q-value) network
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# The SAC algorithm
class SAC:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action)
        self.critic1 = Critic(state_dim, action_dim)
        self.critic2 = Critic(state_dim, action_dim)
        self.target_critic1 = Critic(state_dim, action_dim)
        self.target_critic2 = Critic(state_dim, action_dim)
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=3e-4)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=3e-4)
        # Learnable temperature alpha, optimized in log space
        self.log_alpha = torch.tensor(np.log(0.1), requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=3e-4)
        self.target_entropy = -action_dim  # standard heuristic: -|A|
        self.gamma = 0.99
        self.tau = 0.005
    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1))  # ensure state has shape (1, state_dim)
        action, _ = self.actor.sample(state)
        return action.cpu().data.numpy().flatten()
    def update(self, replay_buffer, batch_size=256):
        state, action, next_state, reward, done = replay_buffer.sample(batch_size)
        state = torch.FloatTensor(state)
        action = torch.FloatTensor(action)
        next_state = torch.FloatTensor(next_state)
        reward = torch.FloatTensor(reward).unsqueeze(1)
        done = torch.FloatTensor(done).unsqueeze(1)

        # Critic update: soft Bellman backup with the clipped double-Q target
        with torch.no_grad():
            next_action, next_log_prob = self.actor.sample(next_state)
            target_q1 = self.target_critic1(next_state, next_action)
            target_q2 = self.target_critic2(next_state, next_action)
            target_q = torch.min(target_q1, target_q2) - self.log_alpha.exp() * next_log_prob
            target_q = reward + (1 - done) * self.gamma * target_q
        current_q1 = self.critic1(state, action)
        current_q2 = self.critic2(state, action)
        critic1_loss = F.mse_loss(current_q1, target_q)
        critic2_loss = F.mse_loss(current_q2, target_q)
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        self.critic2_optimizer.step()

        # Actor update: maximize Q minus the entropy penalty
        action_new, log_prob = self.actor.sample(state)
        q1_new = self.critic1(state, action_new)
        q2_new = self.critic2(state, action_new)
        q_new = torch.min(q1_new, q2_new)
        actor_loss = (self.log_alpha.exp() * log_prob - q_new).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Temperature update: drive the policy entropy toward the target entropy
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # Polyak-average the target critics
        for param, target_param in zip(self.critic1.parameters(), self.target_critic1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.critic2.parameters(), self.target_critic2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
# A simple replay buffer
class ReplayBuffer:
    def __init__(self, max_size=1e6):
        self.buffer = []
        self.max_size = int(max_size)  # cast max_size to an integer
        self.ptr = 0

    def add(self, state, action, next_state, reward, done):
        if len(self.buffer) < self.max_size:
            self.buffer.append(None)
        self.buffer[self.ptr] = (state, action, next_state, reward, done)
        self.ptr = (self.ptr + 1) % self.max_size

    def sample(self, batch_size):
        indices = np.random.randint(0, len(self.buffer), batch_size)
        states, actions, next_states, rewards, dones = [], [], [], [], []
        for idx in indices:
            state, action, next_state, reward, done = self.buffer[idx]
            states.append(state)
            actions.append(action)
            next_states.append(next_state)
            rewards.append(reward)
            dones.append(done)
        return (np.array(states), np.array(actions), np.array(next_states),
                np.array(rewards, dtype=np.float32), np.array(dones, dtype=np.float32))
# Train the SAC agent on Pendulum-v1
env = gym.make('Pendulum-v1')  # pass render_mode='human' here for rendering under gym >= 0.26
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
sac = SAC(state_dim, action_dim, max_action)
replay_buffer = ReplayBuffer()
max_episodes = 1000
batch_size = 256

for episode in range(max_episodes):
    state = env.reset()
    if isinstance(state, tuple):  # gym >= 0.26 returns (obs, info)
        state = state[0]
    episode_reward = 0
    done = False
    while not done:
        env.render()  # optional; rendering slows training considerably
        action = sac.select_action(state)
        step_result = env.step(action)
        if len(step_result) == 5:  # gym >= 0.26 returns (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = step_result
            done = terminated or truncated
        else:  # older gym returns (obs, reward, done, info)
            next_state, reward, done, _ = step_result
        replay_buffer.add(state, action, next_state, reward, done)
        state = next_state
        episode_reward += reward
        if len(replay_buffer.buffer) > batch_size:
            sac.update(replay_buffer, batch_size)
    print(f"Episode {episode + 1}, Reward: {episode_reward}")

env.close()
```
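After training, the learned policy can be checked by acting greedily with the Gaussian mean instead of sampling. The following is a minimal evaluation sketch, not part of the original post; it assumes the `sac` agent and `max_action` from the script above are still in scope and re-creates the environment:

```python
# Hypothetical deterministic evaluation of the trained policy:
# use the mean of the Gaussian (no sampling) and report one episode's return.
eval_env = gym.make('Pendulum-v1')
state = eval_env.reset()
if isinstance(state, tuple):  # gym >= 0.26 returns (obs, info)
    state = state[0]
total_reward, done = 0.0, False
while not done:
    with torch.no_grad():
        mu, _ = sac.actor(torch.FloatTensor(state.reshape(1, -1)))
        action = (torch.tanh(mu) * max_action).numpy().flatten()
    step_result = eval_env.step(action)
    if len(step_result) == 5:
        state, reward, terminated, truncated, _ = step_result
        done = terminated or truncated
    else:
        state, reward, done, _ = step_result
    total_reward += reward
print(f"Deterministic evaluation reward: {total_reward}")
eval_env.close()
```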
Introduction
Soft Actor-Critic (SAC) is a deep reinforcement learning algorithm based on the maximum-entropy framework, designed for continuous action spaces. It combines the actor-critic framework with entropy regularization, striking a good balance between exploration and exploitation.
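For reference, the maximum-entropy objective that SAC optimizes can be written as follows, where α is the temperature (the exponential of `log_alpha` in the code above) and H denotes the policy entropy:

```latex
% Maximum-entropy RL objective: expected return plus an entropy bonus
% weighted by the temperature alpha (auto-tuned via log_alpha in the code).
J(\pi) = \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}
         \left[ r(s_t, a_t) + \alpha \, \mathcal{H}\bigl(\pi(\cdot \mid s_t)\bigr) \right]
```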