PyTorch 教程:第五篇 —— Build Neural Networks 教程
ztj100 2025-07-24 23:23 22 浏览 0 评论
PyTorch 神经网络模型构建详解:第五篇 —— Build Neural Networks 教程
导语
神经网络是深度学习的核心,PyTorch 以其灵活、易扩展的模型定义方式受到了开发者和研究者的广泛欢迎。本章将带你全流程掌握 PyTorch 模型的定义、结构扩展和前向传播原理。无论你想搭建最基础的全连接神经网络,还是构建更复杂的卷积、循环网络,本章都能为你打下坚实基础。
你将学到:
- 1. 如何通过继承 nn.Module 自定义模型结构;
- 2. 常用的神经网络层和激活函数如何组合;
- 3. 如何灵活实现模型的前向传播逻辑;
- 4. 在 GPU 上高效部署神经网络模型;
- 5. 实用技巧与常见问题解答。
一、模型构建的基本思想
概念讲解
在 PyTorch 中,所有神经网络模型都应继承自 torch.nn.Module,并重写两个核心方法:
- o __init__():定义网络结构(如各层和参数);
- o forward():定义输入数据如何前向传递经过各层得到输出。
这一机制既保证了结构的清晰可维护,也支持极大灵活性——你可以在 forward里嵌入分支、循环、条件判断等任意 Python 逻辑。
为什么要这样设计?
- o 易扩展:可像写普通 Python 类一样随时添加、修改模型结构;
- o 模块化:可将常用网络结构拆分成子模块,易于组合和重用;
- o 支持动态图:前向传播过程中,可动态决定模型行为(如 Transformer、RNN 中常见的分支结构)。
二、从零定义一个神经网络
步骤详解
- 1. 继承 nn.Module 并实现构造方法
- o 在 __init__ 方法中定义各层(如 nn.Linear、nn.Conv2d)。
- o 用 super().__init__() 保证父类初始化正确执行。
- 2. 实现 forward 方法
- o 描述数据经过每一层的计算逻辑。
- o forward 方法会被训练与推理自动调用,无需手动调用。
代码实战:三层全连接网络(MLP)
import torch
from torch import nn
# 定义一个简单的神经网络
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
# 1. 拉平成一维向量
self.flatten = nn.Flatten()
# 2. 三层全连接网络(带激活函数)
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512), # 输入层,28*28=784(如用于MNIST)
nn.ReLU(), # 激活函数
nn.Linear(512, 512), # 隐藏层
nn.ReLU(),
nn.Linear(512, 10) # 输出层,假设10类
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# 实例化模型并打印结构
model = NeuralNetwork()
print(model)
输出结构:
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
思考题
- 1. 为什么 forward() 不能直接命名为其他方法?
答案:PyTorch 框架内部会自动调用 forward() 方法实现正向传播。如果不用这个名字,model(input) 语法和各种内置训练工具都无法正常工作。 - 2. nn.Sequential 有什么优点?
答案:nn.Sequential 适用于将多个层串联,结构简单,代码更简洁。对于需要分支、跳连等复杂结构则需自定义 forward。
三、常用网络层与激活函数
概念讲解
torch.nn 提供了丰富的神经网络层与激活函数,包括:
- o 全连接层:nn.Linear
- o 卷积层:nn.Conv1d、nn.Conv2d、nn.Conv3d
- o 池化层:nn.MaxPool2d、nn.AvgPool2d
- o 归一化层:nn.BatchNorm1d、nn.LayerNorm
- o 激活函数:nn.ReLU、nn.Sigmoid、nn.Tanh、nn.Softmax
- o Dropout:nn.Dropout,训练时防止过拟合
你可以在 __init__ 中组合这些模块,也可在 forward 里动态调用。
代码实战:简单卷积神经网络(CNN)
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv_stack = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 输入通道1, 输出通道32
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2), # 降采样
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.flatten = nn.Flatten()
self.fc = nn.Linear(64*7*7, 10) # 以MNIST为例,输入图片28x28,经过两次2x2池化后为7x7
def forward(self, x):
x = self.conv_stack(x)
x = self.flatten(x)
x = self.fc(x)
return x
cnn_model = SimpleCNN()
print(cnn_model)
常见问题解答
- o Q:如何为模型添加 Dropout?
A:可在 nn.Sequential 或 forward 中加入 nn.Dropout(p),如 nn.Dropout(0.5),在训练时随机置零部分神经元。 - o Q:如何为不同输入通道数适配不同网络?
A:将首层卷积的 in_channels 参数设为对应的输入通道数。例如彩色图片为3,灰度为1。
四、模型在设备上的管理(CPU / GPU)
概念讲解
PyTorch 支持将模型和数据灵活迁移至 CPU 或 GPU。通常流程为:
- o 检查是否有 GPU 可用:torch.cuda.is_available()
- o 用 .to(device) 或 .cuda() 将模型和张量迁移到目标设备
- o 训练与推理时,确保所有输入和模型都在同一设备上
代码实战
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("使用设备:", device)
model = NeuralNetwork().to(device) # 模型迁移到设备
# 随机生成一组输入,模拟一个 batch(如MNIST:batch=64,1通道,28x28)
X = torch.rand(64, 1, 28, 28, device=device)
logits = model(X)
print("模型输出 shape:", logits.shape) # [64, 10]
思考题
- 1. 如果模型和输入数据不在同一设备上会发生什么?
答案:会报错:“Expected all tensors to be on the same device” 或类似信息。模型与输入需严格在同一设备上。
五、模型子模块与模块复用
概念讲解
一个大型神经网络可以由多个子模块(子类化 nn.Module)嵌套组合,极大提升复用性和可读性。例如常用的残差块、注意力模块都可独立实现后集成。
代码实战:模块化残差块(ResBlock)
class ResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity # 残差连接
out = self.relu(out)
return out
# 将残差块集成到更大网络
class NetWithResBlock(nn.Module):
def __init__(self):
super().__init__()
self.input_layer = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.resblock = ResBlock(16)
self.pool = nn.AdaptiveAvgPool2d((1,1))
self.fc = nn.Linear(16, 10)
def forward(self, x):
x = self.input_layer(x)
x = self.resblock(x)
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
model = NetWithResBlock()
print(model)
六、模型正向传播与自动求导
概念讲解
调用 model(input) 或 model.forward(input) 即可执行正向传播。PyTorch 会自动为所有参数建立计算图,训练时自动反向传播。只要参数属于 nn.Module 的子模块,都会被自动追踪与优化。
代码实战
model = NeuralNetwork()
X = torch.rand(2, 1, 28, 28) # 模拟2个输入样本
output = model(X) # 正向传播
print("网络输出:", output)
思考题
- 1. 通过 model.parameters() 能获取到哪些参数?
答案:所有注册在模型及其子模块中的可学习参数(如权重、偏置等),通常用于优化器初始化。 - 2. with torch.no_grad(): 有什么作用?
答案:在该代码块内禁用自动求导,节省显存、加快推理,常用于模型评估和预测。
七、实用技巧与常见问题
- o 保存/加载模型结构与参数
- o 保存参数:torch.save(model.state_dict(), "model.pth")
- o 加载参数:model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))
model.eval() - o 查看模型各层参数信息for name, param in model.named_parameters():
print(name, param.shape) - o 灵活控制训练与推理模式
- o 训练模式:model.train()
- o 评估/推理模式:model.eval()(影响如 BatchNorm、Dropout 等行为)
知识小结
- o PyTorch 推荐以 nn.Module 子类化方式构建神经网络;
- o 所有层和可学习参数建议注册为成员变量,forward 灵活实现数据流;
- o 支持自定义分支、循环、条件结构,可高度扩展;
- o 常用网络层、激活、Dropout 等模块丰富,适合拼装多种结构;
- o 可通过嵌套子模块组织大型神经网络,提升可复用性与可维护性;
- o 充分利用 GPU 与自动求导,训练与推理代码高度一致。
下一篇预告:我们将带你深入理解 PyTorch 的损失函数和优化器配置,让你的模型高效收敛并实现最佳性能,敬请期待!
相关推荐
- Linux集群自动化监控系统Zabbix集群搭建到实战
-
自动化监控系统...
- systemd是什么如何使用_systemd/system
-
systemd是什么如何使用简介Systemd是一个在现代Linux发行版中广泛使用的系统和服务管理器。它负责启动系统并管理系统中运行的服务和进程。使用管理服务systemd可以用来启动、停止、...
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
-
Linux系统日常巡检脚本,巡检内容包含了,磁盘,...
- 7,MySQL管理员用户管理_mysql 管理员用户
-
一、首次设置密码1.初始化时设置(推荐)mysqld--initialize--user=mysql--datadir=/data/3306/data--basedir=/usr/local...
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
-
1.1数据库的核心概念在开始Python数据库编程之前,我们需要先理解几个核心概念。数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,它就像一个电子化的文件柜,能让我们高效...
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
-
设置WGCloud开机自动启动服务init.d目录下新建脚本在/etc/rc.d/init.d新建启动脚本wgcloudstart.sh,内容如下...
- linux系统启动流程和服务管理,带你进去系统的世界
-
Linux启动流程Rhel6启动过程:开机自检bios-->MBR引导-->GRUB菜单-->加载内核-->init进程初始化Rhel7启动过程:开机自检BIOS-->M...
- CentOS7系统如何修改主机名_centos更改主机名称
-
请关注本头条号,每天坚持更新原创干货技术文章。如需学习视频,请在微信搜索公众号“智传网优”直接开始自助视频学习1.前言本文将讲解CentOS7系统如何修改主机名。...
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
-
在Linux服务器管理中,SSH(SecureShell)是远程操作的核心工具。以下是SSH终端操作的常用命令和技巧,涵盖连接、文件操作、系统管理等场景:一、SSH连接服务器1.基本连接...
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
-
为什么需要配置开机自启?想象一下:电商服务器重启后,MySQL和Nginx没自动启动,整个网站瘫痪!这就是为什么开机自启是Linux运维的必备技能。自启服务能确保核心程序在系统启动时自动运行,避免人工...
- Kubernetes 高可用(HA)集群部署指南
-
Kubernetes高可用(HA)集群部署指南本指南涵盖从概念理解、架构选择,到kubeadm高可用部署、生产优化、监控备份和运维的全流程,适用于希望搭建稳定、生产级Kubernetes集群...
- Linux项目开发,你必须了解Systemd服务!
-
1.Systemd简介...
- Linux系统systemd服务管理工具使用技巧
-
简介:在Linux系统里,systemd就像是所有进程的“源头”,它可是系统中PID值为1的进程哟。systemd其实是一堆工具的组合,它的作用可不止是启动操作系统这么简单,像后台服务...
- Linux下NetworkManager和network的和平共处
-
简介我们在使用CentoOS系统时偶尔会遇到配置都正确但network启动不了的问题,这问题经常是由NetworkManager引起的,关闭NetworkManage并取消开机启动network就能正...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
Linux下NetworkManager和network的和平共处
-
Kubernetes 高可用(HA)集群部署指南
-
linux系统启动流程和服务管理,带你进去系统的世界
-
7,MySQL管理员用户管理_mysql 管理员用户
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
- 最近发表
-
- Linux集群自动化监控系统Zabbix集群搭建到实战
- systemd是什么如何使用_systemd/system
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
- 7,MySQL管理员用户管理_mysql 管理员用户
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
- linux系统启动流程和服务管理,带你进去系统的世界
- CentOS7系统如何修改主机名_centos更改主机名称
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)