PyTorch编程常规套路(上)(pytorch 60分钟教程)
ztj100 2024-11-14 19:24 25 浏览 0 评论
之前的几篇介绍完PyTorch基本并且常用的API后,大家肯定对PyTorch对Tensor的操作有了一个基本的理解,当然我们介绍的API并不是很完整,也相对比较基础。还有很多复杂的API和一些比较不好理解的API没有讲解。不过我们后面遇到一例case也可以单独和大家一起学习分享吃透相关的问题。
今天开始我们就要开始学习PyTorch我们日常编程的基本套路,这个套路基本上是不怎么变换的,换的也是一些细节,换的也是对应场景的模型,当然本次学习大家最好都能够手动coding一下代码,主要是留下一些印象。
流程基本介绍
回归正题,正常来说PyTorch编程的常见套路如下:
一般来说,整体流程如下解释:
1.首先需要获取当前需要处理的数据,不管是图片,文字,声音,表格,视频等都需要通过PyTorch提供的一些API把这些数据转换成Tensor
2.我们需要选择一个和当前场景比较适合的深度学习模型,这边一般需要遵循如下的一些条件和规律:
- 首先这个模型是能解决你的问题的,没有万能模型。
- 我们也要选择对应的损失函数和优化器
- 训练阶段一般都是几个循环
3.有了处理好的数据,也有了正确的模型(最好走一个sample跑一下模型)接下来我们就应该直接跑一下训练一下模型的参数,通过训练让模型学习到这些数据的特性,简而言之,找到相关规律,将这些规律能够很好地在模型参数里面进行体现。
4.训练好了模型接下来一步就是评估模型是否符合预期,就是用测试数据来评估模型,是否真正地学习到这些数据潜在的规律。
5.在验证完模型后,如果模型符合预期,就需要把训练好的模型参数保存下来,持久化好参数,方便后续就可以直接加载使用这些训练好的模型参数。
基本上所有的PyTorch的编程范式和流程都是这样,在实际使用深度学习解决生活中实际问题的场景中,不同的地方可能就是一些细节,例如模型的复杂度,将实际的数据处理成Tensor的复杂度,训练过程中一些技巧。这边就是需要大家快速了解一下,有一个大概的印象。等后续看到代码的时候,大家就能够快速了解其中的逻辑。至少在hello world中这块还是很简单的。
接下来和大家使用jupyter notebook或者google的colab快速实现一下,废话不用多说,只需要“show me code”就ok,任何语言都是一样的,首先先导入包,如下图所示,后续导入torch和torch.nn包大家可以默认养成习惯,几乎是必须的包,这边还导入了matplotlib包,主要是用来方便用户进行数据可视化,方便用户理解,还记得之前PyTorch入门教材里面写的,可视化和coding是学习深度学习框架PyTorch中不二的法门。
数据部分处理
导入好相关的包,接下来我们就开始准备相关的数据了。刚才也说了,我们可以将我们生活各类的数据转换成Tensor,因为我们这边是入门,所以我们就先造一些数值类数据,方便大家理解,
这边是使用y=ax+b的模式来造数据的,这边我们是知道a和b的值的,这样我们就可以造出很多成对的(x,y)了,这里我们造出100对(x,y)送入到模型,看看模型能不能够通过我们送入的(x,y)键值对能不能够学习出a和b的值。
这边我们假设a等于0.7,b等于0.3。然后我们使用我们之前学到的torch.arange的api,从0开始,步长是0.02,到1为止(不包含1,左闭右开),这边大家顺便理解一下unsqueeze的作用。
好了,我们现在已经有了mock的数据了,并且我们也知道我们这次的模型最后需要得到的“答案”了,接下来我们要对数据进行简单的划分,我们需要将我们自造的这些数据划分为“训练集”,“验证集”和“测试集”了,相信如果对机器学习有过了解的话,这块应该是不陌生的。
- 训练集:这个集合里面的数据就是我们需要传送给模型,让模型进行学习,通过训练集里面的数据可以让模型找到其中的规律,当然这部分的数据的量级是比较大的,大概是80%
- 验证集:这部分数据是不给模型进行学习的,也就是说这部分的数据就像单元测试,模型无法提前获取到这个数据集的答案的,需要模型根据输出,根据平时在训练集所获得到的经验来计算相关的输出,当然这个验证集的数据不是必须的,就像素质教育一样,是没有单元测试一样。
- 测试集:这是需要模型根据测试集的输入传递给模型,让模型输出相关的结果,并且让这个结果和测试集正确的答案做对比,看看模型经过学习的准确率是否符合预期,这是评估模型好坏的一个重要数据集。
了解到数据集的理论知识后,接下来就用python的一些基本操作来进行数据集的划分。因为这是我们的比较简单的测试项目,就没有划分验证集,只有训练集和测试集,我们一共有50个数据对,按照4:1的原则,我们训练集是40个数据对,10个测试数据对。如下所示:
好了,我们现在就是用之前介绍的用“可视化”来方便我们理解,输入和输出之间的关系,这块就涉及到matplotlib里面的基本知识了,我们这里主要是学习PyTorch的,这边就不过多介绍了,我们就简单地把代码敲一遍,有一个大概印象就可以。
我们简单运行一下:
通过可视化,我们能够很清晰地发现是成线性关系的,说明在一些案例中,进行可视化确实可以方便我们理解,不过后面的课程我在学习的过程中,模型的复杂程度超过了可视化的范畴,更多地还是需要自己多看,来回看和来回理解的。
模型搭建
接下来就是正式的环节,需要搭建我们的模型了,我们之前说过选择一个合适的模型来解决我们的问题是至关重要的,但是在这个Hello world关卡里,我们是知道“答案”的,所以我们知道可以使用线性回归模型是能够解决我们的问题的。所以反而对于我们新手而言,更加重要的就是使用PyTorch来写出我们的线性回归模型。
我们来稍微解释一下这段代码,毕竟这是入门级的模型,里面还有的信息相对比较少也比较好理解:
- 我们构建了一个类,且这个类继承于PyTorch的nn.Module
- 在init初始化函数里面,我们定义了weights和bias这2个可学习的参数,是使用我们之前学习过torch.randn初始化的一个参数,类型是torch.float,并且requires_grad是true
- forward方法也是必须要写的,这个方法就是记录前向过程的函数,入参就是一个tensor,最终就是返回这个模型使用weights* x + bias 这个的前向结果作为这个模型的输出。
详细解释
一般来说构建一个网络模型PyTorch提供了四个比较基础的模块方便我们来构建各色各样的网络模型来满足各个不同的复杂场景,不管是后面的LLM模型还是Stable Diffusion模型,我们都是使用PyTorch提供的这四个模块进行构建的,这四个基础模型是
这些模块包含的东西确实非常多,看起来也比较复杂,几乎涵盖了PyTorch搭建模型的所有东西,但是仔细梳理一下你也会发现并不是很难。
torch.nn的模块里面包含了构建复杂模型的基础模块。
torch.nn.Parameter这个里面存储了模型里面的各项tensor参数,我们通过设置requires_grad设置为True的方式来确保可以用梯度下降的方式来更新模型的参数。
torch.nn.Module 这个是基类,是所有子模型的父类,也就是说如果我们使用PyTorch来构建,我们都需要继承torch.nn.Module,并且实现forward方法。
torch.optim 这个包内包含了各项各种各样的优化算法,这个里面optim里面是说在梯度下降的时候如何更好地提高梯度下降的效率和降低相关的loss。
forward函数,这个是所有nn.Module必须需要实现的类,定义了我们在具体模型前向传播的逻辑。
好的,到此为止,我们已经快速讲解了nn.Module模型信息了,我们需要手敲相关的模型,加深我们的理解,下一个小节,我们就开始讲解模型的推理和训练代码coding的流程了,大家周末愉快,本周降温,大家记得保暖。
相关推荐
- 再说圆的面积-蒙特卡洛(蒙特卡洛方法求圆周率的matlab程序)
-
在微积分-圆的面积和周长(1)介绍微积分方法求解圆的面积,本文使用蒙特卡洛方法求解圆面积。...
- python创建分类器小结(pytorch分类数据集创建)
-
简介:分类是指利用数据的特性将其分成若干类型的过程。监督学习分类器就是用带标记的训练数据建立一个模型,然后对未知数据进行分类。...
- matplotlib——绘制散点图(matplotlib散点图颜色和图例)
-
绘制散点图不同条件(维度)之间的内在关联关系观察数据的离散聚合程度...
- python实现实时绘制数据(python如何绘制)
-
方法一importmatplotlib.pyplotaspltimportnumpyasnpimporttimefrommathimport*plt.ion()#...
- 简单学Python——matplotlib库3——绘制散点图
-
前面我们学习了用matplotlib绘制折线图,今天我们学习绘制散点图。其实简单的散点图与折线图的语法基本相同,只是作图函数由plot()变成了scatter()。下面就绘制一个散点图:import...
- 数据分析-相关性分析可视化(相关性分析数据处理)
-
前面介绍了相关性分析的原理、流程和常用的皮尔逊相关系数和斯皮尔曼相关系数,具体可以参考...
- 免费Python机器学习课程一:线性回归算法
-
学习线性回归的概念并从头开始在python中开发完整的线性回归算法最基本的机器学习算法必须是具有单个变量的线性回归算法。如今,可用的高级机器学习算法,库和技术如此之多,以至于线性回归似乎并不重要。但是...
- 用Python进行机器学习(2)之逻辑回归
-
前面介绍了线性回归,本次介绍的是逻辑回归。逻辑回归虽然名字里面带有“回归”两个字,但是它是一种分类算法,通常用于解决二分类问题,比如某个邮件是否是广告邮件,比如某个评价是否为正向的评价。逻辑回归也可以...
- 【Python机器学习系列】拟合和回归傻傻分不清?一文带你彻底搞懂
-
一、拟合和回归的区别拟合...
- 推荐2个十分好用的pandas数据探索分析神器
-
作者:俊欣来源:关于数据分析与可视化...
- 向量数据库:解锁大模型记忆的关键!选型指南+实战案例全解析
-
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...
- 用Python进行机器学习(11)-主成分分析PCA
-
我们在机器学习中有时候需要处理很多个参数,但是这些参数有时候彼此之间是有着各种关系的,这个时候我们就会想:是否可以找到一种方式来降低参数的个数呢?这就是今天我们要介绍的主成分分析,英文是Princip...
- 神经网络基础深度解析:从感知机到反向传播
-
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...
- Python实现基于机器学习的RFM模型
-
CDA数据分析师出品作者:CDALevelⅠ持证人岗位:数据分析师行业:大数据...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)