百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

PyTorch 教程:第五篇 —— Build Neural Networks 教程

ztj100 2025-07-24 23:23 6 浏览 0 评论

PyTorch 神经网络模型构建详解:第五篇 —— Build Neural Networks 教程

导语

神经网络是深度学习的核心,PyTorch 以其灵活、易扩展的模型定义方式受到了开发者和研究者的广泛欢迎。本章将带你全流程掌握 PyTorch 模型的定义、结构扩展和前向传播原理。无论你想搭建最基础的全连接神经网络,还是构建更复杂的卷积、循环网络,本章都能为你打下坚实基础。

你将学到:

  1. 1. 如何通过继承 nn.Module 自定义模型结构;
  2. 2. 常用的神经网络层和激活函数如何组合;
  3. 3. 如何灵活实现模型的前向传播逻辑;
  4. 4. 在 GPU 上高效部署神经网络模型;
  5. 5. 实用技巧与常见问题解答。

一、模型构建的基本思想

概念讲解

在 PyTorch 中,所有神经网络模型都应继承自 torch.nn.Module,并重写两个核心方法:

  • o __init__():定义网络结构(如各层和参数);
  • o forward():定义输入数据如何前向传递经过各层得到输出。

这一机制既保证了结构的清晰可维护,也支持极大灵活性——你可以在 forward里嵌入分支、循环、条件判断等任意 Python 逻辑。

为什么要这样设计?

  • o 易扩展:可像写普通 Python 类一样随时添加、修改模型结构;
  • o 模块化:可将常用网络结构拆分成子模块,易于组合和重用;
  • o 支持动态图:前向传播过程中,可动态决定模型行为(如 Transformer、RNN 中常见的分支结构)。

二、从零定义一个神经网络

步骤详解

  1. 1. 继承 nn.Module 并实现构造方法
  2. o 在 __init__ 方法中定义各层(如 nn.Linearnn.Conv2d)。
  3. o 用 super().__init__() 保证父类初始化正确执行。
  4. 2. 实现 forward 方法
  5. o 描述数据经过每一层的计算逻辑。
  6. o forward 方法会被训练与推理自动调用,无需手动调用。

代码实战:三层全连接网络(MLP)

import torch
from torch import nn

# 定义一个简单的神经网络
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # 1. 拉平成一维向量
        self.flatten = nn.Flatten()
        # 2. 三层全连接网络(带激活函数)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),    # 输入层,28*28=784(如用于MNIST)
            nn.ReLU(),                # 激活函数
            nn.Linear(512, 512),      # 隐藏层
            nn.ReLU(),
            nn.Linear(512, 10)        # 输出层,假设10类
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# 实例化模型并打印结构
model = NeuralNetwork()
print(model)

输出结构:

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

思考题

  1. 1. 为什么 forward() 不能直接命名为其他方法?
    答案:PyTorch 框架内部会自动调用 forward() 方法实现正向传播。如果不用这个名字,model(input) 语法和各种内置训练工具都无法正常工作。
  2. 2. nn.Sequential 有什么优点?
    答案nn.Sequential 适用于将多个层串联,结构简单,代码更简洁。对于需要分支、跳连等复杂结构则需自定义 forward

三、常用网络层与激活函数

概念讲解

torch.nn 提供了丰富的神经网络层与激活函数,包括:

  • o 全连接层nn.Linear
  • o 卷积层nn.Conv1dnn.Conv2dnn.Conv3d
  • o 池化层nn.MaxPool2dnn.AvgPool2d
  • o 归一化层nn.BatchNorm1dnn.LayerNorm
  • o 激活函数nn.ReLUnn.Sigmoidnn.Tanhnn.Softmax
  • o Dropoutnn.Dropout,训练时防止过拟合

你可以在 __init__ 中组合这些模块,也可在 forward 里动态调用。

代码实战:简单卷积神经网络(CNN)

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1), # 输入通道1, 输出通道32
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                # 降采样
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64*7*7, 10) # 以MNIST为例,输入图片28x28,经过两次2x2池化后为7x7

    def forward(self, x):
        x = self.conv_stack(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

cnn_model = SimpleCNN()
print(cnn_model)

常见问题解答

  • o Q:如何为模型添加 Dropout?
    A:可在 nn.Sequentialforward 中加入 nn.Dropout(p),如 nn.Dropout(0.5),在训练时随机置零部分神经元。
  • o Q:如何为不同输入通道数适配不同网络?
    A:将首层卷积的 in_channels 参数设为对应的输入通道数。例如彩色图片为3,灰度为1。

四、模型在设备上的管理(CPU / GPU)

概念讲解

PyTorch 支持将模型和数据灵活迁移至 CPU 或 GPU。通常流程为:

  • o 检查是否有 GPU 可用:torch.cuda.is_available()
  • o 用 .to(device).cuda() 将模型和张量迁移到目标设备
  • o 训练与推理时,确保所有输入和模型都在同一设备上

代码实战

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("使用设备:", device)

model = NeuralNetwork().to(device) # 模型迁移到设备

# 随机生成一组输入,模拟一个 batch(如MNIST:batch=64,1通道,28x28)
X = torch.rand(64, 1, 28, 28, device=device)
logits = model(X)
print("模型输出 shape:", logits.shape) # [64, 10]

思考题

  1. 1. 如果模型和输入数据不在同一设备上会发生什么?
    答案:会报错:“Expected all tensors to be on the same device” 或类似信息。模型与输入需严格在同一设备上。

五、模型子模块与模块复用

概念讲解

一个大型神经网络可以由多个子模块(子类化 nn.Module)嵌套组合,极大提升复用性和可读性。例如常用的残差块、注意力模块都可独立实现后集成。

代码实战:模块化残差块(ResBlock)

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
    
    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity # 残差连接
        out = self.relu(out)
        return out

# 将残差块集成到更大网络
class NetWithResBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_layer = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.resblock = ResBlock(16)
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.resblock(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = NetWithResBlock()
print(model)

六、模型正向传播与自动求导

概念讲解

调用 model(input)model.forward(input) 即可执行正向传播。PyTorch 会自动为所有参数建立计算图,训练时自动反向传播。只要参数属于 nn.Module 的子模块,都会被自动追踪与优化。

代码实战

model = NeuralNetwork()
X = torch.rand(2, 1, 28, 28)   # 模拟2个输入样本
output = model(X)              # 正向传播
print("网络输出:", output)

思考题

  1. 1. 通过 model.parameters() 能获取到哪些参数?
    答案:所有注册在模型及其子模块中的可学习参数(如权重、偏置等),通常用于优化器初始化。
  2. 2. with torch.no_grad(): 有什么作用?
    答案:在该代码块内禁用自动求导,节省显存、加快推理,常用于模型评估和预测。

七、实用技巧与常见问题

  • o 保存/加载模型结构与参数
    • o 保存参数:torch.save(model.state_dict(), "model.pth")
    • o 加载参数:model = NeuralNetwork()
      model.load_state_dict(torch.load(
      "model.pth"))
      model.
      eval()
  • o 查看模型各层参数信息for name, param in model.named_parameters():
    print(name, param.shape)
  • o 灵活控制训练与推理模式
    • o 训练模式:model.train()
    • o 评估/推理模式:model.eval()(影响如 BatchNorm、Dropout 等行为)

知识小结

  • o PyTorch 推荐以 nn.Module 子类化方式构建神经网络;
  • o 所有层和可学习参数建议注册为成员变量,forward 灵活实现数据流;
  • o 支持自定义分支、循环、条件结构,可高度扩展;
  • o 常用网络层、激活、Dropout 等模块丰富,适合拼装多种结构;
  • o 可通过嵌套子模块组织大型神经网络,提升可复用性与可维护性;
  • o 充分利用 GPU 与自动求导,训练与推理代码高度一致。

下一篇预告:我们将带你深入理解 PyTorch 的损失函数和优化器配置,让你的模型高效收敛并实现最佳性能,敬请期待!

相关推荐

10条军规:电商API从数据泄露到高可用的全链路防护

电商API接口避坑指南:数据安全、版本兼容与成本控制的10个教训在电商行业数字化转型中,API接口已成为连接平台、商家、用户与第三方服务的核心枢纽。然而,从数据泄露到版本冲突,从成本超支到系统崩溃,A...

Python 文件处理在实际项目中的困难与应对策略

在Python项目开发,文件处理是一项基础且关键的任务。然而,在实际项目中,Python文件处理往往会面临各种各样的困难和挑战,从文件格式兼容性、编码问题,到性能瓶颈、并发访问冲突等。本文将深入...

The Future of Manufacturing with Custom CNC Parts

ThefutureofmanufacturingisincreasinglybeingshapedbytheintegrationofcustomCNC(ComputerNumericalContro...

Innovative Solutions in Custom CNC Machining

Inrecentyears,thelandscapeofcustomCNCmachininghasevolvedrapidly,drivenbyincreasingdemandsforprecisio...

C#.NET serilog 详解(c# repository)

简介Serilog是...

Custom CNC Machining for Small Batch Production

Inmodernmanufacturing,producingsmallbatchesofcustomizedpartshasbecomeanincreasinglycommondemandacros...

Custom CNC Machining for Customized Solutions

Thedemandforcustomizedsolutionsinmanufacturinghasgrownsignificantly,drivenbydiverseindustryneedsandt...

Revolutionizing Manufacturing with Custom CNC Parts

Understandinghowmanufacturingisevolving,especiallythroughtheuseofcustomCNCparts,canseemcomplex.Thisa...

Breaking Boundaries with Custom CNC Parts

BreakingboundarieswithcustomCNCpartsinvolvesexploringhowadvancedmanufacturingtechniquesaretransformi...

Custom CNC Parts for Aerospace Industry

Intherealmofaerospacemanufacturing,precisionandreliabilityareparamount.Thecomponentsthatmakeupaircra...

Cnc machining for custom parts and components

UnderstandingCNCmachiningforcustompartsandcomponentsinvolvesexploringitsprocesses,advantages,andcomm...

洞察宇宙(十八):深入理解C语言内存管理

分享乐趣,传播快乐,增长见识,留下美好。亲爱的您,这里是LearingYard学苑!今天小编为大家带来“深入理解C语言内存管理”...

The Art of Crafting Custom CNC Parts

UnderstandingtheprocessofcreatingcustomCNCpartscanoftenbeconfusingforbeginnersandevensomeexperienced...

Tailored Custom CNC Solutions for Automotive

Intheautomotiveindustry,precisionandefficiencyarecrucialforproducinghigh-qualityvehiclecomponents.Ta...

关于WEB服务器(.NET)一些经验累积(一)

以前做过技术支持,把一些遇到的问题累积保存起来,现在发出了。1.问题:未能加载文件或程序集“System.EnterpriseServices.Wrapper.dll”或它的某一个依赖项。拒绝访问。解...

取消回复欢迎 发表评论: