百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

特征交叉系列:AFM理论和实践,结合注意力机制的交叉池化

ztj100 2025-07-20 00:02 24 浏览 0 评论

关键词:FM注意力机制推荐算法

内容摘要

  • AFM介绍和结构简述
  • 注意力层问题分析
  • AFM在PyTorch下的实践和效果对比
  • 注意力权重可视化和业务分析

AFM介绍和结构简述

在上一节介绍了NFM算法特征交叉系列:NFM原理和实践,使用交叉池化连接FM和DNN,NFM采用求和池化对两两向量的乘积做信息浓缩聚合,本节介绍它引入注意力机制的升级版AFM(Attentional Factorization Machines)。

AFM整体结构和NFM类似,核心区别在于AFM采用注意力机制做池化,即采用加权求和的方式取代了NFM中的直接求和。

在FM层输出的向量乘积代表一对特征交叉的表征,可想而知,不同的特征交叉对最终预测结果的帮助大小也是不一样的,比如用户历史上感兴趣的商品类目和目标商品类目形成的交叉,肯定比用户性别和用户年龄形成的交叉效果要好,模型会给用户侧和商品侧的交互更多的倾向,而用户内部特征的交叉收益并不大。因此NFM中的求和池化并不合理,因为这相当于所有交叉结果都是相同的权重,并没有对重要的交叉做加强对无用的交叉做衰减,压力全部给到了下游的全连接层。

AFM是NFM的升级,因此AFM也是针对FM的二阶交互层做了修改,AFM的网络结构如下

相比于NFM,AFM网络结构有两处变动

  • 1:引入了Attention Net,采用一个全连接作为Attention相似函数,具体是使用激活函数是Relu的全连接层映射到固定维度,再加一个线性层输出到标量,公式如下

在得到该批次下每个样本的attention权重后,和交叉层输出的向量乘积做相乘,相当于分配权重,再将所有带权的交叉向量对应位置求和,达到加权求和的效果,注意力的输出如下

和NFM一样输出是一个[batch, emb_size]的二维向量。

  • 2:取消了特征交叉之后的多层感知机,而是直接以一个线性层输出为一个标量,和FM的一阶结果相加得到整个模型的输出。

注意力层问题分析

这里对注意力机制有点疑问,一般注意力机制有querykeyvalue,其中query作为参照物,query和key做相似计算输出权重,再对key对应的value进行加权求和,而在AFM中显然key和value都是交叉层两两向量的乘积,而没有query因此注意力权重可能是基于全部样本学习到的特征重要性。哪些特征做交叉更重要是有先验知识的,AFM的注意力学习到的是类似用户侧和商品侧交叉更有意义这样的大致趋势。个人认为由于在AFM中并没有加入query(比如召回的目标商品),因此注意力的作用不明显。


AFM在PyTorch下的实践和效果对比

本次实践的数据集和上一篇特征交叉系列:完全理解FM因子分解机原理和代码实战一致,采用用户的购买记录流水作为训练数据,用户侧特征是年龄,性别,会员年限等离散特征,商品侧特征采用商品的二级类目,产地,品牌三个离散特征,随机构造负样本,一共有10个特征域,全部是离散特征,对于枚举值过多的特征采用hash分箱,得到一共72个特征。
AFM的PyTorch代码实现如下

class Embedding(nn.Module):
    def __init__(self, feat_dim, emb_dim):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(feat_dim, emb_dim)
        nn.init.xavier_normal_(self.embedding.weight.data)

    def forward(self, x):
        # [None, field_num] => [None, field_num, emb_dim]
        return self.embedding(x)


class FM(nn.Module):
    def __init__(self):
        super(FM, self).__init__()

    def forward(self, x):
        # x=[None, field_num, emb_dim]
        field_num = x.shape[1]
        p = []
        q = []
        for i in range(field_num - 1):
            for j in range(i + 1, field_num):
                p.append(i)
                q.append(j)
        pp = x[:, p]
        qq = x[:, q]
        pq = pp * qq
        return pq


class Attention(nn.Module):
    def __init__(self, emb_size, attn_size, dropout=0.1):
        super(Attention, self).__init__()
        self.w = nn.Linear(emb_size, attn_size)
        self.h = nn.Linear(attn_size, 1)
        self.fc = torch.nn.Linear(emb_size, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        for weight in (self.w, self.h, self.fc):
            nn.init.xavier_uniform_(weight.weight.data)

    def forward(self, x):
        # x=[None, field_num*(field_num-1)/2, emb_dim] => [None, field_num*(field_num-1)/2, 1]
        eij = torch.softmax(self.h(self.relu(self.w(x))), dim=1)
        eij = self.dropout(eij)
        attn_x = x * eij
        out = self.dropout(torch.sum(attn_x, dim=1))  # [None, emb_size]
        return self.fc(out)  # [None, 1]


class Linear(nn.Module):
    def __init__(self, feat_dim):
        super(Linear, self).__init__()
        self.embedding = nn.Embedding(feat_dim, 1)
        self.bias = nn.Parameter(torch.zeros(1))
        nn.init.xavier_normal_(self.embedding.weight.data)

    def forward(self, x):
        # [None, field_num] => [None, field_num, 1] => [None, 1]
        return self.embedding(x).sum(dim=1) + self.bias


class AFM(nn.Module):
    def __init__(self, feat_dim, emb_dim, attention_size, dropout=0.1):
        super(AFM, self).__init__()
        self.linear = Linear(feat_dim=feat_dim)
        self.embedding = Embedding(feat_dim=feat_dim, emb_dim=emb_dim)
        self.fm = FM()
        self.attention = Attention(emb_size=emb_dim, attn_size=attention_size, dropout=dropout)

    def forward(self, x):
        linear = self.linear(x)  # [None, 1]
        emb = self.embedding(x)  # [None, field_num, emb_dim]
        cross = self.fm(emb)  # [None, field_num*(field_num-1)/2, emb_dim]
        attn_cross = self.attention(cross)  # [None, 1]
        out = linear + attn_cross
        return torch.sigmoid(out).squeeze(dim=1)

其中FM+Attention子模块实现了交叉注意力池化层。

本例全部是离散分箱变量,所有有值的特征都是1,因此只要输入有值位置的索引即可,一条输入例如

>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))

采用验证集的10次AUC不上升作为早停,FM,NFM,AFM的平均验证集AUC如下


FM

NFM

AFM

AUC

0.6274

0.6329

0.6293

结果带有注意力机制的AFM反而不如NFM,即注意力的学习可能还不如多层感知机。


注意力权重可视化和业务分析

进一步看一下模型训练完成后学习到权重eij,首先基于热力图观察到几乎每一条样本的注意力权重分布都是一致的,如下图横向是所有特征交叉的案例45种,纵向是batchsize,只看了部分30条样本

可以发现第44组特征交叉权重最大,其次是第22组,第16组,基本在8,15,22,44位置形成了三条竖线,进一步说明了AFM的注意力学习的是全局的重要性,随着样本的个性化不足。

然后将所有样本下的注意力权重求平均值,获得所有组合的注意力权重分布

将注意力最高的组合和对应的特征业务名称对应上,结果如下

注意力权重排名

组合编号

特征

1

44

商品品牌×商品产地

2

22

会员年限×商品品牌

3

23

会员年限×商品产地

4

33

会员vip等级×商品品牌

5

34

会员vip等级×商品产地

排名第一的是商品内侧的特征交叉,力压一众用户和商品的跨域交叉,不是太合理,一方面可能AFM训练出的注意力确实很糟糕,另一方面也可能是样本中频繁出现某商品的购买记录形成热点,导致模型给到该商品的品牌和产地以及他们交叉项的权重很大,也说明这种基于全局样本数据分布的注意力可能会收到样本热度的影响。

相关推荐

Linux集群自动化监控系统Zabbix集群搭建到实战

自动化监控系统...

systemd是什么如何使用_systemd/system

systemd是什么如何使用简介Systemd是一个在现代Linux发行版中广泛使用的系统和服务管理器。它负责启动系统并管理系统中运行的服务和进程。使用管理服务systemd可以用来启动、停止、...

Linux服务器日常巡检脚本分享_linux服务器监控脚本

Linux系统日常巡检脚本,巡检内容包含了,磁盘,...

7,MySQL管理员用户管理_mysql 管理员用户

一、首次设置密码1.初始化时设置(推荐)mysqld--initialize--user=mysql--datadir=/data/3306/data--basedir=/usr/local...

Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门

1.1数据库的核心概念在开始Python数据库编程之前,我们需要先理解几个核心概念。数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,它就像一个电子化的文件柜,能让我们高效...

Linux自定义开机自启动服务脚本_linux添加开机自启动脚本

设置WGCloud开机自动启动服务init.d目录下新建脚本在/etc/rc.d/init.d新建启动脚本wgcloudstart.sh,内容如下...

linux系统启动流程和服务管理,带你进去系统的世界

Linux启动流程Rhel6启动过程:开机自检bios-->MBR引导-->GRUB菜单-->加载内核-->init进程初始化Rhel7启动过程:开机自检BIOS-->M...

CentOS7系统如何修改主机名_centos更改主机名称

请关注本头条号,每天坚持更新原创干货技术文章。如需学习视频,请在微信搜索公众号“智传网优”直接开始自助视频学习1.前言本文将讲解CentOS7系统如何修改主机名。...

前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令

在Linux服务器管理中,SSH(SecureShell)是远程操作的核心工具。以下是SSH终端操作的常用命令和技巧,涵盖连接、文件操作、系统管理等场景:一、SSH连接服务器1.基本连接...

Linux开机自启服务完全指南:3步搞定系统服务管理器配置

为什么需要配置开机自启?想象一下:电商服务器重启后,MySQL和Nginx没自动启动,整个网站瘫痪!这就是为什么开机自启是Linux运维的必备技能。自启服务能确保核心程序在系统启动时自动运行,避免人工...

Kubernetes 高可用(HA)集群部署指南

Kubernetes高可用(HA)集群部署指南本指南涵盖从概念理解、架构选择,到kubeadm高可用部署、生产优化、监控备份和运维的全流程,适用于希望搭建稳定、生产级Kubernetes集群...

Linux项目开发,你必须了解Systemd服务!

1.Systemd简介...

Linux系统systemd服务管理工具使用技巧

简介:在Linux系统里,systemd就像是所有进程的“源头”,它可是系统中PID值为1的进程哟。systemd其实是一堆工具的组合,它的作用可不止是启动操作系统这么简单,像后台服务...

Red Hat Enterprise Linux 10 安装 Kubernetes (K8s) 集群及高级管理

一、前言...

Linux下NetworkManager和network的和平共处

简介我们在使用CentoOS系统时偶尔会遇到配置都正确但network启动不了的问题,这问题经常是由NetworkManager引起的,关闭NetworkManage并取消开机启动network就能正...

取消回复欢迎 发表评论: