阿里达摩院魔塔ModelScope模型训练Train
ztj100 2025-01-02 20:33 12 浏览 0 评论
模型训练介绍
ModelScope提供了很多模型,这些模型可以直接在推理中使用,也可以根据用户数据集重新生成模型的参数,这个过程叫做训练。特别地,基于预训练backbone进行训练的过程叫做微调(finetune)。
一般来说,一次完整的模型训练包含了训练(train)和评估(evaluate)两个过程。训练过程使用训练数据集,将数据输入模型计算出loss后更新模型参数。评估过程使用评估数据集,将数据输入模型后评估模型效果。
ModelScope提供了完整的训练组件,其中的主要组件被称为trainer(训练器),这些组件可以在预训练或普通训练场景下使用。
PyTorch训练流程
ModelScope的模型训练步骤如下:
- 使用MsDataset加载数据集
- 编写cfg_modify_fn方法,按需修改部分参数
- 构造trainer,开始训练
- 【训练后步骤】进行模型评估
- 【训练后步骤】使用训练后的模型进行推理
PyTorch模型的训练使用EpochBasedTrainer(及其子类),该类会根据配置文件实例化模型、预处理器、优化器、指标等模块。因此训练模型的重点在于修改出合理的配置,其中用到的各组件都是ModelScope的标准模块。
trainer的重要构造参数
model: 模型id、模型本地路径或模型实例,必填
cfg_file: 额外的配置文件,可选。如果填写,trainer会使用这个配置进行训练
cfg_modify_fn: 读取配置后trainer调用这个回调方法修改配置项,可选。如果不传就使用原始配置
train_dataset: 训练用的数据集,调用训练时必传
eval_dataset: 评估用的数据集,调用评估时必传
optimizers: 自定义的(optimizer、lr_scheduler),可选,如果传入就不会使用配置文件中的
seed: 随机种子
launcher: 支持使用pytorch/mpi/slurm开启分布式训练
device: 训练用设备。可选,值为cpu, gpu, gpu:0, cuda:0等,默认gpu
一个简单的例子:文本分类
下面以一个简单的文本分类任务为例,演示如何通过十几行代码,就可以端到端执行一个finetune任务。假设待训练模型为:
# structbert的backbone,该模型没有有效分类器,因此使用前需要finetune(微调)
model_id = 'damo/nlp_structbert_backbone_base_std'
使用MsDataset加载数据集
MsDataset提供了加载数据集的能力,包括用户的数据和ModelScope生态数据集。下面的示例加载了ModelScope提供的afqmc(Ant Financial Question Matching Corpus,双句相似度任务)数据集:
from modelscope.msdatasets import MsDataset
# 载入训练数据,数据格式类似于{'sentence1': 'some content here', 'sentence2': 'other content here', 'label': 0}
train_dataset = MsDataset.load('clue', subset_name='afqmc', split='train')
# 载入评估数据
eval_dataset = MsDataset.load('clue', subset_name='afqmc', split='validation')
或者,也可以加载用户自己的数据集:
from modelscope.msdatasets import MsDataset
# 载入训练数据
train_dataset = MsDataset.load('/path/to/my_train_file.txt')
# 载入评估数据
eval_dataset = MsDataset.load('/path/to/my_eval_file.txt')
编写cfg_modify_fn方法,按需修改部分参数
建议首先查看模型的配置文件,并查看需要额外修改的参数:
from modelscope.utils.hub import read_config
# 上面的model_id
config = read_config(model_id)
print(config.pretty_text)
一般的配置文件中,在训练时需要修改的参数一般分为:
1. 预处理器参数
# 使用该模型适配的预处理器sen-sim-tokenizer
cfg.preprocessor.type='sen-sim-tokenizer'
# 预处理器输入的dict中,句子1的key,参考上文加载数据集中的afqmc的格式
cfg.preprocessor.first_sequence = 'sentence1'
# 预处理器输入的dict中,句子2的key
cfg.preprocessor.second_sequence = 'sentence2'
# 预处理器输入的dict中,label的key
cfg.preprocessor.label = 'label'
# 预处理器需要的label和id的mapping
cfg.preprocessor.label2id = {'0': 0, '1': 1}
某些模态中,预处理的参数需要根据数据集修改(比如NLP一般需要修改,而CV一般不需要修改),后续可以查看ModelCard或各任务最佳实践中各任务训练的详细描述。
2. 模型参数
# num_labels是该模型分类数
cfg.model.num_labels = 2
3. 任务参数
# 修改task类型为'text-classification'
cfg.task = 'text-classification'
# 修改pipeline名称,用于后续推理
cfg.pipeline = {'type': 'text-classification'}
4. 训练参数
一般训练超参数的调节都在这里进行:
# 设置训练epoch
cfg.train.max_epochs = 5
# 工作目录
cfg.train.work_dir = '/tmp'
# 设置batch_size
cfg.train.dataloader.batch_size_per_gpu = 32
cfg.evaluation.dataloader.batch_size_per_gpu = 32
# 设置learning rate
cfg.train.optimizer.lr = 2e-5
# 设置LinearLR的total_iters,这项和数据集大小相关
cfg.train.lr_scheduler.total_iters = int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs
# 设置评估metric类
cfg.evaluation.metrics = 'seq-cls-metric'
使用cfg_modify_fn将上述配置修改应用起来:
# 这个方法在trainer读取configuration.json后立即执行,先于构造模型、预处理器等组件
def cfg_modify_fn(cfg):
cfg.preprocessor.type='sen-sim-tokenizer'
cfg.preprocessor.first_sequence = 'sentence1'
cfg.preprocessor.second_sequence = 'sentence2'
cfg.preprocessor.label = 'label'
cfg.preprocessor.label2id = {'0': 0, '1': 1}
cfg.model.num_labels = 2
cfg.task = 'text-classification'
cfg.pipeline = {'type': 'text-classification'}
cfg.train.max_epochs = 5
cfg.train.work_dir = '/tmp'
cfg.train.dataloader.batch_size_per_gpu = 32
cfg.evaluation.dataloader.batch_size_per_gpu = 32
cfg.train.dataloader.workers_per_gpu = 0
cfg.evaluation.dataloader.workers_per_gpu = 0
cfg.train.optimizer.lr = 2e-5
cfg.train.lr_scheduler.total_iters = int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs
cfg.evaluation.metrics = 'seq-cls-metric'
# 注意这里需要返回修改后的cfg
return cfg
构造trainer,开始训练
首先,配置训练所需参数:
from modelscope.trainers import build_trainer
# 配置参数
kwargs = dict(
model=model_id,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
cfg_modify_fn=cfg_modify_fn)
trainer = build_trainer(default_args=kwargs)
trainer.train()
需要注意,数据由trainer从dataloader取数据的时候调用预处理器进行处理。
恭喜,你完成了一次模型训练。
进行模型评估
可选地,在训练后可以进行额外数据集的评估。用户可以单独调用evaluate方法对模型进行评估:
from modelscope.msdatasets import MsDataset
# 载入评估数据
eval_dataset = MsDataset.load('clue', subset_name='afqmc', split='validation')
from modelscope.trainers import build_trainer
# 配置参数
kwargs = dict(
# 由于使用的模型训练后的目录,因此不需要传入cfg_modify_fn
model='/tmp/output',
eval_dataset=eval_dataset)
trainer = build_trainer(default_args=kwargs)
trainer.evaluate()
或者,也可以调用predict方法将预测结果保存下来,以供后续打榜:
from modelscope.msdatasets import MsDataset
import numpy as np
# 载入评估数据
eval_dataset = MsDataset.load('clue', subset_name='afqmc', split='test').to_hf_dataset()
from modelscope.trainers import build_trainer
def cfg_modify_fn(cfg):
# 预处理器在mini-batch中留存冗余字段
cfg.preprocessor.val.keep_original_columns = ['sentence1', 'sentence2']
# 预测数据集没有label,将对应key置空
cfg.preprocessor.val.label = None
return cfg
kwargs = dict(
model='damo/nlp_structbert_sentence-similarity_chinese-tiny',
work_dir='/tmp',
cfg_modify_fn=cfg_modify_fn,
# remove_unused_data会将上述keep_original_columns的列转为attributes
remove_unused_data=True)
trainer = build_trainer(default_args=kwargs)
def saving_fn(inputs, outputs):
with open(f'/tmp/predicts.txt', 'a') as f:
# 通过attribute取冗余值
sentence1 = inputs.sentence1
sentence2 = inputs.sentence2
predictions = np.argmax(outputs['logits'].cpu().numpy(), axis=1)
for sent1, sent2, pred in zip(sentence1, sentence2, predictions):
f.writelines(f'{sent1}, {sent2}, {pred}\n')
trainer.predict(predict_datasets=eval_dataset,
saving_fn=saving_fn)
使用训练后的模型进行推理
训练完成以后,文件夹中会生成推理用的模型配置,可以直接用于pipeline:
- {work_dir}/output:训练完成后,存储模型配置文件,及最后一个epoch/iter的模型参数(配置中需要指定CheckpointHook)
- {work_dir}/output_best:最佳模型参数时,存储模型配置文件,及最佳的模型参数(配置中需要指定BestCkptSaverHook)
from modelscope.pipelines import pipeline
pipeline_ins = pipeline('text-classification', model='/tmp/output')
pipeline_ins(('这个功能可用吗', '这个功能现在可用吗'))
此外,ModelScope也会存储*.pth文件,用于后续继续训练、训练后验证、训练后推理。一般一次存储会存储两个pth文件:
- epoch_*.pth 存储模型的state_dict,output/output_best的bin文件是此文件的硬链接
- epoch_*_trainer_state.pth,存储trainer的state_dict
在继续训练场景时,只需要加载模型的pth文件,trainer的pth文件会被同时读取。用户也可以手动link某个pth文件到output/output_best,实现使用任意一个存储节点的推理.
pth的文件名格式如下:
- epoch_{n}/iter_{n}.pth(如epoch_3.pth): 每interval个epoch/iter周期存储(配置中需要指定CheckpointHook)
- best_epoch{n}_{metricname}{m}.pth(如best_iter13_accuracy22.pth):取得最佳模型参数时存储(配置中需要指定BestCkptSaverHook)
# 用于继续训练
trainer.train(checkpoint_path=os.path.join(self.tmp_dir, 'iter_3.pth'))
# 用于训练后评估
trainer.evaluate(checkpoint_path=os.path.join(self.tmp_dir, 'iter_3.pth'))
# 用于训练后推理并通过saving_fn存储预测的label为文件
trainer.predict(checkpoint_path=os.path.join(self.tmp_dir, 'iter_3.pth'),
predict_datasets=some_dataset,
saving_fn=some-saving-fn)
相关推荐
- Java项目宝塔搭建实战MES-Springboot开源MES智能制造系统源码
-
大家好啊,我是测评君,欢迎来到web测评。...
- 一个令人头秃的问题,Logback 日志级别设置竟然无效?
-
原文链接:https://mp.weixin.qq.com/s/EFvbFwetmXXA9ZGBGswUsQ原作者:小黑十一点半...
- 实战!SpringBoot + RabbitMQ死信队列实现超时关单
-
需求背景之为什么要有超时关单原因一:...
- 火了!阿里P8架构师编写堪称神级SpringBoot手册,GitHub星标99+
-
Springboot现在已成为企业面试中必备的知识点,以及企业应用的重要模块。今天小编给大家分享一份来着阿里P8架构师编写的...
- Java本地搭建宝塔部署实战springboot仓库管理系统源码
-
大家好啊,我是测评君,欢迎来到web测评。...
- 工具尝鲜(1)-Fleet构建运行一个Springboot入门Web项目
-
Fleet是JetBrains公司推出的轻量级编辑器,对标VSCode。该款产品还在公测当中,具体下载链接如下JetBrainsFleet:由JetBrains打造的下一代IDE。想要尝试的...
- SPRINGBOOT WEB 实现文件夹上传(保留目录结构)
-
网上搜到的SpringBoot的代码不多,完整的不多,能用的也不多,基本上大部分的文章只是提供了少量的代码,讲一下思路,或者实现方案。之前一般的做法都是使用HTML5来做的,大部都是传文件的,传文件夹...
- Java项目本地部署宝塔搭建实战报修小程序springboot版系统源码
-
大家好啊,我是测评君,欢迎来到web测评。...
- 新年IT界大笑料“工行取得基于SpringBoot的web系统后端实现专利
-
先看看专利描述...
- 看完SpringBoot源码后,整个人都精神了
-
前言当读完SpringBoot源码后,被Spring的设计者们折服,Spring系列中没有几行代码是我们看不懂的,而是难在理解设计思路,阅读Spring、SpringMVC、SpringBoot需要花...
- 阿里大牛再爆神著:SpringBoot+Cloud微服务手册
-
今天给大家分享的这份“Springboot+Springcloud微服务开发实战手册”共有以下三大特点...
- WebClient是什么?SpringBoot中如何使用WebClient?
-
WebClient是什么?WebClient是SpringFramework5引入的一个非阻塞、响应式的Web客户端库。它提供了一种简单而强大的方式来进行HTTP请求,并处理来自服务器的响应。与传...
- SpringBoot系列——基于mui的H5套壳APP开发web框架
-
前言 大致原理:创建一个main主页面,只有主页面有头部、尾部,中间内容嵌入iframe内容子页面,如果在当前页面进行跳转操作,也是在iframe中进行跳转,而如果点击尾部按钮切换模块、页面,那...
- 在Spring Boot中使用 jose4j 实现 JSON Web Token (JWT)
-
JSONWebToken或JWT作为服务之间安全通信的一种方式而闻名。...
- Spring Boot使用AOP方式实现统一的Web请求日志记录?
-
AOP简介AOP(AspectOrientedProgramming),面相切面编程,是通过代码预编译与运行时动态代理的方式来实现程序的统一功能维护的方案。AOP作为Spring框架的核心内容,通...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- Java项目宝塔搭建实战MES-Springboot开源MES智能制造系统源码
- 一个令人头秃的问题,Logback 日志级别设置竟然无效?
- 实战!SpringBoot + RabbitMQ死信队列实现超时关单
- 火了!阿里P8架构师编写堪称神级SpringBoot手册,GitHub星标99+
- Java本地搭建宝塔部署实战springboot仓库管理系统源码
- 工具尝鲜(1)-Fleet构建运行一个Springboot入门Web项目
- SPRINGBOOT WEB 实现文件夹上传(保留目录结构)
- Java项目本地部署宝塔搭建实战报修小程序springboot版系统源码
- 新年IT界大笑料“工行取得基于SpringBoot的web系统后端实现专利
- 看完SpringBoot源码后,整个人都精神了
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)