Softmax的进化之旅
由于主包最近准备了一场 Coding面试,初步领略了下Flash Attention的神奇,其中的关键操作便是对于Softmax的优化。面试结束之后主包也是决定记录一下这个操作,于是有了这篇博客。
由于是准备的Coding面,这里也会附上代码的😁 (所以也可以叫 手撕online softmax?)
为什么需要Online Softmax
这个首先要从Flash Attention说起,这里简单介绍,原始Transformer中,会存在一个Score矩阵,其维度为 N ⋅ N,非常巨大。这导致其在GPU上没法分块计算。由此我们想要分块计算Q,K,V矩阵,Flash Attention便完成了这个任务,有效减小了GPU的IO操作数量,从而实现了加速。
想法很简单,但是实现起来存在一些问题,其中的关键便是Softmax,我们从头开始说起
Softmax的“进化”
标准的Softmax
我们都知道Softmax的公式如下:
$$ Softmax(x_i) = \frac{e^{x_i}}{\sum_{j}^{N} e^{x_j}} $$
这个的手撕代码很简单,我们对矩阵做行Softmax:
1 |
|
输出如下: 1
2
3
4
5
6======= ours softmax =======
tensor([[0.4979, 0.0880, 0.0198, 0.0267, 0.3675],
[0.1271, 0.1828, 0.3216, 0.1422, 0.2263]])
======= Standard softmax =======
tensor([[0.4979, 0.0880, 0.0198, 0.0267, 0.3675],
[0.1271, 0.1828, 0.3216, 0.1422, 0.2263]])
看上去非常不错,但是当我们希望使用fp16精度或者输入的数据大一些,由于指数的存在,会很容易的溢出,比如我们仅仅将x的输入扩大100倍(x = torch.randn(2,5)*100
),
输出就会变为
1 |
|
结果变得不稳定,同时出现数据溢出。由此,便提出了
Safe_softmax
Safe_softmax
为了解决上述问题,我们可以很简单的对输入做一个平移。具体而言我们可以让每个x减去其所在行的最大值,核心可以描述为如下公式:
$$ Softmax(x_i) = \frac{e^{x_i}}{\sum_{j}^{N} e^{x_j}} = \frac{e^{x_i-x_{max}}}{\sum_{j}^{N} e^{x_j-x_{max}}} $$
证明比较简单,我们用上一个标准的softmax不能跑的情况测试一次
1 |
|
输出: 1
2
3
4
5
6======= ours softmax =======
tensor([[0.0000e+00, 0.0000e+00, 7.4508e-32, 9.3580e-40, 1.0000e+00],
[5.3247e-28, 0.0000e+00, 1.0000e+00, 2.9268e-07, 0.0000e+00]])
======= Standard softmax =======
tensor([[0.0000e+00, 0.0000e+00, 7.4508e-32, 9.3580e-40, 1.0000e+00],
[5.3247e-28, 0.0000e+00, 1.0000e+00, 2.9268e-07, 0.0000e+00]])
目前我们简单解决了Softmax的溢出问题,但是我们会发现其计算过程中,依赖全局的和。这也限制了我们希望能够局部分块处理Attention的计算。由此提出了Online Softmax
。
Online Softmax
Online Softmax
将原来的计算过程,通过两个全局变量转变为了流式计算过程。
具体而言,我们假设输入X
可以分为两个部分,即 X = (X1,X2)
。我们依次处理两个部分。
首先针对X1,
我们通过如下公式计算,其行和最大值 M1,以及Row_Sum
:L1,
$$ L_1 = \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M_1) , 即局部exp之和 $$
由此,基于上述两个全局变量我们能够计算出这个局部的(即X1)softmax值,同理我们可以得到M2, L2
但是想要基于全局准确的Softmax结果还要进一步计算,具体可以做如下推导:
$$ \begin{aligned} M &= Max(M_1, M_2) \\ L &= \sum_{j=1}^{X.size(-1)} exp(X^j - M) \\ &= \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M) + \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M) \\ &= exp(M_1 - M) * \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M_1) + exp(M_2 - M) * \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M_2) \\ & = L_1 * exp(M_1 - M) + L_2 * exp(M_2 - M) \end{aligned} $$
由此我们更具局部的信息计算出来全局的两个信息,当然在实际输入顺序中,L2 可以不用计算,可以计算得到M之后,直接计算局部和, 即:
$$ L = L_1 * exp(M_1 - M) + \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M) $$
得到两个的全局信息之后,在遍历一遍即可得到最后的Softmax值,具体的代码实现如下:
1 |
|
输出:
1 |
|
最后用一张图片来总结一下online softmax的思路,我觉得挺贴切的。即可以将原本三次扫描的Safe Softmax转变为两次扫描的 Online Softmax

当然在实际的Flash Attention中,可以进一步简化为one pass的扫描,这就下回分解了😋