用 FTRL 训练 FM 模型

近期尝试了基于 FTRL 来训练 FM 模型,用于短视频的排序。这篇博客主要总结一下算法的理论推导和工程化的一些心得。

一、FM (Factorization Machines) 模型推导

FM 模型简介

在设计排序模型时,至关重要的步骤就是特征的构造和选择。除了一些简单单特征外,往往要对特征进行组合,例如对用户的年龄、性别组合,对视频的演员、类别进行组合等,更大的特征空间能够增加模型表征能力。对于特征组合来说,业界现在通用的做法主要有两大类:

  • FM 系列,常见的模型包括 FMFFMDeepFM,它们对特征的取值范围比较敏感。
  • Tree 系列,常见的模型包括 GBDT,它们对特征的取值范围不敏感。

其中,FM 系列由于适合处理大规模稀疏数据,并且易于与深度神经网络结合,因此使用十分广泛,成为大厂居家必备。

FM 模型的主要思想是在 LR 的基础上,对所有的特征自动做两两组合$^{[1,2]}$。两两组合最直观的方法就是为每对特征组合设置一个参数(例如 Poly2 模型),但是这样就需要 $\text{O}(n^2)$ 个参数,当特征数量很多时,需要的样本量也是巨大的,往往不可能所有的参数都有充足的样本训练。因此 FM 考虑使用矩阵分解的方式来还原这个 $n\times n$ 的参数矩阵,只需要 $n\times k$ ($k$ 通常是个很小的常数)的参数即可实现特征两两组合的目的。

具体来说,给定样本 $z=(\boldsymbol{x},y)$,记 $\boldsymbol{v}_i = (v_i^{(1)},\cdots,v_i^{(d)})^\top$ 为第 $i$ 维特征对应的隐式向量,则 FM 模型为:

FM 的参数包括 $\boldsymbol{w}={w_0,\cdots w_n,v_1^{(1)},\cdots v_n^{(d)}}$,容易得到 FM 对各参数的偏导如下:

FM 模型求解(回归问题)

此时直接将 $\hat{y} = f(\boldsymbol{x}|\boldsymbol{w})$ 作为对 $y$ 的预测结果,因此可以将样本 $z=(\boldsymbol{x},y)$ 的损失函数定义为:

损失函数对参数的偏导为:

FM 模型求解(二分类问题)

此时将 $\hat{y} = \pi(f(\boldsymbol{x}|\boldsymbol{w}))=\frac{1}{1+e^{-f(\boldsymbol{x}|\boldsymbol{w})}}$ 作为对 $y$ 的预测结果,其中,$\pi(x)$ 为 Sigmoid 函数。还是分标签取值来进行讨论(损失函数的推导参考 LR 模型)。

1. Label 为 {1,0}

则将样本 $z=(\boldsymbol{x},y)$ 的损失函数定义为 LogLoss 函数:

损失函数对参数的偏导为:

2. Label 为 {1,-1}

则将样本 $z=(\boldsymbol{x},y)$ 的损失函数定义为 SigmoidLoss 函数:

损失函数对参数的偏导为:

二、FTRL Optimizer 介绍

上面一陀公式实际上是优化算法求梯度的时候用到的。优化算法目前有很多种,在如在线更新模型或者在线排序等对性能有严格要求的场景中,模型的稀疏解十分关键。稀疏的模型意味着只保留最关键特征的参数,意味着更少的存储、查询与计算。为了得到模型的稀疏解,通常的做法是使用 L1 正则、基于参数大小或者累积梯度大小的截断等技术。其中,FTRL 集众家之长,实现了精度与稀疏性的平衡$^{[3]}$。

FTRL 更像是一种启发式的模型组装,其特征权重的更新公式为:

其中,$\boldsymbol{g}^{1:t}$ 表示 $1\sim t$ 轮迭代中参数梯度的累积和,其中,L1 正则化部分是为了生成稀疏解,L2 正则化部分是为了使解更平滑(在论文的推导中不包含这一项),而 $\parallel \boldsymbol{w}-\boldsymbol{w}^t\parallel^2_2$ 是为了保证 $\boldsymbol{w}$ 不要离已迭代过的解太远。经过比较复杂的推导(参考文献 [4]),可以得到每一维参数的求解式:

其中,令 ,主要是方便存储和迭代计算。

三、基于 FTRL 训练 FM 的算法流程

FTRL 来训练 FM 模型,由于我们组习惯用 {0,1} 作为样本标签,则根据式 (2), (6), (10),可以得到如下算法流程:

Algorithm Ftrl+FM

FTRL 算法特别适合在线更新模型,即基于每条实时样本更新模型。但是出于性能和可靠性考虑,也可以稍加修改应用于离线训练或者近线批量训练。例如,离线训练任务,将每天/每小时的数据作为一个批次用来更新 FM 模型:在每轮迭代时,需要将一批次样本的所有梯度、损失等计算结果进行汇总(可以简单的用平均值来代替),再用汇总后的值更新模型。为了训练充分,可以对每个批次的样本迭代训练若干轮。训练完的模型需要将参数 $\boldsymbol{w}, \boldsymbol{s}, \boldsymbol{z}$ 保存起来,下次加载后再增量更新;而在线预测时,只需要加载参数 $\boldsymbol{w}$ 即可。

另外,由于短视频的标签、UP 主等特征变化较快,因此对于离散特征的编码可以考虑使用特征 Hash,虽然牺牲了一定的可解释性,并且存在一定的编码冲突,但是实测下来效果还是不错的,并且工程上确实能省下很多麻烦,提升不少性能。

最后,虽然 FM 模型具备表征特征两两组合的能力,但是实际上我们发现由于样本、调参等的限制,并不能充分发掘每对特征组合的作用,并且该模型对于三个及以上特征的组合就完全无能为力了。因此,实际应用时还是不能太依赖模型的自动特征组合能力,如果有什么对业务比较有帮助的特征,还是人工生成,再一起丢到模型里去训练吧。

参考文献

[1] Rendle, S. (2011). Factorization Machines. IEEE International Conference on Data Mining.
[2] Rendle, S. (2012). Factorization machines with libfm. Acm Transactions on Intelligent Systems & Technology, 3(3), 1-22.
[3] McMahan, H. B., Holt, G., Sculley, D., Young, M., Ebner, D., Grady, J., … & Chikkerur, S. (2013, August). Ad click prediction: a view from the trenches. In Proceedings of the 19th ACM SIGKDD international conference on Knowledge discovery and data mining (pp. 1222-1230). ACM.
[4] 冯扬 (2014). 在线最优化求解.