百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(11):强化学习与深度 Q 网络(DQN)

ztj100 2025-04-26 22:45 35 浏览 0 评论

在之前的文章中,我们介绍了神经网络、卷积神经网络(CNN)、循环神经网络(RNN)、Transformer 等多种深度学习模型,并应用于图像分类、文本分类、时间序列预测等任务。本文将介绍强化学习的基本概念,并使用 PyTorch 实现一个经典的深度 Q 网络(DQN)来解决强化学习中的经典问题——CartPole。

一、强化学习基础

强化学习(Reinforcement Learning, RL)是机器学习的一个重要分支,它通过智能体(Agent)与环境(Environment)的交互来学习策略,以最大化累积奖励。强化学习的核心思想是通过试错来学习,智能体在环境中采取行动,观察结果,并根据奖励信号调整策略。

1. 强化学习的基本要素

  • 智能体(Agent):学习并做出决策的主体。
  • 环境(Environment):智能体交互的外部世界。
  • 状态(State):环境在某一时刻的描述。
  • 动作(Action):智能体在某一状态下采取的行动。
  • 奖励(Reward):智能体采取动作后,环境返回的反馈信号。
  • 策略(Policy):智能体在给定状态下选择动作的规则。
  • 价值函数(Value Function):评估在某一状态下采取某一动作的长期回报。

2. Q-Learning 与深度 Q 网络(DQN)

Q-Learning 是一种经典的强化学习算法,它通过学习一个 Q 函数来评估在某一状态下采取某一动作的长期回报。Q 函数的更新公式为:

深度 Q 网络(DQN)将 Q-Learning 与深度学习结合,使用神经网络来近似 Q 函数。DQN 通过经验回放(Experience Replay)和目标网络(Target Network)来稳定训练过程。

二、CartPole 问题实战

CartPole 是强化学习中的经典问题,目标是控制一个小车(Cart)使其上的杆子(Pole)保持直立。我们将使用 PyTorch 实现一个 DQN 来解决这个问题。

1. 问题描述

CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。

2. 实现步骤

  1. 安装并导入必要的库。
  2. 定义 DQN 模型。
  3. 定义经验回放缓冲区。
  4. 定义 DQN 训练过程。
  5. 测试模型并评估性能。

3. 代码实现

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置字体为 SimHei(黑体)
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 安装并导入必要的库
env = gym.make('CartPole-v1')

# 2. 定义 DQN 模型
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 3. 定义经验回放缓冲区
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)

    def __len__(self):
        return len(self.buffer)

# 4. 定义 DQN 训练过程
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
model = DQN(state_size, action_size)
target_model = DQN(state_size, action_size)
target_model.load_state_dict(model.state_dict())
optimizer = optim.Adam(model.parameters(), lr=0.001)
buffer = ReplayBuffer(10000)

def train(batch_size, gamma=0.99):
    if len(buffer) < batch_size:
        return
    state, action, reward, next_state, done = buffer.sample(batch_size)
    state = torch.FloatTensor(state)
    next_state = torch.FloatTensor(next_state)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)

    q_values = model(state)
    next_q_values = target_model(next_state)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    next_q_value = next_q_values.max(1)[0]
    expected_q_value = reward + gamma * next_q_value * (1 - done)

    loss = nn.MSELoss()(q_value, expected_q_value.detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 5. 测试模型并评估性能
def test(env, model, episodes=10):
    total_reward = 0
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            state = torch.FloatTensor(state).unsqueeze(0)
            action = model(state).max(1)[1].item()
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            state = next_state
    return total_reward / episodes

# 训练过程
episodes = 500
batch_size = 64
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.995
rewards = []

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = model(state_tensor).max(1)[1].item()

        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

        train(batch_size, gamma)

    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    rewards.append(total_reward)

    if (episode + 1) % 50 == 0:
        avg_reward = test(env, model)
        print(f"Episode: {episode + 1}, Avg Reward: {avg_reward:.2f}")

# 6. 可视化训练结果
plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("DQN 训练过程")
plt.show()

三、代码解析

1.环境与模型定义

  • 使用 gym 创建 CartPole 环境。
  • 定义 DQN 模型,包含三个全连接层。

2.经验回放缓冲区

  • 使用 deque 实现经验回放缓冲区,存储状态、动作、奖励等信息。

3.训练过程

  • 使用 epsilon-greedy 策略进行探索与利用。
  • 通过经验回放缓冲区采样数据进行训练,更新模型参数。

4.测试过程

  • 在测试环境中评估模型性能,计算平均奖励。

5.可视化

  • 绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。


五、总结

本文介绍了强化学习的基本概念,并使用 PyTorch 实现了一个深度 Q 网络(DQN)来解决 CartPole 问题。通过这个例子,我们学习了如何定义 DQN 模型、使用经验回放缓冲区、训练模型以及评估性能。

在下一篇文章中,我们将探讨更复杂的强化学习算法,如 Actor-Critic 和 Proximal Policy Optimization (PPO)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:model = model.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解强化学习的基础知识!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

30天学会Python编程:16. Python常用标准库使用教程

16.1collections模块16.1.1高级数据结构16.1.2示例...

强烈推荐!Python 这个宝藏库 re 正则匹配

Python的re模块(RegularExpression正则表达式)提供各种正则表达式的匹配操作。...

Python爬虫中正则表达式的用法,只讲如何应用,不讲原理

Python爬虫:正则的用法(非原理)。大家好,这节课给大家讲正则的实际用法,不讲原理,通俗易懂的讲如何用正则抓取内容。·导入re库,这里是需要从html这段字符串中提取出中间的那几个文字。实例一个对...

Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)

实现功能:Python数据分析实战-利用正则表达式提取文本中的URL网址和邮箱...

python爬虫教程之爬取当当网 Top 500 本五星好评书籍

我们使用requests和re来写一个爬虫作为一个爱看书的你(说的跟真的似的)怎么能发现好书呢?所以我们爬取当当网的前500本好五星评书籍怎么样?ok接下来就是学习python的正确姿...

深入理解re模块:Python中的正则表达式神器解析

在Python中,"re"是一个强大的模块,用于处理正则表达式(regularexpressions)。正则表达式是一种强大的文本模式匹配工具,用于在字符串中查找、替换或提取特定模式...

如何使用正则表达式和 Python 匹配不以模式开头的字符串

需要在Python中使用正则表达式来匹配不以给定模式开头的字符串吗?如果是这样,你可以使用下面的语法来查找所有的字符串,除了那些不以https开始的字符串。r"^(?!https).*&...

先Mark后用!8分钟读懂 Python 性能优化

从本文总结了Python开发时,遇到的性能优化问题的定位和解决。概述:性能优化的原则——优化需要优化的部分。性能优化的一般步骤:首先,让你的程序跑起来结果一切正常。然后,运行这个结果正常的代码,看看它...

Python“三步”即可爬取,毋庸置疑

声明:本实例仅供学习,切忌遵守robots协议,请不要使用多线程等方式频繁访问网站。#第一步导入模块importreimportrequests#第二步获取你想爬取的网页地址,发送请求,获取网页内...

简单学Python——re库(正则表达式)2(split、findall、和sub)

1、split():分割字符串,返回列表语法:re.split('分隔符','目标字符串')例如:importrere.split(',','...

Lavazza拉瓦萨再度牵手上海大师赛

阅读此文前,麻烦您点击一下“关注”,方便您进行讨论和分享。Lavazza拉瓦萨再度牵手上海大师赛标题:2024上海大师赛:网球与咖啡的浪漫邂逅在2024年的上海劳力士大师赛上,拉瓦萨咖啡再次成为官...

ArkUI-X构建Android平台AAR及使用

本教程主要讲述如何利用ArkUI-XSDK完成AndroidAAR开发,实现基于ArkTS的声明式开发范式在android平台显示。包括:1.跨平台Library工程开发介绍...

Deepseek写歌详细教程(怎样用deepseek写歌功能)

以下为结合DeepSeek及相关工具实现AI写歌的详细教程,涵盖作词、作曲、演唱全流程:一、核心流程三步法1.AI生成歌词-打开DeepSeek(网页/APP/API),使用结构化提示词生成歌词:...

“AI说唱解说影视”走红,“零基础入行”靠谱吗?本报记者实测

“手里翻找冻鱼,精心的布局;老漠却不言语,脸上带笑意……”《狂飙》剧情被写成歌词,再配上“科目三”背景音乐的演唱,这段1分钟30秒的视频受到了无数网友的点赞。最近一段时间随着AI技术的发展,说唱解说影...

AI音乐制作神器揭秘!3款工具让你秒变高手

在音乐创作的领域里,每个人都有一颗想要成为大师的心。但是面对复杂的乐理知识和繁复的制作过程,许多人的热情被一点点消磨。...

取消回复欢迎 发表评论: