百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

利用PyTorch的三元组损失Hard Triplet Loss进行嵌入模型微调

ztj100 2024-12-01 07:00 16 浏览 0 评论

本文介绍如何使用 PyTorch 和三元组边缘损失 (Triplet Margin Loss) 微调嵌入模型,并重点阐述实现细节和代码示例。三元组损失是一种对比损失函数,通过缩小锚点与正例间的距离,同时扩大锚点与负例间的距离来优化模型。

数据集准备与处理

一般的嵌入模型都会使用Sentence Transformer ,其中的 encode() 方法可以直接处理文本输入。但是为了进行微调,我们需要采用 Transformer 库,所以就要将文本转换为模型可接受的 token IDs 和 attention masks。Token IDs 代表模型词汇表中的词或字符,attention masks 用于防止模型关注填充 tokens。

本文使用 thenlper/gte-base 模型,需要对应的 tokenizer 对文本进行预处理。该模型基于 BertModel 架构:

BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSdpaSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)

利用 Transformers 库的 AutoTokenizer 和 AutoModel 可以简化模型加载过程,无需手动处理底层架构和配置细节。

from transformers import AutoTokenizer, AutoModel 
from tqdm import tqdm 
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base") 

# 获取文本并进行标记 
train_texts = [df_train.loc[i]['content'] for i in range(df_train.shape[0])] 
dev_texts = [df_dev.loc[i]['content'] for i in range(df_dev.shape[0])] 
test_texts = [df_test.loc[i]['content'] for i in range(df_test.shape[0])] 

train_tokens = [] 
train_attention_masks = [] 
dev_tokens = [] 
dev_attention_masks = [] 
test_tokens = [] 
test_attention_masks = [] 

for sent in tqdm(train_texts): 
encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt') 
train_tokens.append(encoding['input_ids'].squeeze(0)) 
train_attention_masks.append(encoding['attention_mask'].squeeze(0)) 

for sent in tqdm(dev_texts): 
encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt') 
dev_tokens.append(encoding['input_ids'].squeeze(0)) 
dev_attention_masks.append(encoding['attention_mask'].squeeze(0)) 

for sent in tqdm(test_texts): 
encoding = tokenizer(sent, truncation=True, padding='max_length', return_tensors='pt') 
test_tokens.append(encoding['input_ids'].squeeze(0)) 
test_attention_masks.append(encoding['attention_mask'].squeeze(0))

获取 token IDs 和 attention masks 后,需要将其存储并创建一个自定义的 PyTorch 数据集。

import random 
from collections import defaultdict 
import torch 
from torch.utils.data import Dataset, DataLoader, Sampler, SequentialSampler 

class CustomTripletDataset(Dataset): 
def __init__(self, tokens, attention_masks, labels): 
self.tokens = tokens 
self.attention_masks = attention_masks 
self.labels = torch.Tensor(labels) 
self.label_dict = defaultdict(list) 

for i in range(len(tokens)): 
self.label_dict[int(self.labels[i])].append(i) 
self.unique_classes = list(self.label_dict.keys()) 

def __len__(self): 
return len(self.tokens) 

def __getitem__(self, index): 
ids = self.tokens[index].to(device) 
ams = self.attention_masks[index].to(device) 
y = self.labels[index].to(device) 
return ids, ams, y

由于采用三元组损失,需要从数据集中采样正例和负例。label_dict 字典用于存储每个类别及其对应的数据索引,方便随机采样。DataLoader 用于加载数据集:

train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler)

其中 train_batch_sampler 是自定义的批次采样器:

class CustomBatchSampler(SequentialSampler): 
def __init__(self, dataset, batch_size): 
self.dataset = dataset 
self.batch_size = batch_size 
self.unique_classes = sorted(dataset.unique_classes) 
self.label_dict = dataset.label_dict 
self.num_batches = len(self.dataset) // self.batch_size 
self.class_size = self.batch_size // 4 

def __iter__(self): 
total_samples_used = 0 
weights = np.repeat(1, len(self.unique_classes)) 

while total_samples_used < len(self.dataset): 
batch = [] 
classes = [] 
for _ in range(4): 
next_selected_class = self._select_class(weights) 
while next_selected_class in classes: 
next_selected_class = self._select_class(weights) 
weights[next_selected_class] += 1 
classes.append(next_selected_class) 
new_choices = self.label_dict[next_selected_class] 
remaining_samples = list(np.random.choice(new_choices, min(self.class_size, len(new_choices)), replace=False)) 
batch.extend(remaining_samples) 

total_samples_used += len(batch) 

yield batch 

def _select_class(self, weights): 
dist = 1/weights 
dist = dist/np.sum(dist) 
selected = int(np.random.choice(self.unique_classes, p=dist)) 
return selected 

def __len__(self): 
return self.num_batches

自定义批次采样器控制训练批次的构成,本文的实现确保每个批次包含 4 个类别,每个类别包含 8 个数据点。验证采样器则确保验证集批次在不同 epoch 间保持一致。

模型构建

嵌入模型通常基于 Transformer 架构,输出每个 token 的嵌入。为了获得句子嵌入,需要对 token 嵌入进行汇总。常用的方法包括 CLS 池化和平均池化。本文使用的 gte-base 模型采用平均池化,需要从模型输出中提取 token 嵌入并计算平均值。

import torch.nn.functional as F 
import torch.nn as nn 

class EmbeddingModel(nn.Module): 
def __init__(self, base_model): 
super().__init__() 
self.base_model = base_model 

def average_pool(self, last_hidden_states, attention_mask): 
# 平均 token 嵌入 
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) 
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 

def forward(self, input_ids, attention_mask): 
outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask) 
last_hidden_state = outputs.last_hidden_state 
pooled_output = self.average_pool(last_hidden_state, attention_mask) 
normalized_output = F.normalize(pooled_output, p=2, dim=1) 
return normalized_output 

base_model = AutoModel.from_pretrained("thenlper/gte-base") 
model = EmbeddingModel(base_model)

EmbeddingModel 类封装了 Hugging Face 模型,并实现了平均池化和嵌入归一化。

模型训练

训练循环中需要动态计算每个锚点的最难正例和最难负例。

import numpy as np 

def train(model, train_loader, criterion, optimizer, scheduler): 
model.train() 
epoch_train_losses = [] 

for idx, (ids, attention_masks, labels) in enumerate(train_loader): 
optimizer.zero_grad() 

embeddings = model(ids, attention_masks) 

distance_matrix = torch.cdist(embeddings, embeddings, p=2) # 创建方形距离矩阵 

anchors = [] 
positives = [] 
negatives = [] 

for i in range(len(labels)): 

anchor_label = labels[i].item() 
anchor_distance = distance_matrix[i] # 锚点与所有其他点之间的距离 

# 最难的正例(同一类别中最远的) 
hardest_positive_idx = (labels == anchor_label).nonzero(as_tuple=True)[0] # 所有同类索引 
hardest_positive_idx = hardest_positive_idx[hardest_positive_idx != i] # 排除自己的标签 
hardest_positive = hardest_positive_idx[anchor_distance[hardest_positive_idx].argmax()] # 最远同类的标签 

# 最难的负例(不同类别中最近的) 
hardest_negative_idx = (labels != anchor_label).nonzero(as_tuple=True)[0] # 所有不同类索引 
hardest_negative = hardest_negative_idx[anchor_distance[hardest_negative_idx].argmin()] # 最近不同类的标签 

# 加载选择的 
anchors.append(embeddings[i]) 
positives.append(embeddings[hardest_positive]) 
negatives.append(embeddings[hardest_negative]) 

# 将列表转换为张量 
anchors = torch.stack(anchors) 
positives = torch.stack(positives) 
negatives = torch.stack(negatives) 

# 计算损失 
loss = criterion(anchors, positives, negatives) 
epoch_train_losses.append(loss.item()) 

# 反向传播和优化 
loss.backward() 
optimizer.step() 

# 更新学习率 
scheduler.step() 

return np.mean(epoch_train_losses)

训练过程中使用 torch.cdist() 计算嵌入间的距离矩阵,并根据距离选择最难正例和最难负例。PyTorch 的 TripletMarginLoss 用于计算损失。

结论与讨论

实践表明,Batch Hard Triplet Loss 在某些情况下并非最优选择。例如,当正例样本内部差异较大时,强制其嵌入相似可能适得其反。

本文的重点在于 PyTorch 中自定义批次采样和动态距离计算的实现。

对于某些任务,直接在分类任务上微调嵌入模型可能比使用三元组损失更有效。

相关推荐

Win10预览版10532已知问题汇总(微软win11正式版已知问题一览)

IT之家讯微软已向Insider用户推送了Win10预览版10532更新,本次更新对右键菜单、《Windows反馈》应用以及Edge浏览器进行了改进。除此之外还包含一些Bug,汇总如下,有意升级Wi...

Gabe Aul正测试Win10 Mobile 10532,Insider用户还需等

IT之家讯本月中旬微软向Insider用户推送了Win10Mobile预览版10512,该版本修复了一些Bug,增强了系统稳定性,但依然存在一些问题。今天,微软Insider项目负责人GabeAu...

微软开始推送Win10预览版10532快速版更新

8月28日消息,刚才,微软推送了Win10Build10532快速版,修复了之前的Bug,并带来了三项改进。主要来说,这次的更新改进了右键菜单的UI,使其更具Modern风格(见上图)。此外,更新...

Win10预览版10532更新内容大全(windows10更新预览版)

IT之家讯今天凌晨微软向Insider用户推送了Win10预览版10532快速版更新,本次更新主要带来了三处改进,汇总如下:o改进右键菜单,外观更加Modern。这是基于网友要求界面一致的反馈做出...

无法升级Win10预览版10532?也许Hyper-V在搞鬼

根据IT之家网友的反映,安装了微软虚拟机Hyper-V的Win10预览版用户无法成功升级Build10532版本,安装过程中会被要求回滚系统。很多朋友在尝试关闭虚拟机之后重启安装程序,结果仍然无法顺...

Win10预览版10532界面兴起“酷黑”风潮

Win10预览版10532的界面改动还是较为明显的,主要体现在右键菜单上面。总体来看,该版本的右键菜单间距更宽,视觉上更大气,操作上更便于触控。具体来说,任务栏右键菜单的变化最为明显。除了增加选项的宽...

Win10预览版10532上手图集(windows10预览版下载)

IT之家讯8月28日,微软今天推送了Win10预览版10532快速版更新,在该版本中,微软主要是加强细节上调整,并且主要是增强Edge浏览器性能等。在Windows10预览版10532中,微软改进了...

Win10预览版10532上手视频亮点演示

IT之家讯8月28日消息,今天凌晨微软向WindowsInsider快速通道用户推送了Win10预览版10532。在Windows10预览版10532中,微软改进了右键菜单,外观更加现代化。另外还...

第二篇 前端框架Vue.js(vue前端框架技术)

前端三大核心是网页开发的基础,Vue则是基于它们构建的“生产力工具”。通俗理解就是HTML是化妆的工具如眉笔,CSS是化妆品如口红,JavaScript是化妆后的互动,而Vue就是化妆助手。有了化妆工...

基于SpringBoot + vue2实现的旅游推荐管理系统

项目描述...

基于Vue以及iView组件的后端管理UI模板——iview-admin

介绍iView-admin是一套后端管理界面模板,基于Vue2.0,iView(现在为ViewUI)组件是一套完整的基于Vue的高质量组件库,虽然Github上有一套非常火的基于ElementUI...

别再说你会SPA开发了,这5个核心你真的搞懂了吗?

前言此spa非彼spa,不是你所熟知的spa。你所熟知的spa作者肯定是没有你熟悉的。我们这里指的是在前端开发中的一种模型,叫作单页应用程序,顾名思义,就是整个项目只有一个页面,而页面中的内容是动态的...

React.js Top20面试题(react.js中文官网)

概述作为React开发者,对框架的关键概念和原则有扎实的理解是很重要的。考虑到这一点,我整理了一份包含20个重要问题的清单,每个React开发者都应该知道,无论他们是在面试工作还是只是想提高技能。...

美媒:特朗普签署行政令后,FBI又发现约2400份、总计超14000页涉肯尼迪遇刺案文件

来源:环球时报新媒体1月23日特朗普下令公布肯尼迪遇刺案相关机密文件图源:美媒综合福克斯新闻网和Axios网站10日报道,在总统特朗普签署行政令,要求公布“肯尼迪遇刺案”相关政府机密文件之后,美国...

2021 年 Node.js 开发人员学习路线图

Node.js自发布以来,已成为业界重要破局者之一。Uber、Medium、PayPal和沃尔玛等大型企业,纷纷将技术栈转向Node.js。Node.js支持开发功能强大的应用,例如实时追踪...

取消回复欢迎 发表评论: