百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(12):Actor-Critic 算法与策略优化

ztj100 2025-04-26 22:46 33 浏览 0 评论

在上一篇文章中,我们介绍了强化学习的基本概念,并使用深度 Q 网络(DQN)解决了 CartPole 问题。本文将深入探讨 Actor-Critic 算法,这是一种结合了策略梯度(Policy Gradient)和值函数(Value Function)的强化学习方法。我们将使用 PyTorch 实现 Actor-Critic 算法,并应用于经典的 CartPole 问题。

一、Actor-Critic 算法基础

Actor-Critic 算法是强化学习中的一种重要方法,它结合了策略梯度方法和值函数方法的优点。其核心思想是通过两个网络分别学习策略(Actor)和值函数(Critic),从而更高效地进行策略优化。

1. Actor-Critic 的核心组件

  • Actor(策略网络)
    • 负责学习策略,即在给定状态下选择动作的概率分布。
    • 目标是最大化累积奖励。
    • 通过策略梯度方法更新参数。
  • Critic(值函数网络)
    • 负责评估状态或状态-动作对的价值。
    • 目标是准确估计当前策略的价值函数。
    • 通过时序差分(Temporal Difference, TD)误差更新参数。

2. 优势与特点

  • 降低方差
    • 通过 Critic 提供的值函数估计,Actor-Critic 算法可以减少策略梯度方法的高方差问题。
  • 更稳定的训练
    • Critic 提供了更准确的反馈信号,使得策略优化更加稳定。
  • 适用范围广
    • 可以应用于连续动作空间和离散动作空间的问题。

3. 算法流程

  1. Actor 根据当前策略选择动作。
  2. 环境执行动作,返回奖励和下一状态。
  3. Critic 计算当前状态的价值和 TD 误差。
  4. 使用 TD 误差更新 Actor 和 Critic 的参数。
  5. 重复上述过程,直到策略收敛。

二、CartPole 问题实战

我们将使用 PyTorch 实现 Actor-Critic 算法,并应用于 CartPole 问题。目标是控制小车使其上的杆子保持直立。

1. 问题描述

CartPole 环境的状态空间包括小车的位置、速度、杆子的角度和角速度。动作空间包括向左或向右移动小车。智能体每保持杆子直立一步,就会获得 +1 的奖励,当杆子倾斜超过一定角度或小车移动超出范围时,游戏结束。

2. 实现步骤

  1. 安装并导入必要的库。
  2. 定义 Actor 和 Critic 网络。
  3. 定义经验回放缓冲区。
  4. 定义 Actor-Critic 训练过程。
  5. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR

# 设置 Matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

env = gym.make('CartPole-v1')

# 固定随机种子
env.seed(42)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

class Actor(nn.Module):
    def __init__(self, state_size, action_size):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_size)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if m == self.fc3:
                    nn.init.xavier_normal_(m.weight, gain=0.01)
                else:
                    nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        return logits

class Critic(nn.Module):
    def __init__(self, state_size):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), np.array(action), np.array(reward), np.array(next_state), np.array(done)

    def __len__(self):
        return len(self.buffer)

state_size = env.observation_space.shape[0]
action_size = env.action_space.n
actor = Actor(state_size, action_size)

class TwinCritic(nn.Module):
    def __init__(self, state_size):
        super().__init__()
        self.critic1 = Critic(state_size)
        self.critic2 = Critic(state_size)

    def forward(self, x):
        return self.critic1(x), self.critic2(x)

# 初始化网络
critic = TwinCritic(state_size)
target_critic = TwinCritic(state_size)
target_critic.load_state_dict(critic.state_dict())

actor_optimizer = optim.Adam(actor.parameters(), lr=0.0005)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.0005)
actor_scheduler = CosineAnnealingLR(actor_optimizer, T_max=500, eta_min=1e-5)
critic_scheduler = CosineAnnealingLR(critic_optimizer, T_max=500, eta_min=1e-5)
buffer = ReplayBuffer(10000)

batch_size = 128 
while len(buffer) < batch_size:
    state = env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state

def train(actor, critic, target_critic, buffer, batch_size, gamma=0.99):
    if len(buffer) < batch_size:
        return    

    state, action, reward, next_state, done = buffer.sample(batch_size)
    state = torch.FloatTensor(state)
    next_state = torch.FloatTensor(next_state)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    done = torch.FloatTensor(done)

    # Critic 更新
    value1, value2 = critic(state)
    next_value1, next_value2 = target_critic(next_state)
    next_value = torch.min(next_value1, next_value2).detach()
    target_value = reward + gamma * next_value * (1 - done)
    critic_loss = F.mse_loss(value1, target_value) + F.mse_loss(value2, target_value)

    critic_optimizer.zero_grad()
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(critic.parameters(), max_norm=1.0)
    critic_optimizer.step()

    # Actor 更新
    logits = actor(state)
    probs = F.softmax(logits, dim=-1)
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1).mean()
    log_prob = log_probs.gather(1, action.unsqueeze(1)).squeeze(1)
    advantage = target_value - 0.5 * (value1.detach() + value2.detach())  # 平均优势
    actor_loss = (-log_prob * advantage).mean() - 0.1 * entropy

    actor_optimizer.zero_grad()
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(actor.parameters(), max_norm=5.0)
    actor_optimizer.step()

    actor_scheduler.step()
    critic_scheduler.step()

    # 更新目标网络
    for target_param, param in zip(target_critic.parameters(), critic.parameters()):
        target_param.data.copy_(0.95 * target_param.data + 0.05 * param.data)

def test(env, actor, episodes=10):
    total_reward = 0
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            logits = actor(state_tensor)
            probs = F.softmax(logits, dim=-1)
            action = torch.multinomial(probs, 1).item()
            next_state, reward, done, _ = env.step(action)
            total_reward += reward
            state = next_state
    return total_reward / episodes

episodes = 1000
batch_size = 64
gamma = 0.99
rewards = []

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        logits = actor(state_tensor)
        probs = F.softmax(logits, dim=-1)
        action = torch.multinomial(probs, 1).item()
        next_state, reward, done, _ = env.step(action)
        buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward

        if len(buffer) >= batch_size:
            train(actor, critic, target_critic, buffer, batch_size, gamma)

    rewards.append(total_reward)

    if (episode + 1) % 50 == 0:
        avg_reward = test(env, actor)
        print(f"Episode: {episode + 1}, Avg Reward: {avg_reward:.2f}")

plt.plot(rewards)
plt.xlabel("Episode")
plt.ylabel("Total Reward")
plt.title("Actor-Critic train")
plt.show()

三、代码解析

1.Actor 和 Critic 网络

  • Actor 网络输出动作的 logits(未归一化的概率),通过 F.softmax 转换为概率分布。
  • Critic 网络输出状态的价值估计。

2.经验回放缓冲区

  • 使用 deque 实现经验回放缓冲区,存储状态、动作、奖励等信息。

3.训练过程

  • 使用 Critic 计算 TD 误差,更新 Critic 网络。
  • 使用 TD 误差作为优势函数,更新 Actor 网络。
  • 使用梯度裁剪(gradient clipping)防止梯度爆炸。

4.测试过程

  • 在测试环境中评估模型性能,计算平均奖励。

5.可视化

  • 绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。

五、总结

本文介绍了 Actor-Critic 算法的基本原理,并使用 PyTorch 实现了一个简单的 Actor-Critic 模型来解决 CartPole 问题。通过这个例子,我们学习了如何结合策略梯度和值函数方法进行强化学习。

在下一篇文章中,我们将探讨更高级的强化学习算法,如 Proximal Policy Optimization (PPO) 和 Deep Deterministic Policy Gradient (DDPG)。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:actor = actor.to('cuda')state = state.to('cuda')

希望这篇文章能帮助你更好地理解 Actor-Critic 算法!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

30天学会Python编程:16. Python常用标准库使用教程

16.1collections模块16.1.1高级数据结构16.1.2示例...

强烈推荐!Python 这个宝藏库 re 正则匹配

Python的re模块(RegularExpression正则表达式)提供各种正则表达式的匹配操作。...

Python爬虫中正则表达式的用法,只讲如何应用,不讲原理

Python爬虫:正则的用法(非原理)。大家好,这节课给大家讲正则的实际用法,不讲原理,通俗易懂的讲如何用正则抓取内容。·导入re库,这里是需要从html这段字符串中提取出中间的那几个文字。实例一个对...

Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)

实现功能:Python数据分析实战-利用正则表达式提取文本中的URL网址和邮箱...

python爬虫教程之爬取当当网 Top 500 本五星好评书籍

我们使用requests和re来写一个爬虫作为一个爱看书的你(说的跟真的似的)怎么能发现好书呢?所以我们爬取当当网的前500本好五星评书籍怎么样?ok接下来就是学习python的正确姿...

深入理解re模块:Python中的正则表达式神器解析

在Python中,"re"是一个强大的模块,用于处理正则表达式(regularexpressions)。正则表达式是一种强大的文本模式匹配工具,用于在字符串中查找、替换或提取特定模式...

如何使用正则表达式和 Python 匹配不以模式开头的字符串

需要在Python中使用正则表达式来匹配不以给定模式开头的字符串吗?如果是这样,你可以使用下面的语法来查找所有的字符串,除了那些不以https开始的字符串。r"^(?!https).*&...

先Mark后用!8分钟读懂 Python 性能优化

从本文总结了Python开发时,遇到的性能优化问题的定位和解决。概述:性能优化的原则——优化需要优化的部分。性能优化的一般步骤:首先,让你的程序跑起来结果一切正常。然后,运行这个结果正常的代码,看看它...

Python“三步”即可爬取,毋庸置疑

声明:本实例仅供学习,切忌遵守robots协议,请不要使用多线程等方式频繁访问网站。#第一步导入模块importreimportrequests#第二步获取你想爬取的网页地址,发送请求,获取网页内...

简单学Python——re库(正则表达式)2(split、findall、和sub)

1、split():分割字符串,返回列表语法:re.split('分隔符','目标字符串')例如:importrere.split(',','...

Lavazza拉瓦萨再度牵手上海大师赛

阅读此文前,麻烦您点击一下“关注”,方便您进行讨论和分享。Lavazza拉瓦萨再度牵手上海大师赛标题:2024上海大师赛:网球与咖啡的浪漫邂逅在2024年的上海劳力士大师赛上,拉瓦萨咖啡再次成为官...

ArkUI-X构建Android平台AAR及使用

本教程主要讲述如何利用ArkUI-XSDK完成AndroidAAR开发,实现基于ArkTS的声明式开发范式在android平台显示。包括:1.跨平台Library工程开发介绍...

Deepseek写歌详细教程(怎样用deepseek写歌功能)

以下为结合DeepSeek及相关工具实现AI写歌的详细教程,涵盖作词、作曲、演唱全流程:一、核心流程三步法1.AI生成歌词-打开DeepSeek(网页/APP/API),使用结构化提示词生成歌词:...

“AI说唱解说影视”走红,“零基础入行”靠谱吗?本报记者实测

“手里翻找冻鱼,精心的布局;老漠却不言语,脸上带笑意……”《狂飙》剧情被写成歌词,再配上“科目三”背景音乐的演唱,这段1分钟30秒的视频受到了无数网友的点赞。最近一段时间随着AI技术的发展,说唱解说影...

AI音乐制作神器揭秘!3款工具让你秒变高手

在音乐创作的领域里,每个人都有一颗想要成为大师的心。但是面对复杂的乐理知识和繁复的制作过程,许多人的热情被一点点消磨。...

取消回复欢迎 发表评论: