Softmax的进化之旅

由于主包最近准备了一场 Coding面试,初步领略了下Flash Attention的神奇,其中的关键操作便是对于Softmax的优化。面试结束之后主包也是决定记录一下这个操作,于是有了这篇博客。

由于是准备的Coding面,这里也会附上代码的😁 (所以也可以叫 手撕online softmax?)

为什么需要Online Softmax

这个首先要从Flash Attention说起,这里简单介绍,原始Transformer中,会存在一个Score矩阵,其维度为 N ⋅ N,非常巨大。这导致其在GPU上没法分块计算。由此我们想要分块计算Q,K,V矩阵,Flash Attention便完成了这个任务,有效减小了GPU的IO操作数量,从而实现了加速。

想法很简单,但是实现起来存在一些问题,其中的关键便是Softmax,我们从头开始说起

Softmax的“进化”

标准的Softmax

我们都知道Softmax的公式如下:

$$ Softmax(x_i) = \frac{e^{x_i}}{\sum_{j}^{N} e^{x_j}} $$

这个的手撕代码很简单,我们对矩阵做行Softmax:

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn.functional as F

x = torch.randn(2,5)
row_sum = torch.exp(x).sum(dim=-1,keepdim=True)
ours_out = torch.exp(x) / row_sum

print("======= ours softmax =======")
print(ours_out)
print("======= Standard softmax =======")
print(F.softmax(x, dim=-1))

输出如下:

1
2
3
4
5
6
======= ours softmax =======
tensor([[0.4979, 0.0880, 0.0198, 0.0267, 0.3675],
[0.1271, 0.1828, 0.3216, 0.1422, 0.2263]])
======= Standard softmax =======
tensor([[0.4979, 0.0880, 0.0198, 0.0267, 0.3675],
[0.1271, 0.1828, 0.3216, 0.1422, 0.2263]])


看上去非常不错,但是当我们希望使用fp16精度或者输入的数据大一些,由于指数的存在,会很容易的溢出,比如我们仅仅将x的输入扩大100倍(x = torch.randn(2,5)*100), 输出就会变为

1
2
3
======= ours softmax =======
tensor([[nan, 0., 0., 0., 0.],
[0., 0., nan, 0., 0.]])

结果变得不稳定,同时出现数据溢出。由此,便提出了 Safe_softmax

Safe_softmax

为了解决上述问题,我们可以很简单的对输入做一个平移。具体而言我们可以让每个x减去其所在行的最大值,核心可以描述为如下公式:

$$ Softmax(x_i) = \frac{e^{x_i}}{\sum_{j}^{N} e^{x_j}} = \frac{e^{x_i-x_{max}}}{\sum_{j}^{N} e^{x_j-x_{max}}} $$

证明比较简单,我们用上一个标准的softmax不能跑的情况测试一次

1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn.functional as F

x = torch.randn(2,5)*100
x_max,_ = x.max(dim=-1, keepdim=True)
x = x - x_max
row_sum = torch.exp(x).sum(dim=-1,keepdim=True)
ours_out = torch.exp(x) / row_sum
print("======= ours softmax =======")
print(ours_out)
print("======= Standard softmax =======")
print(F.softmax(x, dim=-1))

输出:

1
2
3
4
5
6
======= ours softmax =======
tensor([[0.0000e+00, 0.0000e+00, 7.4508e-32, 9.3580e-40, 1.0000e+00],
[5.3247e-28, 0.0000e+00, 1.0000e+00, 2.9268e-07, 0.0000e+00]])
======= Standard softmax =======
tensor([[0.0000e+00, 0.0000e+00, 7.4508e-32, 9.3580e-40, 1.0000e+00],
[5.3247e-28, 0.0000e+00, 1.0000e+00, 2.9268e-07, 0.0000e+00]])


目前我们简单解决了Softmax的溢出问题,但是我们会发现其计算过程中,依赖全局的和。这也限制了我们希望能够局部分块处理Attention的计算。由此提出了Online Softmax

Online Softmax

Online Softmax将原来的计算过程,通过两个全局变量转变为了流式计算过程。

具体而言,我们假设输入X可以分为两个部分,即 X = (X1,X2) 。我们依次处理两个部分。

首先针对X1, 我们通过如下公式计算,其行和最大值 M1,以及Row_Sum:L1,

$$ L_1 = \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M_1) , 即局部exp之和 $$

由此,基于上述两个全局变量我们能够计算出这个局部的(即X1)softmax值,同理我们可以得到M2, L2

但是想要基于全局准确的Softmax结果还要进一步计算,具体可以做如下推导:

$$ \begin{aligned} M &= Max(M_1, M_2) \\ L &= \sum_{j=1}^{X.size(-1)} exp(X^j - M) \\ &= \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M) + \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M) \\ &= exp(M_1 - M) * \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M_1) + exp(M_2 - M) * \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M_2) \\ & = L_1 * exp(M_1 - M) + L_2 * exp(M_2 - M) \end{aligned} $$

由此我们更具局部的信息计算出来全局的两个信息,当然在实际输入顺序中,L2 可以不用计算,可以计算得到M之后,直接计算局部和, 即:

$$ L = L_1 * exp(M_1 - M) + \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M) $$

得到两个的全局信息之后,在遍历一遍即可得到最后的Softmax值,具体的代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import torch.nn.functional as F

x = torch.randn(2,16)
out = torch.zeros_like(x)
x_blocks = torch.split(x, 4, dim=1)
out_blocks = list(torch.split(out, 4, dim=1))

m = torch.ones(x.size(0)).unsqueeze(-1)*-1e9
l = torch.zeros(x.size(0)).unsqueeze(-1)

for xi in x_blocks:
mi,_ = xi.max(dim=-1, keepdim=True)
m_new = torch.maximum(mi,m)
l = l*torch.exp(m - m_new) + torch.exp(xi-m_new).sum(dim=-1, keepdim=True)
m = m_new

for i in range(len(x_blocks)):
out_blocks[i] = torch.exp(x_blocks[i] - m) / l

out = torch.cat(out_blocks, dim=1)
print("======= ours softmax =======")
print(out)
print("======= Standard softmax =======")
print(F.softmax(x, dim=-1))

输出:

1
2
3
4
5
6
7
8
9
10
======= ours softmax =======
tensor([[0.0072, 0.1026, 0.0273, 0.0057, 0.2055, 0.1391, 0.0090, 0.0694, 0.0036,
0.0191, 0.0087, 0.0462, 0.0069, 0.0355, 0.2882, 0.0259],
[0.1437, 0.0477, 0.0249, 0.0123, 0.0210, 0.0544, 0.0244, 0.0082, 0.2192,
0.1388, 0.0311, 0.1340, 0.0217, 0.0071, 0.0646, 0.0468]])
======= Standard softmax =======
tensor([[0.0072, 0.1026, 0.0273, 0.0057, 0.2055, 0.1391, 0.0090, 0.0694, 0.0036,
0.0191, 0.0087, 0.0462, 0.0069, 0.0355, 0.2882, 0.0259],
[0.1437, 0.0477, 0.0249, 0.0123, 0.0210, 0.0544, 0.0244, 0.0082, 0.2192,
0.1388, 0.0311, 0.1340, 0.0217, 0.0071, 0.0646, 0.0468]])

最后用一张图片来总结一下online softmax的思路,我觉得挺贴切的。即可以将原本三次扫描的Safe Softmax转变为两次扫描的 Online Softmax

示意图

当然在实际的Flash Attention中,可以进一步简化为one pass的扫描,这就下回分解了😋

参考

【手撕online softmax】Flash Attention前传,一撕一个不吱声


Softmax的进化之旅
https://gongzihang.github.io/BLOG/2025/06/19/Softmax的进化之旅/
作者
Zihang Gong
发布于
2025年6月19日
更新于
2025年6月20日
许可协议