百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(14):Deep Deterministic Policy Gradient

ztj100 2025-04-26 22:46 10 浏览 0 评论

在上一篇文章中,我们介绍了 Proximal Policy Optimization (PPO) 算法,并使用它解决了 CartPole 问题。本文将深入探讨 Deep Deterministic Policy Gradient (DDPG) 算法,这是一种用于连续动作空间的强化学习算法。我们将使用 PyTorch 实现 DDPG 算法,并应用于经典的 Pendulum 问题。



一、DDPG 算法基础

DDPG 是一种基于 Actor-Critic 框架的算法,专门用于解决连续动作空间的强化学习问题。它结合了深度 Q 网络(DQN)和策略梯度方法的优点,能够高效地处理高维状态和动作空间。

1. DDPG 的核心思想

  • 确定性策略
    • DDPG 使用确定性策略(Deterministic Policy),即给定状态时,策略网络直接输出一个确定的动作,而不是动作的概率分布。
  • 目标网络
    • DDPG 使用目标网络(Target Network)来稳定训练过程,类似于 DQN 中的目标网络。
  • 经验回放
    • DDPG 使用经验回放缓冲区(Replay Buffer)来存储和重用过去的经验,从而提高数据利用率。

2. DDPG 的优势

  • 适用于连续动作空间
    • DDPG 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。
  • 训练稳定
    • 通过目标网络和经验回放,DDPG 能够稳定地训练策略网络和价值网络。
  • 高效采样
    • DDPG 可以重复使用旧策略的采样数据,从而提高数据利用率。

3. DDPG 的算法流程

  1. 使用当前策略采样一批数据。
  2. 使用目标网络计算目标 Q 值。
  3. 更新 Critic 网络以最小化 Q 值的误差。
  4. 更新 Actor 网络以最大化 Q 值。
  5. 更新目标网络。
  6. 重复上述过程,直到策略收敛。

二、Pendulum 问题实战

我们将使用 PyTorch 实现 DDPG 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。

1. 问题描述

Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 -2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。

2. 实现步骤

  1. 安装并导入必要的库。
  2. 定义 Actor 网络和 Critic 网络。
  3. 定义 DDPG 训练过程。
  4. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt

# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 环境初始化
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

# 随机种子设置
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)


# 定义 Actor 网络
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 512)
        self.ln1 = nn.LayerNorm(512)  # 层归一化
        self.fc2 = nn.Linear(512, 512)
        self.ln2 = nn.LayerNorm(512)
        self.fc3 = nn.Linear(512, action_dim)
        self.max_action = max_action

    def forward(self, x):
        x = F.relu(self.ln1(self.fc1(x)))
        x = F.relu(self.ln2(self.fc2(x)))
        return self.max_action * torch.tanh(self.fc3(x))


# 定义 Critic 网络
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x, u):
        x = F.relu(self.fc1(torch.cat([x, u], 1)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# 添加OU噪声类
class OUNoise:
    def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(action_dim)
        self.theta = theta
        self.sigma = sigma
        self.reset()

    def reset(self):
        self.state = np.copy(self.mu)

    def sample(self):
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))
        self.state += dx
        return self.state


# 定义 DDPG 算法
class DDPG:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-4)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)
        self.noise = OUNoise(action_dim, sigma=0.2)  # 示例:Ornstein-Uhlenbeck噪声

        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)
        self.batch_size = 64
        self.gamma = 0.99
        self.tau = 0.005
        self.noise_sigma = 0.5  # 初始噪声强度
        self.noise_decay = 0.995

        self.actor_lr_scheduler = optim.lr_scheduler.StepLR(self.actor_optimizer, step_size=100, gamma=0.95)
        self.critic_lr_scheduler = optim.lr_scheduler.StepLR(self.critic_optimizer, step_size=100, gamma=0.95)

    def select_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        self.actor.eval()
        with torch.no_grad():
            action = self.actor(state).cpu().data.numpy().flatten()
        self.actor.train()
        return action

    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        # 从经验回放缓冲区中采样
        batch = random.sample(self.replay_buffer, self.batch_size)
        state = torch.FloatTensor(np.array([transition[0] for transition in batch])).to(device)
        action = torch.FloatTensor(np.array([transition[1] for transition in batch])).to(device)
        reward = torch.FloatTensor(np.array([transition[2] for transition in batch])).reshape(-1, 1).to(device)
        next_state = torch.FloatTensor(np.array([transition[3] for transition in batch])).to(device)
        done = torch.FloatTensor(np.array([transition[4] for transition in batch])).reshape(-1, 1).to(device)

        # 计算目标 Q 值
        next_action = self.actor_target(next_state)
        target_Q = self.critic_target(next_state, next_action)
        target_Q = reward + (1 - done) * self.gamma * target_Q

        # 更新 Critic 网络
        current_Q = self.critic(state, action)
        critic_loss = F.mse_loss(current_Q, target_Q.detach())
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 更新 Actor 网络
        actor_loss = -self.critic(state, self.actor(state)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 更新目标网络
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        torch.save(self.actor.state_dict(), filename + "_actor.pth")
        torch.save(self.critic.state_dict(), filename + "_critic.pth")

    def load(self, filename):
        self.actor.load_state_dict(torch.load(filename + "_actor.pth"))
        self.critic.load_state_dict(torch.load(filename + "_critic.pth"))


# 训练流程
def train_ddpg(env, agent, episodes=500):
    rewards_history = []
    moving_avg = []

    for ep in range(episodes):
        state,_ = env.reset()
        episode_reward = 0
        done = False

        while not done:
            action = agent.select_action(state)
            next_state, reward, done, _, _ = env.step(action)
            agent.replay_buffer.append((state, action, reward, next_state, done))
            state = next_state
            episode_reward += reward
            agent.train()

        rewards_history.append(episode_reward)
        moving_avg.append(np.mean(rewards_history[-50:]))

        if (ep + 1) % 50 == 0:
            print(f"Episode: {ep + 1}, Avg Reward: {moving_avg[-1]:.2f}")

    return moving_avg, rewards_history


# 训练启动
ddpg_agent = DDPG(state_dim, action_dim, max_action)
moving_avg, rewards_history = train_ddpg(env, ddpg_agent)

# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.6, label='single round reward')
plt.plot(moving_avg, 'r-', linewidth=2, label='moving average (50 rounds)')
plt.xlabel('episodes')
plt.ylabel('reward')
plt.title('DDPG training performance on Pendulum-v1')
plt.legend()
plt.grid(True)
plt.show()

三、代码解析

1.Actor 和 Critic 网络

  1. Actor 网络输出连续动作,通过 tanh 函数将动作限制在 -max_action,max_action 范围内。
  2. Critic 网络输出状态-动作对的 Q 值。

2.DDPG 训练过程

  1. 使用当前策略采样一批数据。
  2. 使用目标网络计算目标 Q 值。
  3. 更新 Critic 网络以最小化 Q 值的误差。
  4. 更新 Actor 网络以最大化 Q 值。
  5. 更新目标网络。

3.训练过程

  1. 在训练过程中,每 50 个 episode 打印一次平均奖励。
  2. 训练结束后,绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。

五、总结

本文介绍了 DDPG 算法的基本原理,并使用 PyTorch 实现了一个简单的 DDPG 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 DDPG 算法进行连续动作空间的策略优化。

在下一篇文章中,我们将探讨更高级的强化学习算法,如 Twin Delayed DDPG (TD3)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:actor = actor.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解 DDPG 算法!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL

阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...

Python数据分析:探索性分析

写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...

CSP-J/S冲奖第21天:插入排序

...

C++基础语法梳理:算法丨十大排序算法(二)

本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...

C 语言的标准库有哪些

C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...

[深度学习] ncnn安装和调用基础教程

1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...

用rust实现经典的冒泡排序和快速排序

1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...

ncnn+PPYOLOv2首次结合!全网最详细代码解读来了

编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...

C++特性使用建议

1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...

Qt4/5升级到Qt6吐血经验总结V202308

00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...

到底什么是C++11新特性,请看下文

C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...

掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!

C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...

经典算法——凸包算法

凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...

一起学习c++11——c++11中的新增的容器

c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...

C++ 编程中的一些最佳实践

1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...

取消回复欢迎 发表评论: