马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
在 CartPole 游戏中实现 Double DQN(DDQN)训练网络时,我们必要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning 中太过估计的偏差。
下面是一个基于 PyTorch 的 Double DQN 实现:
1. 导入依靠
- import random
- import torch
- import torch.nn as nn
- import torch.optim as optim
- import numpy as np
- import gym
- from collections import deque
复制代码 2. 定义 Q 网络
我们必要定义一个 Q 网络,用于盘算 Q 值。这里使用简单的全毗连网络。
- class QNetwork(nn.Module):
- def __init__(self, state_dim, action_dim):
- super(QNetwork, self).__init__()
- self.fc1 = nn.Linear(state_dim, 128)
- self.fc2 = nn.Linear(128, 128)
- self.fc3 = nn.Linear(128, action_dim)
- def forward(self, x):
- x = torch.relu(self.fc1(x))
- x = torch.relu(self.fc2(x))
- return self.fc3(x)
复制代码 3. 创建 Agent
- class DoubleDQNAgent:
- def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
- self.state_dim = state_dim
- self.action_dim = action_dim
- self.gamma = gamma
- self.epsilon = epsilon
- self.epsilon_decay = epsilon_decay
- self.epsilon_min = epsilon_min
- self.lr = lr
- self.q_network = QNetwork(state_dim, action_dim)
- self.target_network = QNetwork(state_dim, action_dim)
- self.target_network.load_state_dict(self.q_network.state_dict())
- self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
- self.memory = deque(maxlen=10000)
- self.batch_size = 64
- def select_action(self, state):
- if random.random() < self.epsilon:
- return random.choice(range(self.action_dim)) # Explore
- else:
- state = torch.FloatTensor(state).unsqueeze(0)
- with torch.no_grad():
- q_values = self.q_network(state)
- return torch.argmax(q_values).item() # Exploit
- def store_experience(self, state, action, reward, next_state, done):
- self.memory.append((state, action, reward, next_state, done))
- def sample_batch(self):
- return random.sample(self.memory, self.batch_size)
- def update_target_network(self):
- self.target_network.load_state_dict(self.q_network.state_dict())
- def train(self):
- if len(self.memory) < self.batch_size:
- return
- batch = self.sample_batch()
- states, actions, rewards, next_states, dones = zip(*batch)
- states = torch.FloatTensor(states)
- actions = torch.LongTensor(actions)
- rewards = torch.FloatTensor(rewards)
- next_states = torch.FloatTensor(next_states)
- dones = torch.FloatTensor(dones)
- # Q values for current states
- q_values = self.q_network(states)
- q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
- # Next Q values using target network
- next_q_values = self.target_network(next_states)
- next_actions = self.q_network(next_states).argmax(1)
- next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
- # Double DQN update
- target = rewards + (1 - dones) * self.gamma * next_q_values
- # Compute loss
- loss = nn.MSELoss()(q_values, target)
- # Optimize the Q-network
- self.optimizer.zero_grad()
- loss.backward()
- self.optimizer.step()
- if self.epsilon > self.epsilon_min:
- self.epsilon *= self.epsilon_decay
复制代码 4. 训练过程
- def train_cartpole():
- env = gym.make('CartPole-v1')
- agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
- episodes = 1000
- for episode in range(episodes):
- state, info = env.reset()
- done = False
- total_reward = 0
- while not done:
- action = agent.select_action(state)
- next_state, reward, done, truncated, info = env.step(action)
- agent.store_experience(state, action, reward, next_state, done)
- state = next_state
- agent.train()
- total_reward += reward
- agent.update_target_network()
- print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")
- env.close()
- if __name__ == '__main__':
- train_cartpole()
复制代码 5. 表明
- QNetwork: 使用一个简单的 3 层全毗连神经网络来近似 Q 函数。
- DoubleDQNAgent:
- select_action: 根据 ε-greedy 策略选择动作。
- store_experience: 存储履历回放。
- sample_batch: 从记忆中随机采样批次。
- train: 更新 Q 网络的权重,使用 Double DQN 的目标盘算方法。
- update_target_network: 每一定步数更新目标网络。
- 训练过程: 在每一回合中,署理与情况互动并更新 Q 网络,通过履历回放机制逐步学习。
6. 调整超参数
- gamma: 扣头因子,控制未来嘉奖的影响。
- epsilon: 初始的探索率,随着训练的举行逐渐减小。
- lr: 学习率,控制权重更新的步伐。
- batch_size: 每次更新时,从记忆库中采样的批量巨细。
这个代码可以直接用于训练一个 CartPole 的 Double DQN 署理,逐步优化 Q 网络来完成游戏使命。假如你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |