百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(16):Soft Actor-Critic (SAC) 算法

ztj100 2025-04-26 22:46 67 浏览 0 评论

在上一篇文章中,我们介绍了 Twin Delayed DDPG (TD3) 算法,并使用它解决了 Pendulum 问题。本文将深入探讨 Soft Actor-Critic (SAC) 算法,这是一种基于最大熵强化学习的算法,能够在连续动作空间中实现高效的策略优化。我们将使用 PyTorch 实现 SAC 算法,并应用于经典的 Pendulum 问题。



一、SAC 算法基础

SAC 是一种基于 Actor-Critic 框架的算法,旨在最大化累积奖励的同时最大化策略的熵。通过引入熵正则化项,SAC 能够在探索和利用之间取得更好的平衡,从而提高训练效率和稳定性。

1. SAC 的核心思想

  • 最大熵目标
    • SAC 不仅最大化累积奖励,还最大化策略的熵,从而鼓励策略探索更多的状态-动作对。
  • 双重 Q 网络
    • SAC 使用两个 Critic 网络来估计 Q 值,从而减少过估计问题。
  • 自动调整熵系数
    • SAC 自动调整熵正则化项的权重,从而避免手动调参。

2. SAC 的优势

  • 高效探索
    • 通过最大化熵,SAC 能够在探索和利用之间取得更好的平衡。
  • 训练稳定
    • 使用双重 Q 网络和目标网络,SAC 能够稳定地训练策略网络和价值网络。
  • 适用于连续动作空间
    • SAC 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。

3. SAC 的算法流程

  1. 使用当前策略采样一批数据。
  2. 使用目标网络计算目标 Q 值。
  3. 更新 Critic 网络以最小化 Q 值的误差。
  4. 更新 Actor 网络以最大化 Q 值和熵。
  5. 更新目标网络。
  6. 重复上述过程,直到策略收敛。

二、Pendulum 问题实战

我们将使用 PyTorch 实现 SAC 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。

1. 问题描述

Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 -2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。

2. 实现步骤

  1. 安装并导入必要的库。
  2. 定义 Actor 网络和 Critic 网络。
  3. 定义 SAC 训练过程。
  4. 测试模型并评估性能。

3. 代码实现

以下是完整的代码实现:

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque

# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 环境初始化
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])


# 增强型Actor网络
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU()
        )
        self.mu = nn.Linear(512, action_dim)
        self.log_std = nn.Linear(512, action_dim)
        self.max_action = max_action
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                if m == self.mu:
                    nn.init.uniform_(m.weight, -0.1, 0.1)
                    nn.init.constant_(m.bias, 0.5)
                elif m == self.log_std:
                    nn.init.uniform_(m.weight, -0.1, 0.1)
                    nn.init.constant_(m.bias, 0.0)
                else:
                    nn.init.kaiming_normal_(m.weight, mode='fan_in')

    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        log_std = torch.clamp(self.log_std(x), min=-20, max=2)
        return mu, torch.exp(log_std)

    def sample(self, state):
        mu, std = self.forward(state)
        dist = torch.distributions.Normal(mu, std)
        action = dist.rsample()
        action_tanh = torch.tanh(action)
        log_prob = dist.log_prob(action).sum(1, keepdim=True)
        log_prob -= torch.log(1 - action_tanh.pow(2) + 1e-6).sum(1, keepdim=True)
        return action_tanh * self.max_action, log_prob


# 增强型Critic网络
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, state, action):
        return self.net(torch.cat([state, action], 1))


# SAC算法
class SAC:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.critic1 = Critic(state_dim, action_dim).to(device)
        self.critic2 = Critic(state_dim, action_dim).to(device)

        # 优化器配置
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=1e-4)
        self.critic_optim = optim.Adam(
            list(self.critic1.parameters()) + list(self.critic2.parameters()),
            lr=3e-4
        )

        # 目标网络
        self.critic1_target = Critic(state_dim, action_dim).to(device).eval()
        self.critic2_target = Critic(state_dim, action_dim).to(device).eval()
        self.hard_update()

        # 算法参数
        self.gamma = 0.999
        self.tau = 0.01
        self.alpha = 0.2
        self.target_entropy = -action_dim
        self.log_alpha = torch.log(torch.tensor([self.alpha], device=device)).requires_grad_(True)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=1e-4)

        # 经验回放
        self.buffer = deque(maxlen=1_000_000)
        self.batch_size = 512

    def select_action(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(device)
            action, _ = self.actor.sample(state.unsqueeze(0))
            return action.cpu().numpy().flatten()

    def train(self):
        if len(self.buffer) < self.batch_size:
            return

        # 从缓冲区采样
        states, actions, rewards, next_states, dones = self.sample_batch()

        # Critic更新
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)
            target_Q = torch.min(
                self.critic1_target(next_states, next_actions),
                self.critic2_target(next_states, next_actions)
            ) - self.alpha * next_log_probs
            target_Q = rewards + (1 - dones) * self.gamma * target_Q

        current_Q1 = self.critic1(states, actions)
        current_Q2 = self.critic2(states, actions)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        nn.utils.clip_grad_norm_(self.critic1.parameters(), 1.0)
        nn.utils.clip_grad_norm_(self.critic2.parameters(), 1.0)
        self.critic_optim.step()

        # Actor更新
        new_actions, log_probs = self.actor.sample(states)
        Q = torch.min(self.critic1(states, new_actions),
                      self.critic2(states, new_actions))
        actor_loss = (self.alpha * log_probs - Q).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
        self.actor_optim.step()

        # 温度系数更新
        alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp().item()

        # 目标网络更新
        self.soft_update()

    def sample_batch(self):
        indices = np.random.choice(len(self.buffer), self.batch_size)
        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return (
            torch.FloatTensor(states).to(device),
            torch.FloatTensor(actions).to(device),
            torch.FloatTensor(rewards).unsqueeze(1).to(device),
            torch.FloatTensor(next_states).to(device),
            torch.FloatTensor(dones).unsqueeze(1).to(device)
        )

    def hard_update(self):
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())

    def soft_update(self):
        for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def save(self, filename):
        """保存模型参数到文件"""
        torch.save(self.actor.state_dict(), f"{filename}_actor.pth")
        torch.save(self.critic1.state_dict(), f"{filename}_critic1.pth")
        torch.save(self.critic2.state_dict(), f"{filename}_critic2.pth")
        print(f"模型已保存到 {filename}_[network].pth")

    def load(self, filename):
        """从文件加载模型参数"""
        self.actor.load_state_dict(torch.load(f"{filename}_actor.pth", map_location=device))
        self.critic1.load_state_dict(torch.load(f"{filename}_critic1.pth", map_location=device))
        self.critic2.load_state_dict(torch.load(f"{filename}_critic2.pth", map_location=device))
        self.hard_update()  # 同步目标网络
        print(f"已从 {filename}_[network].pth 加载模型")

    def hard_update(self):
        """完全同步目标网络"""
        self.critic1_target.load_state_dict(self.critic1.state_dict())
        self.critic2_target.load_state_dict(self.critic2.state_dict())


# 训练流程优化
def train_sac(env, agent, episodes=2000):
    best_reward = -float('inf')

    for ep in range(episodes):
        state, _ = env.reset()
        episode_reward = 0

        for _ in range(200):  # 限制单轮步数
            if ep < 300:  # 前期探索阶段
                action = env.action_space.sample()
            else:
                action = agent.select_action(state)

            next_state, reward, done, _, _ = env.step(action)
            agent.buffer.append((state, action, reward, next_state, done))
            state = next_state
            episode_reward += reward

            if len(agent.buffer) >= agent.batch_size:
                agent.train()

            if done:
                break

        # 每100轮进行策略评估
        if ep % 100 == 0:
            test_reward = evaluate_policy(agent, env)
            if test_reward > best_reward:
                best_reward = test_reward
                agent.save("best_model")
            print(f"Ep:{ep:4d} | Train:{episode_reward:7.1f} | Test:{test_reward:7.1f} | Best:{best_reward:7.1f}")


def evaluate_policy(agent, env, trials=5):
    total_reward = 0
    for _ in range(trials):
        state, _ = env.reset()
        episode_reward = 0
        for _ in range(200):
            action = agent.select_action(state)
            state, reward, done, _, _ = env.step(action)
            episode_reward += reward
            if done:
                break
        total_reward += episode_reward
    return total_reward / trials


# 启动训练
sac_agent = SAC(state_dim, action_dim, max_action)
train_sac(env, sac_agent, episodes=2000)

三、代码解析

1.Actor 和 Critic 网络

  • Actor 网络输出连续动作,通过 tanh 函数将动作限制在 -max_action,max_action 范围内。
  • Critic 网络输出状态-动作对的 Q 值。

2.SAC 训练过程

  • 使用当前策略采样一批数据。
  • 使用目标网络计算目标 Q 值。
  • 更新 Critic 网络以最小化 Q 值的误差。
  • 更新 Actor 网络以最大化 Q 值和熵。
  • 更新目标网络。

3.训练过程

  • 在训练过程中,每 50 个 episode 打印一次平均奖励。
  • 训练结束后,绘制训练过程中的总奖励曲线。

四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 50 个 episode 打印一次平均奖励。
模型已保存到 best_model_[network].pth
Ep:   0 | Train:-1019.6 | Test:-1281.7 | Best:-1281.7
模型已保存到 best_model_[network].pth
Ep: 100 | Train: -992.1 | Test: -165.6 | Best: -165.6
Ep: 200 | Train:-1513.8 | Test: -182.1 | Best: -165.6
模型已保存到 best_model_[network].pth
Ep: 300 | Train: -234.2 | Test: -154.4 | Best: -154.4
模型已保存到 best_model_[network].pth
Ep: 400 | Train: -321.7 | Test: -143.5 | Best: -143.5
模型已保存到 best_model_[network].pth
Ep: 500 | Train: -133.9 | Test: -123.4 | Best: -123.4
Ep: 600 | Train:  -18.1 | Test: -177.7 | Best: -123.4
Ep: 700 | Train: -287.0 | Test: -147.3 | Best: -123.4
Ep: 800 | Train: -122.2 | Test: -126.6 | Best: -123.4
模型已保存到 best_model_[network].pth
Ep: 900 | Train: -121.5 | Test:  -97.8 | Best:  -97.8
Ep:1000 | Train: -131.7 | Test: -163.8 | Best:  -97.8
模型已保存到 best_model_[network].pth
Ep:1100 | Train: -229.1 | Test:  -75.5 | Best:  -75.5
Ep:1200 | Train: -121.4 | Test: -120.4 | Best:  -75.5
Ep:1300 | Train: -117.5 | Test: -152.6 | Best:  -75.5
Ep:1400 | Train:   -4.0 | Test: -139.0 | Best:  -75.5
Ep:1500 | Train: -124.4 | Test: -116.8 | Best:  -75.5
Ep:1600 | Train: -124.6 | Test: -117.8 | Best:  -75.5
Ep:1700 | Train: -118.5 | Test: -190.6 | Best:  -75.5
Ep:1800 | Train: -308.7 | Test: -137.0 | Best:  -75.5
Ep:1900 | Train: -116.7 | Test: -119.6 | Best:  -75.5

五、总结

本文介绍了 SAC 算法的基本原理,并使用 PyTorch 实现了一个简单的 SAC 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 SAC 算法进行连续动作空间的策略优化。

在下一篇文章中,我们将探讨强化学习领域的重要里程碑——Asynchronous Advantage Actor-Critic (A3C) 算法。敬请期待!

代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
  • 如果你有 GPU,代码会自动检测并使用 GPU 加速。

希望这篇文章能帮助你更好地理解 SAC 算法!如果有任何问题,欢迎在评论区留言讨论。

相关推荐

sharding-jdbc实现`分库分表`与`读写分离`

一、前言本文将基于以下环境整合...

三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么

在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...

MySQL8行级锁_mysql如何加行级锁

MySQL8行级锁版本:8.0.34基本概念...

mysql使用小技巧_mysql使用入门

1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...

MySQL/MariaDB中如何支持全部的Unicode?

永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...

聊聊 MySQL Server 可执行注释,你懂了吗?

前言MySQLServer当前支持如下3种注释风格:...

MySQL系列-源码编译安装(v5.7.34)

一、系统环境要求...

MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了

对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...

MySQL字符问题_mysql中字符串的位置

中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...

深圳尚学堂:mysql基本sql语句大全(三)

数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...

MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?

大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...

一文讲清怎么利用Python Django实现Excel数据表的导入导出功能

摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...

用DataX实现两个MySQL实例间的数据同步

DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...

MySQL数据库知识_mysql数据库基础知识

MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...

如何为MySQL中的JSON字段设置索引

背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...

取消回复欢迎 发表评论: