AutoInt 论文精读

论文引用: Song, Weiping, et al. “Autoint: Automatic feature interaction learning via self-attentive neural networks.” Proceedings of the 28th ACM International Conference on Information and Knowledge Management. 2019.

这是一篇北京大学发表在 CIKM 2019 的文章,看作者列表没有企业背景,主要还是提供一些理论思路。文章的核心也是想通过自动挖掘特征间的高阶交互关系来提升减少人工特征工程,但是与前面的 DeepFMDCN 等能够提供显式特征交叉能力的模型最大的差别在于:本文是通过不同 field 间特征做 Self-Attention 来实现特征的交互,也因此获得了一定的特征组合的可视化能力(即文章中声称提供了较好的可解释性)。

背景与动机

其实已有的深度模型的相关工作基本核心都是在做高阶的特征交叉,但是诸如 PNNFNNDeepCrossingWide&DeepDeepFM 等模型,主要是依赖前馈网络来实现高阶特征交叉,主要的问题是特征交叉过程是隐式的,很难解释是哪些特征间的组合起到了关键作用,这个问题也存在于 DCNxDeepFM 等提供显式特征交叉能力的模型中。

另外一些提供显式特征交叉能力的模型和算法也存在各种各样的问题,比如基于树+embedding 的$^{[1,2,3]}$,会将训练过程分裂成几个部分;比如显式做所有特征的高阶组合的 HOFM 模型$^{[4]}$,参数量过大,高于 5 阶的组合基本不可用(实际上根据张俊林的说法,高于 4 阶的特征组合就已经收益很低了,这在 DCN 的测试中也得到了一定的验证)。

基于这个背景,文章的目标是想找到一种特征自动进行高阶交叉的方法,既能弥补 MLP 对乘性特征组合捕获能力不强的弱点,又能够较好的解释哪些特征组合比较有效。

AutoInt 网络设计

模型概览

AutoInt 网络内部主要包括两块,如下图所示:

AutoInt 底层是 embedding 层,类似于 DeepFM 的设计,将所有的离散、连续特征都映射成一个等长的 embedding 向量,其中,离散特征是直接 lookup embedding 表,多值的离散特征使用 average pooling;连续特征则相当于乘以一个不含 bias 的 Dense 层。

AutoInt 核心是上面的交互层,使用 Multi-head Self-Attention 来实现,并且可以叠加多层,实现特征的高阶交叉。作者认为,特征组合的关键是知道哪些特征组合在一起有强大的表征能力,这个实际上相当于在人工特征工程中进行特征选择,那么在深度网络里怎么自动去实现特征组合的选择呢?作者受到 Self-Attention 的启发,考虑让每个 field 的特征与其他 field 的特征分别做 attention,根据 attention 的权重来判断该 field 特征与其他 field 特征组合的重要性,越重要的组合给予的权重越高,最后生成加权后的 sum pooling 作为该 field 特征与所有其他 field 特征组合的结果。

Self-Attention 层计算流程详解

关于 Self-Attention 相关的理论和应用,后面还会再单独介绍,感觉这一块是后面网络发展的重点。下面仅仅结合计算过程再详细的说明一下上面的思路,使用的符号与文章略有不同。

  1. 假设输入特征一共有 $m$ 个 field,每个 field 的特征记为 $\boldsymbol{x}_i$,对应的 embedding 向量记为 $\boldsymbol{e}_i$;所有 field 拼接起来的 embedding 记为 $\boldsymbol{e}$;
  2. 考虑第 $i$ 个 field 的特征 $\boldsymbol{x}_i$ 对应的 embedding 向量 $\boldsymbol{e}_i$,首先计算它经过单层 Self-Attention 后生成的特征组合向量 $\tilde{\boldsymbol{e}}_i$(其他 field 的计算过程完全相同);
  3. 对于 head $h$,根据该 head 中的 query、key、value 矩阵,计算第 $i$ 个 field 特征的 query 向量和所有其他 field 特征的 key、value 向量:
  4. 计算第 $i$ 个 field 的 query 向量与所有其他 field 特征对应的 key 向量的 attention 权重
  5. 计算第 $i$ 个 field 对应的所有其他 field 的加权 value 向量 $\tilde{\boldsymbol{e}}_i^{(h)}$,作为第 $h$ 个 head 中第 $i$ 个 field 特征与所有其他 field 特征交互后的组合向量:
  6. 将所有 head 的结果进行 concat,得到第 $i$ 个 field 特征与所有其他 field 特征交互后的组合向量:
  7. 将所有 field 的组合向量进行 concat,并加上输入层,得到单层 Self-Attention 的输出向量 $\tilde{\boldsymbol{e}}$:

说明:

  • 与 google 的原始论文 $[5]$ 相比,AutoInt 中的 Self-Attention 没有进行 scale,即第 4 步求 softmax 之前没有将内积除以一个缩放系数,导致的结果是突出了高效组合的重要性。当然,在实现的时候还是可以尝试把缩放加进来;
  • 文献 $[5]$ 在最后一步还会再过一个 LayerNormalization,文章里并没有加;实现的时候可以加了看看效果;
  • 在实现的时候,假设 embedding 的维度为 $d$,head 的数量为 $k$,则可以设置每个 head 中 query、key、value 矩阵 $\boldsymbol{W}^{(h)}$ 的维度为 $(\frac{d}{k}, d)$,这样得到的 $\tilde{\boldsymbol{e}}_i^{(h)}$ 就是 $\frac{d}{k}$ 维,将 $k$ 个 head 的结果 concat 后,$\tilde{\boldsymbol{e}}_i$ 又变成了 $d$ 维,从而保证输入的维度与输出的维度相同;这样的话,最后一步中矩阵 $\boldsymbol{W}$ 其实就不需要了(文章的实验中 $\boldsymbol{e}$ 与 $\tilde{\boldsymbol{e}}’$ 维度不同,因此需要通过 $\boldsymbol{W}$ 将 $\boldsymbol{e}$ 变成与 $\tilde{\boldsymbol{e}}’$ 相同的维度)。

Self-Attention 交叉能力分析

文章将有 $p$ 个不同 field 特征乘性组合的特征称为 $p$ 阶组合特征,记为 ,从计算过程容易看出来, 乃至 $\tilde{\boldsymbol{e}}_i$ 都是包含 $\boldsymbol{x}_i$ 与所有 交互的 2 阶组合特征:${g(x_i,x_1),g(x_i,x_2),\cdots,g(x_i,x_m)}$。因此,单层 Self-Attention 就能表征所有 field 的 2 阶组合特征。

到了两层时,由于第一层输出中每个 field 相当于都是包含了所有的 2 阶组合,因此它的输出就包含了 3 阶和 4 阶的组合特征,例如 $g(x_1,x_2,x_3,x_4)$ 就包含在 $\tilde{\boldsymbol{e}}_1$ 和 $\tilde{\boldsymbol{e}}_3$ 的交互中。同理,三层 Self-Attention 就包含 8 阶内的组合特征……与 DCN 中的 cross 层相比,cross 每层增加 1 阶特征组合,而 Self-Attention 每层增加 1 倍特征组合。

应用与讨论

上面已经介绍了 embedding 层和 Self-Attetion 层,其中,Self-Attention 层是可以直接堆叠的,由于有残差结构的设计(最后一步加上了输入),理论上可以堆的比较深(文章的实验也证明了这个设计是比较有效的)。它还可以作为子结构,通过串联或者并联的方式,嵌入到其他网络中去,例如:

  • 串联:最上面一层的 Self-Attention 输出可以直接送到 LR 里输出预测结果,或者再接一个 MLP 再输出预测结果;
  • 并联:embedding 层同时作为输入,送到 MLPFM、cross 等其他层中,最后所有层结果进行 concat,送到 LR 中输出预测结果;

训练一般还是使用 logloss 作为损失函数,用 Adam 等优化算法进行优化。

我在我们的数据集上测试的时候,发现 Self-Attention 层数也是 3 层就够了,到了 4 层测试 AUC 反而会降低,这与文章的参数是吻合的。

至于文章另外一个鼓吹的亮点,即特征组合的可解释性,实际上就是画出 attention 权重的热力图,主要是用于数据分析,感觉除了汇报好看点,也没啥实际的用处。

最后想说的一点,文章将不同 field 的特征当成了序列来做 Self-Attention,但其实 Self-Attention 也经常会用于对序列特征做 pooling,这也会在以后一起介绍。

参考文献

[1] Wang, Xiang, et al. “Tem: Tree-enhanced embedding model for explainable recommendation.” Proceedings of the 2018 World Wide Web Conference. 2018.
[2] Zhao, Qian, Yue Shi, and Liangjie Hong. “Gb-cent: Gradient boosted categorical embedding and numerical trees.” Proceedings of the 26th International Conference on World Wide Web. 2017.
[3] Zhu, Jie, et al. “Deep embedding forest: Forest-based serving with deep embedding features.” Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. 2017.
[4] Blondel, Mathieu, et al. “Higher-order factorization machines.” Advances in Neural Information Processing Systems. 2016.
[5] Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems. 2017.