Pytorch学习-day8: 损失函数与优化器
ztj100 2025-06-09 07:26 57 浏览 0 评论
学习目标
- 理解损失函数:学习什么是损失函数,为什么需要它,以及常见类型(如 MSE 和 CrossEntropy)。
- 理解优化器:了解优化器如何帮助模型学习,以及 SGD 和 Adam 的基本原理。
- 实践任务:为 Day 7 的 MLP 添加 MSE 损失函数和 SGD 优化器,训练 10 轮,观察模型如何改进预测。
术语解释
1. 损失函数 (Loss Function)
- 是什么:损失函数是衡量模型预测结果与真实答案之间差距的指标。模型的目标是让这个“差距”尽可能小。
- 类比:想象你在玩飞镖,目标是击中靶心(真实答案)。你的飞镖落在哪里(模型预测)与靶心的距离就是“损失”。损失越小,说明你越接近目标。
- 常见类型: MSE (Mean Squared Error, 均方误差):用于回归任务(预测连续值,如房价、温度)。它计算预测值与真实值差的平方平均值。 公式:MSE = (1/n) * Σ(预测值 - 真实值)^2 例子:预测房价是 100 万,真实是 120 万,差 20 万,平方后加权平均。 CrossEntropy (交叉熵损失):用于分类任务(预测类别,如猫狗分类)。它衡量预测概率分布与真实标签分布的差异。 例子:预测一张图片是猫的概率是 0.8,真实是猫,损失小;如果预测是狗,损失大。
- 作用:告诉模型“错在哪里,错多少”,为优化提供方向。
2. 优化器 (Optimizer)
- 是什么:优化器是调整模型参数(权重和偏置)的算法,让损失函数的值变小,模型预测更准确。
- 类比:模型像一个在山谷里找最低点(最小损失)的人。优化器是“导航员”,告诉它每次走哪一步(调整参数)。
- 常见类型: SGD (Stochastic Gradient Descent, 随机梯度下降): 原理:根据损失函数的梯度(斜率),小步调整参数,朝损失下降的方向走。 类比:像下山时看脚下的坡度,慢慢走。 特点:简单,但可能需要较多步数,容易卡在“局部低点”。 Adam: 原理:结合动量法和自适应学习率,比 SGD 更“聪明”,能更快找到最低点。 类比:像个有经验的登山者,知道哪里陡、哪里平,步伐大小自动调整。 特点:收敛更快,适合复杂模型,但参数多。
- 作用:通过反复调整模型参数,降低损失,让模型学到数据的规律。
示例:MLP + MSE + SGD
场景
假设你有一个简单的数据集:输入是 2 个数字(x1, x2),输出是它们的和(y = x1 + x2)。我们用一个 MLP 模型预测这个和,MSE 衡量预测误差,SGD 优化模型。
数据示例
- 输入:[1.0, 2.0],真实输出:3.0
- 输入:[0.5, 1.5],真实输出:2.0
- 目标:模型学会预测任意输入的和。
代码
以下是完整的 PyTorch 代码,包含注释,帮助新手理解每一步。
python
import torch
import torch.nn as nn
import torch.optim as optim
# 定义 MLP 模型(Day 7 的假设结构)
class MLP(nn.Module):
# 继承 PyTorch 的 Module 类
def __init__(self, input_size=2, hidden_size=10, output_size=1):
super(MLP, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size) # 第一层:输入到隐藏层
self.relu = nn.ReLU() # 激活函数,让模型学非线性关系
self.fc2 = nn.Linear(hidden_size, output_size) # 第二层:隐藏层到输出
def forward(self, x): # 前向传播
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 设置超参数
input_size = 2 # 输入是 2 个数字
hidden_size = 10 # 隐藏层有 10 个神经元
output_size = 1 # 输出是 1 个数字(和)
learning_rate = 0.01 # 学习率,控制优化器步伐
epochs = 10 # 训练 10 轮
batch_size = 32 # 每批处理 32 个样本
# 生成示例数据
torch.manual_seed(42) # 固定随机种子,结果可重复
X = torch.randn(100, input_size) # 100 个样本,每个有 2 个特征
y = X.sum(dim=1, keepdim=True) + 0.1 * torch.randn(100, 1) # 目标:特征和 + 噪声
# 初始化模型、损失函数、优化器
model = MLP(input_size, hidden_size, output_size) # 创建模型
criterion = nn.MSELoss() # MSE 损失函数
optimizer = optim.SGD(model.parameters(), lr=learning_rate) # SGD 优化器
# 训练循环
model.train() # 设置模型为训练模式
for epoch in range(epochs):
total_loss = 0
# 小批量训练
for i in range(0, len(X), batch_size):
# 取一批数据
inputs = X[i:i+batch_size]
targets = y[i:i+batch_size]
# 前向传播
outputs = model(inputs) # 模型预测
loss = criterion(outputs, targets) # 计算 MSE 损失
# 反向传播与优化
optimizer.zero_grad() # 清空上一步的梯度
loss.backward() # 计算当前梯度
optimizer.step() # 根据梯度更新参数
total_loss += loss.item() # 累加损失
# 打印每轮平均损失
avg_loss = total_loss / (len(X) // batch_size)
print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}')
# 测试模型
model.eval() # 设置模型为评估模式
with torch.no_grad(): # 禁用梯度计算
test_input = torch.tensor([[1.0, 2.0]]) # 测试输入
prediction = model(test_input) # 预测
print(f'Test input: {test_input.tolist()}, Prediction: {prediction.tolist()}')
代码逐行解释
- 模型定义: MLP 类定义了一个两层神经网络:输入(2 个特征)→ 隐藏层(10 个神经元)→ 输出(1 个值)。 ReLU 激活函数让模型学到非线性关系(比如复杂的曲线)。 forward 方法描述数据如何通过网络。
- 数据生成: 创建 100 个随机样本,输入是 2 维(X),输出是输入的和加一点噪声(y)。 噪声模拟现实中的不完美数据。
- 损失函数: nn.MSELoss() 计算预测值和真实值的均方误差。 比如,预测 2.8,真实 3.0,误差是 (2.8 - 3.0)^2 = 0.04。
- 优化器: optim.SGD 使用梯度下降更新模型参数。 lr=0.01 控制每次调整的大小,太大可能跳过最低点,太小收敛慢。
- 训练过程: 每轮(epoch)遍历所有数据,分批(batch)处理。 前向传播:输入通过模型得到预测,计算损失。 反向传播:根据损失计算梯度(告诉参数怎么调整)。 优化:SGD 用梯度更新参数,减小损失。 打印平均损失,观察是否下降。
- 测试: 用 [1.0, 2.0] 测试,理想预测接近 3.0。 model.eval() 和 no_grad() 确保不计算梯度,节省内存。
预期输出
运行后,你会看到类似:
Epoch [1/10], Loss: 0.1234
Epoch [2/10], Loss: 0.0987
...
Epoch [10/10], Loss: 0.0123
Test input: [[1.0, 2.0]], Prediction: [[2.9876]]
- 损失逐渐下降,说明模型在学习。
- 预测值接近 3.0,说明模型学会了“求和”。
新手常见问题
- 为什么损失不下降? 学习率可能太高或太低,试试 0.001 或 0.1。 数据可能有问题,检查输入和目标是否匹配。 模型结构可能太简单,增加隐藏层或神经元。
- MSE 和 CrossEntropy 怎么选? 用 MSE 预测连续值(如温度、房价)。 用 CrossEntropy 预测类别(如猫、狗)。
- SGD 和 Adam 哪个好? SGD 简单,适合小数据集,但慢。 Adam 更快,适合复杂模型,但可能过拟合。
- 什么是梯度? 梯度是损失函数对参数的“斜率”,告诉优化器参数该往哪调(增大还是减小)。
资源建议
- PyTorch 文档: MSELoss SGD
- 视频教程: PyTorch 官方 YouTube 频道(如 Intro to PyTorch)。 Aladdin Persson 的 PyTorch 教程(YouTube)。
- 实践: 试试用 CrossEntropy 做分类任务。 调整学习率、隐藏层大小,观察损失变化。
下一步
- 试试 Adam 优化器:替换 optim.SGD 为 optim.Adam(model.parameters(), lr=0.001),比较收敛速度。
- 加载真实数据集:如果有 CSV 或图像数据,我可以帮你写加载代码。
- 可视化:用 Matplotlib 画损失曲线或预测结果。
相关推荐
- Linux集群自动化监控系统Zabbix集群搭建到实战
-
自动化监控系统...
- systemd是什么如何使用_systemd/system
-
systemd是什么如何使用简介Systemd是一个在现代Linux发行版中广泛使用的系统和服务管理器。它负责启动系统并管理系统中运行的服务和进程。使用管理服务systemd可以用来启动、停止、...
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
-
Linux系统日常巡检脚本,巡检内容包含了,磁盘,...
- 7,MySQL管理员用户管理_mysql 管理员用户
-
一、首次设置密码1.初始化时设置(推荐)mysqld--initialize--user=mysql--datadir=/data/3306/data--basedir=/usr/local...
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
-
1.1数据库的核心概念在开始Python数据库编程之前,我们需要先理解几个核心概念。数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,它就像一个电子化的文件柜,能让我们高效...
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
-
设置WGCloud开机自动启动服务init.d目录下新建脚本在/etc/rc.d/init.d新建启动脚本wgcloudstart.sh,内容如下...
- linux系统启动流程和服务管理,带你进去系统的世界
-
Linux启动流程Rhel6启动过程:开机自检bios-->MBR引导-->GRUB菜单-->加载内核-->init进程初始化Rhel7启动过程:开机自检BIOS-->M...
- CentOS7系统如何修改主机名_centos更改主机名称
-
请关注本头条号,每天坚持更新原创干货技术文章。如需学习视频,请在微信搜索公众号“智传网优”直接开始自助视频学习1.前言本文将讲解CentOS7系统如何修改主机名。...
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
-
在Linux服务器管理中,SSH(SecureShell)是远程操作的核心工具。以下是SSH终端操作的常用命令和技巧,涵盖连接、文件操作、系统管理等场景:一、SSH连接服务器1.基本连接...
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
-
为什么需要配置开机自启?想象一下:电商服务器重启后,MySQL和Nginx没自动启动,整个网站瘫痪!这就是为什么开机自启是Linux运维的必备技能。自启服务能确保核心程序在系统启动时自动运行,避免人工...
- Kubernetes 高可用(HA)集群部署指南
-
Kubernetes高可用(HA)集群部署指南本指南涵盖从概念理解、架构选择,到kubeadm高可用部署、生产优化、监控备份和运维的全流程,适用于希望搭建稳定、生产级Kubernetes集群...
- Linux项目开发,你必须了解Systemd服务!
-
1.Systemd简介...
- Linux系统systemd服务管理工具使用技巧
-
简介:在Linux系统里,systemd就像是所有进程的“源头”,它可是系统中PID值为1的进程哟。systemd其实是一堆工具的组合,它的作用可不止是启动操作系统这么简单,像后台服务...
- Linux下NetworkManager和network的和平共处
-
简介我们在使用CentoOS系统时偶尔会遇到配置都正确但network启动不了的问题,这问题经常是由NetworkManager引起的,关闭NetworkManage并取消开机启动network就能正...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
Linux下NetworkManager和network的和平共处
-
Kubernetes 高可用(HA)集群部署指南
-
linux系统启动流程和服务管理,带你进去系统的世界
-
7,MySQL管理员用户管理_mysql 管理员用户
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
- 最近发表
-
- Linux集群自动化监控系统Zabbix集群搭建到实战
- systemd是什么如何使用_systemd/system
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
- 7,MySQL管理员用户管理_mysql 管理员用户
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
- linux系统启动流程和服务管理,带你进去系统的世界
- CentOS7系统如何修改主机名_centos更改主机名称
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)