百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)

ztj100 2025-04-26 22:45 40 浏览 0 评论

在之前的文章中,我们介绍了神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、Transformer 等多种深度学习模型,并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念,并使用 PyTorch 实现一个经典的深度 Q 网络(DQN)来解决强化学习中的经典问题——CartPole。

一、强化学习基础

强化学习(Reinforcement Learning, RL)是机器学习的一个重要分支,它通过智能体(Agent)与环境(Environment)的交互来学习策略,以最大化累积奖励。强化学习的核心思想是通过试错来学习,智能体在环境中采取行动,观察结果,并根据奖励信号调整策略。

1. 强化学习的基本要素

  • 智能体(Agent):学习并做出决策的主体。
  • 环境(Environment):智能体交互的外部世界。
  • 状态(State):环境在某一时刻的描述。
  • 动作(Action):智能体在某一状态下采取的行动。
  • 奖励(Reward):智能体采取动作后,环境返回的反馈信号。
  • 策略(Policy):智能体在给定状态下选择动作的规则。
  • 价值函数(Value Function):评估在某一状态下采取某一动作的长期回报。

2. Q-Learning 与深度 Q 网络(DQN)

Q-Learning 是一种经典的强化学习算法,它通过学习一个 Q 函数来评估在某一状态下采取某一动作的长期回报。Q 函数的更新公式为:

深度 Q 网络(DQN)将 Q-Learning 与深度学习结合,使用神经网络来近似 Q 函数。DQN 通过经验回放(Experience Replay)和目标网络(Target Network)来稳定训练过程。

二、CartPole 问题实战

CartPole 是强化学习中的经典问题,目标是控制一个小车(Cart)使其上的杆子(Pole)保持直立。我们将使用 PyTorch 实现一个 DQN 来解决这个问题。

1. 问题描述

CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。

2. 实现步骤

  1. 安装并导入必要的库。
  2. 定义 DQN 模型。
  3. 定义经验回放缓冲区。
  4. 定义 DQN 训练过程。
  5. 测试模型并评估性能。

3. 代码实现

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 安装并导入必要的库
env = gym.make('CartPole-v1')

# 2. 定义 DQN 模型
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 3. 定义经验回放缓冲区
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)

    def __len__(self):
        return len(self.buffer)

# 4. 定义 DQN 训练过程
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
model = DQN(state_size, action_size)
target_model = DQN(state_size, action_size)
target_model.load_state_dict(model.state_dict())
optimizer = optim.Adam(model.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)

def train(batch_size, gamma=0.99):
    if len(buffer) < batch_size:
        return
    state, action, reward, next_state, done = buffer.sample(batch_size)
    state = torch.FloatTensor(state)
    next_state = torch.FloatTensor(next_state)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)

    q_values = model(state)
    next_q_values = target_model(next_state)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = nn.MSELoss()(q_value, expected_q_value.detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 5. 测试模型并评估性能
def test(env, model, episodes=10):
    total_reward = 0
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            state = torch.FloatTensor(state).unsqueeze(0)
            action = model(state).max(1)[1].item()
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            state = next_state
    return total_reward / episodes

# 训练过程
episodes = 500
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
rewards = []

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = model(state_tensor).max(1)[1].item()

        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

        train(batch_size, gamma)

    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    rewards.append(total_reward)

    if (episode + 1) % 50 == 0:
        avg_reward = test(env, model)
        print(f"Episode: {episode + 1}, Avg Reward: {avg_reward:.2f}")

# 6. 可视化训练结果
plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("DQN 训练过程")
plt.show()

三、代码解析

1.环境与模型定义

  • 使用 gym 创建 CartPole 环境。
  • 定义 DQN 模型,包含三个全连接层。

2.经验回放缓冲区

  • 使用 deque 实现经验回放缓冲区,存储状态、动作、奖励等信息。

3.训练过程

  • 使用 epsilon-greedy 策略进行探索与利用。
  • 通过经验回放缓冲区采样数据进行训练,更新模型参数。

4.测试过程

  • 在测试环境中评估模型性能,计算平均奖励。

5.可视化

  • 绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。


五、总结

本文介绍了强化学习的基本概念,并使用 PyTorch 实现了一个深度 Q 网络(DQN)来解决 CartPole 问题。通过这个例子,我们学习了如何定义 DQN 模型、使用经验回放缓冲区、训练模型以及评估性能。

在下一篇文章中,我们将探讨更复杂的强化学习算法,如 Actor-Critic 和 Proximal Policy Optimization (PPO)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:model = model.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解强化学习的基础知识!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

其实TensorFlow真的很水无非就这30篇熬夜练

好的!以下是TensorFlow需要掌握的核心内容,用列表形式呈现,简洁清晰(含表情符号,<300字):1.基础概念与环境TensorFlow架构(计算图、会话->EagerE...

交叉验证和超参数调整:如何优化你的机器学习模型

准确预测Fitbit的睡眠得分在本文的前两部分中,我获取了Fitbit的睡眠数据并对其进行预处理,将这些数据分为训练集、验证集和测试集,除此之外,我还训练了三种不同的机器学习模型并比较了它们的性能。在...

机器学习交叉验证全指南:原理、类型与实战技巧

机器学习模型常常需要大量数据,但它们如何与实时新数据协同工作也同样关键。交叉验证是一种通过将数据集分成若干部分、在部分数据上训练模型、在其余数据上测试模型的方法,用来检验模型的表现。这有助于发现过拟合...

深度学习中的类别激活热图可视化

作者:ValentinaAlto编译:ronghuaiyang导读使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性...

超强,必会的机器学习评估指标

大侠幸会,在下全网同名[算法金]0基础转AI上岸,多个算法赛Top[日更万日,让更多人享受智能乐趣]构建机器学习模型的关键步骤是检查其性能,这是通过使用验证指标来完成的。选择正确的验证指...

机器学习入门教程-第六课:监督学习与非监督学习

1.回顾与引入上节课我们谈到了机器学习的一些实战技巧,比如如何处理数据、选择模型以及调整参数。今天,我们将更深入地探讨机器学习的两大类:监督学习和非监督学习。2.监督学习监督学习就像是有老师的教学...

Python教程(三十八):机器学习基础

...

Python 模型部署不用愁!容器化实战,5 分钟搞定环境配置

你是不是也遇到过这种糟心事:花了好几天训练出的Python模型,在自己电脑上跑得顺顺当当,一放到服务器就各种报错。要么是Python版本不对,要么是依赖库冲突,折腾半天还是用不了。别再喊“我...

超全面讲透一个算法模型,高斯核!!

...

神经网络与传统统计方法的简单对比

传统的统计方法如...

AI 基础知识从0.1到0.2——用“房价预测”入门机器学习全流程

...

自回归滞后模型进行多变量时间序列预测

下图显示了关于不同类型葡萄酒销量的月度多元时间序列。每种葡萄酒类型都是时间序列中的一个变量。假设要预测其中一个变量。比如,sparklingwine。如何建立一个模型来进行预测呢?一种常见的方...

苹果AI策略:慢哲学——科技行业的“长期主义”试金石

苹果AI策略的深度原创分析,结合技术伦理、商业逻辑与行业博弈,揭示其“慢哲学”背后的战略智慧:一、反常之举:AI狂潮中的“逆行者”当科技巨头深陷AI军备竞赛,苹果的克制显得格格不入:功能延期:App...

时间序列预测全攻略,6大模型代码实操

如果你对数据分析感兴趣,希望学习更多的方法论,希望听听经验分享,欢迎移步宝藏公众号...

AI 基础知识从 0.4 到 0.5—— 计算机视觉之光 CNN

...

取消回复欢迎 发表评论: