百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch编程常规套路(上)(pytorch 60分钟教程)

ztj100 2024-11-14 19:24 35 浏览 0 评论

之前的几篇介绍完PyTorch基本并且常用的API后,大家肯定对PyTorch对Tensor的操作有了一个基本的理解,当然我们介绍的API并不是很完整,也相对比较基础。还有很多复杂的API和一些比较不好理解的API没有讲解。不过我们后面遇到一例case也可以单独和大家一起学习分享吃透相关的问题。

今天开始我们就要开始学习PyTorch我们日常编程的基本套路,这个套路基本上是不怎么变换的,换的也是一些细节,换的也是对应场景的模型,当然本次学习大家最好都能够手动coding一下代码,主要是留下一些印象。

流程基本介绍

回归正题,正常来说PyTorch编程的常见套路如下:

一般来说,整体流程如下解释:

1.首先需要获取当前需要处理的数据,不管是图片,文字,声音,表格,视频等都需要通过PyTorch提供的一些API把这些数据转换成Tensor

2.我们需要选择一个和当前场景比较适合的深度学习模型,这边一般需要遵循如下的一些条件和规律:

  • 首先这个模型是能解决你的问题的,没有万能模型。
  • 我们也要选择对应的损失函数和优化器
  • 训练阶段一般都是几个循环

3.有了处理好的数据,也有了正确的模型(最好走一个sample跑一下模型)接下来我们就应该直接跑一下训练一下模型的参数,通过训练让模型学习到这些数据的特性,简而言之,找到相关规律,将这些规律能够很好地在模型参数里面进行体现。

4.训练好了模型接下来一步就是评估模型是否符合预期,就是用测试数据来评估模型,是否真正地学习到这些数据潜在的规律。

5.在验证完模型后,如果模型符合预期,就需要把训练好的模型参数保存下来,持久化好参数,方便后续就可以直接加载使用这些训练好的模型参数。

基本上所有的PyTorch的编程范式和流程都是这样,在实际使用深度学习解决生活中实际问题的场景中,不同的地方可能就是一些细节,例如模型的复杂度,将实际的数据处理成Tensor的复杂度,训练过程中一些技巧。这边就是需要大家快速了解一下,有一个大概的印象。等后续看到代码的时候,大家就能够快速了解其中的逻辑。至少在hello world中这块还是很简单的。

接下来和大家使用jupyter notebook或者google的colab快速实现一下,废话不用多说,只需要“show me code”就ok,任何语言都是一样的,首先先导入包,如下图所示,后续导入torch和torch.nn包大家可以默认养成习惯,几乎是必须的包,这边还导入了matplotlib包,主要是用来方便用户进行数据可视化,方便用户理解,还记得之前PyTorch入门教材里面写的,可视化和coding是学习深度学习框架PyTorch中不二的法门。

数据部分处理

导入好相关的包,接下来我们就开始准备相关的数据了。刚才也说了,我们可以将我们生活各类的数据转换成Tensor,因为我们这边是入门,所以我们就先造一些数值类数据,方便大家理解,

这边是使用y=ax+b的模式来造数据的,这边我们是知道a和b的值的,这样我们就可以造出很多成对的(x,y)了,这里我们造出100对(x,y)送入到模型,看看模型能不能够通过我们送入的(x,y)键值对能不能够学习出a和b的值。

这边我们假设a等于0.7,b等于0.3。然后我们使用我们之前学到的torch.arange的api,从0开始,步长是0.02,到1为止(不包含1,左闭右开),这边大家顺便理解一下unsqueeze的作用。

好了,我们现在已经有了mock的数据了,并且我们也知道我们这次的模型最后需要得到的“答案”了,接下来我们要对数据进行简单的划分,我们需要将我们自造的这些数据划分为“训练集”,“验证集”和“测试集”了,相信如果对机器学习有过了解的话,这块应该是不陌生的。

  • 训练集:这个集合里面的数据就是我们需要传送给模型,让模型进行学习,通过训练集里面的数据可以让模型找到其中的规律,当然这部分的数据的量级是比较大的,大概是80%
  • 验证集:这部分数据是不给模型进行学习的,也就是说这部分的数据就像单元测试,模型无法提前获取到这个数据集的答案的,需要模型根据输出,根据平时在训练集所获得到的经验来计算相关的输出,当然这个验证集的数据不是必须的,就像素质教育一样,是没有单元测试一样。
  • 测试集:这是需要模型根据测试集的输入传递给模型,让模型输出相关的结果,并且让这个结果和测试集正确的答案做对比,看看模型经过学习的准确率是否符合预期,这是评估模型好坏的一个重要数据集。

了解到数据集的理论知识后,接下来就用python的一些基本操作来进行数据集的划分。因为这是我们的比较简单的测试项目,就没有划分验证集,只有训练集和测试集,我们一共有50个数据对,按照4:1的原则,我们训练集是40个数据对,10个测试数据对。如下所示:

好了,我们现在就是用之前介绍的用“可视化”来方便我们理解,输入和输出之间的关系,这块就涉及到matplotlib里面的基本知识了,我们这里主要是学习PyTorch的,这边就不过多介绍了,我们就简单地把代码敲一遍,有一个大概印象就可以。

我们简单运行一下:

通过可视化,我们能够很清晰地发现是成线性关系的,说明在一些案例中,进行可视化确实可以方便我们理解,不过后面的课程我在学习的过程中,模型的复杂程度超过了可视化的范畴,更多地还是需要自己多看,来回看和来回理解的。

模型搭建

接下来就是正式的环节,需要搭建我们的模型了,我们之前说过选择一个合适的模型来解决我们的问题是至关重要的,但是在这个Hello world关卡里,我们是知道“答案”的,所以我们知道可以使用线性回归模型是能够解决我们的问题的。所以反而对于我们新手而言,更加重要的就是使用PyTorch来写出我们的线性回归模型。

我们来稍微解释一下这段代码,毕竟这是入门级的模型,里面还有的信息相对比较少也比较好理解:

  • 我们构建了一个类,且这个类继承于PyTorch的nn.Module
  • 在init初始化函数里面,我们定义了weights和bias这2个可学习的参数,是使用我们之前学习过torch.randn初始化的一个参数,类型是torch.float,并且requires_grad是true
  • forward方法也是必须要写的,这个方法就是记录前向过程的函数,入参就是一个tensor,最终就是返回这个模型使用weights* x + bias 这个的前向结果作为这个模型的输出。

详细解释

一般来说构建一个网络模型PyTorch提供了四个比较基础的模块方便我们来构建各色各样的网络模型来满足各个不同的复杂场景,不管是后面的LLM模型还是Stable Diffusion模型,我们都是使用PyTorch提供的这四个模块进行构建的,这四个基础模型是

这些模块包含的东西确实非常多,看起来也比较复杂,几乎涵盖了PyTorch搭建模型的所有东西,但是仔细梳理一下你也会发现并不是很难。

torch.nn的模块里面包含了构建复杂模型的基础模块。

torch.nn.Parameter这个里面存储了模型里面的各项tensor参数,我们通过设置requires_grad设置为True的方式来确保可以用梯度下降的方式来更新模型的参数。

torch.nn.Module 这个是基类,是所有子模型的父类,也就是说如果我们使用PyTorch来构建,我们都需要继承torch.nn.Module,并且实现forward方法。

torch.optim 这个包内包含了各项各种各样的优化算法,这个里面optim里面是说在梯度下降的时候如何更好地提高梯度下降的效率和降低相关的loss。

forward函数,这个是所有nn.Module必须需要实现的类,定义了我们在具体模型前向传播的逻辑。

好的,到此为止,我们已经快速讲解了nn.Module模型信息了,我们需要手敲相关的模型,加深我们的理解,下一个小节,我们就开始讲解模型的推理和训练代码coding的流程了,大家周末愉快,本周降温,大家记得保暖。

相关推荐

sharding-jdbc实现`分库分表`与`读写分离`

一、前言本文将基于以下环境整合...

三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么

在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...

MySQL8行级锁_mysql如何加行级锁

MySQL8行级锁版本:8.0.34基本概念...

mysql使用小技巧_mysql使用入门

1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...

MySQL/MariaDB中如何支持全部的Unicode?

永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...

聊聊 MySQL Server 可执行注释,你懂了吗?

前言MySQLServer当前支持如下3种注释风格:...

MySQL系列-源码编译安装(v5.7.34)

一、系统环境要求...

MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了

对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...

MySQL字符问题_mysql中字符串的位置

中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...

深圳尚学堂:mysql基本sql语句大全(三)

数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...

MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?

大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...

一文讲清怎么利用Python Django实现Excel数据表的导入导出功能

摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...

用DataX实现两个MySQL实例间的数据同步

DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...

MySQL数据库知识_mysql数据库基础知识

MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...

如何为MySQL中的JSON字段设置索引

背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...

取消回复欢迎 发表评论: