特征交叉系列:AFM理论和实践,结合注意力机制的交叉池化
ztj100 2025-07-20 00:02 24 浏览 0 评论
关键词:FM,注意力机制,推荐算法
内容摘要
- AFM介绍和结构简述
- 注意力层问题分析
- AFM在PyTorch下的实践和效果对比
- 注意力权重可视化和业务分析
AFM介绍和结构简述
在上一节介绍了NFM算法特征交叉系列:NFM原理和实践,使用交叉池化连接FM和DNN,NFM采用求和池化对两两向量的乘积做信息浓缩聚合,本节介绍它引入注意力机制的升级版AFM(Attentional Factorization Machines)。
AFM整体结构和NFM类似,核心区别在于AFM采用注意力机制做池化,即采用加权求和的方式取代了NFM中的直接求和。
在FM层输出的向量乘积代表一对特征交叉的表征,可想而知,不同的特征交叉对最终预测结果的帮助大小也是不一样的,比如用户历史上感兴趣的商品类目和目标商品类目形成的交叉,肯定比用户性别和用户年龄形成的交叉效果要好,模型会给用户侧和商品侧的交互更多的倾向,而用户内部特征的交叉收益并不大。因此NFM中的求和池化并不合理,因为这相当于所有交叉结果都是相同的权重,并没有对重要的交叉做加强,对无用的交叉做衰减,压力全部给到了下游的全连接层。
AFM是NFM的升级,因此AFM也是针对FM的二阶交互层做了修改,AFM的网络结构如下
相比于NFM,AFM网络结构有两处变动
- 1:引入了Attention Net,采用一个全连接作为Attention相似函数,具体是使用激活函数是Relu的全连接层映射到固定维度,再加一个线性层输出到标量,公式如下
在得到该批次下每个样本的attention权重后,和交叉层输出的向量乘积做相乘,相当于分配权重,再将所有带权的交叉向量对应位置求和,达到加权求和的效果,注意力的输出如下
和NFM一样输出是一个[batch, emb_size]的二维向量。
- 2:取消了特征交叉之后的多层感知机,而是直接以一个线性层输出为一个标量,和FM的一阶结果相加得到整个模型的输出。
注意力层问题分析
这里对注意力机制有点疑问,一般注意力机制有query,key,value,其中query作为参照物,query和key做相似计算输出权重,再对key对应的value进行加权求和,而在AFM中显然key和value都是交叉层两两向量的乘积,而没有query,因此注意力权重可能是基于全部样本学习到的特征重要性。哪些特征做交叉更重要是有先验知识的,AFM的注意力学习到的是类似用户侧和商品侧交叉更有意义这样的大致趋势。个人认为由于在AFM中并没有加入query(比如召回的目标商品),因此注意力的作用不明显。
AFM在PyTorch下的实践和效果对比
本次实践的数据集和上一篇特征交叉系列:完全理解FM因子分解机原理和代码实战一致,采用用户的购买记录流水作为训练数据,用户侧特征是年龄,性别,会员年限等离散特征,商品侧特征采用商品的二级类目,产地,品牌三个离散特征,随机构造负样本,一共有10个特征域,全部是离散特征,对于枚举值过多的特征采用hash分箱,得到一共72个特征。
AFM的PyTorch代码实现如下
class Embedding(nn.Module):
def __init__(self, feat_dim, emb_dim):
super(Embedding, self).__init__()
self.embedding = nn.Embedding(feat_dim, emb_dim)
nn.init.xavier_normal_(self.embedding.weight.data)
def forward(self, x):
# [None, field_num] => [None, field_num, emb_dim]
return self.embedding(x)
class FM(nn.Module):
def __init__(self):
super(FM, self).__init__()
def forward(self, x):
# x=[None, field_num, emb_dim]
field_num = x.shape[1]
p = []
q = []
for i in range(field_num - 1):
for j in range(i + 1, field_num):
p.append(i)
q.append(j)
pp = x[:, p]
qq = x[:, q]
pq = pp * qq
return pq
class Attention(nn.Module):
def __init__(self, emb_size, attn_size, dropout=0.1):
super(Attention, self).__init__()
self.w = nn.Linear(emb_size, attn_size)
self.h = nn.Linear(attn_size, 1)
self.fc = torch.nn.Linear(emb_size, 1)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
for weight in (self.w, self.h, self.fc):
nn.init.xavier_uniform_(weight.weight.data)
def forward(self, x):
# x=[None, field_num*(field_num-1)/2, emb_dim] => [None, field_num*(field_num-1)/2, 1]
eij = torch.softmax(self.h(self.relu(self.w(x))), dim=1)
eij = self.dropout(eij)
attn_x = x * eij
out = self.dropout(torch.sum(attn_x, dim=1)) # [None, emb_size]
return self.fc(out) # [None, 1]
class Linear(nn.Module):
def __init__(self, feat_dim):
super(Linear, self).__init__()
self.embedding = nn.Embedding(feat_dim, 1)
self.bias = nn.Parameter(torch.zeros(1))
nn.init.xavier_normal_(self.embedding.weight.data)
def forward(self, x):
# [None, field_num] => [None, field_num, 1] => [None, 1]
return self.embedding(x).sum(dim=1) + self.bias
class AFM(nn.Module):
def __init__(self, feat_dim, emb_dim, attention_size, dropout=0.1):
super(AFM, self).__init__()
self.linear = Linear(feat_dim=feat_dim)
self.embedding = Embedding(feat_dim=feat_dim, emb_dim=emb_dim)
self.fm = FM()
self.attention = Attention(emb_size=emb_dim, attn_size=attention_size, dropout=dropout)
def forward(self, x):
linear = self.linear(x) # [None, 1]
emb = self.embedding(x) # [None, field_num, emb_dim]
cross = self.fm(emb) # [None, field_num*(field_num-1)/2, emb_dim]
attn_cross = self.attention(cross) # [None, 1]
out = linear + attn_cross
return torch.sigmoid(out).squeeze(dim=1)
其中FM+Attention子模块实现了交叉注意力池化层。
本例全部是离散分箱变量,所有有值的特征都是1,因此只要输入有值位置的索引即可,一条输入例如
>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))
采用验证集的10次AUC不上升作为早停,FM,NFM,AFM的平均验证集AUC如下
FM | NFM | AFM | |
AUC | 0.6274 | 0.6329 | 0.6293 |
结果带有注意力机制的AFM反而不如NFM,即注意力的学习可能还不如多层感知机。
注意力权重可视化和业务分析
进一步看一下模型训练完成后学习到权重eij,首先基于热力图观察到几乎每一条样本的注意力权重分布都是一致的,如下图横向是所有特征交叉的案例45种,纵向是batchsize,只看了部分30条样本
可以发现第44组特征交叉权重最大,其次是第22组,第16组,基本在8,15,22,44位置形成了三条竖线,进一步说明了AFM的注意力学习的是全局的重要性,随着样本的个性化不足。
然后将所有样本下的注意力权重求平均值,获得所有组合的注意力权重分布
将注意力最高的组合和对应的特征业务名称对应上,结果如下
注意力权重排名 | 组合编号 | 特征 |
1 | 44 | 商品品牌×商品产地 |
2 | 22 | 会员年限×商品品牌 |
3 | 23 | 会员年限×商品产地 |
4 | 33 | 会员vip等级×商品品牌 |
5 | 34 | 会员vip等级×商品产地 |
排名第一的是商品内侧的特征交叉,力压一众用户和商品的跨域交叉,不是太合理,一方面可能AFM训练出的注意力确实很糟糕,另一方面也可能是样本中频繁出现某商品的购买记录形成热点,导致模型给到该商品的品牌和产地以及他们交叉项的权重很大,也说明这种基于全局样本数据分布的注意力可能会收到样本热度的影响。
相关推荐
- Linux集群自动化监控系统Zabbix集群搭建到实战
-
自动化监控系统...
- systemd是什么如何使用_systemd/system
-
systemd是什么如何使用简介Systemd是一个在现代Linux发行版中广泛使用的系统和服务管理器。它负责启动系统并管理系统中运行的服务和进程。使用管理服务systemd可以用来启动、停止、...
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
-
Linux系统日常巡检脚本,巡检内容包含了,磁盘,...
- 7,MySQL管理员用户管理_mysql 管理员用户
-
一、首次设置密码1.初始化时设置(推荐)mysqld--initialize--user=mysql--datadir=/data/3306/data--basedir=/usr/local...
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
-
1.1数据库的核心概念在开始Python数据库编程之前,我们需要先理解几个核心概念。数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,它就像一个电子化的文件柜,能让我们高效...
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
-
设置WGCloud开机自动启动服务init.d目录下新建脚本在/etc/rc.d/init.d新建启动脚本wgcloudstart.sh,内容如下...
- linux系统启动流程和服务管理,带你进去系统的世界
-
Linux启动流程Rhel6启动过程:开机自检bios-->MBR引导-->GRUB菜单-->加载内核-->init进程初始化Rhel7启动过程:开机自检BIOS-->M...
- CentOS7系统如何修改主机名_centos更改主机名称
-
请关注本头条号,每天坚持更新原创干货技术文章。如需学习视频,请在微信搜索公众号“智传网优”直接开始自助视频学习1.前言本文将讲解CentOS7系统如何修改主机名。...
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
-
在Linux服务器管理中,SSH(SecureShell)是远程操作的核心工具。以下是SSH终端操作的常用命令和技巧,涵盖连接、文件操作、系统管理等场景:一、SSH连接服务器1.基本连接...
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
-
为什么需要配置开机自启?想象一下:电商服务器重启后,MySQL和Nginx没自动启动,整个网站瘫痪!这就是为什么开机自启是Linux运维的必备技能。自启服务能确保核心程序在系统启动时自动运行,避免人工...
- Kubernetes 高可用(HA)集群部署指南
-
Kubernetes高可用(HA)集群部署指南本指南涵盖从概念理解、架构选择,到kubeadm高可用部署、生产优化、监控备份和运维的全流程,适用于希望搭建稳定、生产级Kubernetes集群...
- Linux项目开发,你必须了解Systemd服务!
-
1.Systemd简介...
- Linux系统systemd服务管理工具使用技巧
-
简介:在Linux系统里,systemd就像是所有进程的“源头”,它可是系统中PID值为1的进程哟。systemd其实是一堆工具的组合,它的作用可不止是启动操作系统这么简单,像后台服务...
- Linux下NetworkManager和network的和平共处
-
简介我们在使用CentoOS系统时偶尔会遇到配置都正确但network启动不了的问题,这问题经常是由NetworkManager引起的,关闭NetworkManage并取消开机启动network就能正...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
这一次,彻底搞懂Java并发包中的Atomic原子类
-
- 最近发表
-
- Linux集群自动化监控系统Zabbix集群搭建到实战
- systemd是什么如何使用_systemd/system
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
- 7,MySQL管理员用户管理_mysql 管理员用户
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
- linux系统启动流程和服务管理,带你进去系统的世界
- CentOS7系统如何修改主机名_centos更改主机名称
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)