百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL

ztj100 2025-04-26 22:45 227 浏览 0 评论

一、多目标强化学习原理

1. 多目标学习核心思想

多目标强化学习(Multi-Objective RL)旨在让智能体同时优化多个冲突目标,通过平衡目标间的权衡关系找到帕累托最优解集。与传统强化学习的区别在于:

对比维度

传统强化学习

多目标强化学习

目标数量

单一奖励函数

多个奖励函数(可能相互冲突)

优化目标

最大化单一累计奖励

找到帕累托最优策略集合

解的唯一性

唯一最优解

多个非支配解(Pareto Front)

应用场景

目标明确且无冲突的任务

自动驾驶(安全 vs 效率)、资源分配

2. 多目标问题建模


二、标量化方法(Scalarization)


三、多目标 PPO 算法实现(基于 Gymnasium)

自定义机器人控制环境 为例,实现基于权重调整的多目标 PPO 算法:

  1. 定义多目标环境:机器人需同时最大化前进速度和最小化能耗
  2. 构建策略网络:共享特征提取层 + 多目标价值头
  3. 动态权重调整:根据训练阶段调整目标权重
  4. 帕累托前沿分析:评估不同权重下的策略性能

四、代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal
import gymnasium as gym
import time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ================== 自定义多目标环境 ==================
class MultiObjectiveRobotEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(4,))
        self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
        self.state = None
        self.steps = 0
        self.max_steps = 200  # 添加最大步数限制
    
    def reset(self):
        self.state = np.random.uniform(-1, 1, size=(4,))
        self.steps = 0
        return self.state.copy()
    
    def step(self, action):
        # Clip actions to valid range
        action = np.clip(action, -1, 1)
        
        # 定义两个冲突目标:速度(正向奖励)和能耗(负向奖励)
        velocity_reward = np.abs(action[0])  # 速度与动作绝对值正相关
        energy_cost = 0.1 * np.sum(np.square(action))  # 能耗与动作平方正相关
        
        self.state += 0.1 * np.array([action[0], action[1], 0, 0])  # 简化动力学
        self.state = np.clip(self.state, -1, 1)  # 保持状态在合理范围内
        
        self.steps += 1
        done = self.steps >= self.max_steps  # 超过最大步数时终止
        
        return self.state.copy(), [velocity_reward, -energy_cost], done, {}

# ================== 多目标 PPO 策略网络 ==================
class MultiObjectivePPO(nn.Module):
    def __init__(self, state_dim, action_dim, num_objectives=2):
        super().__init__()
        # 共享特征提取层
        self.shared_net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.Tanh(),  # Using Tanh for more stable gradients
            nn.Linear(256, 256),
            nn.Tanh()
        )
        # 多目标价值头
        self.value_heads = nn.ModuleList([
            nn.Sequential(nn.Linear(256, 1), nn.Tanh()) for _ in range(num_objectives)
        ])
        # 策略头
        self.actor_mean = nn.Sequential(
            nn.Linear(256, action_dim),
            nn.Tanh()  # Output between -1 and 1
        )
        self.actor_log_std = nn.Parameter(torch.zeros(action_dim) - 1.0)  # Initialize to smaller std
    
    def forward(self, state):
        features = self.shared_net(state)
        values = [head(features) for head in self.value_heads]
        action_mean = self.actor_mean(features)
        action_std = torch.exp(self.actor_log_std).clamp(1e-4, 1.0)  # Clamp std to avoid NaN
        return action_mean, action_std, values

# ================== 训练系统 ==================
class MultiObjectivePPOTrainer:
    def __init__(self):
        self.env = MultiObjectiveRobotEnv()
        self.state_dim = self.env.observation_space.shape[0]
        self.action_dim = self.env.action_space.shape[0]
        self.num_objectives = 2
        self.policy = MultiObjectivePPO(self.state_dim, self.action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-4)  # Reduced learning rate
        self.weights = np.array([0.5, 0.5])  # 初始权重
        self.gamma = 0.99  # Discount factor
        self.clip_epsilon = 0.2  # PPO clip parameter
        self.max_steps_per_episode = 200  # 每episode最大步数
    
    def update_weights(self, episode):
        # 动态调整权重(示例:周期变化)
        self.weights = np.array([np.sin(episode * 0.01) * 0.5 + 0.5, 
                                1 - (np.sin(episode * 0.01) * 0.5 + 0.5)])
        self.weights /= np.sum(self.weights)
    
    def train(self, max_episodes=1000):
        for episode in range(max_episodes):
            self.update_weights(episode)
            state = self.env.reset()
            episode_rewards = np.zeros(self.num_objectives)
            episode_steps = 0
            
            # 采集轨迹数据
            states, actions, log_probs, objectives = [], [], [], []
            
            for _ in range(self.max_steps_per_episode):
                state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
                with torch.no_grad():
                    action_mean, action_std, values = self.policy(state_tensor)
                    dist = Normal(action_mean, action_std)
                    action = dist.sample()
                    log_prob = dist.log_prob(action).sum(dim=-1)
                
                next_state, obj_rewards, done, _ = self.env.step(action.squeeze(0).cpu().numpy())
                
                states.append(state)
                actions.append(action.squeeze(0))
                log_probs.append(log_prob)
                objectives.append(obj_rewards)
                episode_rewards += np.array(obj_rewards)
                episode_steps += 1
                state = next_state
                
                if done:
                    break
            
            # 计算加权累计奖励
            weighted_reward = np.dot(episode_rewards, self.weights)
            if (episode + 1) % 100 == 0:
                print(f"Episode {episode+1}/{max_episodes} | Steps: {episode_steps} | "
                    f"Weight: {self.weights} | Reward: {weighted_reward:.1f} | "
                    f"Obj1: {episode_rewards[0]:.1f} | Obj2: {episode_rewards[1]:.1f}")

if __name__ == "__main__":
    start = time.time()
    start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))
    print(f"开始时间: {start_str}")
    print("初始化环境...")
    trainer = MultiObjectivePPOTrainer()
    trainer.train(max_episodes=500)  # 先测试少量episode
    end = time.time()
    end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))
    print(f"训练完成时间: {end_str}")
    print(f"训练完成,耗时: {end - start:.2f}秒")

五、关键代码解析

1.多目标环境设计

  • step() 返回两个奖励:速度奖励(velocity_reward)和能耗惩罚(-energy_cost)。
  • 通过调整动作的绝对值(速度)和平方值(能耗)实现目标冲突。

2.动态权重调整

  • update_weights() 周期性调整目标权重(示例中使用正弦函数)。
  • 实际应用中可根据需求设计自适应权重策略。

3.策略网络结构

  • 共享特征提取层(shared_net)学习状态共性表示。
  • 独立价值头(value_heads)分别预测各目标的价值。

六、训练输出示例

开始时间: 2025-03-26 03:37:15
初始化环境...
Episode 100/500 | Steps: 200 | Weight: [0.91801299 0.08198701] | Reward: 53.4 | Obj1: 58.6 | Obj2: -5.2
Episode 200/500 | Steps: 200 | Weight: [0.95670668 0.04329332] | Reward: 59.9 | Obj1: 62.9 | Obj2: -5.7
Episode 300/500 | Steps: 200 | Weight: [0.57550636 0.42449364] | Reward: 31.9 | Obj1: 59.3 | Obj2: -5.2
Episode 400/500 | Steps: 200 | Weight: [0.12488584 0.87511416] | Reward: 2.2 | Obj1: 61.9 | Obj2: -6.3
Episode 500/500 | Steps: 200 | Weight: [0.01914355 0.98085645] | Reward: -4.4 | Obj1: 64.0 | Obj2: -5.8
训练完成时间: 2025-03-26 03:39:16
训练完成,耗时: 120.95秒

七、总结与扩展

本文实现了多目标强化学习的核心范式——基于动态权重的标量化方法,展示了帕累托前沿的探索能力。

在下一篇文章中,我们将探索 稳定扩散模型(Stable Diffusion),并实现文本到图像生成(Text-to-Image Generation)的完整流程!


注意事项

1.安装依赖:

pip install gymnasium torch numpy

2.自定义环境需继承 gym.Env 并实现 reset()step() 方法。

3.动态权重调整策略可根据实际需求设计(如基于任务难度或用户偏好)。

相关推荐

xls文件保存宏_excel如何保存宏为其他文件使用

一、直接保存为旧版.xls格式(兼容性优先)1.操作方法-在Excel中打开文件→点击「文件」→「另存为」→保存类型选择“Excel97-2003工作簿(*.xls)”。-系统...

C 插入或删除word分页符_怎么删除插了分页符的空白页

分页符是word中常用的一种分页的符号,它标志着上一页的结束和下一页的开始。在word中分页符有两种,一种是自动分页符,也叫软分页符,即一页数据写满以后转到下一页时word自动插入的一个分页符;另一种...

177.C# SqlSugar 删除数据_sql删除数据代码

摘要普通删除、单表删除、表达式删除,子查询删除正文根据主键Where条件删除varret=Db.Deleteable<wms_user>().Where(newwms_user...

C#使用handle实现获取占用指定文件或文件夹的进程(Locksmith功能)

前言:很多时候,一些不知道啥进程,把你的文件给占用了,然后就没办法删掉或者做其他操作。如果使用Locksmith功能,就可以实现快速锁定是哪个进程在搞事情,把对应进程干掉就可以了。下面内容演示C#使用...

小材大用!用好Windows 10文件缩略图

当我们将图片传输到电脑中后,默认情况下Windows会显示小图预览,因此我们可以不打开图片就能看到图的基本模样。为了防止系统负担过重,Windows只在打开特定的文件夹时生成缩略图,且在系统关机时缩略...

C#:删除 Word 中的页眉或页脚_c# 删除文件

C#:删除Word中的页眉或页脚在处理Word文档批量操作时,我们经常需要清除页眉页脚——比如合并文档后去除冗余信息,或为标准化报告格式。手动操作不仅繁琐,更难以集成到自动化流程中。使用Spire...

C# INI文件读写方法_c#ini文件如何一次读取所有数据

在C#项目的开发实践里,存在着一种十分常见且实用的操作习惯,那就是把一部分常用的参数值写入到.ini文件当中。这种做法背后有着充分的考量。从软件系统的设计角度来看,将常用参数集中存放在.ini文件...

C# 基础知识系列- 14 IO篇 文件的操作(1)

0.前言本章节是IO篇的第二集,我们在上一篇中介绍了C#中IO的基本概念和一些基本方法,接下来我们介绍一下操作文件的方法。在编程的世界中,操作文件是一个很重要的技能。...

C# 删除 Excel 工作表中的空白行和空白列

在日常处理Excel数据时,经常会遇到表格中夹杂着许多空白行或空白列。这些空白内容不仅影响数据的整洁性,还可能导致数据处理和分析结果出错。手动逐一删除这些空白行列不仅效率低下,而且容易遗漏。本文将...

微信小程序原生开发【辅助框架】 LWX

项目介绍作者开发了一年多的小程序,在开发过程中遇到了很多的坑与不方便之处,同时又对原生开发有着一定的执著,但是对于习惯了我这种用惯了vue的人来说,原生小程序中的一些写法确实让人感到难受,我想大家在进...

谷歌正式发布Android 12,UI更好看,打造属于自己的定制化属性

焕然一新的Android12今年5月的GoogleI/O大会上,谷歌推出了Android12系统,这是原生安卓系统史上最大的设计变化,从此旧貌换新颜。...

【推荐】一个基于 SpringBoot 框架开发的 OA 办公自动化系统

如果您对源码&技术感兴趣,请点赞+收藏+转发+关注,大家的支持是我分享最大的动力!!!项目介绍...

「干货」9个最热门React PC端组件库|UI框架

最近一直在使用React.js开发项目,在开发过程中也用到了一些开源UI组件库。上次有给大家分享React移动端组件库,今天,就给大家推荐9个常用ReactPC端组件库。...

Android主流UI开源库整理_android ui 布局开源框架

前言最近老大让我整理一份Android主流UI开源库的资料,以补充公司的Android知识库。由于对格式不做特别限制,于是打算用博客的形式记录下来,方便查看、防丢并且可以持续维护、不断更新。标题隐...

Datetimepicker.js用法_datepicker的用法

$('.form_date').datetimepicker({//初始化language:'zh-CN',//weekStart:1,...

取消回复欢迎 发表评论: