PyTorch 深度学习实战(13):Proximal Policy Optimization 算法
ztj100 2025-04-26 22:46 9 浏览 0 评论
在上一篇文章中,我们介绍了 Actor-Critic 算法,并使用它解决了 CartPole 问题。本文将深入探讨 Proximal Policy Optimization (PPO) 算法,这是一种更稳定、更高效的策略优化方法。我们将使用 PyTorch 实现 PPO 算法,并应用于经典的 CartPole 问题。
一、PPO 算法基础
PPO 是 OpenAI 提出的一种强化学习算法,旨在解决策略梯度方法中的训练不稳定问题。PPO 通过限制策略更新的幅度,确保每次更新不会偏离当前策略太远,从而提高训练的稳定性。
1. PPO 的核心思想
- 重要性采样:
- PPO 使用重要性采样(Importance Sampling)来估计新策略的期望回报,从而避免重新采样。
- 裁剪目标函数:
- PPO 通过裁剪策略更新的幅度,确保新策略不会偏离旧策略太远。
- 优势函数:
- PPO 使用优势函数(Advantage Function)来评估动作的好坏,从而更高效地更新策略。
2. PPO 的优势
- 训练稳定:
- 通过限制策略更新的幅度,PPO 避免了策略梯度方法中的训练不稳定问题。
- 高效采样:
- PPO 可以重复使用旧策略的采样数据,从而提高数据利用率。
- 适用范围广:
- PPO 可以应用于连续动作空间和离散动作空间的问题。
3. PPO 的算法流程
- 使用当前策略采样一批数据。
- 计算优势函数和旧策略的概率。
- 通过裁剪目标函数更新策略。
- 重复上述过程,直到策略收敛。
二、CartPole 问题实战
我们将使用 PyTorch 实现 PPO 算法,并应用于 CartPole 问题。目标是控制小车使其上的杆子保持直立。
1. 问题描述
CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。
2. 实现步骤
- 安装并导入必要的库。
- 定义策略网络和价值网络。
- 定义 PPO 训练过程。
- 测试模型并评估性能。
3. 代码实现
以下是完整的代码实现:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
# 可视化设置
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 环境初始化
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
# 随机种子设置
SEED = 42
# env.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# 增强型策略网络
class PolicyNetwork(nn.Module):
def __init__(self, state_size, action_size):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_size, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, action_size)
)
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
nn.init.constant_(m.bias, 0.0)
def forward(self, x):
return self.net(x)
# 增强型价值网络
class ValueNetwork(nn.Module):
def __init__(self, state_size):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_size, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.orthogonal_(m.weight, gain=1.0)
nn.init.constant_(m.bias, 0.0)
def forward(self, x):
return self.net(x)
class PPO:
def __init__(self, state_dim, action_dim):
# 超参数设置
self.gamma = 0.995
self.gae_lambda = 0.97
self.clip_ratio = 0.15
self.update_epochs = 4
self.norm_adv = True
# 网络初始化
self.actor = PolicyNetwork(state_dim, action_dim)
self.critic = ValueNetwork(state_dim)
# 优化器设置
self.actor_optim = optim.Adam(self.actor.parameters(), lr=2e-4)
self.critic_optim = optim.Adam(self.critic.parameters(), lr=8e-4)
# 学习率调度
self.actor_scheduler = StepLR(self.actor_optim, step_size=100, gamma=0.95)
self.critic_scheduler = StepLR(self.critic_optim, step_size=100, gamma=0.95)
def get_action(self, state):
# state_tensor = torch.FloatTensor(state) # 如果使用 Gym <0.26
state_tensor = torch.FloatTensor(state).unsqueeze(0) # 如果使用 Gym >=0.26
with torch.no_grad():
logits = self.actor(state_tensor)
probs = torch.softmax(logits, dim=-1)
action = torch.multinomial(probs, 1).item()
return action
def compute_gae(self, rewards, values, next_values, dones):
advantages = np.zeros_like(rewards)
last_gae = 0
for t in reversed(range(len(rewards))):
if dones[t]:
delta = rewards[t] - values[t]
last_gae = delta
else:
delta = rewards[t] + self.gamma * next_values[t] - values[t]
last_gae = delta + self.gamma * self.gae_lambda * last_gae
advantages[t] = last_gae
return torch.FloatTensor(advantages)
def update(self, states, actions, rewards, next_states, dones):
# 转换为张量
states = torch.FloatTensor(np.array(states))
actions = torch.LongTensor(np.array(actions))
rewards = torch.FloatTensor(np.array(rewards))
next_states = torch.FloatTensor(np.array(next_states))
dones = torch.BoolTensor(np.array(dones))
# 计算价值估计
with torch.no_grad():
current_values = self.critic(states).squeeze()
next_values = self.critic(next_states).squeeze()
next_values[dones] = 0.0
# 计算GAE
advantages = self.compute_gae(rewards.numpy(),
current_values.numpy(),
next_values.numpy(),
dones.numpy())
# 标准化优势
if self.norm_adv:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# 获取旧策略概率
with torch.no_grad():
old_logits = self.actor(states)
old_probs = torch.softmax(old_logits, dim=-1)
old_log_probs = torch.log(old_probs.gather(1, actions.unsqueeze(1))).squeeze()
# 策略优化
for _ in range(self.update_epochs):
logits = self.actor(states)
probs = torch.softmax(logits, dim=-1)
new_log_probs = torch.log(probs.gather(1, actions.unsqueeze(1))).squeeze()
ratios = torch.exp(new_log_probs - old_log_probs)
clipped_ratios = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio)
policy_loss = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()
self.actor_optim.zero_grad()
policy_loss.backward()
torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.6)
self.actor_optim.step()
# 价值优化
for _ in range(self.update_epochs):
current_values = self.critic(states).squeeze()
target_values = rewards + self.gamma * next_values
value_loss = F.mse_loss(current_values, target_values)
self.critic_optim.zero_grad()
value_loss.backward()
torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.8)
self.critic_optim.step()
# 更新学习率
self.actor_scheduler.step()
self.critic_scheduler.step()
# 训练流程
def train_ppo(env, agent, episodes=800, early_stop=30):
rewards_history = []
moving_avg = []
best_score = -np.inf
no_improve = 0
for ep in range(episodes):
# 重置环境时确保正确处理返回值
# state = env.reset() # 如果使用 Gym <0.26
state, _ = env.reset() # 如果使用 Gym >=0.26
episode_data = {
'states': [],
'actions': [],
'rewards': [],
'next_states': [],
'dones': []
}
total_reward = 0
done = False
while not done:
action = agent.get_action(state)
# next_state, reward, done, _ = env.step(action) # 如果使用 Gym <0.26
next_state, reward, done, _,_ = env.step(action) # 如果使用 Gym >=0.26
# 处理环境提前终止
if env._elapsed_steps >= env.spec.max_episode_steps:
done = True
reward = 0
# 存储轨迹
episode_data['states'].append(state)
episode_data['actions'].append(action)
episode_data['rewards'].append(reward)
episode_data['next_states'].append(next_state)
episode_data['dones'].append(done)
state = next_state
total_reward += reward
# 更新模型
agent.update(**episode_data)
# 记录训练进度
rewards_history.append(total_reward)
current_avg = np.mean(rewards_history[-50:])
moving_avg.append(current_avg)
# 早停机制
if current_avg > best_score:
best_score = current_avg
no_improve = 0
else:
no_improve += 1
if no_improve >= early_stop:
print(f"早停触发于第{ep + 1}轮,最佳平均奖励: {best_score:.1f}")
break
# 进度报告
if (ep + 1) % 50 == 0:
print(f"Episode {ep + 1:3d} | 当前奖励: {total_reward:4.0f} | 平均奖励: {current_avg:6.1f}")
return moving_avg, rewards_history
# 训练启动
ppo_agent = PPO(state_dim, action_dim)
moving_avg, rewards_history = train_ppo(env, ppo_agent)
# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.4, label='单轮奖励')
plt.plot(moving_avg, 'r-', linewidth=2, label='滑动平均(50轮)')
plt.xlabel('训练轮次')
plt.ylabel('奖励')
plt.title('PPO训练进程')
plt.legend()
plt.grid(True)
plt.show()
三、代码解析
1.策略网络和价值网络:
- 策略网络输出动作的 logits(未归一化的概率),通过 F.softmax 转换为概率分布。
- 价值网络输出状态的价值估计。
2.PPO 训练过程:
- 使用当前策略采样一批数据。
- 计算优势函数和旧策略的概率。
- 通过裁剪目标函数更新策略。
3.训练过程:
- 在训练过程中,每 50 个 episode 打印一次平均奖励。
- 训练结束后,绘制训练过程中的总奖励曲线。
四、运行结果
运行上述代码后,你将看到以下输出:
- 训练过程中每 50 个 episode 打印一次平均奖励。
- 训练结束后,绘制训练过程中的总奖励曲线。
Episode 50 | 当前奖励: 29 | 平均奖励: 20.8
Episode 100 | 当前奖励: 132 | 平均奖励: 69.1
Episode 150 | 当前奖励: 499 | 平均奖励: 261.6
Episode 200 | 当前奖励: 295 | 平均奖励: 412.3
Episode 250 | 当前奖励: 499 | 平均奖励: 481.8
早停触发于第290轮,最佳平均奖励: 497.7
五、总结
本文介绍了 PPO 算法的基本原理,并使用 PyTorch 实现了一个简单的 PPO 模型来解决 CartPole 问题。通过这个例子,我们学习了如何使用 PPO 算法进行策略优化。
在下一篇文章中,我们将探讨更高级的强化学习算法,如 Deep Deterministic Policy Gradient (DDPG)。敬请期待!
代码实例说明:
- 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
- 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:actor = actor.to('cuda'),state = state.to('cuda')。
希望这篇文章能帮助你更好地理解 PPO 算法!如果有任何问题,欢迎在评论区留言讨论。
相关推荐
- 如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL
-
阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...
- Python数据分析:探索性分析
-
写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...
- C++基础语法梳理:算法丨十大排序算法(二)
-
本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...
- C 语言的标准库有哪些
-
C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...
- [深度学习] ncnn安装和调用基础教程
-
1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...
- 用rust实现经典的冒泡排序和快速排序
-
1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...
- ncnn+PPYOLOv2首次结合!全网最详细代码解读来了
-
编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...
- C++特性使用建议
-
1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...
- Qt4/5升级到Qt6吐血经验总结V202308
-
00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...
- 到底什么是C++11新特性,请看下文
-
C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...
- 掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!
-
C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...
- 经典算法——凸包算法
-
凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...
- 一起学习c++11——c++11中的新增的容器
-
c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...
- C++ 编程中的一些最佳实践
-
1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)