
Author: 农民    Posted: 2025-3-17 04:13
Title: PyTorch Deep Learning in Practice (15): The Twin Delayed DDPG (TD3) Algorithm
In the previous article, we introduced the Deep Deterministic Policy Gradient (DDPG) algorithm and used it to solve the Pendulum problem. This article takes a closer look at Twin Delayed DDPG (TD3), an improved variant of DDPG that effectively mitigates DDPG's Q-value overestimation problem. We will implement TD3 in PyTorch and apply it to the classic Pendulum problem.


I. TD3 Algorithm Fundamentals

TD3 is an improved version of DDPG. It tackles DDPG's overestimation problem by introducing three key techniques.

1. Core Ideas of TD3

- Clipped double-Q learning: two independent critics are trained, and the smaller of their two target estimates is used when forming the TD target, which counteracts overestimation.
- Delayed policy updates: the actor and the target networks are updated less often than the critics (every `policy_freq` critic updates), so the policy improves against a more stable value estimate.
- Target policy smoothing: clipped Gaussian noise is added to the target action when computing the TD target, which regularizes the value estimate and keeps the policy from exploiting sharp peaks in the Q-function.

The resulting TD target is written out below.
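As a compact reference, the target value computed in the implementation later in this article can be written as follows, where $\sigma$ corresponds to `policy_noise`, $c$ to `noise_clip`, $a_{\max}$ to `max_action`, and $d$ is the done flag:

$$\tilde a = \operatorname{clip}\big(\pi_{\phi'}(s') + \epsilon,\ -a_{\max},\ a_{\max}\big), \qquad \epsilon \sim \operatorname{clip}\big(\mathcal N(0, \sigma^2),\ -c,\ c\big)$$

$$y = r + \gamma\,(1 - d)\,\min\big(Q_{\theta_1'}(s', \tilde a),\ Q_{\theta_2'}(s', \tilde a)\big)$$

Both critics are then trained with an MSE loss toward the same target $y$, and the actor is updated (less frequently) to maximize $Q_{\theta_1}(s, \pi_\phi(s))$.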
2. Advantages of TD3

Compared with DDPG, TD3 produces less biased Q-value estimates, which stabilizes the actor updates and typically yields better and more reliable performance on continuous-control tasks such as Pendulum.

3. TD3 Algorithm Flow

Each training iteration (mirrored by the `train()` method in the code below) proceeds as follows:

- Sample a mini-batch of transitions from the replay buffer.
- Compute the smoothed target action with the target actor plus clipped noise, take the minimum of the two target critics, and form the TD target.
- Update both critics with an MSE loss toward that shared target.
- Every `policy_freq` iterations, update the actor by maximizing the first critic's value of the actor's actions, then softly update all target networks (the soft-update rule is shown right after this list).

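For completeness, the soft (Polyak) update applied to every target network uses the rate $\tau$, set to `tau = 0.005` in the code:

$$\theta' \leftarrow \tau\,\theta + (1 - \tau)\,\theta'$$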
II. Hands-On: The Pendulum Problem

We will implement the TD3 algorithm in PyTorch and apply it to the Pendulum problem. The goal is to control the pendulum so that it stays upright.
1. Problem Description

The observation of the Pendulum environment is a 3-dimensional vector (cos θ, sin θ, angular velocity) describing the pendulum's angle and angular velocity. The action is a single continuous torque in the range [-2, 2]. The per-step reward is always non-positive: it penalizes deviation from the upright position, high angular velocity, and large applied torque, so the agent maximizes cumulative reward by swinging the pendulum up and holding it upright. A quick environment check is sketched below.
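As a minimal sanity check of these dimensions before diving into the full script (assuming `gym`/Gymnasium with `Pendulum-v1` is installed, as the main code below also requires):

import gym

# Inspect the Pendulum-v1 spaces before building the networks.
env = gym.make('Pendulum-v1')
print(env.observation_space)            # Box of shape (3,): [cos(theta), sin(theta), angular velocity]
print(env.action_space)                 # Box of shape (1,): a single torque in [-2.0, 2.0]
print(float(env.action_space.high[0]))  # 2.0 -> used as max_action below
env.close()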
2. Implementation Steps

- Define the Actor and Critic networks (with layer normalization).
- Build the TD3 agent: twin critics with target networks, a replay buffer, exploration noise, and the clipped-double-Q / delayed-update logic.
- Run the training loop on Pendulum-v1, decaying the exploration noise and tracking a 50-episode moving average with early stopping.
- Plot the per-episode rewards and the moving average.

3. Code Implementation

The complete implementation is shown below:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Actor network (with layer normalization added)
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.l1 = nn.Linear(state_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        self.l2 = nn.Linear(256, 256)
        self.ln2 = nn.LayerNorm(256)
        self.l3 = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, x):
        x = F.relu(self.ln1(self.l1(x)))
        x = F.relu(self.ln2(self.l2(x)))
        x = torch.tanh(self.l3(x)) * self.max_action
        return x

# Critic network (with layer normalization added)
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.l1 = nn.Linear(state_dim + action_dim, 256)
        self.ln1 = nn.LayerNorm(256)
        self.l2 = nn.Linear(256, 256)
        self.ln2 = nn.LayerNorm(256)
        self.l3 = nn.Linear(256, 1)

    def forward(self, x, u):
        x = F.relu(self.ln1(self.l1(torch.cat([x, u], 1))))
        x = F.relu(self.ln2(self.l2(x)))
        x = self.l3(x)
        return x

class TD3:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=3e-4)

        self.critic1 = Critic(state_dim, action_dim).to(device)
        self.critic2 = Critic(state_dim, action_dim).to(device)
        self.critic1_target = Critic(state_dim, action_dim).to(device)
        self.critic2_target = Critic(state_dim, action_dim).to(device)
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=3e-4)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=3e-4)

        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)
        self.batch_size = 256
        self.gamma = 0.99
        self.tau = 0.005
        self.policy_noise = 0.2
        self.noise_clip = 0.5
        self.policy_freq = 2
        self.total_it = 0
        self.exploration_noise = 0.1  # exploration noise added when acting

    def select_action(self, state, add_noise=True):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        action = self.actor(state).cpu().data.numpy().flatten()
        if add_noise:
            noise = np.random.normal(0, self.exploration_noise, size=action.shape)
            action = (action + noise).clip(-self.max_action, self.max_action)
        return action

    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        self.total_it += 1

        batch = random.sample(self.replay_buffer, self.batch_size)
        state = torch.FloatTensor(np.array([t[0] for t in batch])).to(device)
        action = torch.FloatTensor(np.array([t[1] for t in batch])).to(device)
        reward = torch.FloatTensor(np.array([t[2] for t in batch])).reshape(-1, 1).to(device) / 10.0  # reward scaling
        next_state = torch.FloatTensor(np.array([t[3] for t in batch])).to(device)
        done = torch.FloatTensor(np.array([t[4] for t in batch])).reshape(-1, 1).to(device)

        with torch.no_grad():
            # Target policy smoothing: clipped noise on the target action
            noise = (torch.randn_like(action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (self.actor_target(next_state) + noise).clamp(-self.max_action, self.max_action)
            # Clipped double-Q: take the minimum of the two target critics
            target_Q1 = self.critic1_target(next_state, next_action)
            target_Q2 = self.critic2_target(next_state, next_action)
            target_Q = torch.min(target_Q1, target_Q2)
            target_Q = reward + (1 - done) * self.gamma * target_Q

        current_Q1 = self.critic1(state, action)
        current_Q2 = self.critic2(state, action)
        critic1_loss = F.mse_loss(current_Q1, target_Q)
        critic2_loss = F.mse_loss(current_Q2, target_Q)

        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic1.parameters(), 1.0)  # gradient clipping
        self.critic1_optimizer.step()

        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic2.parameters(), 1.0)
        self.critic2_optimizer.step()

        # Delayed policy and target-network updates
        if self.total_it % self.policy_freq == 0:
            actor_loss = -self.critic1(state, self.actor(state)).mean()
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
            self.actor_optimizer.step()

            for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        torch.save(self.actor.state_dict(), filename + "_actor.pth")

def train_td3(env, agent, episodes=2000, early_stop_threshold=-150):
    rewards_history = []
    moving_avg = []
    best_avg = -np.inf
    for ep in range(episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False
        step = 0
        while not done:
            # Linearly decay the exploration noise over the first 300 episodes
            if ep < 300:
                agent.exploration_noise = max(0.5 * (1 - ep / 300), 0.1)
            else:
                agent.exploration_noise = 0.1
            action = agent.select_action(state, add_noise=(ep < 100))  # add exploration noise only in the first 100 episodes
            # Gym >= 0.26 returns (obs, reward, terminated, truncated, info);
            # Pendulum never terminates on its own, so the time-limit truncation must end the episode.
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.replay_buffer.append((state, action, reward, next_state, float(terminated)))
            state = next_state
            episode_reward += reward
            agent.train()
            step += 1
        rewards_history.append(episode_reward)
        current_avg = np.mean(rewards_history[-50:])
        moving_avg.append(current_avg)
        if current_avg > best_avg:
            best_avg = current_avg
            agent.save("td3_pendulum_best")
        if (ep + 1) % 50 == 0:
            print(f"Episode: {ep + 1}, Avg Reward: {current_avg:.2f}")
        # Early stopping
        if current_avg >= early_stop_threshold:
            print(f"Early stopping triggered, average reward reached {current_avg:.2f}")
            break
    return moving_avg, rewards_history

# Train and visualize
td3_agent = TD3(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_td3(env, td3_agent, episodes=2000)

# Plot the results
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('TD3 training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()

III. Code Walkthrough

- Actor and Critic: both use two 256-unit hidden layers with layer normalization; the actor squashes its output with tanh and scales it by `max_action`, while each critic takes the concatenated state-action pair and outputs a scalar Q-value.
- TD3 agent: it keeps two critics plus target copies of all three networks, a 1M-transition replay buffer, and the TD3 hyper-parameters (`gamma=0.99`, `tau=0.005`, `policy_noise=0.2`, `noise_clip=0.5`, `policy_freq=2`).
- `train()`: each call samples a 256-transition batch, builds the clipped-double-Q target from a smoothed target action, updates both critics (with reward scaling and gradient clipping), and only every second call updates the actor and softly updates the target networks.
- `train_td3()`: runs the episodes, decays the exploration noise, saves the best actor according to the 50-episode moving average, and stops early once that average reaches -150.

IV. Results

After running the code above, the script prints the device in use, then the 50-episode moving-average reward every 50 episodes, and finally shows a plot of the per-episode rewards together with their moving average. Training stops early once the moving average reaches the -150 threshold.
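If you want to inspect the learned behavior after training, the actor checkpoint written by `agent.save("td3_pendulum_best")` can be reloaded and run greedily (without exploration noise). This is a minimal sketch; it assumes the training script above has already run in the same session, so that `Actor`, `state_dim`, `action_dim`, `max_action`, and `device` are defined and `td3_pendulum_best_actor.pth` exists:

# Reload the best saved actor and evaluate it greedily for a few episodes.
eval_env = gym.make('Pendulum-v1')  # pass render_mode="human" to watch, if your gym version supports it
eval_actor = Actor(state_dim, action_dim, max_action).to(device)
eval_actor.load_state_dict(torch.load("td3_pendulum_best_actor.pth", map_location=device))
eval_actor.eval()

for episode in range(5):
    state, _ = eval_env.reset()
    total_reward, done = 0.0, False
    while not done:
        with torch.no_grad():
            action = eval_actor(torch.FloatTensor(state.reshape(1, -1)).to(device)).cpu().numpy().flatten()
        state, reward, terminated, truncated, _ = eval_env.step(action)
        done = terminated or truncated
        total_reward += reward
    print(f"Evaluation episode {episode + 1}: reward = {total_reward:.2f}")
eval_env.close()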



V. Summary

This article introduced the basic principles of the TD3 algorithm and implemented a simple TD3 agent in PyTorch to solve the Pendulum problem. Through this example, we saw how TD3 performs policy optimization in a continuous action space.
In the next article, we will explore more advanced reinforcement learning algorithms, such as Soft Actor-Critic (SAC). Stay tuned!

I hope this article helps you better understand the TD3 algorithm! If you have any questions, feel free to leave a comment for discussion.




