强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN

嚴華  论坛元老 | 2025-1-16 01:18:10 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1050|帖子 1050|积分 3150

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
在 CartPole 游戏中实现 Double DQN(DDQN)训练网络时,我们必要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning 中太过估计的偏差。
下面是一个基于 PyTorch 的 Double DQN 实现:
1. 导入依靠

  1. import random
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import numpy as np
  6. import gym
  7. from collections import deque
复制代码
2. 定义 Q 网络

我们必要定义一个 Q 网络,用于盘算 Q 值。这里使用简单的全毗连网络。
  1. class QNetwork(nn.Module):
  2.     def __init__(self, state_dim, action_dim):
  3.         super(QNetwork, self).__init__()
  4.         self.fc1 = nn.Linear(state_dim, 128)
  5.         self.fc2 = nn.Linear(128, 128)
  6.         self.fc3 = nn.Linear(128, action_dim)
  7.     def forward(self, x):
  8.         x = torch.relu(self.fc1(x))
  9.         x = torch.relu(self.fc2(x))
  10.         return self.fc3(x)
复制代码
3. 创建 Agent

  1. class DoubleDQNAgent:
  2.     def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
  3.         self.state_dim = state_dim
  4.         self.action_dim = action_dim
  5.         self.gamma = gamma
  6.         self.epsilon = epsilon
  7.         self.epsilon_decay = epsilon_decay
  8.         self.epsilon_min = epsilon_min
  9.         self.lr = lr
  10.         self.q_network = QNetwork(state_dim, action_dim)
  11.         self.target_network = QNetwork(state_dim, action_dim)
  12.         self.target_network.load_state_dict(self.q_network.state_dict())
  13.         self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
  14.         self.memory = deque(maxlen=10000)
  15.         self.batch_size = 64
  16.     def select_action(self, state):
  17.         if random.random() < self.epsilon:
  18.             return random.choice(range(self.action_dim))  # Explore
  19.         else:
  20.             state = torch.FloatTensor(state).unsqueeze(0)
  21.             with torch.no_grad():
  22.                 q_values = self.q_network(state)
  23.             return torch.argmax(q_values).item()  # Exploit
  24.     def store_experience(self, state, action, reward, next_state, done):
  25.         self.memory.append((state, action, reward, next_state, done))
  26.     def sample_batch(self):
  27.         return random.sample(self.memory, self.batch_size)
  28.     def update_target_network(self):
  29.         self.target_network.load_state_dict(self.q_network.state_dict())
  30.     def train(self):
  31.         if len(self.memory) < self.batch_size:
  32.             return
  33.         batch = self.sample_batch()
  34.         states, actions, rewards, next_states, dones = zip(*batch)
  35.         states = torch.FloatTensor(states)
  36.         actions = torch.LongTensor(actions)
  37.         rewards = torch.FloatTensor(rewards)
  38.         next_states = torch.FloatTensor(next_states)
  39.         dones = torch.FloatTensor(dones)
  40.         # Q values for current states
  41.         q_values = self.q_network(states)
  42.         q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
  43.         # Next Q values using target network
  44.         next_q_values = self.target_network(next_states)
  45.         next_actions = self.q_network(next_states).argmax(1)
  46.         next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
  47.         # Double DQN update
  48.         target = rewards + (1 - dones) * self.gamma * next_q_values
  49.         # Compute loss
  50.         loss = nn.MSELoss()(q_values, target)
  51.         # Optimize the Q-network
  52.         self.optimizer.zero_grad()
  53.         loss.backward()
  54.         self.optimizer.step()
  55.         if self.epsilon > self.epsilon_min:
  56.             self.epsilon *= self.epsilon_decay
复制代码
4. 训练过程

  1. def train_cartpole():
  2.     env = gym.make('CartPole-v1')
  3.     agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
  4.     episodes = 1000
  5.     for episode in range(episodes):
  6.         state, info = env.reset()
  7.         done = False
  8.         total_reward = 0
  9.         while not done:
  10.             action = agent.select_action(state)
  11.             next_state, reward, done, truncated, info = env.step(action)
  12.             agent.store_experience(state, action, reward, next_state, done)
  13.             state = next_state
  14.             agent.train()
  15.             total_reward += reward
  16.         agent.update_target_network()
  17.         print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")
  18.     env.close()
  19. if __name__ == '__main__':
  20.     train_cartpole()
复制代码
5. 表明



  • QNetwork: 使用一个简单的 3 层全毗连神经网络来近似 Q 函数。
  • DoubleDQNAgent:

    • select_action: 根据 ε-greedy 策略选择动作。
    • store_experience: 存储履历回放。
    • sample_batch: 从记忆中随机采样批次。
    • train: 更新 Q 网络的权重,使用 Double DQN 的目标盘算方法。
    • update_target_network: 每一定步数更新目标网络。

  • 训练过程: 在每一回合中,署理与情况互动并更新 Q 网络,通过履历回放机制逐步学习。
6. 调整超参数



  • gamma: 扣头因子,控制未来嘉奖的影响。
  • epsilon: 初始的探索率,随着训练的举行逐渐减小。
  • lr: 学习率,控制权重更新的步伐。
  • batch_size: 每次更新时,从记忆库中采样的批量巨细。
这个代码可以直接用于训练一个 CartPole 的 Double DQN 署理,逐步优化 Q 网络来完成游戏使命。假如你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

嚴華

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表