Pytorch学习-day8: 损失函数与优化器
ztj100 2025-06-09 07:26 44 浏览 0 评论
学习目标
- 理解损失函数:学习什么是损失函数,为什么需要它,以及常见类型(如 MSE 和 CrossEntropy)。
- 理解优化器:了解优化器如何帮助模型学习,以及 SGD 和 Adam 的基本原理。
- 实践任务:为 Day 7 的 MLP 添加 MSE 损失函数和 SGD 优化器,训练 10 轮,观察模型如何改进预测。
术语解释
1. 损失函数 (Loss Function)
- 是什么:损失函数是衡量模型预测结果与真实答案之间差距的指标。模型的目标是让这个“差距”尽可能小。
- 类比:想象你在玩飞镖,目标是击中靶心(真实答案)。你的飞镖落在哪里(模型预测)与靶心的距离就是“损失”。损失越小,说明你越接近目标。
- 常见类型: MSE (Mean Squared Error, 均方误差):用于回归任务(预测连续值,如房价、温度)。它计算预测值与真实值差的平方平均值。 公式:MSE = (1/n) * Σ(预测值 - 真实值)^2 例子:预测房价是 100 万,真实是 120 万,差 20 万,平方后加权平均。 CrossEntropy (交叉熵损失):用于分类任务(预测类别,如猫狗分类)。它衡量预测概率分布与真实标签分布的差异。 例子:预测一张图片是猫的概率是 0.8,真实是猫,损失小;如果预测是狗,损失大。
- 作用:告诉模型“错在哪里,错多少”,为优化提供方向。
2. 优化器 (Optimizer)
- 是什么:优化器是调整模型参数(权重和偏置)的算法,让损失函数的值变小,模型预测更准确。
- 类比:模型像一个在山谷里找最低点(最小损失)的人。优化器是“导航员”,告诉它每次走哪一步(调整参数)。
- 常见类型: SGD (Stochastic Gradient Descent, 随机梯度下降): 原理:根据损失函数的梯度(斜率),小步调整参数,朝损失下降的方向走。 类比:像下山时看脚下的坡度,慢慢走。 特点:简单,但可能需要较多步数,容易卡在“局部低点”。 Adam: 原理:结合动量法和自适应学习率,比 SGD 更“聪明”,能更快找到最低点。 类比:像个有经验的登山者,知道哪里陡、哪里平,步伐大小自动调整。 特点:收敛更快,适合复杂模型,但参数多。
- 作用:通过反复调整模型参数,降低损失,让模型学到数据的规律。
示例:MLP + MSE + SGD
场景
假设你有一个简单的数据集:输入是 2 个数字(x1, x2),输出是它们的和(y = x1 + x2)。我们用一个 MLP 模型预测这个和,MSE 衡量预测误差,SGD 优化模型。
数据示例
- 输入:[1.0, 2.0],真实输出:3.0
- 输入:[0.5, 1.5],真实输出:2.0
- 目标:模型学会预测任意输入的和。
代码
以下是完整的 PyTorch 代码,包含注释,帮助新手理解每一步。
python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义 MLP 模型(Day 7 的假设结构)
class MLP(nn.Module):
# 继承 PyTorch 的 Module 类
def __init__(self, input_size=2, hidden_size=10, output_size=1):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size) # 第一层:输入到隐藏层
self.relu = nn.ReLU() # 激活函数,让模型学非线性关系
self.fc2 = nn.Linear(hidden_size, output_size) # 第二层:隐藏层到输出
def forward(self, x): # 前向传播
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 设置超参数
input_size = 2 # 输入是 2 个数字
hidden_size = 10 # 隐藏层有 10 个神经元
output_size = 1 # 输出是 1 个数字(和)
learning_rate = 0.01 # 学习率,控制优化器步伐
epochs = 10 # 训练 10 轮
batch_size = 32 # 每批处理 32 个样本
# 生成示例数据
torch.manual_seed(42) # 固定随机种子,结果可重复
X = torch.randn(100, input_size) # 100 个样本,每个有 2 个特征
y = X.sum(dim=1, keepdim=True) + 0.1 * torch.randn(100, 1) # 目标:特征和 + 噪声
# 初始化模型、损失函数、优化器
model = MLP(input_size, hidden_size, output_size) # 创建模型
criterion = nn.MSELoss() # MSE 损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # SGD 优化器
# 训练循环
model.train() # 设置模型为训练模式
for epoch in range(epochs):
total_loss = 0
# 小批量训练
for i in range(0, len(X), batch_size):
# 取一批数据
inputs = X[i:i+batch_size]
targets = y[i:i+batch_size]
# 前向传播
outputs = model(inputs) # 模型预测
loss = criterion(outputs, targets) # 计算 MSE 损失
# 反向传播与优化
optimizer.zero_grad() # 清空上一步的梯度
loss.backward() # 计算当前梯度
optimizer.step() # 根据梯度更新参数
total_loss += loss.item() # 累加损失
# 打印每轮平均损失
avg_loss = total_loss / (len(X) // batch_size)
print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
# 测试模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 禁用梯度计算
test_input = torch.tensor([[1.0, 2.0]]) # 测试输入
prediction = model(test_input) # 预测
print(f'Test input: {test_input.tolist()}, Prediction: {prediction.tolist()}')
代码逐行解释
- 模型定义: MLP 类定义了一个两层神经网络:输入(2 个特征)→ 隐藏层(10 个神经元)→ 输出(1 个值)。 ReLU 激活函数让模型学到非线性关系(比如复杂的曲线)。 forward 方法描述数据如何通过网络。
- 数据生成: 创建 100 个随机样本,输入是 2 维(X),输出是输入的和加一点噪声(y)。 噪声模拟现实中的不完美数据。
- 损失函数: nn.MSELoss() 计算预测值和真实值的均方误差。 比如,预测 2.8,真实 3.0,误差是 (2.8 - 3.0)^2 = 0.04。
- 优化器: optim.SGD 使用梯度下降更新模型参数。 lr=0.01 控制每次调整的大小,太大可能跳过最低点,太小收敛慢。
- 训练过程: 每轮(epoch)遍历所有数据,分批(batch)处理。 前向传播:输入通过模型得到预测,计算损失。 反向传播:根据损失计算梯度(告诉参数怎么调整)。 优化:SGD 用梯度更新参数,减小损失。 打印平均损失,观察是否下降。
- 测试: 用 [1.0, 2.0] 测试,理想预测接近 3.0。 model.eval() 和 no_grad() 确保不计算梯度,节省内存。
预期输出
运行后,你会看到类似:
Epoch [1/10], Loss: 0.1234
Epoch [2/10], Loss: 0.0987
...
Epoch [10/10], Loss: 0.0123
Test input: [[1.0, 2.0]], Prediction: [[2.9876]]
- 损失逐渐下降,说明模型在学习。
- 预测值接近 3.0,说明模型学会了“求和”。
新手常见问题
- 为什么损失不下降? 学习率可能太高或太低,试试 0.001 或 0.1。 数据可能有问题,检查输入和目标是否匹配。 模型结构可能太简单,增加隐藏层或神经元。
- MSE 和 CrossEntropy 怎么选? 用 MSE 预测连续值(如温度、房价)。 用 CrossEntropy 预测类别(如猫、狗)。
- SGD 和 Adam 哪个好? SGD 简单,适合小数据集,但慢。 Adam 更快,适合复杂模型,但可能过拟合。
- 什么是梯度? 梯度是损失函数对参数的“斜率”,告诉优化器参数该往哪调(增大还是减小)。
资源建议
- PyTorch 文档: MSELoss SGD
- 视频教程: PyTorch 官方 YouTube 频道(如 Intro to PyTorch)。 Aladdin Persson 的 PyTorch 教程(YouTube)。
- 实践: 试试用 CrossEntropy 做分类任务。 调整学习率、隐藏层大小,观察损失变化。
下一步
- 试试 Adam 优化器:替换 optim.SGD 为 optim.Adam(model.parameters(), lr=0.001),比较收敛速度。
- 加载真实数据集:如果有 CSV 或图像数据,我可以帮你写加载代码。
- 可视化:用 Matplotlib 画损失曲线或预测结果。
相关推荐
- 这个 JavaScript Api 已被废弃!请慎用!
-
在开发过程中,我们可能会不自觉地使用一些已经被标记为废弃的JavaScriptAPI。这些...
- JavaScript中10个“过时”的API,你的代码里还在用吗?
-
JavaScript作为一门不断发展的语言,其API也在持续进化。新的、更安全、更高效的API不断涌现,而一些旧的API则因为各种原因(如安全问题、性能瓶颈、设计缺陷或有了更好的替代品)被标记为“废...
- 几大开源免费的 JavaScript 富文本编辑器测评
-
MarkDown编辑器用的时间长了,发现发现富文本编辑器用起来是真的舒服。...
- 比较好的网页里面的 html 编辑器 推荐
-
如果您正在寻找嵌入到网页中的HTML编辑器,以便用户可以直接在网页上编辑HTML内容,以下是几个备受推荐的:CKEditor:CKEditor是一个功能强大的、开源的富文本编辑器,可以嵌入到...
- Luckysheet 实现excel多人在线协同编辑
-
前言前些天看到Luckysheet支持协同编辑Excel,正符合我们协同项目的一部分,故而想进一步完善协同文章,但是遇到了一下困难,特此做声明哈,若侵权,请联系我删除文章!若侵犯版权、个人隐私,请联系...
- 从 Element UI 源码的构建流程来看前端 UI 库设计
-
作者:前端森林转发链接:https://mp.weixin.qq.com/s/ziDMLDJcvx07aM6xoEyWHQ引言...
- 手把手教你如何用 Decorator 装饰你的 Typescript?「实践」
-
作者:Nealyang转发连接:https://mp.weixin.qq.com/s/PFgc8xD7gT40-9qXNTpk7A...
- 推荐五个优秀的富文本编辑器
-
富文本编辑器是一种可嵌入浏览器网页中,所见即所得的文本编辑器。对于许多从事前端开发的小伙伴来说并不算陌生,它的应用场景非常广泛,平时发个评论、写篇博客文章等都能见到它的身影。...
- 基于vue + element的后台管理系统解决方案
-
作者:林鑫转发链接:https://github.com/lin-xin前言该方案作为一套多功能的后台框架模板,适用于绝大部分的后台管理系统(WebManagementSystem)开发。基于v...
- 开源富文本编辑器Quill 2.0重磅发布
-
开源富文本编辑器Quill正式发布2.0版本。官方TypeScript声明...
- Python之Web开发框架学习 Django-表单处理
-
在Django中创建表单实际上类似于创建模型。同样,我们只需要从Django类继承,则类属性将是表单字段。让我们在myapp文件夹中添加一个forms.py文件以包含我们的应用程序表单。我们将创建一个...
- Django测试入门:打造坚实代码基础的钥匙
-
这一篇说一下django框架的自动化测试,...
- Django ORM vs SQLAlchemy:到底谁更香?从入门到上头的选择指南
-
阅读文章前辛苦您点下“关注”,方便讨论和分享,为了回馈您的支持,我将每日更新优质内容。...
- 超详细的Django 框架介绍,它来了!
-
时光荏苒,一晃小编的Tornado框架系列也结束了。这个框架虽然没有之前的FastAPI高流量,但是,它也是小编的心血呀。总共16篇博文,从入门到进阶,包含了框架的方方面面。虽然小编有些方面介绍得不是...
- 20《Nginx 入门教程》使用 Nginx 部署 Python 项目
-
今天的目标是完成一个PythonWeb项目的线上部署,我们使用最新的Django项目搭建一个简易的Web工程,然后基于Nginx服务部署该PythonWeb项目。1.前期准备...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)