PyTorch Deep Learning in Practice (25): Inverse Reinforcement Learning (Inverse RL)
I. Principles of Inverse Reinforcement Learning
1. Core Idea of Inverse RL
Inverse Reinforcement Learning (IRL) aims to infer the reward function from expert demonstrations rather than learn a policy directly. It differs from standard reinforcement learning (RL) as follows:
| Dimension | Reinforcement Learning (RL) | Inverse Reinforcement Learning (IRL) |
| --- | --- | --- |
| Input | Known reward function | Expert trajectories (state-action sequences) |
| Output | Optimal policy | Inferred reward function + imitation policy |
| Objective | Maximize cumulative reward | Make the expert trajectories optimal under the inferred reward |
| Typical applications | Games, robot control | Imitation learning, autonomous-driving policy inference |
2. Maximum-Entropy IRL Framework
Maximum-entropy IRL infers the reward function by maximizing the likelihood of the expert trajectories. With a reward parameterized by $\theta$, each trajectory $\tau$ is assigned the probability
$$P(\tau \mid \theta) = \frac{\exp\big(R_\theta(\tau)\big)}{Z(\theta)}, \qquad Z(\theta) = \sum_{\tau} \exp\big(R_\theta(\tau)\big),$$
and $\theta$ is chosen to maximize $\sum_{\tau \in \mathcal{D}_E} \log P(\tau \mid \theta)$ over the expert demonstrations $\mathcal{D}_E$. Higher-reward trajectories become exponentially more likely, so the expert data is (soft-)optimal under the inferred reward.
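For a linear reward $R_\theta(\tau) = \theta^\top \sum_t \phi(s_t)$, the gradient of this log-likelihood reduces to the expert feature expectation minus the feature expectation under the current soft-optimal policy. The sketch below illustrates one such gradient step with NumPy; the feature map phi and the toy trajectories are hypothetical stand-ins, not part of this article's code.

import numpy as np

def maxent_irl_gradient(expert_trajs, policy_trajs, phi):
    """Gradient of the expert log-likelihood for a linear reward R_theta(s) = theta . phi(s).

    expert_trajs / policy_trajs: lists of trajectories, each an iterable of states
    phi: feature map, state -> (d,) feature vector
    Returns mu_expert - mu_policy, the feature-expectation difference.
    """
    def feature_expectation(trajs):
        return np.mean([np.sum([phi(s) for s in traj], axis=0) for traj in trajs], axis=0)
    return feature_expectation(expert_trajs) - feature_expectation(policy_trajs)

# Toy example with 1-D states and phi(s) = [s, s**2] (purely illustrative):
phi = lambda s: np.array([s, s ** 2])
expert = [np.array([1.0, 1.2, 0.9])]    # the expert stays near s = 1
policy = [np.array([0.1, -0.3, 0.2])]   # the current policy wanders near 0
theta = np.zeros(2)
theta += 0.1 * maxent_irl_gradient(expert, policy, phi)
print(theta)  # the weights move toward features the expert visits more often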
II. Generative Adversarial Imitation Learning (GAIL)
Generative Adversarial Imitation Learning (GAIL) combines the ideas of IRL and GANs:
- Generator (policy network): produces trajectories that imitate the expert's behavior
- Discriminator (reward network): distinguishes expert trajectories from generated ones
- Adversarial training: the policy network tries to fool the discriminator until it can no longer tell which source a trajectory came from
Mathematically, with $D(s,a)$ denoting the probability that a state-action pair comes from the expert (the convention used in the code below), GAIL solves the saddle-point problem
$$\min_{\pi} \max_{D} \; \mathbb{E}_{\pi_E}\big[\log D(s,a)\big] + \mathbb{E}_{\pi}\big[\log\big(1 - D(s,a)\big)\big] - \lambda H(\pi),$$
where $\pi_E$ is the expert policy and $H(\pi)$ is an entropy regularizer.
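The toy snippet below makes the two sides of this objective concrete on random tensors with a tiny discriminator. It is only an illustration; the state/action sizes (17 and 6, HalfCheetah-like) and the random (s, a) batches are assumptions, not part of this article's implementation.

import torch
import torch.nn as nn

state_dim, action_dim, batch = 17, 6, 32   # HalfCheetah-sized placeholders (assumed)
D = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.ReLU(),
                  nn.Linear(64, 1), nn.Sigmoid())

expert_sa = torch.randn(batch, state_dim + action_dim)  # stand-in for (s, a) ~ pi_E
policy_sa = torch.randn(batch, state_dim + action_dim)  # stand-in for (s, a) ~ pi

# The discriminator ascends  E_piE[log D] + E_pi[log(1 - D)]
d_objective = torch.log(D(expert_sa) + 1e-8).mean() + torch.log(1 - D(policy_sa) + 1e-8).mean()

# The policy descends E_pi[log(1 - D)], i.e. it maximizes the surrogate reward -log(1 - D(s, a))
policy_reward = -torch.log(1 - D(policy_sa) + 1e-8)

print(f"discriminator objective: {d_objective.item():.3f}")
print(f"mean surrogate reward:   {policy_reward.mean().item():.3f}")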
III. GAIL Implementation Steps (with Gymnasium)
Using the MuJoCo HalfCheetah environment as the example, we implement GAIL in four steps:
- Collect expert data: generate expert trajectories with a pre-trained policy
- Build the policy network: a PPO-based generator
- Build the discriminator network: a binary classifier that separates expert data from generated data
- Adversarial training: alternately optimize the generator and the discriminator
IV. Code Implementation
Generating the expert data (a SAC training script):
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
from torch.distributions import Normal  # needed for Gaussian action sampling
from collections import deque
import random  # needed for replay-buffer sampling
import time
# ================== Configuration ==================
class SACConfig:
    env_name = "HalfCheetah-v5"  # same environment as the GAIL stage
    hidden_dim = 256             # hidden-layer width
    actor_lr = 3e-4              # policy-network learning rate
    critic_lr = 3e-4             # value-network learning rate
    alpha_lr = 3e-4              # temperature learning rate
    gamma = 0.99                 # discount factor
    tau = 0.005                  # soft-update coefficient
    buffer_size = 100000         # replay-buffer capacity
    batch_size = 256             # batch size
    max_episodes = 1000          # maximum number of training episodes (adjust as needed)
    target_entropy = -6.0        # entropy target, roughly -action_dim (HalfCheetah has 6-D actions)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ================== Policy Network (Actor) ==================
class Actor(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, SACConfig.hidden_dim),
nn.ReLU(),
nn.Linear(SACConfig.hidden_dim, SACConfig.hidden_dim),
nn.ReLU(),
nn.Linear(SACConfig.hidden_dim, action_dim),
            nn.Tanh()  # assumes the action space is [-1, 1]
)
self.log_std = nn.Parameter(torch.zeros(action_dim))
    def forward(self, state):
        mean = self.net(state)
        std = self.log_std.exp()
        return mean, std

    def sample(self, state):
        # Simplified sampling: a Gaussian around the tanh-squashed mean with a learned std.
        # Returns a reparameterized action and its log-probability, as SAC's update needs both.
        mean, std = self.forward(state)
        dist = Normal(mean, std)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        return action, log_prob
# ================== Value Network (Critic) ==================
class Critic(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim + action_dim, SACConfig.hidden_dim),
nn.ReLU(),
nn.Linear(SACConfig.hidden_dim, SACConfig.hidden_dim),
nn.ReLU(),
nn.Linear(SACConfig.hidden_dim, 1)
)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
return self.net(x)
# ================== SAC Training System ==================
class SACTrainer:
def __init__(self):
self.env = gym.make(SACConfig.env_name)
self.state_dim = self.env.observation_space.shape[0]
self.action_dim = self.env.action_space.shape[0]
        # Initialize the networks
self.actor = Actor(self.state_dim, self.action_dim).to(SACConfig.device)
self.critic1 = Critic(self.state_dim, self.action_dim).to(SACConfig.device)
self.critic2 = Critic(self.state_dim, self.action_dim).to(SACConfig.device)
self.target_critic1 = Critic(self.state_dim, self.action_dim).to(SACConfig.device)
self.target_critic2 = Critic(self.state_dim, self.action_dim).to(SACConfig.device)
self.target_critic1.load_state_dict(self.critic1.state_dict())
self.target_critic2.load_state_dict(self.critic2.state_dict())
        # Optimizers
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=SACConfig.actor_lr)
self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=SACConfig.critic_lr)
self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=SACConfig.critic_lr)
        # Automatically tuned temperature coefficient alpha
self.log_alpha = torch.zeros(1, requires_grad=True, device=SACConfig.device)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=SACConfig.alpha_lr)
        # Experience replay buffer
self.buffer = deque(maxlen=SACConfig.buffer_size)
def select_action(self, state):
state = torch.FloatTensor(state).to(SACConfig.device)
mean, std = self.actor(state)
dist = Normal(mean, std)
action = dist.sample()
return action.detach().cpu().numpy()
def update(self):
if len(self.buffer) < SACConfig.batch_size:
return
        # Sample a batch from the replay buffer
samples = random.sample(self.buffer, SACConfig.batch_size)
states, actions, rewards, next_states, dones = zip(*samples)
states = torch.FloatTensor(np.array(states)).to(SACConfig.device)
actions = torch.FloatTensor(np.array(actions)).to(SACConfig.device)
rewards = torch.FloatTensor(np.array(rewards)).unsqueeze(-1).to(SACConfig.device)
next_states = torch.FloatTensor(np.array(next_states)).to(SACConfig.device)
dones = torch.FloatTensor(np.array(dones)).unsqueeze(-1).to(SACConfig.device)
        # Update the critics
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)
target_q1 = self.target_critic1(next_states, next_actions)
target_q2 = self.target_critic2(next_states, next_actions)
target_q = torch.min(target_q1, target_q2) - self.log_alpha.exp() * next_log_probs
target_q = rewards + SACConfig.gamma * (1 - dones) * target_q
current_q1 = self.critic1(states, actions)
current_q2 = self.critic2(states, actions)
critic1_loss = nn.MSELoss()(current_q1, target_q)
critic2_loss = nn.MSELoss()(current_q2, target_q)
self.critic1_optimizer.zero_grad()
critic1_loss.backward()
self.critic1_optimizer.step()
self.critic2_optimizer.zero_grad()
critic2_loss.backward()
self.critic2_optimizer.step()
        # Update the actor
        new_actions, log_probs = self.actor.sample(states)
q1 = self.critic1(states, new_actions)
q2 = self.critic2(states, new_actions)
actor_loss = (self.log_alpha.exp() * log_probs - torch.min(q1, q2)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
        # Update the temperature alpha
alpha_loss = -(self.log_alpha * (log_probs + SACConfig.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
        # Soft-update the target networks
for param, target_param in zip(self.critic1.parameters(), self.target_critic1.parameters()):
target_param.data.copy_(SACConfig.tau * param.data + (1 - SACConfig.tau) * target_param.data)
for param, target_param in zip(self.critic2.parameters(), self.target_critic2.parameters()):
target_param.data.copy_(SACConfig.tau * param.data + (1 - SACConfig.tau) * target_param.data)
def train_and_save_expert_data(self, save_path="expert_data.npy"):
expert_states = []
expert_actions = []
for episode in range(SACConfig.max_episodes):
            state, _ = self.env.reset()  # Gymnasium reset returns (observation, info)
episode_reward = 0
while True:
action = self.select_action(state)
                next_state, reward, terminated, truncated, _ = self.env.step(action)
                done = terminated or truncated  # Gymnasium splits termination and truncation
                self.buffer.append((state, action, reward, next_state, done))
                # Collect expert data during the later training phase
                if episode > SACConfig.max_episodes // 2:  # the second half of training serves as expert data
expert_states.append(state)
expert_actions.append(action)
state = next_state
episode_reward += reward
self.update()
if done:
break
if (episode + 1) % 100 == 0:
print(f"Episode {episode+1} | Reward: {episode_reward:.1f}")
        # Save the expert data
np.save(save_path, {
'states': np.array(expert_states),
'actions': np.array(expert_actions)
})
print(f"专家数据已保存至 {save_path}")
if __name__ == "__main__":
start = time.time()
start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
print(f"开始时间: {start_str}")
print("训练专家策略...")
trainer = SACTrainer()
trainer.train_and_save_expert_data()
end = time.time()
end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
print(f"训练完成时间: {end_str}")
print(f"总耗时: {end - start:.2f}秒")
The GAIL training code:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
import gymnasium as gym
from collections import deque
import time
import random
class GAILConfig:
env_name = "HalfCheetah-v5"
expert_data_path = "expert_data.npy"
hidden_dim = 256
policy_lr = 3e-4
discriminator_lr = 1e-4
gamma = 0.99
lam = 0.95
clip_epsilon = 0.2
batch_size = 64
max_episodes = 100
max_steps = 1000 # Added max steps per episode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Policy(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.actor = nn.Sequential(
nn.Linear(state_dim, GAILConfig.hidden_dim),
nn.ReLU(),
nn.Linear(GAILConfig.hidden_dim, GAILConfig.hidden_dim),
nn.ReLU(),
nn.Linear(GAILConfig.hidden_dim, action_dim)
)
self.critic = nn.Sequential(
nn.Linear(state_dim, GAILConfig.hidden_dim),
nn.ReLU(),
nn.Linear(GAILConfig.hidden_dim, GAILConfig.hidden_dim),
nn.ReLU(),
nn.Linear(GAILConfig.hidden_dim, 1)
)
self.log_std = nn.Parameter(torch.zeros(action_dim))
def forward(self, state):
action_mean = self.actor(state)
value = self.critic(state)
return action_mean, value
def act(self, state):
with torch.no_grad():
action_mean, value = self.forward(state)
dist = Normal(action_mean, self.log_std.exp())
action = dist.sample()
log_prob = dist.log_prob(action).sum(-1)
return action.cpu().numpy(), log_prob.cpu().numpy(), value.cpu().numpy()
class Discriminator(nn.Module):
def __init__(self, state_dim, action_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim + action_dim, GAILConfig.hidden_dim),
nn.ReLU(),
nn.Linear(GAILConfig.hidden_dim, GAILConfig.hidden_dim),
nn.ReLU(),
nn.Linear(GAILConfig.hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, state, action):
x = torch.cat([state, action], dim=-1)
return self.net(x)
class ReplayBuffer:
def __init__(self):
self.buffer = deque(maxlen=100000)
def add(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
samples = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*samples)
return (
torch.FloatTensor(np.array(states)).to(GAILConfig.device),
torch.FloatTensor(np.array(actions)).to(GAILConfig.device),
torch.FloatTensor(np.array(rewards)).unsqueeze(-1).to(GAILConfig.device),
torch.FloatTensor(np.array(next_states)).to(GAILConfig.device),
torch.FloatTensor(np.array(dones)).unsqueeze(-1).to(GAILConfig.device)
)
def __len__(self):
return len(self.buffer)
class GAILTrainer:
def __init__(self):
self.env = gym.make(GAILConfig.env_name)
self.state_dim = self.env.observation_space.shape[0]
self.action_dim = self.env.action_space.shape[0]
self.policy = Policy(self.state_dim, self.action_dim).to(GAILConfig.device)
self.discriminator = Discriminator(self.state_dim, self.action_dim).to(GAILConfig.device)
self.optimizer_policy = optim.Adam(self.policy.parameters(), lr=GAILConfig.policy_lr)
self.optimizer_discriminator = optim.Adam(self.discriminator.parameters(), lr=GAILConfig.discriminator_lr)
# Load expert data
self.expert_data = np.load(GAILConfig.expert_data_path, allow_pickle=True).item()
self.expert_states = torch.FloatTensor(self.expert_data['states']).to(GAILConfig.device)
self.expert_actions = torch.FloatTensor(self.expert_data['actions']).to(GAILConfig.device)
self.buffer = ReplayBuffer()
def compute_reward(self, states, actions):
with torch.no_grad():
d = self.discriminator(states, actions)
return -torch.log(1 - d + 1e-8)
def update_discriminator(self):
states, actions, _, _, _ = self.buffer.sample(GAILConfig.batch_size)
idx = np.random.randint(0, len(self.expert_states), GAILConfig.batch_size)
expert_states = self.expert_states[idx]
expert_actions = self.expert_actions[idx]
real_output = self.discriminator(expert_states, expert_actions)
fake_output = self.discriminator(states, actions)
loss_real = -torch.log(real_output + 1e-8).mean()
loss_fake = -torch.log(1 - fake_output + 1e-8).mean()
loss = loss_real + loss_fake
self.optimizer_discriminator.zero_grad()
loss.backward()
self.optimizer_discriminator.step()
return loss.item()
def compute_gae(self, rewards, values, dones):
advantages = torch.zeros_like(rewards)
last_advantage = 0
for t in reversed(range(len(rewards))):
if t == len(rewards) - 1:
next_value = 0
next_non_terminal = 1.0 - dones[t]
else:
next_value = values[t + 1]
next_non_terminal = 1.0 - dones[t]
delta = rewards[t] + GAILConfig.gamma * next_value * next_non_terminal - values[t]
advantages[t] = delta + GAILConfig.gamma * GAILConfig.lam * next_non_terminal * last_advantage
last_advantage = advantages[t]
returns = advantages + values
return advantages, returns
    def update_policy(self, states, actions, log_probs, rewards, dones):
        # Value estimates for the GAE targets (detached so the targets are treated as constants)
        with torch.no_grad():
            _, values = self.policy(states)
            values = values.squeeze(-1)
        # Compute GAE advantages and returns
        advantages, returns = self.compute_gae(rewards.squeeze(-1), values, dones.squeeze(-1))
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        # New log-probabilities and value estimates under the current parameters
        action_means, new_values = self.policy(states)
        dist = Normal(action_means, self.policy.log_std.exp())  # use the learned std, not a fixed one
        new_log_probs = dist.log_prob(actions).sum(-1, keepdim=True)
# PPO loss
ratio = (new_log_probs - log_probs).exp()
surr1 = ratio * advantages.unsqueeze(-1)
surr2 = torch.clamp(ratio, 1 - GAILConfig.clip_epsilon, 1 + GAILConfig.clip_epsilon) * advantages.unsqueeze(-1)
policy_loss = -torch.min(surr1, surr2).mean()
# Value loss
value_loss = 0.5 * (new_values - returns.unsqueeze(-1)).pow(2).mean()
# Total loss
total_loss = policy_loss + value_loss
self.optimizer_policy.zero_grad()
total_loss.backward()
self.optimizer_policy.step()
return total_loss.item()
def train(self):
for episode in range(GAILConfig.max_episodes):
state, _ = self.env.reset()
episode_reward = 0
episode_states = []
episode_actions = []
episode_rewards = []
episode_log_probs = []
episode_dones = []
for _ in range(GAILConfig.max_steps):
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(GAILConfig.device)
action, log_prob, _ = self.policy.act(state_tensor)
action = action[0]
next_state, reward, terminated, truncated, _ = self.env.step(action)
done = terminated or truncated
episode_states.append(state)
episode_actions.append(action)
episode_rewards.append(reward)
episode_log_probs.append(log_prob)
episode_dones.append(done)
episode_reward += reward
self.buffer.add(state, action, reward, next_state, done)
state = next_state
if done:
break
# Convert to tensors
states_tensor = torch.FloatTensor(np.array(episode_states)).to(GAILConfig.device)
actions_tensor = torch.FloatTensor(np.array(episode_actions)).to(GAILConfig.device)
            log_probs_tensor = torch.FloatTensor(np.array(episode_log_probs)).unsqueeze(-1).to(GAILConfig.device)
            dones_tensor = torch.FloatTensor(np.array(episode_dones)).unsqueeze(-1).to(GAILConfig.device)
            # The policy is trained on the discriminator-based surrogate reward,
            # not on the environment reward (episode_reward is kept only for logging)
            rewards_tensor = self.compute_reward(states_tensor, actions_tensor)
# Update policy
p_loss = self.update_policy(
states_tensor,
actions_tensor,
log_probs_tensor,
rewards_tensor,
dones_tensor
)
# Update discriminator
d_loss = self.update_discriminator()
if (episode + 1) % 10 == 0:
print(f"Episode {episode+1} | Reward: {episode_reward:.1f} | Policy Loss: {p_loss:.2f} | D Loss: {d_loss:.2f}")
if __name__ == "__main__":
print("初始化环境...")
trainer = GAILTrainer()
trainer.train()
V. Key Code Walkthrough
1. Discriminator network
- Takes the concatenation of state and action as input and outputs the probability that the pair comes from the expert (Sigmoid activation)
- The loss is a binary cross-entropy that separates expert data from generated data (see the equivalence check below)
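The hand-written loss in update_discriminator is exactly binary cross-entropy with label 1 for expert pairs and label 0 for generated pairs. The small check below uses random tensors as stand-ins for discriminator outputs (the batch size is arbitrary) to show the equivalence with nn.BCELoss:

import torch
import torch.nn as nn

# Random stand-ins for discriminator outputs (probabilities in (0, 1)); sizes are arbitrary
real_output = torch.rand(64, 1) * 0.98 + 0.01   # D(s, a) on expert pairs
fake_output = torch.rand(64, 1) * 0.98 + 0.01   # D(s, a) on generated pairs

# Hand-written loss, as in update_discriminator
manual = -torch.log(real_output + 1e-8).mean() - torch.log(1 - fake_output + 1e-8).mean()

# Equivalent formulation with nn.BCELoss (targets: 1 = expert, 0 = generated)
bce = nn.BCELoss()
standard = bce(real_output, torch.ones_like(real_output)) + bce(fake_output, torch.zeros_like(fake_output))

print(manual.item(), standard.item())  # nearly identical, up to the 1e-8 stabilizer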
2. Reward computation
- The generator's reward comes from the discriminator output: $r(s,a) = -\log(1 - D(s,a))$
- The closer $D(s,a)$ is to 1 (i.e., the more expert-like the pair), the larger the reward; a small numeric demo follows
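As a quick illustration, the values below show that $-\log(1 - D)$ grows rapidly as the discriminator becomes convinced a pair is expert-like, which is why the implementation adds the small 1e-8 term (and why clamping D away from 1 is a common additional stabilizer):

import torch

d_values = torch.tensor([0.1, 0.5, 0.9, 0.99, 0.999])
rewards = -torch.log(1 - d_values + 1e-8)
for d, r in zip(d_values.tolist(), rewards.tolist()):
    print(f"D(s,a) = {d:.3f}  ->  reward = {r:.2f}")
# D(s,a) = 0.100  ->  reward = 0.11
# D(s,a) = 0.500  ->  reward = 0.69
# D(s,a) = 0.900  ->  reward = 2.30
# D(s,a) = 0.990  ->  reward = 4.61
# D(s,a) = 0.999  ->  reward = 6.91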
3. Adversarial training flow
- Step 1: the generator samples trajectories and stores them in the replay buffer
- Step 2: the discriminator is updated with both generated and expert data
- Step 3: the generator is updated (via PPO) using the rewards computed by the discriminator
VI. Example Training Output
Initializing the environment...
Episode 10 | Reward: -953.6 | Policy Loss: 78.33 | D Loss: 1.23
Episode 20 | Reward: -1001.0 | Policy Loss: 58.94 | D Loss: 1.05
Episode 30 | Reward: -1096.4 | Policy Loss: 50.88 | D Loss: 0.88
Episode 40 | Reward: -1108.3 | Policy Loss: 34.90 | D Loss: 0.78
Episode 50 | Reward: -1144.0 | Policy Loss: 34.77 | D Loss: 0.66
Episode 60 | Reward: -1292.6 | Policy Loss: 33.78 | D Loss: 0.60
Episode 70 | Reward: -1403.1 | Policy Loss: 38.53 | D Loss: 0.59
Episode 80 | Reward: -1741.3 | Policy Loss: 29.86 | D Loss: 0.45
Episode 90 | Reward: -2023.8 | Policy Loss: 45.42 | D Loss: 0.45
Episode 100 | Reward: -2192.2 | Policy Loss: 111.91 | D Loss: 0.36
In the next article, we will explore Multi-Objective Reinforcement Learning (Multi-Objective RL) and implement an optimization algorithm based on the Pareto front!
Notes
1. Install the dependencies (HalfCheetah-v5 requires the MuJoCo extra):
pip install "gymnasium[mujoco]" torch numpy
2. Generating the expert data:
- Use a pre-trained policy (e.g., SAC) to generate expert trajectories in the target environment and save them as expert_data.npy
- Example data format (a loading example follows the snippet):
expert_data = {
    'states': np.array([s0, s1, ..., sn]),   # state sequence
    'actions': np.array([a0, a1, ..., an])   # action sequence
}
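A quick way to verify the saved file is to load it back, mirroring what GAILTrainer.__init__ does; the shapes shown in the comment assume HalfCheetah-v5 (17-D observations, 6-D actions):

import numpy as np

# np.save stores the dictionary as a 0-d object array, hence allow_pickle=True and .item()
expert_data = np.load("expert_data.npy", allow_pickle=True).item()
states, actions = expert_data['states'], expert_data['actions']
print(states.shape, actions.shape)   # e.g. (N, 17) and (N, 6) for HalfCheetah-v5
assert len(states) == len(actions), "states and actions must be aligned step by step"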
3. Full training benefits greatly from GPU acceleration (≥ 8 GB of VRAM recommended)
相关推荐
- WPS 隐藏黑科技!OCT2HEX 函数用法全攻略,数据转换不再愁
-
WPS隐藏黑科技!OCT2HEX函数用法全攻略,数据转换不再愁在WPS表格的强大函数库中,OCT2HEX函数堪称数据进制转换的“魔法钥匙”。无论是程序员处理代码数据,还是工程师进行电路设计...
- WPS 表格隐藏神器!LEFTB 函数让文本处理更高效
-
WPS表格隐藏神器!LEFTB函数让文本处理更高效在职场办公和日常数据处理中,WPS表格堪称我们的得力助手,而其中丰富多样的函数更是提升效率的关键。今天,要为大家介绍一个“宝藏函数”——LEF...
- Java lombok 使用教程(lombok.jar idea)
-
简介Lombok是...
- PART 48: 万能结果自定义,SWITCH函数!
-
公式解析SWITCH:根据值列表计算表达式并返回与第一个匹配值对应的结果。如果没有匹配项,则返回可选默认值用法解析1:评级=SWITCH(TRUE,C2>=90,"优秀",C2...
- Excel 必备if函数使用方法详解(excel表if函数使用)
-
excel表格if函数使用方法介绍打开Excel,在想输出数据的单元格点击工具栏上的“公式”--“插入函数”--“IF”,然后点击确定。...
- Jetty使用场景(jetty入门)
-
Jetty作为一款高性能、轻量级的嵌入式Web服务器和Servlet容器,其核心优势在于模块化设计、快速启动、低资源消耗...
- 【Java教程】基础语法到高级特性(java语言高级特性)
-
Java作为一门面向对象的编程语言,拥有清晰规范的语法体系。本文将系统性地介绍Java的核心语法特性,帮助开发者全面掌握Java编程基础。...
- WPS里这个EVEN 函数,90%的人都没用过!
-
一、开篇引入在日常工作中,我们常常会与各种数据打交道。比如,在统计员工绩效时,需要对绩效分数进行一系列处理;在计算销售数据时,可能要对销售额进行特定的运算。这些看似简单的数据处理任务,实则隐藏着许多技...
- 64 AI助力Excel,查函数查用法简单方便
-
在excel表格当中接入ai之后会是一种什么样的使用体验?今天就跟大家一起来分享一下小程序商店的下一步重大的版本更新。下一个版本将会加入ai功能,接下来会跟大家演示一下基础的用法。ai功能规划的是有三...
- python入门到脱坑 函数—函数的调用
-
Python函数调用详解函数调用是Python编程中最基础也是最重要的操作之一。下面我将详细介绍Python中函数调用的各种方式和注意事项。...
- 从简到繁,一文说清vlookup函数的常见用法
-
VLOOKUP函数是Excel中常用的查找与引用函数,用于在表格中按列查找数据。本文将从简单到复杂,逐步讲解VLOOKUP的用法、语法、应用场景及注意事项。一、VLOOKUP基础:快速入门1.什么是...
- Java新特性:Lambda表达式(java lambda表达式的3种简写方式)
-
1、Lambda表达式概述1.1、Lambda表达式的简介Lambda表达式(Lambdaexpression),也可称为闭包(Closure),是Java(SE)8中一个重要的新特性。Lam...
- WPS 冷门却超实用!ODD 函数用法大揭秘,轻松解决数据处理难题
-
WPS冷门却超实用!ODD函数用法大揭秘,轻松解决数据处理难题在WPS表格庞大的函数家族里,有一些函数虽然不像SUM、VLOOKUP那样广为人知,却在特定场景下能发挥出令人惊叹的作用,OD...
- Python 函数式编程的 8 大核心技巧,不允许你还不会
-
函数式编程是一种强调使用纯函数、避免共享状态和可变数据的编程范式。Python虽然不是纯函数式语言,但提供了丰富的函数式编程特性。以下是Python函数式编程的8个核心技巧:...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)