百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch编程常规套路(上)(pytorch 60分钟教程)

ztj100 2024-11-14 19:24 19 浏览 0 评论

之前的几篇介绍完PyTorch基本并且常用的API后,大家肯定对PyTorch对Tensor的操作有了一个基本的理解,当然我们介绍的API并不是很完整,也相对比较基础。还有很多复杂的API和一些比较不好理解的API没有讲解。不过我们后面遇到一例case也可以单独和大家一起学习分享吃透相关的问题。

今天开始我们就要开始学习PyTorch我们日常编程的基本套路,这个套路基本上是不怎么变换的,换的也是一些细节,换的也是对应场景的模型,当然本次学习大家最好都能够手动coding一下代码,主要是留下一些印象。

流程基本介绍

回归正题,正常来说PyTorch编程的常见套路如下:

一般来说,整体流程如下解释:

1.首先需要获取当前需要处理的数据,不管是图片,文字,声音,表格,视频等都需要通过PyTorch提供的一些API把这些数据转换成Tensor

2.我们需要选择一个和当前场景比较适合的深度学习模型,这边一般需要遵循如下的一些条件和规律:

  • 首先这个模型是能解决你的问题的,没有万能模型。
  • 我们也要选择对应的损失函数和优化器
  • 训练阶段一般都是几个循环

3.有了处理好的数据,也有了正确的模型(最好走一个sample跑一下模型)接下来我们就应该直接跑一下训练一下模型的参数,通过训练让模型学习到这些数据的特性,简而言之,找到相关规律,将这些规律能够很好地在模型参数里面进行体现。

4.训练好了模型接下来一步就是评估模型是否符合预期,就是用测试数据来评估模型,是否真正地学习到这些数据潜在的规律。

5.在验证完模型后,如果模型符合预期,就需要把训练好的模型参数保存下来,持久化好参数,方便后续就可以直接加载使用这些训练好的模型参数。

基本上所有的PyTorch的编程范式和流程都是这样,在实际使用深度学习解决生活中实际问题的场景中,不同的地方可能就是一些细节,例如模型的复杂度,将实际的数据处理成Tensor的复杂度,训练过程中一些技巧。这边就是需要大家快速了解一下,有一个大概的印象。等后续看到代码的时候,大家就能够快速了解其中的逻辑。至少在hello world中这块还是很简单的。

接下来和大家使用jupyter notebook或者google的colab快速实现一下,废话不用多说,只需要“show me code”就ok,任何语言都是一样的,首先先导入包,如下图所示,后续导入torch和torch.nn包大家可以默认养成习惯,几乎是必须的包,这边还导入了matplotlib包,主要是用来方便用户进行数据可视化,方便用户理解,还记得之前PyTorch入门教材里面写的,可视化和coding是学习深度学习框架PyTorch中不二的法门。

数据部分处理

导入好相关的包,接下来我们就开始准备相关的数据了。刚才也说了,我们可以将我们生活各类的数据转换成Tensor,因为我们这边是入门,所以我们就先造一些数值类数据,方便大家理解,

这边是使用y=ax+b的模式来造数据的,这边我们是知道a和b的值的,这样我们就可以造出很多成对的(x,y)了,这里我们造出100对(x,y)送入到模型,看看模型能不能够通过我们送入的(x,y)键值对能不能够学习出a和b的值。

这边我们假设a等于0.7,b等于0.3。然后我们使用我们之前学到的torch.arange的api,从0开始,步长是0.02,到1为止(不包含1,左闭右开),这边大家顺便理解一下unsqueeze的作用。

好了,我们现在已经有了mock的数据了,并且我们也知道我们这次的模型最后需要得到的“答案”了,接下来我们要对数据进行简单的划分,我们需要将我们自造的这些数据划分为“训练集”,“验证集”和“测试集”了,相信如果对机器学习有过了解的话,这块应该是不陌生的。

  • 训练集:这个集合里面的数据就是我们需要传送给模型,让模型进行学习,通过训练集里面的数据可以让模型找到其中的规律,当然这部分的数据的量级是比较大的,大概是80%
  • 验证集:这部分数据是不给模型进行学习的,也就是说这部分的数据就像单元测试,模型无法提前获取到这个数据集的答案的,需要模型根据输出,根据平时在训练集所获得到的经验来计算相关的输出,当然这个验证集的数据不是必须的,就像素质教育一样,是没有单元测试一样。
  • 测试集:这是需要模型根据测试集的输入传递给模型,让模型输出相关的结果,并且让这个结果和测试集正确的答案做对比,看看模型经过学习的准确率是否符合预期,这是评估模型好坏的一个重要数据集。

了解到数据集的理论知识后,接下来就用python的一些基本操作来进行数据集的划分。因为这是我们的比较简单的测试项目,就没有划分验证集,只有训练集和测试集,我们一共有50个数据对,按照4:1的原则,我们训练集是40个数据对,10个测试数据对。如下所示:

好了,我们现在就是用之前介绍的用“可视化”来方便我们理解,输入和输出之间的关系,这块就涉及到matplotlib里面的基本知识了,我们这里主要是学习PyTorch的,这边就不过多介绍了,我们就简单地把代码敲一遍,有一个大概印象就可以。

我们简单运行一下:

通过可视化,我们能够很清晰地发现是成线性关系的,说明在一些案例中,进行可视化确实可以方便我们理解,不过后面的课程我在学习的过程中,模型的复杂程度超过了可视化的范畴,更多地还是需要自己多看,来回看和来回理解的。

模型搭建

接下来就是正式的环节,需要搭建我们的模型了,我们之前说过选择一个合适的模型来解决我们的问题是至关重要的,但是在这个Hello world关卡里,我们是知道“答案”的,所以我们知道可以使用线性回归模型是能够解决我们的问题的。所以反而对于我们新手而言,更加重要的就是使用PyTorch来写出我们的线性回归模型。

我们来稍微解释一下这段代码,毕竟这是入门级的模型,里面还有的信息相对比较少也比较好理解:

  • 我们构建了一个类,且这个类继承于PyTorch的nn.Module
  • 在init初始化函数里面,我们定义了weights和bias这2个可学习的参数,是使用我们之前学习过torch.randn初始化的一个参数,类型是torch.float,并且requires_grad是true
  • forward方法也是必须要写的,这个方法就是记录前向过程的函数,入参就是一个tensor,最终就是返回这个模型使用weights* x + bias 这个的前向结果作为这个模型的输出。

详细解释

一般来说构建一个网络模型PyTorch提供了四个比较基础的模块方便我们来构建各色各样的网络模型来满足各个不同的复杂场景,不管是后面的LLM模型还是Stable Diffusion模型,我们都是使用PyTorch提供的这四个模块进行构建的,这四个基础模型是

这些模块包含的东西确实非常多,看起来也比较复杂,几乎涵盖了PyTorch搭建模型的所有东西,但是仔细梳理一下你也会发现并不是很难。

torch.nn的模块里面包含了构建复杂模型的基础模块。

torch.nn.Parameter这个里面存储了模型里面的各项tensor参数,我们通过设置requires_grad设置为True的方式来确保可以用梯度下降的方式来更新模型的参数。

torch.nn.Module 这个是基类,是所有子模型的父类,也就是说如果我们使用PyTorch来构建,我们都需要继承torch.nn.Module,并且实现forward方法。

torch.optim 这个包内包含了各项各种各样的优化算法,这个里面optim里面是说在梯度下降的时候如何更好地提高梯度下降的效率和降低相关的loss。

forward函数,这个是所有nn.Module必须需要实现的类,定义了我们在具体模型前向传播的逻辑。

好的,到此为止,我们已经快速讲解了nn.Module模型信息了,我们需要手敲相关的模型,加深我们的理解,下一个小节,我们就开始讲解模型的推理和训练代码coding的流程了,大家周末愉快,本周降温,大家记得保暖。

相关推荐

如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL

阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...

Python数据分析:探索性分析

写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...

CSP-J/S冲奖第21天:插入排序

...

C++基础语法梳理:算法丨十大排序算法(二)

本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...

C 语言的标准库有哪些

C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...

[深度学习] ncnn安装和调用基础教程

1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...

用rust实现经典的冒泡排序和快速排序

1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...

ncnn+PPYOLOv2首次结合!全网最详细代码解读来了

编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...

C++特性使用建议

1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...

Qt4/5升级到Qt6吐血经验总结V202308

00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...

到底什么是C++11新特性,请看下文

C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...

掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!

C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...

经典算法——凸包算法

凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...

一起学习c++11——c++11中的新增的容器

c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...

C++ 编程中的一些最佳实践

1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...

取消回复欢迎 发表评论: