大模型入门-day13-14:小规模训练(小规模教学)
ztj100 2025-06-09 07:26 43 浏览 0 评论
小规模训练 内容基于 Hugging Face 的 datasets 库加载 WikiText-2 数据集,训练简单 Transformer 模型,并观察 Perplexity 下降。
第 13-14 天:小规模训练(6-10 小时)
学习目标
- 理解 Transformer 原理:掌握 Self-Attention 等核心概念。
- 加载数据:用 Hugging Face 的 datasets 库加载 WikiText-2。
- 构建模型:用 PyTorch 搭建简单 Transformer。
- 训练与评估:训练模型,观察 Perplexity 下降。
- 成果:能解释 Transformer,手写简单代码。
时间安排
- 总计:6-10 小时
- 第 13 天:3-5 小时(原理、数据加载、模型搭建)
- 第 14 天:3-5 小时(训练、评估、总结)
第 13 天:准备与搭建
任务 1:理解 Transformer 原理
时间:1-2 小时
内容:
- Transformer 核心:通过 Self-Attention 关注句子中的重要词,用编码器和解码器处理输入和生成输出。
- 关键组件:
- Self-Attention:让模型关注每个词与其他词的关系。
- Multi-Head Attention:多角度理解句子。
- Positional Encoding:给词加上位置信息。
- 资源:
- The Illustrated Transformer
- Attention is All You Need(可选)
练习:用自己的话说:“Transformer 怎么预测下一个词?”
任务 2:加载数据集和分词
时间:1 小时
内容:用 datasets 库加载 WikiText-2 数据集,并用分词器处理文本。
代码:
python
# 导入库
from datasets import load_dataset # 加载数据集
from transformers import AutoTokenizer # 加载分词器
# 加载 WikiText-2 数据集
dataset = load_dataset("wikitext", "wikitext-2-v1")
# 解释:从 Hugging Face 下载 WikiText-2,包含 train、valid、test 三部分。
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# 解释:用 BERT 的分词器,把文本转为数字 token。
# 分词函数
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
# 解释:对文本分词,截断或填充到 512 个 token。
# 分词整个数据集
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# 解释:批量处理数据集,分词后返回新数据集。
输出示例:
- 检查 tokenized_dataset["train"][0]["input_ids"],会看到一串数字(如 [101, 1996, 4937, ...])。
任务 3:数据预处理
时间:1 小时
内容:将分词数据转为 PyTorch 张量,创建 DataLoader。
代码:
python
# 导入库
from torch.utils.data import DataLoader # 创建 DataLoader
from torch.nn.utils.rnn import pad_sequence # 填充序列
# 自定义批处理函数
def collate_fn(batch):
input_ids = [torch.tensor(item["input_ids"]) for item in batch]
# 解释:从每个样本中提取 input_ids,转为张量。
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
# 解释:填充序列到相同长度,用 pad_token_id(如 0)填充。
return {"input_ids": input_ids}
# 解释:返回字典,包含填充后的 input_ids。
# 创建 DataLoader
train_dataloader = DataLoader(
tokenized_dataset["train"], # 训练集
batch_size=8, # 每批 8 个样本
shuffle=True, # 随机打乱
collate_fn=collate_fn # 用自定义函数处理批次
)
# 解释:DataLoader 批量加载数据,方便训练。
输出示例:
- next(iter(train_dataloader))["input_ids"] 输出形状:(8, 512)。
任务 4:构建简单 Transformer 模型
时间:2-3 小时
内容:用 PyTorch 搭建一个简单 Transformer。
代码:
python
# 导入库
import torch
import torch.nn as nn # 神经网络模块
# 定义模型
class SimpleTransformer(nn.Module):
def __init__(self, vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, max_seq_length, dropout=0.1):
super(SimpleTransformer, self).__init__()
# 解释:初始化父类 nn.Module。
self.embedding = nn.Embedding(vocab_size, d_model)
# 解释:将词索引转为 d_model 维的向量。
self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_length, d_model))
# 解释:可学习的位置编码,记录词的位置。
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True
)
# 解释:定义单层编码器,包含注意力机制和前馈网络。
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
# 解释:堆叠多层编码器。
self.fc_out = nn.Linear(d_model, vocab_size)
# 解释:将编码器输出映射到词汇表大小。
def forward(self, src):
seq_length = src.size(1)
# 解释:获取输入序列长度。
src = self.embedding(src) + self.positional_encoding[:, :seq_length, :]
# 解释:词嵌入加上位置编码。
output = self.transformer_encoder(src)
# 解释:通过 Transformer 编码器处理。
return self.fc_out(output)
# 解释:输出预测结果。
# 初始化模型
vocab_size = tokenizer.vocab_size # 词汇表大小(如 30522)
model = SimpleTransformer(
vocab_size=vocab_size, d_model=512, nhead=8, num_encoder_layers=6,
dim_feedforward=2048, max_seq_length=512, dropout=0.1
)
# 解释:创建模型实例,设置超参数。
参数说明:
- d_model:词嵌入维度(512)。
- nhead:注意力头数(8)。
- num_encoder_layers:编码器层数(6)。
第 14 天:训练与评估
任务 5:训练模型
时间:2-3 小时
内容:编写训练循环,训练模型。
代码:
python
# 导入库
import torch.optim as optim # 优化器
# 设置设备和损失函数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# 训练函数
def train(model, dataloader, epochs=3): # 减少 epoch 以节省时间
model.train()
for epoch in range(epochs):
total_loss = 0
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
# 解释:将输入移到 GPU/CPU。
labels = input_ids # 语言模型用输入预测下一个词
optimizer.zero_grad()
# 解释:清空上一次梯度。
output = model(input_ids)
# 解释:模型前向传播。
loss = criterion(output.view(-1, vocab_size), labels.view(-1))
# 解释:计算损失,view(-1) 展平张量。
loss.backward()
# 解释:反向传播计算梯度。
optimizer.step()
# 解释:更新模型参数。
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.3f}")
# 开始训练
train(model, train_dataloader)
输出示例:
Epoch 1, Avg Loss: 7.500
Epoch 2, Avg Loss: 6.800
Epoch 3, Avg Loss: 6.200
任务 6:评估模型
时间:1 小时
内容:计算 Perplexity,观察下降趋势。
代码:
python
import math
def calculate_perplexity(model, dataloader):
model.eval()
total_loss = 0
with torch.no_grad(): # 不计算梯度
for batch in dataloader:
input_ids = batch["input_ids"].to(device)
labels = input_ids
output = model(input_ids)
loss = criterion(output.view(-1, vocab_size), labels.view(-1))
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
perplexity = math.exp(avg_loss) # Perplexity = e^loss
return perplexity
# 计算并打印
perplexity = calculate_perplexity(model, train_dataloader)
print(f"Perplexity: {perplexity:.3f}")
输出示例:
Perplexity: 500.000
任务 7:总结与反思
时间:1 小时
内容:回顾学习成果,回答问题:
- Perplexity 是否下降? 如果没有,可能是学习率太高(调小 lr)或数据问题。
- Self-Attention 怎么工作? 试着解释:“它让模型关注句子中重要的词,比如‘猫’和‘坐’的关系。”
- 改进建议:调整 lr(如 0.001)、增加 epochs、减少 batch_size。
成果验收
- 原理:能用简单语言解释 Transformer。
- 数据:成功加载并分词 WikiText-2。
- 模型:手写并运行 Transformer 代码。
- 评估:观察到 Perplexity 下降(比如从 1800 到 500)。
小 Tips
- 硬件:没 GPU 用 CPU,调小 batch_size(如 4)。
- 调试:报错告诉我,我帮你调。
- 扩展:试试用 valid_dataloader 评估验证集。
完成任务后,告诉我 Perplexity 结果和你的 Transformer 解释,我帮你确认!动手开始吧!
相关推荐
- Linux集群自动化监控系统Zabbix集群搭建到实战
-
自动化监控系统...
- systemd是什么如何使用_systemd/system
-
systemd是什么如何使用简介Systemd是一个在现代Linux发行版中广泛使用的系统和服务管理器。它负责启动系统并管理系统中运行的服务和进程。使用管理服务systemd可以用来启动、停止、...
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
-
Linux系统日常巡检脚本,巡检内容包含了,磁盘,...
- 7,MySQL管理员用户管理_mysql 管理员用户
-
一、首次设置密码1.初始化时设置(推荐)mysqld--initialize--user=mysql--datadir=/data/3306/data--basedir=/usr/local...
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
-
1.1数据库的核心概念在开始Python数据库编程之前,我们需要先理解几个核心概念。数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,它就像一个电子化的文件柜,能让我们高效...
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
-
设置WGCloud开机自动启动服务init.d目录下新建脚本在/etc/rc.d/init.d新建启动脚本wgcloudstart.sh,内容如下...
- linux系统启动流程和服务管理,带你进去系统的世界
-
Linux启动流程Rhel6启动过程:开机自检bios-->MBR引导-->GRUB菜单-->加载内核-->init进程初始化Rhel7启动过程:开机自检BIOS-->M...
- CentOS7系统如何修改主机名_centos更改主机名称
-
请关注本头条号,每天坚持更新原创干货技术文章。如需学习视频,请在微信搜索公众号“智传网优”直接开始自助视频学习1.前言本文将讲解CentOS7系统如何修改主机名。...
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
-
在Linux服务器管理中,SSH(SecureShell)是远程操作的核心工具。以下是SSH终端操作的常用命令和技巧,涵盖连接、文件操作、系统管理等场景:一、SSH连接服务器1.基本连接...
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
-
为什么需要配置开机自启?想象一下:电商服务器重启后,MySQL和Nginx没自动启动,整个网站瘫痪!这就是为什么开机自启是Linux运维的必备技能。自启服务能确保核心程序在系统启动时自动运行,避免人工...
- Kubernetes 高可用(HA)集群部署指南
-
Kubernetes高可用(HA)集群部署指南本指南涵盖从概念理解、架构选择,到kubeadm高可用部署、生产优化、监控备份和运维的全流程,适用于希望搭建稳定、生产级Kubernetes集群...
- Linux项目开发,你必须了解Systemd服务!
-
1.Systemd简介...
- Linux系统systemd服务管理工具使用技巧
-
简介:在Linux系统里,systemd就像是所有进程的“源头”,它可是系统中PID值为1的进程哟。systemd其实是一堆工具的组合,它的作用可不止是启动操作系统这么简单,像后台服务...
- Linux下NetworkManager和network的和平共处
-
简介我们在使用CentoOS系统时偶尔会遇到配置都正确但network启动不了的问题,这问题经常是由NetworkManager引起的,关闭NetworkManage并取消开机启动network就能正...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
Linux下NetworkManager和network的和平共处
-
Kubernetes 高可用(HA)集群部署指南
-
linux系统启动流程和服务管理,带你进去系统的世界
-
7,MySQL管理员用户管理_mysql 管理员用户
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
- 最近发表
-
- Linux集群自动化监控系统Zabbix集群搭建到实战
- systemd是什么如何使用_systemd/system
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
- 7,MySQL管理员用户管理_mysql 管理员用户
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
- linux系统启动流程和服务管理,带你进去系统的世界
- CentOS7系统如何修改主机名_centos更改主机名称
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)