PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
ztj100 2025-04-26 22:45 227 浏览 0 评论
一、多目标强化学习原理
1. 多目标学习核心思想
多目标强化学习(Multi-Objective RL)旨在让智能体同时优化多个冲突目标,通过平衡目标间的权衡关系找到帕累托最优解集。与传统强化学习的区别在于:
对比维度 | 传统强化学习 | 多目标强化学习 |
目标数量 | 单一奖励函数 | 多个奖励函数(可能相互冲突) |
优化目标 | 最大化单一累计奖励 | 找到帕累托最优策略集合 |
解的唯一性 | 唯一最优解 | 多个非支配解(Pareto Front) |
应用场景 | 目标明确且无冲突的任务 | 自动驾驶(安全 vs 效率)、资源分配 |
2. 多目标问题建模
二、标量化方法(Scalarization)
三、多目标 PPO 算法实现(基于 Gymnasium)
以 自定义机器人控制环境 为例,实现基于权重调整的多目标 PPO 算法:
- 定义多目标环境:机器人需同时最大化前进速度和最小化能耗
- 构建策略网络:共享特征提取层 + 多目标价值头
- 动态权重调整:根据训练阶段调整目标权重
- 帕累托前沿分析:评估不同权重下的策略性能
四、代码实现
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
import gymnasium as gym
import time
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ================== 自定义多目标环境 ==================
class MultiObjectiveRobotEnv(gym.Env):
def __init__(self):
self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(4,))
self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
self.state = None
self.steps = 0
self.max_steps = 200 # 添加最大步数限制
def reset(self):
self.state = np.random.uniform(-1, 1, size=(4,))
self.steps = 0
return self.state.copy()
def step(self, action):
# Clip actions to valid range
action = np.clip(action, -1, 1)
# 定义两个冲突目标:速度(正向奖励)和能耗(负向奖励)
velocity_reward = np.abs(action[0]) # 速度与动作绝对值正相关
energy_cost = 0.1 * np.sum(np.square(action)) # 能耗与动作平方正相关
self.state += 0.1 * np.array([action[0], action[1], 0, 0]) # 简化动力学
self.state = np.clip(self.state, -1, 1) # 保持状态在合理范围内
self.steps += 1
done = self.steps >= self.max_steps # 超过最大步数时终止
return self.state.copy(), [velocity_reward, -energy_cost], done, {}
# ================== 多目标 PPO 策略网络 ==================
class MultiObjectivePPO(nn.Module):
def __init__(self, state_dim, action_dim, num_objectives=2):
super().__init__()
# 共享特征提取层
self.shared_net = nn.Sequential(
nn.Linear(state_dim, 256),
nn.Tanh(), # Using Tanh for more stable gradients
nn.Linear(256, 256),
nn.Tanh()
)
# 多目标价值头
self.value_heads = nn.ModuleList([
nn.Sequential(nn.Linear(256, 1), nn.Tanh()) for _ in range(num_objectives)
])
# 策略头
self.actor_mean = nn.Sequential(
nn.Linear(256, action_dim),
nn.Tanh() # Output between -1 and 1
)
self.actor_log_std = nn.Parameter(torch.zeros(action_dim) - 1.0) # Initialize to smaller std
def forward(self, state):
features = self.shared_net(state)
values = [head(features) for head in self.value_heads]
action_mean = self.actor_mean(features)
action_std = torch.exp(self.actor_log_std).clamp(1e-4, 1.0) # Clamp std to avoid NaN
return action_mean, action_std, values
# ================== 训练系统 ==================
class MultiObjectivePPOTrainer:
def __init__(self):
self.env = MultiObjectiveRobotEnv()
self.state_dim = self.env.observation_space.shape[0]
self.action_dim = self.env.action_space.shape[0]
self.num_objectives = 2
self.policy = MultiObjectivePPO(self.state_dim, self.action_dim).to(device)
self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-4) # Reduced learning rate
self.weights = np.array([0.5, 0.5]) # 初始权重
self.gamma = 0.99 # Discount factor
self.clip_epsilon = 0.2 # PPO clip parameter
self.max_steps_per_episode = 200 # 每episode最大步数
def update_weights(self, episode):
# 动态调整权重(示例:周期变化)
self.weights = np.array([np.sin(episode * 0.01) * 0.5 + 0.5,
1 - (np.sin(episode * 0.01) * 0.5 + 0.5)])
self.weights /= np.sum(self.weights)
def train(self, max_episodes=1000):
for episode in range(max_episodes):
self.update_weights(episode)
state = self.env.reset()
episode_rewards = np.zeros(self.num_objectives)
episode_steps = 0
# 采集轨迹数据
states, actions, log_probs, objectives = [], [], [], []
for _ in range(self.max_steps_per_episode):
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
with torch.no_grad():
action_mean, action_std, values = self.policy(state_tensor)
dist = Normal(action_mean, action_std)
action = dist.sample()
log_prob = dist.log_prob(action).sum(dim=-1)
next_state, obj_rewards, done, _ = self.env.step(action.squeeze(0).cpu().numpy())
states.append(state)
actions.append(action.squeeze(0))
log_probs.append(log_prob)
objectives.append(obj_rewards)
episode_rewards += np.array(obj_rewards)
episode_steps += 1
state = next_state
if done:
break
# 计算加权累计奖励
weighted_reward = np.dot(episode_rewards, self.weights)
if (episode + 1) % 100 == 0:
print(f"Episode {episode+1}/{max_episodes} | Steps: {episode_steps} | "
f"Weight: {self.weights} | Reward: {weighted_reward:.1f} | "
f"Obj1: {episode_rewards[0]:.1f} | Obj2: {episode_rewards[1]:.1f}")
if __name__ == "__main__":
start = time.time()
start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
print(f"开始时间: {start_str}")
print("初始化环境...")
trainer = MultiObjectivePPOTrainer()
trainer.train(max_episodes=500) # 先测试少量episode
end = time.time()
end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
print(f"训练完成时间: {end_str}")
print(f"训练完成,耗时: {end - start:.2f}秒")
五、关键代码解析
1.多目标环境设计
- step() 返回两个奖励:速度奖励(velocity_reward)和能耗惩罚(-energy_cost)。
- 通过调整动作的绝对值(速度)和平方值(能耗)实现目标冲突。
2.动态权重调整
- update_weights() 周期性调整目标权重(示例中使用正弦函数)。
- 实际应用中可根据需求设计自适应权重策略。
3.策略网络结构
- 共享特征提取层(shared_net)学习状态共性表示。
- 独立价值头(value_heads)分别预测各目标的价值。
六、训练输出示例
开始时间: 2025-03-26 03:37:15
初始化环境...
Episode 100/500 | Steps: 200 | Weight: [0.91801299 0.08198701] | Reward: 53.4 | Obj1: 58.6 | Obj2: -5.2
Episode 200/500 | Steps: 200 | Weight: [0.95670668 0.04329332] | Reward: 59.9 | Obj1: 62.9 | Obj2: -5.7
Episode 300/500 | Steps: 200 | Weight: [0.57550636 0.42449364] | Reward: 31.9 | Obj1: 59.3 | Obj2: -5.2
Episode 400/500 | Steps: 200 | Weight: [0.12488584 0.87511416] | Reward: 2.2 | Obj1: 61.9 | Obj2: -6.3
Episode 500/500 | Steps: 200 | Weight: [0.01914355 0.98085645] | Reward: -4.4 | Obj1: 64.0 | Obj2: -5.8
训练完成时间: 2025-03-26 03:39:16
训练完成,耗时: 120.95秒
七、总结与扩展
本文实现了多目标强化学习的核心范式——基于动态权重的标量化方法,展示了帕累托前沿的探索能力。
在下一篇文章中,我们将探索 稳定扩散模型(Stable Diffusion),并实现文本到图像生成(Text-to-Image Generation)的完整流程!
注意事项
1.安装依赖:
pip install gymnasium torch numpy
2.自定义环境需继承 gym.Env 并实现 reset() 和 step() 方法。
3.动态权重调整策略可根据实际需求设计(如基于任务难度或用户偏好)。
相关推荐
- xls文件保存宏_excel如何保存宏为其他文件使用
-
一、直接保存为旧版.xls格式(兼容性优先)1.操作方法-在Excel中打开文件→点击「文件」→「另存为」→保存类型选择“Excel97-2003工作簿(*.xls)”。-系统...
- C 插入或删除word分页符_怎么删除插了分页符的空白页
-
分页符是word中常用的一种分页的符号,它标志着上一页的结束和下一页的开始。在word中分页符有两种,一种是自动分页符,也叫软分页符,即一页数据写满以后转到下一页时word自动插入的一个分页符;另一种...
- 177.C# SqlSugar 删除数据_sql删除数据代码
-
摘要普通删除、单表删除、表达式删除,子查询删除正文根据主键Where条件删除varret=Db.Deleteable<wms_user>().Where(newwms_user...
- C#使用handle实现获取占用指定文件或文件夹的进程(Locksmith功能)
-
前言:很多时候,一些不知道啥进程,把你的文件给占用了,然后就没办法删掉或者做其他操作。如果使用Locksmith功能,就可以实现快速锁定是哪个进程在搞事情,把对应进程干掉就可以了。下面内容演示C#使用...
- 小材大用!用好Windows 10文件缩略图
-
当我们将图片传输到电脑中后,默认情况下Windows会显示小图预览,因此我们可以不打开图片就能看到图的基本模样。为了防止系统负担过重,Windows只在打开特定的文件夹时生成缩略图,且在系统关机时缩略...
- C#:删除 Word 中的页眉或页脚_c# 删除文件
-
C#:删除Word中的页眉或页脚在处理Word文档批量操作时,我们经常需要清除页眉页脚——比如合并文档后去除冗余信息,或为标准化报告格式。手动操作不仅繁琐,更难以集成到自动化流程中。使用Spire...
- C# INI文件读写方法_c#ini文件如何一次读取所有数据
-
在C#项目的开发实践里,存在着一种十分常见且实用的操作习惯,那就是把一部分常用的参数值写入到.ini文件当中。这种做法背后有着充分的考量。从软件系统的设计角度来看,将常用参数集中存放在.ini文件...
- C# 基础知识系列- 14 IO篇 文件的操作(1)
-
0.前言本章节是IO篇的第二集,我们在上一篇中介绍了C#中IO的基本概念和一些基本方法,接下来我们介绍一下操作文件的方法。在编程的世界中,操作文件是一个很重要的技能。...
- C# 删除 Excel 工作表中的空白行和空白列
-
在日常处理Excel数据时,经常会遇到表格中夹杂着许多空白行或空白列。这些空白内容不仅影响数据的整洁性,还可能导致数据处理和分析结果出错。手动逐一删除这些空白行列不仅效率低下,而且容易遗漏。本文将...
- 微信小程序原生开发【辅助框架】 LWX
-
项目介绍作者开发了一年多的小程序,在开发过程中遇到了很多的坑与不方便之处,同时又对原生开发有着一定的执著,但是对于习惯了我这种用惯了vue的人来说,原生小程序中的一些写法确实让人感到难受,我想大家在进...
- 谷歌正式发布Android 12,UI更好看,打造属于自己的定制化属性
-
焕然一新的Android12今年5月的GoogleI/O大会上,谷歌推出了Android12系统,这是原生安卓系统史上最大的设计变化,从此旧貌换新颜。...
- 【推荐】一个基于 SpringBoot 框架开发的 OA 办公自动化系统
-
如果您对源码&技术感兴趣,请点赞+收藏+转发+关注,大家的支持是我分享最大的动力!!!项目介绍...
- 「干货」9个最热门React PC端组件库|UI框架
-
最近一直在使用React.js开发项目,在开发过程中也用到了一些开源UI组件库。上次有给大家分享React移动端组件库,今天,就给大家推荐9个常用ReactPC端组件库。...
- Android主流UI开源库整理_android ui 布局开源框架
-
前言最近老大让我整理一份Android主流UI开源库的资料,以补充公司的Android知识库。由于对格式不做特别限制,于是打算用博客的形式记录下来,方便查看、防丢并且可以持续维护、不断更新。标题隐...
- Datetimepicker.js用法_datepicker的用法
-
$('.form_date').datetimepicker({//初始化language:'zh-CN',//weekStart:1,...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)