使用PyTorch进行迁移学习(pytorch模型迁移)
ztj100 2024-10-31 16:13 22 浏览 0 评论
以及为什么不应该从头开始编写CNN架构
如今,训练深度学习模型(尤其是与图像识别相关的模型)是一项非常简单的任务。 您不应该过多强调架构的原因很多,主要是有人已经为您完成了这一步骤。 其余的,您需要进一步阅读。
源代码:Colab Notebook
如今,作为工程师,您唯一应关注的就是数据准备-在深度学习领域,该术语概括了数据收集,加载,规范化和扩充的过程。
今天的议程很简单-解释什么是转移学习以及如何使用转移学习,然后给出带有或不带有预训练架构的模型训练的实际示例。
听起来很简单,所以我们直接开始吧!
数据集下载和基本准备
让我们从导入开始。 在这里,我们有像Numpy,Pandas和Matplotlib这样的常见嫌疑人,还有我们最喜欢的深度学习库Pytorch,其次是它所提供的一切。
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision import models, transforms, datasets
我们将在Colab或Colab Pro中更精确地编写此代码,因此我们将利用GPU的强大功能进行培训。
由于我们正在使用GPU进行培训,而您可能并非如此,因此,我们需要一种可靠的方法来进行处理。 这是一种标准方法:
device = torch.device(‘cuda:0’ if torch.cuda.is_available() else ‘cpu’)
device
>>> device(type=’cuda’, index=0)
如果您正在使用CPU进行培训,则应该输入type ='cpu'之类的字眼,但是由于Colab是免费的,因此您无需这样做。
现在到数据集上。 我们将为此使用Dog或Cat数据集。 它具有大量各种尺寸的图像,我们将在以后处理这些图像。 现在,我们需要下载并解压缩它。 就是这样:
%mkdir data
%cd /content/data/
!wget http://files.fast.ai/data/dogscats.zip
!unzip dogscats.zip
大约一分钟后,根据您的互联网速度,可以使用该数据集。 现在,我们可以将其声明为数据目录-不是必需的,但可以节省一些时间。
DIR_DATA = '/content/data/dogscats/'
资料准备
现在已经完成了第一部分的第一部分。 接下来,我们必须对训练和验证子集应用一些转换,然后使用DataLoaders加载转换后的数据。 这是我们应用的转换:
· 随机旋转
· 随机水平翻转
· 调整为224x224-预训练架构所需
· 转换为张量
· 正常化
这是代码:
train_transforms = transforms.Compose([
transforms.RandomRotation(10),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize(224),
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
valid_transforms = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
现在,我们使用DataLoaders加载数据。 此步骤也很简单,您可能已经熟悉了:
train_data = datasets.ImageFolder(os.path.join(DIR_DATA, ‘train’), transform=train_transforms)
valid_data = datasets.ImageFolder(os.path.join(DIR_DATA, ‘valid’), transform=valid_transforms)
torch.manual_seed(42)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=64, shuffle=False)
class_names = train_data.classes
class_names
>>> ['cats', 'dogs']
如果现在要对单个批次进行逆归一化并可视化,则可以得到以下信息:
快速浏览上图表明我们的转换工作符合预期。
数据准备部分现已完成,在下一节中,我们将声明一个自定义的CNN架构,对其进行训练并评估性能。
定制架构CNN
对于这一部分,我们想要做一些非常简单的事情-3个卷积层,每个卷积层之后是max-pooling和ReLU,然后是一个完全连接的层和一个输出层。
这是该架构的代码:
class CustomCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1)
self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
self.fc1 = nn.Linear(in_features=26*26*64, out_features=128)
self.out = nn.Linear(in_features=128, out_features=2)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, kernel_size=2, stride=2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, kernel_size=2, stride=2)
x = F.relu(self.conv3(x))
x = F.max_pool2d(x, kernel_size=2, stride=2)
x = x.view(-1, 26*26*64)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=0.2)
x = self.out(x)
return F.log_softmax(x, dim=1)
torch.manual_seed(42)
model = CustomCNN()
model.to(device)
从这里我们可以定义一个优化器和标准,我们准备进行训练:
custom_criterion = nn.CrossEntropyLoss()
custom_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
由于您可以访问源代码,并且train_model函数很长,因此我们决定不将其放在此处。 因此,如果您要继续,请参考源代码。 我们将训练模型10个时间段:
custom_model_trained = train_model(
train_loader=train_loader,
test_loader=valid_loader,
model=model,
criterion=custom_criterion,
optimizer=custom_optimizer,
epochs=10
)
一段时间后,这里是获得的结果:
无论如何,这都不是可怕的结果,但是我们如何才能做得更好? 迁移学习就派得上用场了。
迁移学习
您可以轻松地在线查找正式定义。 对我们而言,迁移学习意味着下载预制的体系结构,该体系结构接受过1M +图像的训练,并调整输出层,以便根据需要对尽可能多的类进行分类。
由于我们这里只有猫和狗,因此我们需要将此数字修改为两个。
现在,我们将下载ResNet101架构的预训练版本,并使它的参数不可训练-因为该网络已经过训练:
pretrained_model = models.resnet101(pretrained=True)
for param in pretrained_model.parameters():
param.requires_grad = False
赞! 让我们检查一下输出层的外观:
pretrained_model.fc
>>> Linear(in_features=2048, out_features=1000, bias=True)
因此,默认情况下,该体系结构具有1000个可能的类,但是我们只需要两个类-一个用于猫,一个用于狗。 调整方法如下:
pretrained_model.fc = nn.Sequential(
nn.Linear(2048, 1000),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(1000, 2),
nn.LogSoftmax(dim=1)
)
pretrained_model.to(device)
这就是我们要做的。
好了,我们仍然必须定义和优化器以及一个准则,但是您知道如何做到这一点:
pretrained_criterion = nn.CrossEntropyLoss()
pretrained_optimizer = torch.optim.Adam(pretrained_model.fc.parameters(), lr=0.001)
训练过程与自定义体系结构相同,但是我们不需要太多的时间,因为好了,我们已经知道权重和偏差的正确值。
pretrained_model_trained = train_model(
train_loader=train_loader,
test_loader=valid_loader,
model=pretrained_model,
criterion=pretrained_criterion,
optimizer=pretrained_optimizer,
epochs=1
)
经过一段时间后,得出的结果如下:
那有多神奇? 不但提高了准确性,而且还因为没有训练太多的时间段而节省了很多时间。
现在您知道了迁移学习可以做什么,以及如何以及为什么使用它。 让我们在下一节中总结一下。
结论
而且,您已获得了— PyTorch最简单的迁移学习指南。 当然,如果网络更深入,自定义模型的结果可能会更好,但这不是重点。 关键是,无需强调多少层就足够了,以及最佳超参数值是多少。 至少在大多数情况下。
确保尝试不同的体系结构,并随时在下面的评论部分中告知我们有关结果的信息。
谢谢阅读。
(本文翻译自Dario Rade?i?的文章《Transfer Learning with PyTorch》,参考:https://towardsdatascience.com/transfer-learning-with-pytorch-95dd5dca82a)
相关推荐
- Vue 技术栈(全家桶)(vue technology)
-
Vue技术栈(全家桶)尚硅谷前端研究院第1章:Vue核心Vue简介官网英文官网:https://vuejs.org/中文官网:https://cn.vuejs.org/...
- vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)
-
前言《vue基础》系列是再次回炉vue记的笔记,除了官网那部分知识点外,还会加入自己的一些理解。(里面会有部分和官网相同的文案,有经验的同学择感兴趣的阅读)在开发时,是不是遇到过这样的场景,响应...
- vue3 组件初始化流程(vue组件初始化顺序)
-
学习完成响应式系统后,咋们来看看vue3组件的初始化流程既然是看vue组件的初始化流程,咋们先来创建基本的代码,跑跑流程(在app.vue中写入以下内容,来跑流程)...
- vue3优雅的设置element-plus的table自动滚动到底部
-
场景我是需要在table最后添加一行数据,然后把滚动条滚动到最后。查网上的解决方案都是读取html结构,暴力的去获取,虽能解决问题,但是不喜欢这种打补丁的解决方案,我想着官方应该有相关的定义,于是就去...
- Vue3为什么推荐使用ref而不是reactive
-
为什么推荐使用ref而不是reactivereactive本身具有很大局限性导致使用过程需要额外注意,如果忽视这些问题将对开发造成不小的麻烦;ref更像是vue2时代optionapi的data的替...
- 9、echarts 在 vue 中怎么引用?(必会)
-
首先我们初始化一个vue项目,执行vueinitwebpackechart,接着我们进入初始化的项目下。安装echarts,npminstallecharts-S//或...
- 无所不能,将 Vue 渲染到嵌入式液晶屏
-
该文章转载自公众号@前端时刻,https://mp.weixin.qq.com/s/WDHW36zhfNFVFVv4jO2vrA前言...
- vue-element-admin 增删改查(五)(vue-element-admin怎么用)
-
此篇幅比较长,涉及到的小知识点也比较多,一定要耐心看完,记住学东西没有耐心可不行!!!一、添加和修改注:添加和编辑用到了同一个组件,也就是此篇文章你能学会如何封装组件及引用组件;第二能学会async和...
- 最全的 Vue 面试题+详解答案(vue面试题知识点大全)
-
前言本文整理了...
- 基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时
-
今天给大家分享的是Vue3聊天实例中的朋友圈的实现及登录验证和倒计时操作。先上效果图这个是最新开发的vue3.x网页端聊天项目中的朋友圈模块。用到了ElementPlus...
- 不来看看这些 VUE 的生命周期钩子函数?| 原力计划
-
作者|huangfuyk责编|王晓曼出品|CSDN博客VUE的生命周期钩子函数:就是指在一个组件从创建到销毁的过程自动执行的函数,包含组件的变化。可以分为:创建、挂载、更新、销毁四个模块...
- Vue3.5正式上线,父传子props用法更丝滑简洁
-
前言Vue3.5在2024-09-03正式上线,目前在Vue官网显最新版本已经是Vue3.5,其中主要包含了几个小改动,我留意到日常最常用的改动就是props了,肯定是用Vue3的人必用的,所以针对性...
- Vue 3 生命周期完整指南(vue生命周期及使用)
-
Vue2和Vue3中的生命周期钩子的工作方式非常相似,我们仍然可以访问相同的钩子,也希望将它们能用于相同的场景。...
- 救命!这 10 个 Vue3 技巧藏太深了!性能翻倍 + 摸鱼神器全揭秘
-
前端打工人集合!是不是经常遇到这些崩溃瞬间:Vue3项目越写越卡,组件通信像走迷宫,复杂逻辑写得脑壳疼?别慌!作为在一线摸爬滚打多年的老前端,今天直接甩出10个超实用的Vue3实战技巧,手把...
- 怎么在 vue 中使用 form 清除校验状态?
-
在Vue中使用表单验证时,经常需要清除表单的校验状态。下面我将介绍一些方法来清除表单的校验状态。1.使用this.$refs...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- Vue 技术栈(全家桶)(vue technology)
- vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)
- vue3 组件初始化流程(vue组件初始化顺序)
- vue3优雅的设置element-plus的table自动滚动到底部
- Vue3为什么推荐使用ref而不是reactive
- 9、echarts 在 vue 中怎么引用?(必会)
- 无所不能,将 Vue 渲染到嵌入式液晶屏
- vue-element-admin 增删改查(五)(vue-element-admin怎么用)
- 最全的 Vue 面试题+详解答案(vue面试题知识点大全)
- 基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)