PyTorch 教程:第五篇 —— Build Neural Networks 教程
ztj100 2025-07-24 23:23 6 浏览 0 评论
PyTorch 神经网络模型构建详解:第五篇 —— Build Neural Networks 教程
导语
神经网络是深度学习的核心,PyTorch 以其灵活、易扩展的模型定义方式受到了开发者和研究者的广泛欢迎。本章将带你全流程掌握 PyTorch 模型的定义、结构扩展和前向传播原理。无论你想搭建最基础的全连接神经网络,还是构建更复杂的卷积、循环网络,本章都能为你打下坚实基础。
你将学到:
- 1. 如何通过继承 nn.Module 自定义模型结构;
- 2. 常用的神经网络层和激活函数如何组合;
- 3. 如何灵活实现模型的前向传播逻辑;
- 4. 在 GPU 上高效部署神经网络模型;
- 5. 实用技巧与常见问题解答。
一、模型构建的基本思想
概念讲解
在 PyTorch 中,所有神经网络模型都应继承自 torch.nn.Module,并重写两个核心方法:
- o __init__():定义网络结构(如各层和参数);
- o forward():定义输入数据如何前向传递经过各层得到输出。
这一机制既保证了结构的清晰可维护,也支持极大灵活性——你可以在 forward里嵌入分支、循环、条件判断等任意 Python 逻辑。
为什么要这样设计?
- o 易扩展:可像写普通 Python 类一样随时添加、修改模型结构;
- o 模块化:可将常用网络结构拆分成子模块,易于组合和重用;
- o 支持动态图:前向传播过程中,可动态决定模型行为(如 Transformer、RNN 中常见的分支结构)。
二、从零定义一个神经网络
步骤详解
- 1. 继承 nn.Module 并实现构造方法
- o 在 __init__ 方法中定义各层(如 nn.Linear、nn.Conv2d)。
- o 用 super().__init__() 保证父类初始化正确执行。
- 2. 实现 forward 方法
- o 描述数据经过每一层的计算逻辑。
- o forward 方法会被训练与推理自动调用,无需手动调用。
代码实战:三层全连接网络(MLP)
import torch
from torch import nn
# 定义一个简单的神经网络
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
# 1. 拉平成一维向量
self.flatten = nn.Flatten()
# 2. 三层全连接网络(带激活函数)
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512), # 输入层,28*28=784(如用于MNIST)
nn.ReLU(), # 激活函数
nn.Linear(512, 512), # 隐藏层
nn.ReLU(),
nn.Linear(512, 10) # 输出层,假设10类
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# 实例化模型并打印结构
model = NeuralNetwork()
print(model)
输出结构:
NeuralNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
思考题
- 1. 为什么 forward() 不能直接命名为其他方法?
答案:PyTorch 框架内部会自动调用 forward() 方法实现正向传播。如果不用这个名字,model(input) 语法和各种内置训练工具都无法正常工作。 - 2. nn.Sequential 有什么优点?
答案:nn.Sequential 适用于将多个层串联,结构简单,代码更简洁。对于需要分支、跳连等复杂结构则需自定义 forward。
三、常用网络层与激活函数
概念讲解
torch.nn 提供了丰富的神经网络层与激活函数,包括:
- o 全连接层:nn.Linear
- o 卷积层:nn.Conv1d、nn.Conv2d、nn.Conv3d
- o 池化层:nn.MaxPool2d、nn.AvgPool2d
- o 归一化层:nn.BatchNorm1d、nn.LayerNorm
- o 激活函数:nn.ReLU、nn.Sigmoid、nn.Tanh、nn.Softmax
- o Dropout:nn.Dropout,训练时防止过拟合
你可以在 __init__ 中组合这些模块,也可在 forward 里动态调用。
代码实战:简单卷积神经网络(CNN)
class SimpleCNN(nn.Module):
def __init__(self):
super().__init__()
self.conv_stack = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 输入通道1, 输出通道32
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2), # 降采样
nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
)
self.flatten = nn.Flatten()
self.fc = nn.Linear(64*7*7, 10) # 以MNIST为例,输入图片28x28,经过两次2x2池化后为7x7
def forward(self, x):
x = self.conv_stack(x)
x = self.flatten(x)
x = self.fc(x)
return x
cnn_model = SimpleCNN()
print(cnn_model)
常见问题解答
- o Q:如何为模型添加 Dropout?
A:可在 nn.Sequential 或 forward 中加入 nn.Dropout(p),如 nn.Dropout(0.5),在训练时随机置零部分神经元。 - o Q:如何为不同输入通道数适配不同网络?
A:将首层卷积的 in_channels 参数设为对应的输入通道数。例如彩色图片为3,灰度为1。
四、模型在设备上的管理(CPU / GPU)
概念讲解
PyTorch 支持将模型和数据灵活迁移至 CPU 或 GPU。通常流程为:
- o 检查是否有 GPU 可用:torch.cuda.is_available()
- o 用 .to(device) 或 .cuda() 将模型和张量迁移到目标设备
- o 训练与推理时,确保所有输入和模型都在同一设备上
代码实战
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("使用设备:", device)
model = NeuralNetwork().to(device) # 模型迁移到设备
# 随机生成一组输入,模拟一个 batch(如MNIST:batch=64,1通道,28x28)
X = torch.rand(64, 1, 28, 28, device=device)
logits = model(X)
print("模型输出 shape:", logits.shape) # [64, 10]
思考题
- 1. 如果模型和输入数据不在同一设备上会发生什么?
答案:会报错:“Expected all tensors to be on the same device” 或类似信息。模型与输入需严格在同一设备上。
五、模型子模块与模块复用
概念讲解
一个大型神经网络可以由多个子模块(子类化 nn.Module)嵌套组合,极大提升复用性和可读性。例如常用的残差块、注意力模块都可独立实现后集成。
代码实战:模块化残差块(ResBlock)
class ResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity # 残差连接
out = self.relu(out)
return out
# 将残差块集成到更大网络
class NetWithResBlock(nn.Module):
def __init__(self):
super().__init__()
self.input_layer = nn.Conv2d(1, 16, kernel_size=3, padding=1)
self.resblock = ResBlock(16)
self.pool = nn.AdaptiveAvgPool2d((1,1))
self.fc = nn.Linear(16, 10)
def forward(self, x):
x = self.input_layer(x)
x = self.resblock(x)
x = self.pool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
model = NetWithResBlock()
print(model)
六、模型正向传播与自动求导
概念讲解
调用 model(input) 或 model.forward(input) 即可执行正向传播。PyTorch 会自动为所有参数建立计算图,训练时自动反向传播。只要参数属于 nn.Module 的子模块,都会被自动追踪与优化。
代码实战
model = NeuralNetwork()
X = torch.rand(2, 1, 28, 28) # 模拟2个输入样本
output = model(X) # 正向传播
print("网络输出:", output)
思考题
- 1. 通过 model.parameters() 能获取到哪些参数?
答案:所有注册在模型及其子模块中的可学习参数(如权重、偏置等),通常用于优化器初始化。 - 2. with torch.no_grad(): 有什么作用?
答案:在该代码块内禁用自动求导,节省显存、加快推理,常用于模型评估和预测。
七、实用技巧与常见问题
- o 保存/加载模型结构与参数
- o 保存参数:torch.save(model.state_dict(), "model.pth")
- o 加载参数:model = NeuralNetwork()
model.load_state_dict(torch.load("model.pth"))
model.eval() - o 查看模型各层参数信息for name, param in model.named_parameters():
print(name, param.shape) - o 灵活控制训练与推理模式
- o 训练模式:model.train()
- o 评估/推理模式:model.eval()(影响如 BatchNorm、Dropout 等行为)
知识小结
- o PyTorch 推荐以 nn.Module 子类化方式构建神经网络;
- o 所有层和可学习参数建议注册为成员变量,forward 灵活实现数据流;
- o 支持自定义分支、循环、条件结构,可高度扩展;
- o 常用网络层、激活、Dropout 等模块丰富,适合拼装多种结构;
- o 可通过嵌套子模块组织大型神经网络,提升可复用性与可维护性;
- o 充分利用 GPU 与自动求导,训练与推理代码高度一致。
下一篇预告:我们将带你深入理解 PyTorch 的损失函数和优化器配置,让你的模型高效收敛并实现最佳性能,敬请期待!
相关推荐
- 10条军规:电商API从数据泄露到高可用的全链路防护
-
电商API接口避坑指南:数据安全、版本兼容与成本控制的10个教训在电商行业数字化转型中,API接口已成为连接平台、商家、用户与第三方服务的核心枢纽。然而,从数据泄露到版本冲突,从成本超支到系统崩溃,A...
- Python 文件处理在实际项目中的困难与应对策略
-
在Python项目开发,文件处理是一项基础且关键的任务。然而,在实际项目中,Python文件处理往往会面临各种各样的困难和挑战,从文件格式兼容性、编码问题,到性能瓶颈、并发访问冲突等。本文将深入...
- The Future of Manufacturing with Custom CNC Parts
-
ThefutureofmanufacturingisincreasinglybeingshapedbytheintegrationofcustomCNC(ComputerNumericalContro...
- Innovative Solutions in Custom CNC Machining
-
Inrecentyears,thelandscapeofcustomCNCmachininghasevolvedrapidly,drivenbyincreasingdemandsforprecisio...
- C#.NET serilog 详解(c# repository)
-
简介Serilog是...
- Custom CNC Machining for Small Batch Production
-
Inmodernmanufacturing,producingsmallbatchesofcustomizedpartshasbecomeanincreasinglycommondemandacros...
- Custom CNC Machining for Customized Solutions
-
Thedemandforcustomizedsolutionsinmanufacturinghasgrownsignificantly,drivenbydiverseindustryneedsandt...
- Revolutionizing Manufacturing with Custom CNC Parts
-
Understandinghowmanufacturingisevolving,especiallythroughtheuseofcustomCNCparts,canseemcomplex.Thisa...
- Breaking Boundaries with Custom CNC Parts
-
BreakingboundarieswithcustomCNCpartsinvolvesexploringhowadvancedmanufacturingtechniquesaretransformi...
- Custom CNC Parts for Aerospace Industry
-
Intherealmofaerospacemanufacturing,precisionandreliabilityareparamount.Thecomponentsthatmakeupaircra...
- Cnc machining for custom parts and components
-
UnderstandingCNCmachiningforcustompartsandcomponentsinvolvesexploringitsprocesses,advantages,andcomm...
- 洞察宇宙(十八):深入理解C语言内存管理
-
分享乐趣,传播快乐,增长见识,留下美好。亲爱的您,这里是LearingYard学苑!今天小编为大家带来“深入理解C语言内存管理”...
- The Art of Crafting Custom CNC Parts
-
UnderstandingtheprocessofcreatingcustomCNCpartscanoftenbeconfusingforbeginnersandevensomeexperienced...
- Tailored Custom CNC Solutions for Automotive
-
Intheautomotiveindustry,precisionandefficiencyarecrucialforproducinghigh-qualityvehiclecomponents.Ta...
- 关于WEB服务器(.NET)一些经验累积(一)
-
以前做过技术支持,把一些遇到的问题累积保存起来,现在发出了。1.问题:未能加载文件或程序集“System.EnterpriseServices.Wrapper.dll”或它的某一个依赖项。拒绝访问。解...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 10条军规:电商API从数据泄露到高可用的全链路防护
- Python 文件处理在实际项目中的困难与应对策略
- The Future of Manufacturing with Custom CNC Parts
- Innovative Solutions in Custom CNC Machining
- C#.NET serilog 详解(c# repository)
- Custom CNC Machining for Small Batch Production
- Custom CNC Machining for Customized Solutions
- Revolutionizing Manufacturing with Custom CNC Parts
- Breaking Boundaries with Custom CNC Parts
- Custom CNC Parts for Aerospace Industry
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)