PyTorch 深度学习实战(16):Soft Actor-Critic (SAC) 算法
ztj100 2025-04-26 22:46 43 浏览 0 评论
在上一篇文章中,我们介绍了 Twin Delayed DDPG (TD3) 算法,并使用它解决了 Pendulum 问题。本文将深入探讨 Soft Actor-Critic (SAC) 算法,这是一种基于最大熵强化学习的算法,能够在连续动作空间中实现高效的策略优化。我们将使用 PyTorch 实现 SAC 算法,并应用于经典的 Pendulum 问题。
一、SAC 算法基础
SAC 是一种基于 Actor-Critic 框架的算法,旨在最大化累积奖励的同时最大化策略的熵。通过引入熵正则化项,SAC 能够在探索和利用之间取得更好的平衡,从而提高训练效率和稳定性。
1. SAC 的核心思想
- 最大熵目标:
- SAC 不仅最大化累积奖励,还最大化策略的熵,从而鼓励策略探索更多的状态-动作对。
- 双重 Q 网络:
- SAC 使用两个 Critic 网络来估计 Q 值,从而减少过估计问题。
- 自动调整熵系数:
- SAC 自动调整熵正则化项的权重,从而避免手动调参。
2. SAC 的优势
- 高效探索:
- 通过最大化熵,SAC 能够在探索和利用之间取得更好的平衡。
- 训练稳定:
- 使用双重 Q 网络和目标网络,SAC 能够稳定地训练策略网络和价值网络。
- 适用于连续动作空间:
- SAC 能够直接输出连续动作,适用于机器人控制、自动驾驶等任务。
3. SAC 的算法流程
- 使用当前策略采样一批数据。
- 使用目标网络计算目标 Q 值。
- 更新 Critic 网络以最小化 Q 值的误差。
- 更新 Actor 网络以最大化 Q 值和熵。
- 更新目标网络。
- 重复上述过程,直到策略收敛。
二、Pendulum 问题实战
我们将使用 PyTorch 实现 SAC 算法,并应用于 Pendulum 问题。目标是控制摆杆使其保持直立。
1. 问题描述
Pendulum 环境的状态空间包括摆杆的角度和角速度。动作空间是一个连续的扭矩值,范围在 -2,2 之间。智能体每保持摆杆直立一步,就会获得一个负的奖励,目标是最大化累积奖励。
2. 实现步骤
- 安装并导入必要的库。
- 定义 Actor 网络和 Critic 网络。
- 定义 SAC 训练过程。
- 测试模型并评估性能。
3. 代码实现
以下是完整的代码实现:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
# 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 环境初始化
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
# 增强型Actor网络
class Actor(nn.Module):
def __init__(self, state_dim, action_dim, max_action):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 512),
nn.LayerNorm(512),
nn.ReLU(),
nn.Linear(512, 512),
nn.LayerNorm(512),
nn.ReLU()
)
self.mu = nn.Linear(512, action_dim)
self.log_std = nn.Linear(512, action_dim)
self.max_action = max_action
self._initialize_weights()
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
if m == self.mu:
nn.init.uniform_(m.weight, -0.1, 0.1)
nn.init.constant_(m.bias, 0.5)
elif m == self.log_std:
nn.init.uniform_(m.weight, -0.1, 0.1)
nn.init.constant_(m.bias, 0.0)
else:
nn.init.kaiming_normal_(m.weight, mode='fan_in')
def forward(self, x):
x = self.net(x)
mu = self.mu(x)
log_std = torch.clamp(self.log_std(x), min=-20, max=2)
return mu, torch.exp(log_std)
def sample(self, state):
mu, std = self.forward(state)
dist = torch.distributions.Normal(mu, std)
action = dist.rsample()
action_tanh = torch.tanh(action)
log_prob = dist.log_prob(action).sum(1, keepdim=True)
log_prob -= torch.log(1 - action_tanh.pow(2) + 1e-6).sum(1, keepdim=True)
return action_tanh * self.max_action, log_prob
# 增强型Critic网络
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim + action_dim, 512),
nn.LayerNorm(512),
nn.ReLU(),
nn.Linear(512, 512),
nn.LayerNorm(512),
nn.ReLU(),
nn.Linear(512, 1)
)
def forward(self, state, action):
return self.net(torch.cat([state, action], 1))
# SAC算法
class SAC:
def __init__(self, state_dim, action_dim, max_action):
self.actor = Actor(state_dim, action_dim, max_action).to(device)
self.critic1 = Critic(state_dim, action_dim).to(device)
self.critic2 = Critic(state_dim, action_dim).to(device)
# 优化器配置
self.actor_optim = optim.Adam(self.actor.parameters(), lr=1e-4)
self.critic_optim = optim.Adam(
list(self.critic1.parameters()) + list(self.critic2.parameters()),
lr=3e-4
)
# 目标网络
self.critic1_target = Critic(state_dim, action_dim).to(device).eval()
self.critic2_target = Critic(state_dim, action_dim).to(device).eval()
self.hard_update()
# 算法参数
self.gamma = 0.999
self.tau = 0.01
self.alpha = 0.2
self.target_entropy = -action_dim
self.log_alpha = torch.log(torch.tensor([self.alpha], device=device)).requires_grad_(True)
self.alpha_optim = optim.Adam([self.log_alpha], lr=1e-4)
# 经验回放
self.buffer = deque(maxlen=1_000_000)
self.batch_size = 512
def select_action(self, state):
with torch.no_grad():
state = torch.FloatTensor(state).unsqueeze(0).to(device)
action, _ = self.actor.sample(state.unsqueeze(0))
return action.cpu().numpy().flatten()
def train(self):
if len(self.buffer) < self.batch_size:
return
# 从缓冲区采样
states, actions, rewards, next_states, dones = self.sample_batch()
# Critic更新
with torch.no_grad():
next_actions, next_log_probs = self.actor.sample(next_states)
target_Q = torch.min(
self.critic1_target(next_states, next_actions),
self.critic2_target(next_states, next_actions)
) - self.alpha * next_log_probs
target_Q = rewards + (1 - dones) * self.gamma * target_Q
current_Q1 = self.critic1(states, actions)
current_Q2 = self.critic2(states, actions)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
self.critic_optim.zero_grad()
critic_loss.backward()
nn.utils.clip_grad_norm_(self.critic1.parameters(), 1.0)
nn.utils.clip_grad_norm_(self.critic2.parameters(), 1.0)
self.critic_optim.step()
# Actor更新
new_actions, log_probs = self.actor.sample(states)
Q = torch.min(self.critic1(states, new_actions),
self.critic2(states, new_actions))
actor_loss = (self.alpha * log_probs - Q).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor.parameters(), 1.0)
self.actor_optim.step()
# 温度系数更新
alpha_loss = -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
self.alpha_optim.zero_grad()
alpha_loss.backward()
self.alpha_optim.step()
self.alpha = self.log_alpha.exp().item()
# 目标网络更新
self.soft_update()
def sample_batch(self):
indices = np.random.choice(len(self.buffer), self.batch_size)
batch = [self.buffer[i] for i in indices]
states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
return (
torch.FloatTensor(states).to(device),
torch.FloatTensor(actions).to(device),
torch.FloatTensor(rewards).unsqueeze(1).to(device),
torch.FloatTensor(next_states).to(device),
torch.FloatTensor(dones).unsqueeze(1).to(device)
)
def hard_update(self):
self.critic1_target.load_state_dict(self.critic1.state_dict())
self.critic2_target.load_state_dict(self.critic2.state_dict())
def soft_update(self):
for param, target_param in zip(self.critic1.parameters(), self.critic1_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
for param, target_param in zip(self.critic2.parameters(), self.critic2_target.parameters()):
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
def save(self, filename):
"""保存模型参数到文件"""
torch.save(self.actor.state_dict(), f"{filename}_actor.pth")
torch.save(self.critic1.state_dict(), f"{filename}_critic1.pth")
torch.save(self.critic2.state_dict(), f"{filename}_critic2.pth")
print(f"模型已保存到 {filename}_[network].pth")
def load(self, filename):
"""从文件加载模型参数"""
self.actor.load_state_dict(torch.load(f"{filename}_actor.pth", map_location=device))
self.critic1.load_state_dict(torch.load(f"{filename}_critic1.pth", map_location=device))
self.critic2.load_state_dict(torch.load(f"{filename}_critic2.pth", map_location=device))
self.hard_update() # 同步目标网络
print(f"已从 {filename}_[network].pth 加载模型")
def hard_update(self):
"""完全同步目标网络"""
self.critic1_target.load_state_dict(self.critic1.state_dict())
self.critic2_target.load_state_dict(self.critic2.state_dict())
# 训练流程优化
def train_sac(env, agent, episodes=2000):
best_reward = -float('inf')
for ep in range(episodes):
state, _ = env.reset()
episode_reward = 0
for _ in range(200): # 限制单轮步数
if ep < 300: # 前期探索阶段
action = env.action_space.sample()
else:
action = agent.select_action(state)
next_state, reward, done, _, _ = env.step(action)
agent.buffer.append((state, action, reward, next_state, done))
state = next_state
episode_reward += reward
if len(agent.buffer) >= agent.batch_size:
agent.train()
if done:
break
# 每100轮进行策略评估
if ep % 100 == 0:
test_reward = evaluate_policy(agent, env)
if test_reward > best_reward:
best_reward = test_reward
agent.save("best_model")
print(f"Ep:{ep:4d} | Train:{episode_reward:7.1f} | Test:{test_reward:7.1f} | Best:{best_reward:7.1f}")
def evaluate_policy(agent, env, trials=5):
total_reward = 0
for _ in range(trials):
state, _ = env.reset()
episode_reward = 0
for _ in range(200):
action = agent.select_action(state)
state, reward, done, _, _ = env.step(action)
episode_reward += reward
if done:
break
total_reward += episode_reward
return total_reward / trials
# 启动训练
sac_agent = SAC(state_dim, action_dim, max_action)
train_sac(env, sac_agent, episodes=2000)
三、代码解析
1.Actor 和 Critic 网络:
- Actor 网络输出连续动作,通过 tanh 函数将动作限制在 -max_action,max_action 范围内。
- Critic 网络输出状态-动作对的 Q 值。
2.SAC 训练过程:
- 使用当前策略采样一批数据。
- 使用目标网络计算目标 Q 值。
- 更新 Critic 网络以最小化 Q 值的误差。
- 更新 Actor 网络以最大化 Q 值和熵。
- 更新目标网络。
3.训练过程:
- 在训练过程中,每 50 个 episode 打印一次平均奖励。
- 训练结束后,绘制训练过程中的总奖励曲线。
四、运行结果
运行上述代码后,你将看到以下输出:
- 训练过程中每 50 个 episode 打印一次平均奖励。
模型已保存到 best_model_[network].pth
Ep: 0 | Train:-1019.6 | Test:-1281.7 | Best:-1281.7
模型已保存到 best_model_[network].pth
Ep: 100 | Train: -992.1 | Test: -165.6 | Best: -165.6
Ep: 200 | Train:-1513.8 | Test: -182.1 | Best: -165.6
模型已保存到 best_model_[network].pth
Ep: 300 | Train: -234.2 | Test: -154.4 | Best: -154.4
模型已保存到 best_model_[network].pth
Ep: 400 | Train: -321.7 | Test: -143.5 | Best: -143.5
模型已保存到 best_model_[network].pth
Ep: 500 | Train: -133.9 | Test: -123.4 | Best: -123.4
Ep: 600 | Train: -18.1 | Test: -177.7 | Best: -123.4
Ep: 700 | Train: -287.0 | Test: -147.3 | Best: -123.4
Ep: 800 | Train: -122.2 | Test: -126.6 | Best: -123.4
模型已保存到 best_model_[network].pth
Ep: 900 | Train: -121.5 | Test: -97.8 | Best: -97.8
Ep:1000 | Train: -131.7 | Test: -163.8 | Best: -97.8
模型已保存到 best_model_[network].pth
Ep:1100 | Train: -229.1 | Test: -75.5 | Best: -75.5
Ep:1200 | Train: -121.4 | Test: -120.4 | Best: -75.5
Ep:1300 | Train: -117.5 | Test: -152.6 | Best: -75.5
Ep:1400 | Train: -4.0 | Test: -139.0 | Best: -75.5
Ep:1500 | Train: -124.4 | Test: -116.8 | Best: -75.5
Ep:1600 | Train: -124.6 | Test: -117.8 | Best: -75.5
Ep:1700 | Train: -118.5 | Test: -190.6 | Best: -75.5
Ep:1800 | Train: -308.7 | Test: -137.0 | Best: -75.5
Ep:1900 | Train: -116.7 | Test: -119.6 | Best: -75.5
五、总结
本文介绍了 SAC 算法的基本原理,并使用 PyTorch 实现了一个简单的 SAC 模型来解决 Pendulum 问题。通过这个例子,我们学习了如何使用 SAC 算法进行连续动作空间的策略优化。
在下一篇文章中,我们将探讨强化学习领域的重要里程碑——Asynchronous Advantage Actor-Critic (A3C) 算法。敬请期待!
代码实例说明:
- 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。
- 如果你有 GPU,代码会自动检测并使用 GPU 加速。
希望这篇文章能帮助你更好地理解 SAC 算法!如果有任何问题,欢迎在评论区留言讨论。
相关推荐
- Sublime Text 4 稳定版 Build 4113 发布
-
IT之家7月18日消息知名编辑器SublimeText4近日发布了Build4113版本,是SublimeText4的第二个稳定版。IT之家了解到,SublimeTe...
- 【小白课程】openKylin便签贴的设计与实现
-
openKylin便签贴作为侧边栏的一个小插件,提供便捷的文本记录和灵活的页面展示。openKylin便签贴分为两个部分:便签列表...
- 壹啦罐罐 Android 手机里的 Xposed 都装了啥
-
这是少数派推出的系列专题,叫做「我的手机里都装了啥」。这个系列将邀请到不同的玩家,从他们各自的角度介绍手机中最爱的或是日常使用最频繁的App。文章将以「每周一篇」的频率更新,内容范围会包括iOS、...
- 电气自动化专业词汇中英文对照表(电气自动化专业英语单词)
-
专业词汇中英文对照表...
- Python界面设计Tkinter模块的核心组件
-
我们使用一个模块,我们要熟悉这个模块的主要元件。如我们设计一个窗口,我们可以用Tk()来完成创建;一些交互元素,按钮、标签、编辑框用到控件;怎么去布局你的界面,我们可以用到pack()、grid()...
- 以色列发现“死海古卷”新残片(死海古卷是真的吗)
-
编译|陈家琦据艺术新闻网(artnews.com)报道,3月16日,以色列考古学家发现了死海古卷(DeadSeaScrolls)新残片。新出土的羊皮纸残片中包括以希腊文书写的《十二先知书》段落,这...
- 鸿蒙Next仓颉语言开发实战教程:订单列表
-
大家上午好,最近不断有友友反馈仓颉语言和ArkTs很像,所以要注意不要混淆。今天要分享的是仓颉语言开发商城应用的订单列表页。首先来分析一下这个页面,它分为三大部分,分别是导航栏、订单类型和订单列表部分...
- 哪些模块可以用在 Xposed for Lollipop 上?Xposed 模块兼容性解答
-
虽然已经有了XposedforLollipop的安装教程,但由于其还处在alpha阶段,一些Xposed模块能不能依赖其正常工作还未可知。为了解决大家对于模块兼容性的疑惑,笔者尽可能多...
- 利用 Fluid 自制 Mac 版 Overcast 应用
-
我喜爱收听播客,健身、上/下班途中,工作中,甚至是忙着做家务时。大多数情况下我会用MarcoArment开发的Overcast(Freemium)在iPhone上收听,这是我目前最喜爱的Po...
- 浅色Al云食堂APP代码(三)(手机云食堂)
-
以下是进一步优化完善后的浅色AI云食堂APP完整代码,新增了数据可视化、用户反馈、智能推荐等功能,并优化了代码结构和性能。项目结构...
- 实战PyQt5: 121-使用QImage实现一个看图应用
-
QImage简介QImage类提供了独立于硬件的图像表示形式,该图像表示形式可以直接访问像素数据,并且可以用作绘制设备。QImage是QPaintDevice子类,因此可以使用QPainter直接在图...
- 滚动条隐藏及美化(滚动条隐藏但是可以滚动)
-
1、滚动条隐藏背景/场景:在移动端,滑动的时候,会显示默认滚动条,如图1://隐藏代码:/*隐藏滚轮*/.ul-scrool-box::-webkit-scrollbar,.ul-scrool...
- 浅色AI云食堂APP完整代码(二)(ai 食堂)
-
以下是整合后的浅色AI云食堂APP完整代码,包含后端核心功能、前端界面以及优化增强功能。项目采用Django框架开发,支持库存管理、订单处理、财务管理等核心功能,并包含库存预警、数据导出、权限管理等增...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)