百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(13):Proximal Policy Optimization 算法

ztj100 2025-04-26 22:46 43 浏览 0 评论

在上一篇文章中,我们介绍了 Actor-Critic 算法,并使用它解决了 CartPole 问题。本文将深入探讨 Proximal Policy Optimization (PPO) 算法,这是一种更稳定、更高效的策略优化方法。我们将使用 PyTorch 实现 PPO 算法,并应用于经典的 CartPole 问题。



一、PPO 算法基础

PPO 是 OpenAI 提出的一种强化学习算法,旨在解决策略梯度方法中的训练不稳定问题。PPO 通过限制策略更新的幅度,确保每次更新不会偏离当前策略太远,从而提高训练的稳定性。

1. PPO 的核心思想

  • 重要性采样
    • PPO 使用重要性采样(Importance Sampling)来估计新策略的期望回报,从而避免重新采样。
  • 裁剪目标函数
    • PPO 通过裁剪策略更新的幅度,确保新策略不会偏离旧策略太远。
  • 优势函数
    • PPO 使用优势函数(Advantage Function)来评估动作的好坏,从而更高效地更新策略。

2. PPO 的优势

  • 训练稳定
    • 通过限制策略更新的幅度,PPO 避免了策略梯度方法中的训练不稳定问题。
  • 高效采样
    • PPO 可以重复使用旧策略的采样数据,从而提高数据利用率。
  • 适用范围广
    • PPO 可以应用于连续动作空间和离散动作空间的问题。

3. PPO 的算法流程

  1. 使用当前策略采样一批数据。
  2. 计算优势函数和旧策略的概率。
  3. 通过裁剪目标函数更新策略。
  4. 重复上述过程,直到策略收敛。

二、CartPole 问题实战

我们将使用 PyTorch 实现 PPO 算法,并应用于 CartPole 问题。目标是控制小车使其上的杆子保持直立。

1. 问题描述

CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。

2. 实现步骤

  1. 安装并导入必要的库。
  2. 定义策略网络和价值网络。
  3. 定义 PPO 训练过程。
  4. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR

# 可视化设置
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# 环境初始化
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# 随机种子设置
SEED = 42
# env.seed(SEED) 
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# 增强型策略网络
class PolicyNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_size)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        return self.net(x)


# 增强型价值网络
class ValueNetwork(nn.Module):
    def __init__(self, state_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=1.0)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        return self.net(x)


class PPO:
    def __init__(self, state_dim, action_dim):
        # 超参数设置
        self.gamma = 0.995
        self.gae_lambda = 0.97
        self.clip_ratio = 0.15
        self.update_epochs = 4
        self.norm_adv = True

        # 网络初始化
        self.actor = PolicyNetwork(state_dim, action_dim)
        self.critic = ValueNetwork(state_dim)

        # 优化器设置
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=2e-4)
        self.critic_optim = optim.Adam(self.critic.parameters(), lr=8e-4)

        # 学习率调度
        self.actor_scheduler = StepLR(self.actor_optim, step_size=100, gamma=0.95)
        self.critic_scheduler = StepLR(self.critic_optim, step_size=100, gamma=0.95)

    def get_action(self, state):
        # state_tensor = torch.FloatTensor(state) # 如果使用 Gym <0.26
        state_tensor = torch.FloatTensor(state).unsqueeze(0)  # 如果使用 Gym >=0.26
        with torch.no_grad():
            logits = self.actor(state_tensor)
        probs = torch.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        return action

    def compute_gae(self, rewards, values, next_values, dones):
        advantages = np.zeros_like(rewards)
        last_gae = 0
        for t in reversed(range(len(rewards))):
            if dones[t]:
                delta = rewards[t] - values[t]
                last_gae = delta
            else:
                delta = rewards[t] + self.gamma * next_values[t] - values[t]
                last_gae = delta + self.gamma * self.gae_lambda * last_gae
            advantages[t] = last_gae
        return torch.FloatTensor(advantages)

    def update(self, states, actions, rewards, next_states, dones):
        # 转换为张量
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(np.array(actions))
        rewards = torch.FloatTensor(np.array(rewards))
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.BoolTensor(np.array(dones))

        # 计算价值估计
        with torch.no_grad():
            current_values = self.critic(states).squeeze()
            next_values = self.critic(next_states).squeeze()
            next_values[dones] = 0.0

        # 计算GAE
        advantages = self.compute_gae(rewards.numpy(),
                                      current_values.numpy(),
                                      next_values.numpy(),
                                      dones.numpy())

        # 标准化优势
        if self.norm_adv:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # 获取旧策略概率
        with torch.no_grad():
            old_logits = self.actor(states)
            old_probs = torch.softmax(old_logits, dim=-1)
            old_log_probs = torch.log(old_probs.gather(1, actions.unsqueeze(1))).squeeze()

        # 策略优化
        for _ in range(self.update_epochs):
            logits = self.actor(states)
            probs = torch.softmax(logits, dim=-1)
            new_log_probs = torch.log(probs.gather(1, actions.unsqueeze(1))).squeeze()

            ratios = torch.exp(new_log_probs - old_log_probs)
            clipped_ratios = torch.clamp(ratios, 1 - self.clip_ratio, 1 + self.clip_ratio)

            policy_loss = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()

            self.actor_optim.zero_grad()
            policy_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.6)
            self.actor_optim.step()

        # 价值优化
        for _ in range(self.update_epochs):
            current_values = self.critic(states).squeeze()
            target_values = rewards + self.gamma * next_values
            value_loss = F.mse_loss(current_values, target_values)

            self.critic_optim.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.8)
            self.critic_optim.step()

        # 更新学习率
        self.actor_scheduler.step()
        self.critic_scheduler.step()


# 训练流程
def train_ppo(env, agent, episodes=800, early_stop=30):
    rewards_history = []
    moving_avg = []
    best_score = -np.inf
    no_improve = 0

    for ep in range(episodes):
        # 重置环境时确保正确处理返回值
        # state = env.reset() # 如果使用 Gym <0.26
        state, _ = env.reset()  # 如果使用 Gym >=0.26
        episode_data = {
            'states': [],
            'actions': [],
            'rewards': [],
            'next_states': [],
            'dones': []
        }
        total_reward = 0
        done = False

        while not done:
            action = agent.get_action(state)
            # next_state, reward, done, _ = env.step(action) # 如果使用 Gym <0.26
            next_state, reward, done, _,_ = env.step(action) # 如果使用 Gym >=0.26

            # 处理环境提前终止
            if env._elapsed_steps >= env.spec.max_episode_steps:
                done = True
                reward = 0

            # 存储轨迹
            episode_data['states'].append(state)
            episode_data['actions'].append(action)
            episode_data['rewards'].append(reward)
            episode_data['next_states'].append(next_state)
            episode_data['dones'].append(done)

            state = next_state
            total_reward += reward

        # 更新模型
        agent.update(**episode_data)

        # 记录训练进度
        rewards_history.append(total_reward)
        current_avg = np.mean(rewards_history[-50:])
        moving_avg.append(current_avg)

        # 早停机制
        if current_avg > best_score:
            best_score = current_avg
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= early_stop:
            print(f"早停触发于第{ep + 1}轮,最佳平均奖励: {best_score:.1f}")
            break

        # 进度报告
        if (ep + 1) % 50 == 0:
            print(f"Episode {ep + 1:3d} | 当前奖励: {total_reward:4.0f} | 平均奖励: {current_avg:6.1f}")

    return moving_avg, rewards_history


# 训练启动
ppo_agent = PPO(state_dim, action_dim)
moving_avg, rewards_history = train_ppo(env, ppo_agent)
# 可视化结果
plt.figure(figsize=(12, 6))
plt.plot(rewards_history, alpha=0.4, label='单轮奖励')
plt.plot(moving_avg, 'r-', linewidth=2, label='滑动平均(50轮)')
plt.xlabel('训练轮次')
plt.ylabel('奖励')
plt.title('PPO训练进程')
plt.legend()
plt.grid(True)
plt.show()

三、代码解析

1.策略网络和价值网络

  • 策略网络输出动作的 logits(未归一化的概率),通过 F.softmax 转换为概率分布。
  • 价值网络输出状态的价值估计。

2.PPO 训练过程

  • 使用当前策略采样一批数据。
  • 计算优势函数和旧策略的概率。
  • 通过裁剪目标函数更新策略。

3.训练过程

  • 在训练过程中,每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。
Episode 50 | 当前奖励: 29 | 平均奖励: 20.8
Episode 100 | 当前奖励: 132 | 平均奖励: 69.1
Episode 150 | 当前奖励: 499 | 平均奖励: 261.6
Episode 200 | 当前奖励: 295 | 平均奖励: 412.3
Episode 250 | 当前奖励: 499 | 平均奖励: 481.8
早停触发于第290轮,最佳平均奖励: 497.7

五、总结

本文介绍了 PPO 算法的基本原理,并使用 PyTorch 实现了一个简单的 PPO 模型来解决 CartPole 问题。通过这个例子,我们学习了如何使用 PPO 算法进行策略优化。

在下一篇文章中,我们将探讨更高级的强化学习算法,如 Deep Deterministic Policy Gradient (DDPG)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:actor = actor.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解 PPO 算法!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

30天学会Python编程:16. Python常用标准库使用教程

16.1collections模块16.1.1高级数据结构16.1.2示例...

强烈推荐!Python 这个宝藏库 re 正则匹配

Python的re模块(RegularExpression正则表达式)提供各种正则表达式的匹配操作。...

Python爬虫中正则表达式的用法,只讲如何应用,不讲原理

Python爬虫:正则的用法(非原理)。大家好,这节课给大家讲正则的实际用法,不讲原理,通俗易懂的讲如何用正则抓取内容。·导入re库,这里是需要从html这段字符串中提取出中间的那几个文字。实例一个对...

Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)

实现功能:Python数据分析实战-利用正则表达式提取文本中的URL网址和邮箱...

python爬虫教程之爬取当当网 Top 500 本五星好评书籍

我们使用requests和re来写一个爬虫作为一个爱看书的你(说的跟真的似的)怎么能发现好书呢?所以我们爬取当当网的前500本好五星评书籍怎么样?ok接下来就是学习python的正确姿...

深入理解re模块:Python中的正则表达式神器解析

在Python中,"re"是一个强大的模块,用于处理正则表达式(regularexpressions)。正则表达式是一种强大的文本模式匹配工具,用于在字符串中查找、替换或提取特定模式...

如何使用正则表达式和 Python 匹配不以模式开头的字符串

需要在Python中使用正则表达式来匹配不以给定模式开头的字符串吗?如果是这样,你可以使用下面的语法来查找所有的字符串,除了那些不以https开始的字符串。r"^(?!https).*&...

先Mark后用!8分钟读懂 Python 性能优化

从本文总结了Python开发时,遇到的性能优化问题的定位和解决。概述:性能优化的原则——优化需要优化的部分。性能优化的一般步骤:首先,让你的程序跑起来结果一切正常。然后,运行这个结果正常的代码,看看它...

Python“三步”即可爬取,毋庸置疑

声明:本实例仅供学习,切忌遵守robots协议,请不要使用多线程等方式频繁访问网站。#第一步导入模块importreimportrequests#第二步获取你想爬取的网页地址,发送请求,获取网页内...

简单学Python——re库(正则表达式)2(split、findall、和sub)

1、split():分割字符串,返回列表语法:re.split('分隔符','目标字符串')例如:importrere.split(',','...

Lavazza拉瓦萨再度牵手上海大师赛

阅读此文前,麻烦您点击一下“关注”,方便您进行讨论和分享。Lavazza拉瓦萨再度牵手上海大师赛标题:2024上海大师赛:网球与咖啡的浪漫邂逅在2024年的上海劳力士大师赛上,拉瓦萨咖啡再次成为官...

ArkUI-X构建Android平台AAR及使用

本教程主要讲述如何利用ArkUI-XSDK完成AndroidAAR开发,实现基于ArkTS的声明式开发范式在android平台显示。包括:1.跨平台Library工程开发介绍...

Deepseek写歌详细教程(怎样用deepseek写歌功能)

以下为结合DeepSeek及相关工具实现AI写歌的详细教程,涵盖作词、作曲、演唱全流程:一、核心流程三步法1.AI生成歌词-打开DeepSeek(网页/APP/API),使用结构化提示词生成歌词:...

“AI说唱解说影视”走红,“零基础入行”靠谱吗?本报记者实测

“手里翻找冻鱼,精心的布局;老漠却不言语,脸上带笑意……”《狂飙》剧情被写成歌词,再配上“科目三”背景音乐的演唱,这段1分钟30秒的视频受到了无数网友的点赞。最近一段时间随着AI技术的发展,说唱解说影...

AI音乐制作神器揭秘!3款工具让你秒变高手

在音乐创作的领域里,每个人都有一颗想要成为大师的心。但是面对复杂的乐理知识和繁复的制作过程,许多人的热情被一点点消磨。...

取消回复欢迎 发表评论: