百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

特征交叉系列:AFM理论和实践,结合注意力机制的交叉池化

ztj100 2025-07-20 00:02 5 浏览 0 评论

关键词:FM注意力机制推荐算法

内容摘要

  • AFM介绍和结构简述
  • 注意力层问题分析
  • AFM在PyTorch下的实践和效果对比
  • 注意力权重可视化和业务分析

AFM介绍和结构简述

在上一节介绍了NFM算法特征交叉系列:NFM原理和实践,使用交叉池化连接FM和DNN,NFM采用求和池化对两两向量的乘积做信息浓缩聚合,本节介绍它引入注意力机制的升级版AFM(Attentional Factorization Machines)。

AFM整体结构和NFM类似,核心区别在于AFM采用注意力机制做池化,即采用加权求和的方式取代了NFM中的直接求和。

在FM层输出的向量乘积代表一对特征交叉的表征,可想而知,不同的特征交叉对最终预测结果的帮助大小也是不一样的,比如用户历史上感兴趣的商品类目和目标商品类目形成的交叉,肯定比用户性别和用户年龄形成的交叉效果要好,模型会给用户侧和商品侧的交互更多的倾向,而用户内部特征的交叉收益并不大。因此NFM中的求和池化并不合理,因为这相当于所有交叉结果都是相同的权重,并没有对重要的交叉做加强对无用的交叉做衰减,压力全部给到了下游的全连接层。

AFM是NFM的升级,因此AFM也是针对FM的二阶交互层做了修改,AFM的网络结构如下

相比于NFM,AFM网络结构有两处变动

  • 1:引入了Attention Net,采用一个全连接作为Attention相似函数,具体是使用激活函数是Relu的全连接层映射到固定维度,再加一个线性层输出到标量,公式如下

在得到该批次下每个样本的attention权重后,和交叉层输出的向量乘积做相乘,相当于分配权重,再将所有带权的交叉向量对应位置求和,达到加权求和的效果,注意力的输出如下

和NFM一样输出是一个[batch, emb_size]的二维向量。

  • 2:取消了特征交叉之后的多层感知机,而是直接以一个线性层输出为一个标量,和FM的一阶结果相加得到整个模型的输出。

注意力层问题分析

这里对注意力机制有点疑问,一般注意力机制有querykeyvalue,其中query作为参照物,query和key做相似计算输出权重,再对key对应的value进行加权求和,而在AFM中显然key和value都是交叉层两两向量的乘积,而没有query因此注意力权重可能是基于全部样本学习到的特征重要性。哪些特征做交叉更重要是有先验知识的,AFM的注意力学习到的是类似用户侧和商品侧交叉更有意义这样的大致趋势。个人认为由于在AFM中并没有加入query(比如召回的目标商品),因此注意力的作用不明显。


AFM在PyTorch下的实践和效果对比

本次实践的数据集和上一篇特征交叉系列:完全理解FM因子分解机原理和代码实战一致,采用用户的购买记录流水作为训练数据,用户侧特征是年龄,性别,会员年限等离散特征,商品侧特征采用商品的二级类目,产地,品牌三个离散特征,随机构造负样本,一共有10个特征域,全部是离散特征,对于枚举值过多的特征采用hash分箱,得到一共72个特征。
AFM的PyTorch代码实现如下

class Embedding(nn.Module):
    def __init__(self, feat_dim, emb_dim):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(feat_dim, emb_dim)
        nn.init.xavier_normal_(self.embedding.weight.data)

    def forward(self, x):
        # [None, field_num] => [None, field_num, emb_dim]
        return self.embedding(x)


class FM(nn.Module):
    def __init__(self):
        super(FM, self).__init__()

    def forward(self, x):
        # x=[None, field_num, emb_dim]
        field_num = x.shape[1]
        p = []
        q = []
        for i in range(field_num - 1):
            for j in range(i + 1, field_num):
                p.append(i)
                q.append(j)
        pp = x[:, p]
        qq = x[:, q]
        pq = pp * qq
        return pq


class Attention(nn.Module):
    def __init__(self, emb_size, attn_size, dropout=0.1):
        super(Attention, self).__init__()
        self.w = nn.Linear(emb_size, attn_size)
        self.h = nn.Linear(attn_size, 1)
        self.fc = torch.nn.Linear(emb_size, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        for weight in (self.w, self.h, self.fc):
            nn.init.xavier_uniform_(weight.weight.data)

    def forward(self, x):
        # x=[None, field_num*(field_num-1)/2, emb_dim] => [None, field_num*(field_num-1)/2, 1]
        eij = torch.softmax(self.h(self.relu(self.w(x))), dim=1)
        eij = self.dropout(eij)
        attn_x = x * eij
        out = self.dropout(torch.sum(attn_x, dim=1))  # [None, emb_size]
        return self.fc(out)  # [None, 1]


class Linear(nn.Module):
    def __init__(self, feat_dim):
        super(Linear, self).__init__()
        self.embedding = nn.Embedding(feat_dim, 1)
        self.bias = nn.Parameter(torch.zeros(1))
        nn.init.xavier_normal_(self.embedding.weight.data)

    def forward(self, x):
        # [None, field_num] => [None, field_num, 1] => [None, 1]
        return self.embedding(x).sum(dim=1) + self.bias


class AFM(nn.Module):
    def __init__(self, feat_dim, emb_dim, attention_size, dropout=0.1):
        super(AFM, self).__init__()
        self.linear = Linear(feat_dim=feat_dim)
        self.embedding = Embedding(feat_dim=feat_dim, emb_dim=emb_dim)
        self.fm = FM()
        self.attention = Attention(emb_size=emb_dim, attn_size=attention_size, dropout=dropout)

    def forward(self, x):
        linear = self.linear(x)  # [None, 1]
        emb = self.embedding(x)  # [None, field_num, emb_dim]
        cross = self.fm(emb)  # [None, field_num*(field_num-1)/2, emb_dim]
        attn_cross = self.attention(cross)  # [None, 1]
        out = linear + attn_cross
        return torch.sigmoid(out).squeeze(dim=1)

其中FM+Attention子模块实现了交叉注意力池化层。

本例全部是离散分箱变量,所有有值的特征都是1,因此只要输入有值位置的索引即可,一条输入例如

>>> train_data[0]
Out[120]: (tensor([ 2, 10, 14, 18, 34, 39, 47, 51, 58, 64]), tensor(0))

采用验证集的10次AUC不上升作为早停,FM,NFM,AFM的平均验证集AUC如下


FM

NFM

AFM

AUC

0.6274

0.6329

0.6293

结果带有注意力机制的AFM反而不如NFM,即注意力的学习可能还不如多层感知机。


注意力权重可视化和业务分析

进一步看一下模型训练完成后学习到权重eij,首先基于热力图观察到几乎每一条样本的注意力权重分布都是一致的,如下图横向是所有特征交叉的案例45种,纵向是batchsize,只看了部分30条样本

可以发现第44组特征交叉权重最大,其次是第22组,第16组,基本在8,15,22,44位置形成了三条竖线,进一步说明了AFM的注意力学习的是全局的重要性,随着样本的个性化不足。

然后将所有样本下的注意力权重求平均值,获得所有组合的注意力权重分布

将注意力最高的组合和对应的特征业务名称对应上,结果如下

注意力权重排名

组合编号

特征

1

44

商品品牌×商品产地

2

22

会员年限×商品品牌

3

23

会员年限×商品产地

4

33

会员vip等级×商品品牌

5

34

会员vip等级×商品产地

排名第一的是商品内侧的特征交叉,力压一众用户和商品的跨域交叉,不是太合理,一方面可能AFM训练出的注意力确实很糟糕,另一方面也可能是样本中频繁出现某商品的购买记录形成热点,导致模型给到该商品的品牌和产地以及他们交叉项的权重很大,也说明这种基于全局样本数据分布的注意力可能会收到样本热度的影响。

相关推荐

爬取电影视频数据(电影资源爬虫)

本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理。作者:yangrq1018原文链接:https://segmentfault.com/a/11900...

Python效率倍增的10个实用代码片段

引言Python是一门功能强大且灵活的编程语言,广泛应用于数据分析、Web开发、人工智能等多个领域。它的简洁语法和高可读性让开发者能够快速上手,但在实际工作中,我们常常会遇到一些重复性或繁琐的任务。这...

Python数据处理:深入理解序列化与反序列化

在现代编程实践中,数据的序列化与反序列化是数据持久化、网络通信等领域不可或缺的技术。本文将深入探讨Python中数据序列化与反序列化的概念、实现方式以及数据验证的重要性,并提供丰富的代码示例。...

亿纬锂能:拟向PKL买地,在马来西亚建立锂电池制造厂

亿纬锂能5月12日公告,亿纬马来西亚与PEMAJUKELANGLAMASDN.BHD.(PKL)签订《MEMORANDUMOFUNDERSTANDING》(谅解备忘录),亿纬马来西亚拟向PKL购买标的...

一个超强的机器学习库(spark机器学习库)

简介PyCaret...

30天学会Python编程:9. Python文件与IO操作

9.1文件操作基础9.1.1文件操作流程9.1.2文件打开模式表9-1Python文件打开模式...

Python的Pickle序列化与反序列化(python反序列化json)

动动小手,点击关注...

python进阶突破内置模块——数据序列化与格式

数据序列化是将数据结构或对象转换为可存储/传输格式的过程,反序列化则是逆向操作。Python提供了多种工具来处理不同场景下的序列化需求。一、核心内置模块...

微信聊天记录可视化工具详细介绍(微信聊天记录分析报告小程序)

功能概要能做什么...

Python常用文件操作库使用详解(python中文件操作的相关函数有哪些)

Python生态系统提供了丰富的文件操作库,可以处理各种复杂的文件操作需求。本教程将介绍Python中最常用的文件操作库及其实际应用。一、标准库核心模块1.1os模块-操作系统接口主要功能...

Vue3+Django4全新技术实战全栈项目(已完结)

获课》aixuetang.xyz/5739/Django与推荐算法的集成及模型部署实践...

性能调优方面,经常要优化跑的最慢的代码,教你一种快速的方法

在我们遇到性能问题的时候,很多时候需要去查看性能的瓶颈在哪里,本篇文章就是提供了多种常用的方案来监控函数的运行时间。1.time首先说明,time模块很多是系统相关的,在不同的OS中可能会有一些精度差...

Python解决读取excel数据慢的问题

前言:在做自动化测试的时候,我思考了一个问题,就是如果我们的测试用例随着项目的推进越来越多时,我们做自动化回归的时间也就越来越长,其中影响自动化测试速度的一个原因就是测试用例的读取问题。用例越多,所消...

【Python机器学习系列】基于Flask来构建API调用机器学习模型服务

这是我的第364篇...

不会用mmdet工具?速看MMDetection工具的终极指南

来源:计算机视觉工坊添加微信:dddvisiona,备注:目标检测,拉你入群。文末附行业细分群...

取消回复欢迎 发表评论: