IT评测·应用市场-qidao123.com技术社区
标题:
强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN
[打印本页]
作者:
嚴華
时间:
2025-1-16 01:18
标题:
强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN
在 CartPole 游戏中实现 Double DQN(DDQN)训练网络时,我们必要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning 中太过估计的偏差。
下面是一个基于 PyTorch 的 Double DQN 实现:
1. 导入依靠
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from collections import deque
复制代码
2. 定义 Q 网络
我们必要定义一个 Q 网络,用于盘算 Q 值。这里使用简单的全毗连网络。
class QNetwork(nn.Module):
def __init__(self, state_dim, action_dim):
super(QNetwork, self).__init__()
self.fc1 = nn.Linear(state_dim, 128)
self.fc2 = nn.Linear(128, 128)
self.fc3 = nn.Linear(128, action_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
复制代码
3. 创建 Agent
class DoubleDQNAgent:
def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
self.state_dim = state_dim
self.action_dim = action_dim
self.gamma = gamma
self.epsilon = epsilon
self.epsilon_decay = epsilon_decay
self.epsilon_min = epsilon_min
self.lr = lr
self.q_network = QNetwork(state_dim, action_dim)
self.target_network = QNetwork(state_dim, action_dim)
self.target_network.load_state_dict(self.q_network.state_dict())
self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)
self.memory = deque(maxlen=10000)
self.batch_size = 64
def select_action(self, state):
if random.random() < self.epsilon:
return random.choice(range(self.action_dim)) # Explore
else:
state = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
q_values = self.q_network(state)
return torch.argmax(q_values).item() # Exploit
def store_experience(self, state, action, reward, next_state, done):
self.memory.append((state, action, reward, next_state, done))
def sample_batch(self):
return random.sample(self.memory, self.batch_size)
def update_target_network(self):
self.target_network.load_state_dict(self.q_network.state_dict())
def train(self):
if len(self.memory) < self.batch_size:
return
batch = self.sample_batch()
states, actions, rewards, next_states, dones = zip(*batch)
states = torch.FloatTensor(states)
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones)
# Q values for current states
q_values = self.q_network(states)
q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
# Next Q values using target network
next_q_values = self.target_network(next_states)
next_actions = self.q_network(next_states).argmax(1)
next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)
# Double DQN update
target = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss
loss = nn.MSELoss()(q_values, target)
# Optimize the Q-network
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
复制代码
4. 训练过程
def train_cartpole():
env = gym.make('CartPole-v1')
agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)
episodes = 1000
for episode in range(episodes):
state, info = env.reset()
done = False
total_reward = 0
while not done:
action = agent.select_action(state)
next_state, reward, done, truncated, info = env.step(action)
agent.store_experience(state, action, reward, next_state, done)
state = next_state
agent.train()
total_reward += reward
agent.update_target_network()
print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")
env.close()
if __name__ == '__main__':
train_cartpole()
复制代码
5. 表明
QNetwork
: 使用一个简单的 3 层全毗连神经网络来近似 Q 函数。
DoubleDQNAgent
:
select_action: 根据 ε-greedy 策略选择动作。
store_experience: 存储履历回放。
sample_batch: 从记忆中随机采样批次。
train: 更新 Q 网络的权重,使用 Double DQN 的目标盘算方法。
update_target_network: 每一定步数更新目标网络。
训练过程
: 在每一回合中,署理与情况互动并更新 Q 网络,通过履历回放机制逐步学习。
6. 调整超参数
gamma: 扣头因子,控制未来嘉奖的影响。
epsilon: 初始的探索率,随着训练的举行逐渐减小。
lr: 学习率,控制权重更新的步伐。
batch_size: 每次更新时,从记忆库中采样的批量巨细。
这个代码可以直接用于训练一个 CartPole 的 Double DQN 署理,逐步优化 Q 网络来完成游戏使命。假如你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/)
Powered by Discuz! X3.4