IT评测·应用市场-qidao123.com技术社区

标题: 强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN [打印本页]

作者: 嚴華    时间: 2025-1-16 01:18
标题: 强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN
在 CartPole 游戏中实现 Double DQN(DDQN)训练网络时,我们必要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning 中太过估计的偏差。
下面是一个基于 PyTorch 的 Double DQN 实现:
1. 导入依靠

  1. import random
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import numpy as np
  6. import gym
  7. from collections import deque
复制代码
2. 定义 Q 网络

我们必要定义一个 Q 网络,用于盘算 Q 值。这里使用简单的全毗连网络。
  1. class QNetwork(nn.Module):
  2.     def __init__(self, state_dim, action_dim):
  3.         super(QNetwork, self).__init__()
  4.         self.fc1 = nn.Linear(state_dim, 128)
  5.         self.fc2 = nn.Linear(128, 128)
  6.         self.fc3 = nn.Linear(128, action_dim)
  7.     def forward(self, x):
  8.         x = torch.relu(self.fc1(x))
  9.         x = torch.relu(self.fc2(x))
  10.         return self.fc3(x)
复制代码
3. 创建 Agent

  1. class DoubleDQNAgent:
  2.     def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
  3.         self.state_dim = state_dim
  4.         self.action_dim = action_dim
  5.         self.gamma = gamma
  6.         self.epsilon = epsilon
  7.         self.epsilon_decay = epsilon_decay
  8.         self.epsilon_min = epsilon_min
  9.         self.lr = lr
  10.         self.q_network = QNetwork(state_dim, action_dim)
  11.         self.target_network = QNetwork(state_dim, action_dim)
  12.         self.target_network.load_state_dict(self.q_network.state_dict())
  13.         self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
  14.         self.memory = deque(maxlen=10000)
  15.         self.batch_size = 64
  16.     def select_action(self, state):
  17.         if random.random() < self.epsilon:
  18.             return random.choice(range(self.action_dim))  # Explore
  19.         else:
  20.             state = torch.FloatTensor(state).unsqueeze(0)
  21.             with torch.no_grad():
  22.                 q_values = self.q_network(state)
  23.             return torch.argmax(q_values).item()  # Exploit
  24.     def store_experience(self, state, action, reward, next_state, done):
  25.         self.memory.append((state, action, reward, next_state, done))
  26.     def sample_batch(self):
  27.         return random.sample(self.memory, self.batch_size)
  28.     def update_target_network(self):
  29.         self.target_network.load_state_dict(self.q_network.state_dict())
  30.     def train(self):
  31.         if len(self.memory) < self.batch_size:
  32.             return
  33.         batch = self.sample_batch()
  34.         states, actions, rewards, next_states, dones = zip(*batch)
  35.         states = torch.FloatTensor(states)
  36.         actions = torch.LongTensor(actions)
  37.         rewards = torch.FloatTensor(rewards)
  38.         next_states = torch.FloatTensor(next_states)
  39.         dones = torch.FloatTensor(dones)
  40.         # Q values for current states
  41.         q_values = self.q_network(states)
  42.         q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
  43.         # Next Q values using target network
  44.         next_q_values = self.target_network(next_states)
  45.         next_actions = self.q_network(next_states).argmax(1)
  46.         next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
  47.         # Double DQN update
  48.         target = rewards + (1 - dones) * self.gamma * next_q_values
  49.         # Compute loss
  50.         loss = nn.MSELoss()(q_values, target)
  51.         # Optimize the Q-network
  52.         self.optimizer.zero_grad()
  53.         loss.backward()
  54.         self.optimizer.step()
  55.         if self.epsilon > self.epsilon_min:
  56.             self.epsilon *= self.epsilon_decay
复制代码
4. 训练过程

  1. def train_cartpole():
  2.     env = gym.make('CartPole-v1')
  3.     agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
  4.     episodes = 1000
  5.     for episode in range(episodes):
  6.         state, info = env.reset()
  7.         done = False
  8.         total_reward = 0
  9.         while not done:
  10.             action = agent.select_action(state)
  11.             next_state, reward, done, truncated, info = env.step(action)
  12.             agent.store_experience(state, action, reward, next_state, done)
  13.             state = next_state
  14.             agent.train()
  15.             total_reward += reward
  16.         agent.update_target_network()
  17.         print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")
  18.     env.close()
  19. if __name__ == '__main__':
  20.     train_cartpole()
复制代码
5. 表明


6. 调整超参数


这个代码可以直接用于训练一个 CartPole 的 Double DQN 署理,逐步优化 Q 网络来完成游戏使命。假如你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/) Powered by Discuz! X3.4