关于FlashAttention?
·5 分钟阅读
- 减少 HBM 读写(Tiling 瓦片化) 在标准的 Attention 计算中,计算过程通常是: 计算 $QK^T = S$(写入显存) 计算 $Softmax(S) = P$(读取 $S$,计算后写入显存) 计算 $PV = O$(读取 $P$,计算后写入显存) 显存(HBM)的读写速度远慢于 GPU 核心的计算速度。FlashAttention 使用了 Tiling(分块) 技术: 将大矩阵切分成小的方块。 将小块数据加载到 GPU 内部极快的 SRAM(静态随机存取存储器)中。 在 SRAM 中直接完成整个 Attention 的流水线操作,最后只将结果写回显存。 FlashAttention是如何解决Softmax要求必须先读取整块矩阵的
- 标准 Softmax 的局部瓶颈 为了防止数值溢出,标准的 Softmax 通常分为三步: 找出整行最大值 $m = \max(x)$。 计算每个元素的指数并求和:$d = \sum e^{x_i - m}$。 归一化:$y_i = \frac{e^{x_i - m}}{d}$。 这要求必须遍历两次完整数据(一次找最大值,一次求和)后,才能进行第三次遍历计算最终结果。在分块(Tiling)计算时,如果你只看一个小块,你不知道这个小块的最大值是不是整行的全局最大值。
- 在线 Softmax 的增量公式 FlashAttention 引入了一个可以合并的公式。假设我们将一行数据拆分为两块:$B_1$ 和 $B_2$。 第一步: 当我们只处理 $B_1$ 时,我们记录该块的最大值 $m_1$ 和局部累加和 $d_1$。 第二步: 当处理 $B_2$ 时,我们得到它的局部最大值 $m_2$ 和累加和 $d_2$。 关键:更新全局统计量。 全局最大值 $m_{new} = \max(m_1, m_2)$。 重缩放(Rescaling): 这是最精妙的地方。由于最大值变了,之前的累加和 $d_1$ 就不对了。我们需要对旧的累加和进行“修正”: $$d_{new} = d_1 \cdot e^{m_1 - m_{new}} + d_2 \cdot e^{m_2 - m_{new}}$$ 通过这种方式,可以在遍历数据的同时,不断更新当前的全局最大值和归一化分母,而不需要一次性读入整行。








