PyTorch 深度学习实战(23):多任务强化学习(Multi-Task RL)
ztj100 2025-04-26 22:45 14 浏览 0 评论
一、多任务强化学习原理
1. 多任务学习核心思想
多任务强化学习(Multi-Task RL)旨在让智能体同时学习多个任务,通过共享知识提升学习效率和泛化能力。与单任务强化学习的区别在于:
对比维度 | 单任务强化学习 | 多任务强化学习 |
目标 | 优化单一任务策略 | 同时优化多个任务的共享策略 |
训练方式 | 单任务独立训练 | 多任务联合训练 |
知识迁移 | 无 | 共享表示或参数实现跨任务知识迁移 |
应用场景 | 任务特定场景 | 复杂环境中的通用智能体 |
2. 基于共享表示的多任务框架
通过共享网络层学习任务共性,任务特定层处理任务差异。算法流程如下:
- 任务采样:从任务分布中随机选择一个任务
- 策略执行:基于共享网络生成动作
- 梯度更新:联合优化共享参数和任务特定参数
数学表达:联合优化目标为各任务损失的加权和 $L(\theta_{shared}, \{\theta_i\}) = \sum_{i=1}^{N} w_i \, L_i(\theta_{shared}, \theta_i)$,其中 $\theta_{shared}$ 为跨任务共享参数,$\theta_i$ 为第 $i$ 个任务的特定参数,$w_i$ 为任务权重。
二、多任务 PPO 算法实现(基于 Gymnasium)
我们将以 Meta-World 多任务机械臂环境 为例,实现基于 PPO 的多任务强化学习:
- 定义任务集合:包含 reach、push、pick-place 等任务
- 构建共享策略网络:共享全连接层 + 任务特定全连接头部
- 实现多任务采样:动态切换任务训练
- 联合梯度更新:平衡多任务损失
三、代码实现
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
from torch.cuda.amp import autocast, GradScaler
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
from collections import deque
# ================== 配置参数 ==================
class MultiTaskPPOConfig:
    """Hyper-parameters shared by the policy network and the trainer.

    All values are class attributes (no instance is ever created); other
    modules read them as ``MultiTaskPPOConfig.<name>``.
    """

    # Meta-World tasks trained jointly by one shared policy.
    task_names = [
        'reach-v2-goal-observable',
        'push-v2-goal-observable',
        'pick-place-v2-goal-observable'
    ]
    num_tasks = 3            # must equal len(task_names)

    # Network sizes.
    hidden_dim = 512         # width of the shared trunk
    task_specific_dim = 128  # width of each per-task head

    # PPO / optimisation.
    lr = 3e-4
    gamma = 0.99             # discount factor
    gae_lambda = 0.95        # GAE smoothing factor
    clip_epsilon = 0.2       # PPO clipping range
    ppo_epochs = 4           # passes over the buffer per update
    batch_size = 512
    max_episodes = 2000
    max_steps = 500          # steps collected (and buffered) per episode
    grad_clip = 0.5          # global grad-norm clip

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ================== 共享策略网络 ==================
class SharedPolicy(nn.Module):
    """Multi-task policy: one shared MLP trunk plus per-task heads.

    The trunk learns task-agnostic features; ``task_heads[i]`` maps them to
    the action mean for task ``i`` and ``value_heads[i]`` to its state value.

    Args:
        state_dim: dimensionality of the (flat) observation vector.
        action_dim: dimensionality of the continuous action vector.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.action_dim = action_dim
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, MultiTaskPPOConfig.hidden_dim),
            nn.LayerNorm(MultiTaskPPOConfig.hidden_dim),
            nn.GELU(),
            nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.hidden_dim),
            nn.GELU()
        )
        # One action head per task.
        self.task_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, action_dim)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])
        # One value head per task.
        self.value_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(MultiTaskPPOConfig.hidden_dim, MultiTaskPPOConfig.task_specific_dim),
                nn.GELU(),
                nn.Linear(MultiTaskPPOConfig.task_specific_dim, 1)
            ) for _ in range(MultiTaskPPOConfig.num_tasks)
        ])

    def forward(self, states, task_ids):
        """Route each sample through the head of its own task.

        Args:
            states: float tensor of shape ``(batch, state_dim)``.
            task_ids: long tensor of shape ``(batch,)`` with values in
                ``[0, num_tasks)``.

        Returns:
            ``(action_means, values)`` with shapes ``(batch, action_dim)``
            and ``(batch, 1)``, in ``states``' dtype/device.
        """
        shared_features = self.shared_net(states)
        batch_size = states.size(0)
        # BUG FIX: the original allocated via `torch.zeros_like(states[:, :action_dim])`,
        # which silently requires state_dim >= action_dim. Allocate the correct
        # shape directly instead.
        action_means = torch.zeros(
            batch_size, self.action_dim,
            dtype=states.dtype,
            device=states.device
        )
        values = torch.zeros(
            batch_size, 1,
            dtype=states.dtype,
            device=states.device
        )
        # Process each task present in the batch once, via a boolean mask.
        for task_id_tensor in torch.unique(task_ids):
            task_id = task_id_tensor.item()
            mask = (task_ids == task_id_tensor)
            selected_features = shared_features[mask]
            # Cast head outputs back to states.dtype: under autocast the heads
            # may produce fp16 while the output buffers are fp32.
            action_means[mask] = self.task_heads[task_id](selected_features).to(dtype=states.dtype)
            values[mask] = self.value_heads[task_id](selected_features).to(dtype=states.dtype)
        return action_means, values
# ================== 训练系统 ==================
class MultiTaskPPOTrainer:
    """Joint PPO trainer: one shared policy trained across several Meta-World tasks.

    Per episode it collects ``max_steps`` transitions (random task each step)
    into a buffer, then runs several PPO epochs of minibatch updates.
    """

    def __init__(self):
        # Build one environment per task and validate that observation
        # dimensionalities agree (the shared trunk needs a single state_dim).
        self.envs = []
        self.state_dim = None
        self.action_dim = None
        for task_name in MultiTaskPPOConfig.task_names:
            env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[task_name]()
            obs, _ = env.reset()
            if self.state_dim is None:
                self.state_dim = obs.shape[0]
                self.action_dim = env.action_space.shape[0]
            else:
                assert obs.shape[0] == self.state_dim, f"状态维度不一致: {task_name}"
            self.envs.append(env)
        # BUG FIX: the original read the private `env._last_obs`, which this
        # code never wrote, so episode continuation was broken. Track each
        # env's current observation explicitly (None => needs reset).
        self.last_obs = [None] * MultiTaskPPOConfig.num_tasks
        self.policy = SharedPolicy(self.state_dim, self.action_dim).to(MultiTaskPPOConfig.device)
        self.optimizer = optim.AdamW(self.policy.parameters(), lr=MultiTaskPPOConfig.lr)
        self.scaler = GradScaler()
        # Rollout buffer; holds exactly one episode's worth of transitions.
        self.buffer = deque(maxlen=MultiTaskPPOConfig.max_steps)

    def collect_experience(self, num_steps):
        """Collect `num_steps` transitions, sampling a random task each step."""
        for _ in range(num_steps):
            task_id = int(np.random.randint(MultiTaskPPOConfig.num_tasks))
            env = self.envs[task_id]
            if self.last_obs[task_id] is None:
                state, _ = env.reset()
            else:
                state = self.last_obs[task_id]
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(MultiTaskPPOConfig.device)
                task_id_tensor = torch.tensor([task_id], dtype=torch.long, device=MultiTaskPPOConfig.device)
                action_mean, value = self.policy(state_tensor, task_id_tensor)
                # Fixed-std Gaussian policy (std = 1).
                dist = Normal(action_mean, torch.ones_like(action_mean))
                action_tensor = dist.sample()
                # BUG FIX: the log-prob must be evaluated at the *sampled*
                # action; the original used `dist.log_prob(action_mean)`,
                # which makes the PPO importance ratio meaningless.
                log_prob = dist.log_prob(action_tensor).detach()
                action = action_tensor.squeeze(0).cpu().numpy()
            next_state, reward, done, trunc, _ = env.step(action)
            self.buffer.append({
                'state': state,
                'action': action,
                'log_prob': log_prob.cpu(),
                'reward': float(reward),
                'done': bool(done),
                'task_id': task_id,
                'value': float(value.item())
            })
            # Continue the episode next time this task is drawn, or force a
            # reset if it ended.
            self.last_obs[task_id] = next_state if not (done or trunc) else None

    def compute_gae(self, values, rewards, dones):
        """Generalised Advantage Estimation over the (chronological) buffer.

        Args:
            values, rewards, dones: 1-D numpy arrays of equal length.

        Returns:
            (normalised advantages, returns) as tensors on the config device.
        """
        advantages = []
        last_advantage = 0
        next_value = 0  # bootstrap value past the last step (treated as 0)
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + MultiTaskPPOConfig.gamma * next_value * (1 - dones[t]) - values[t]
            last_advantage = delta + MultiTaskPPOConfig.gamma * MultiTaskPPOConfig.gae_lambda * (1 - dones[t]) * last_advantage
            advantages.append(last_advantage)
            next_value = values[t]
        # Built backwards; restore chronological order.
        advantages = torch.tensor(advantages[::-1], dtype=torch.float32).to(MultiTaskPPOConfig.device)
        returns = advantages + torch.as_tensor(np.asarray(values), dtype=torch.float32).to(MultiTaskPPOConfig.device)
        # Normalise advantages for stable PPO updates.
        return (advantages - advantages.mean()) / (advantages.std() + 1e-8), returns

    def update_policy(self):
        """Run PPO epochs over the buffer; returns mean (policy, value) losses."""
        if not self.buffer:
            return 0, 0
        batch = list(self.buffer)
        device = MultiTaskPPOConfig.device
        # Convert via a single np.array first: torch.tensor on a list of
        # ndarrays is extremely slow (the original triggered that warning).
        states = torch.as_tensor(np.array([x['state'] for x in batch]), dtype=torch.float32, device=device)
        actions = torch.as_tensor(np.array([x['action'] for x in batch]), dtype=torch.float32, device=device)
        old_log_probs = torch.cat([x['log_prob'] for x in batch]).to(device)
        rewards = np.array([x['reward'] for x in batch], dtype=np.float32)
        dones = np.array([x['done'] for x in batch], dtype=np.float32)
        task_ids = torch.tensor([x['task_id'] for x in batch], dtype=torch.long, device=device)
        values = np.array([x['value'] for x in batch], dtype=np.float32)
        advantages, returns = self.compute_gae(values, rewards, dones)

        total_policy_loss = 0.0
        total_value_loss = 0.0
        num_updates = 0
        for _ in range(MultiTaskPPOConfig.ppo_epochs):
            perm = torch.randperm(len(batch))
            for i in range(0, len(batch), MultiTaskPPOConfig.batch_size):
                idx = perm[i:i + MultiTaskPPOConfig.batch_size]
                batch_states = states[idx]
                batch_actions = actions[idx]
                batch_old_log_probs = old_log_probs[idx]
                batch_returns = returns[idx]
                batch_advantages = advantages[idx]
                batch_task_ids = task_ids[idx]
                # BUG FIX: step once per minibatch (zero_grad before backward);
                # the original accumulated gradients over all epochs and
                # stepped only once at the very end.
                self.optimizer.zero_grad()
                # Mixed precision only around the forward/loss; backward and
                # the optimizer step run outside autocast as recommended.
                with autocast():
                    # BUG FIX: forward on the *minibatch*; the original ran the
                    # full batch here while using minibatch actions/advantages.
                    action_means, new_values = self.policy(batch_states, batch_task_ids)
                    dist = Normal(action_means, torch.ones_like(action_means))
                    new_log_probs = dist.log_prob(batch_actions)
                    # Per-dimension importance ratio (old log-probs are stored
                    # per-dimension too).
                    ratio = (new_log_probs - batch_old_log_probs).exp()
                    adv = batch_advantages.unsqueeze(-1)
                    surr1 = ratio * adv
                    surr2 = torch.clamp(ratio, 1 - MultiTaskPPOConfig.clip_epsilon,
                                        1 + MultiTaskPPOConfig.clip_epsilon) * adv
                    policy_loss = -torch.min(surr1, surr2).mean()
                    value_loss = 0.5 * (new_values.squeeze(-1) - batch_returns).pow(2).mean()
                    loss = policy_loss + value_loss
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the clip threshold is meaningful.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), MultiTaskPPOConfig.grad_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                num_updates += 1
        num_updates = max(num_updates, 1)
        return total_policy_loss / num_updates, total_value_loss / num_updates

    def train(self):
        """Main loop: collect -> update -> log, for max_episodes iterations."""
        print(f"开始训练,设备:{MultiTaskPPOConfig.device}")
        start_time = time.time()
        episode_rewards = {i: deque(maxlen=100) for i in range(MultiTaskPPOConfig.num_tasks)}
        for episode in range(MultiTaskPPOConfig.max_episodes):
            # 1) Experience collection.
            self.collect_experience(MultiTaskPPOConfig.max_steps)
            # 2) Policy optimisation.
            policy_loss, value_loss = self.update_policy()
            # 3) Statistics: a rough per-task reward proxy summed from the
            # current buffer for one randomly chosen task.
            task_id = np.random.randint(MultiTaskPPOConfig.num_tasks)
            episode_reward = sum(x['reward'] for x in self.buffer if x['task_id'] == task_id)
            episode_rewards[task_id].append(episode_reward)
            # 4) Periodic logging.
            if (episode + 1) % 100 == 0:
                avg_rewards = {k: np.mean(v) if v else 0 for k, v in episode_rewards.items()}
                time_cost = time.time() - start_time
                print(f"Episode {episode+1:5d} | Time: {time_cost:6.1f}s")
                for task_id in range(MultiTaskPPOConfig.num_tasks):
                    task_name = MultiTaskPPOConfig.task_names[task_id]
                    print(f"  {task_name:25s} | Avg Reward: {avg_rewards[task_id]:7.2f}")
                print(f"  Policy Loss: {policy_loss:.4f} | Value Loss: {value_loss:.4f}\n")
                start_time = time.time()
def _main():
    """Script entry point: build the trainer, report dims, start training."""
    trainer = MultiTaskPPOTrainer()
    print(f"状态维度: {trainer.state_dim}, 动作维度: {trainer.action_dim}")
    trainer.train()


if __name__ == "__main__":
    _main()
四、关键代码解析
1.共享策略网络
- SharedPolicy 包含共享网络层和任务特定头部
- task_heads 和 value_heads 分别处理不同任务的动作和值函数
2.多任务采样机制
- 每个回合随机选择一个任务进行训练
- 动态切换环境实例 env = self.envs[task_id]
3.联合梯度更新
- 计算多任务的策略损失和值函数损失
- 通过 task_id 索引选择对应任务头部参数
五、训练输出示例
状态维度: 39, 动作维度: 4
开始训练,设备:cuda
/workspace/e23.py:184: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:278.)
states = torch.tensor(
/workspace/e23.py:204: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with autocast():
Episode 100 | Time: 931.2s
reach-v2-goal-observable | Avg Reward: 226.83
push-v2-goal-observable | Avg Reward: 8.82
pick-place-v2-goal-observable | Avg Reward: 3.31
Policy Loss: 0.0386 | Value Loss: 13.2587
Episode 200 | Time: 935.3s
reach-v2-goal-observable | Avg Reward: 227.12
push-v2-goal-observable | Avg Reward: 8.83
pick-place-v2-goal-observable | Avg Reward: 3.23
Policy Loss: 0.0434 | Value Loss: 14.9413
Episode 300 | Time: 939.4s
reach-v2-goal-observable | Avg Reward: 226.78
push-v2-goal-observable | Avg Reward: 8.82
pick-place-v2-goal-observable | Avg Reward: 3.23
Policy Loss: 0.0429 | Value Loss: 13.9076
Episode 400 | Time: 938.4s
reach-v2-goal-observable | Avg Reward: 225.74
push-v2-goal-observable | Avg Reward: 8.84
pick-place-v2-goal-observable | Avg Reward: 3.20
Policy Loss: 0.0378 | Value Loss: 14.7157
Episode 500 | Time: 938.4s
reach-v2-goal-observable | Avg Reward: 225.45
push-v2-goal-observable | Avg Reward: 8.81
pick-place-v2-goal-observable | Avg Reward: 3.20
Policy Loss: 0.0381 | Value Loss: 11.7940
Episode 600 | Time: 928.5s
reach-v2-goal-observable | Avg Reward: 225.39
push-v2-goal-observable | Avg Reward: 8.75
pick-place-v2-goal-observable | Avg Reward: 3.20
Policy Loss: 0.0462 | Value Loss: 14.5566
Episode 700 | Time: 926.6s
reach-v2-goal-observable | Avg Reward: 226.37
push-v2-goal-observable | Avg Reward: 8.65
pick-place-v2-goal-observable | Avg Reward: 3.23
Policy Loss: 0.0394 | Value Loss: 15.5556
Episode 800 | Time: 943.8s
reach-v2-goal-observable | Avg Reward: 224.72
push-v2-goal-observable | Avg Reward: 8.64
pick-place-v2-goal-observable | Avg Reward: 3.23
Policy Loss: 0.0361 | Value Loss: 16.0126
Episode 900 | Time: 937.2s
reach-v2-goal-observable | Avg Reward: 224.15
push-v2-goal-observable | Avg Reward: 8.72
pick-place-v2-goal-observable | Avg Reward: 3.21
Policy Loss: 0.0417 | Value Loss: 14.1907
Episode 1000 | Time: 940.7s
reach-v2-goal-observable | Avg Reward: 223.77
push-v2-goal-observable | Avg Reward: 8.73
pick-place-v2-goal-observable | Avg Reward: 3.19
Policy Loss: 0.0399 | Value Loss: 16.0540
Episode 1100 | Time: 937.0s
reach-v2-goal-observable | Avg Reward: 224.73
push-v2-goal-observable | Avg Reward: 8.68
pick-place-v2-goal-observable | Avg Reward: 3.17
Policy Loss: 0.0409 | Value Loss: 15.5525
Episode 1200 | Time: 933.0s
reach-v2-goal-observable | Avg Reward: 224.73
push-v2-goal-observable | Avg Reward: 8.68
pick-place-v2-goal-observable | Avg Reward: 3.17
Policy Loss: 0.0388 | Value Loss: 17.4549
Episode 1300 | Time: 942.1s
reach-v2-goal-observable | Avg Reward: 224.35
push-v2-goal-observable | Avg Reward: 8.71
pick-place-v2-goal-observable | Avg Reward: 3.19
Policy Loss: 0.0447 | Value Loss: 14.6700
Episode 1400 | Time: 966.6s
reach-v2-goal-observable | Avg Reward: 224.27
push-v2-goal-observable | Avg Reward: 8.73
pick-place-v2-goal-observable | Avg Reward: 3.19
Policy Loss: 0.0434 | Value Loss: 13.3487
Episode 1500 | Time: 943.0s
reach-v2-goal-observable | Avg Reward: 223.03
push-v2-goal-observable | Avg Reward: 8.69
pick-place-v2-goal-observable | Avg Reward: 3.21
Policy Loss: 0.0438 | Value Loss: 14.7557
Episode 1600 | Time: 929.1s
reach-v2-goal-observable | Avg Reward: 224.01
push-v2-goal-observable | Avg Reward: 8.69
pick-place-v2-goal-observable | Avg Reward: 3.21
Policy Loss: 0.0365 | Value Loss: 12.2506
Episode 1700 | Time: 937.9s
reach-v2-goal-observable | Avg Reward: 222.88
push-v2-goal-observable | Avg Reward: 8.71
pick-place-v2-goal-observable | Avg Reward: 3.21
Policy Loss: 0.0365 | Value Loss: 11.8954
Episode 1800 | Time: 930.1s
reach-v2-goal-observable | Avg Reward: 224.42
push-v2-goal-observable | Avg Reward: 8.75
pick-place-v2-goal-observable | Avg Reward: 3.18
Policy Loss: 0.0437 | Value Loss: 13.6396
Episode 1900 | Time: 927.0s
reach-v2-goal-observable | Avg Reward: 224.66
push-v2-goal-observable | Avg Reward: 8.71
pick-place-v2-goal-observable | Avg Reward: 3.18
Policy Loss: 0.0360 | Value Loss: 14.3216
Episode 2000 | Time: 934.3s
reach-v2-goal-observable | Avg Reward: 224.73
push-v2-goal-observable | Avg Reward: 8.63
pick-place-v2-goal-observable | Avg Reward: 3.18
Policy Loss: 0.0475 | Value Loss: 14.0712
六、总结与扩展
本文实现了多任务强化学习的核心范式——基于共享策略的 PPO 算法,展示了跨任务知识迁移的能力。读者可尝试以下扩展方向:
1.动态任务权重 根据任务难度自适应调整损失权重:
# 在 update() 中添加任务权重
task_weights = calculate_task_difficulty()
loss = sum([weight * loss_i for weight, loss_i in zip(task_weights, losses)])
2.分层强化学习 引入高层策略调度任务:
class MetaController(nn.Module):
def __init__(self, num_tasks):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 64),
nn.ReLU(),
nn.Linear(64, num_tasks)
)
3.课程学习 从简单任务逐步过渡到复杂任务:
def schedule_task():
if episode < 1000:
return 'reach-v2-goal-observable'
elif episode < 2000:
return 'push-v2-goal-observable'
else:
return 'pick-place-v2-goal-observable'
在下一篇文章中,我们将探索 分层强化学习(HRL),并实现 Option-Critic 算法!
注意事项
1.安装依赖:
pip install metaworld gymnasium torch
2.metaworld问题:
如果稳定版存在问题,尝试安装GitHub上的最新版:
pip install git+https://github.com/rlworkgroup/metaworld.git@master
相关推荐
- 如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL
-
阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...
- Python数据分析:探索性分析
-
写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...
- C++基础语法梳理:算法丨十大排序算法(二)
-
本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...
- C 语言的标准库有哪些
-
C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...
- [深度学习] ncnn安装和调用基础教程
-
1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...
- 用rust实现经典的冒泡排序和快速排序
-
1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...
- ncnn+PPYOLOv2首次结合!全网最详细代码解读来了
-
编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...
- C++特性使用建议
-
1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...
- Qt4/5升级到Qt6吐血经验总结V202308
-
00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...
- 到底什么是C++11新特性,请看下文
-
C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...
- 掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!
-
C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...
- 经典算法——凸包算法
-
凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...
- 一起学习c++11——c++11中的新增的容器
-
c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...
- C++ 编程中的一些最佳实践
-
1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)