Skip to content

FlashAttention 算法详解

FlashAttention 是一种 IO-aware 的精确注意力算法,将内存复杂度从 O(N2) 降至 O(N),同时数值上严格等价于标准注意力。


目录


标准注意力瓶颈

标准自注意力定义为:

Attention(Q,K,V)=softmax(QKTd)V

这会产生三个需要物化的中间矩阵:

S=QKTRN×N,P=softmax(S)RN×N,O=PVRN×d

核心问题: SP 具有 O(N2) 大小,必须存放在 HBM(设备内存)中。对于大的 N

问题影响
内存占用N=4096, 32 heads SP 就约 2 GB
带宽瓶颈GPU 算力 HBM 带宽;时间由数据搬运主导
IO 操作SP 各需写入和读出 HBM:共 4 次 O(N2) 操作

分块概览

图 1:Q/K/V 分块加载到 SRAM。中间量 SP 永不触碰 HBM。


核心 FlashAttention 概念

1. 分块 (Tiling)

QKV 切分为可放入 SRAM(共享内存 / L1 缓存)的块:

Q=[Q1,Q2,,QTr],QiRBr×dK=[K1,K2,,KTc],KjRBc×dV=[V1,V2,,VTc],VjRBc×d

分块大小选择:

GPU 架构SRAM 容量典型 Br×Bc
Volta (V100)96 KB64×64
Ampere (A100)164 KB128×64
Hopper (H100)228 KB128×128

分块为何有效:

  • 每块放入高速 SRAM($\sim19TB/sHBM\sim$2 TB/s)。
  • 避免中间结果反复访问 HBM。
  • 独立的 Q 块可并行处理。

2. Online Softmax

标准 softmax 需要对每行两次遍历(找 max 计算 exp 归一化)。FlashAttention 使用 online softmax,在单次遍历 KV 块的过程中增量更新:

对每个 KV 块 j,由 Q 块 i 处理:

mijnew=max(mijold,rowmax(Sij))P~ij=exp(Sijmijnew)lijnew=exp(mijoldmijnew)lijold+rowsum(P~ij)Oijnew=exp(mijoldmijnew)Oijold+P~ijVjlijnew

Online Softmax 状态机

图 2:Online softmax 状态更新。当新的 KV 块揭示更大的行最大值时,先前输出被 exp(moldmnew) 重新缩放。

关键洞察: 处理新 KV 块时,若全局行最大值改变,先前输出必须通过 exp(moldmnew) 修正。

数值稳定性: 追踪运行最大值确保 exp() 永不过溢,即使注意力分数很大。

3. 重计算 (Recomputation)

标准反向传播存储 O(N2) 的注意力矩阵 P 用于梯度计算。FlashAttention 的策略:

阶段存储内容内存
前向仅输出 O 和 logsumexp LO(N)
反向Q,K,V,O,L 实时重计算 PO(N)

权衡: 增加约 33% 的额外 FLOPs,但显著减少 HBM IO,整体仍加速。

反向重计算数据流

图 3:反向传播在 SRAM 中从 forward 输出重计算 Pij。不存储 O(N2) 矩阵。


前向传播算法

输入:  Q, K, V ∈ R^(N×d), scale = 1/√d
输出: O ∈ R^(N×d), L ∈ R^N

初始化: O = 0, m = -∞, l = 0  (每行)

对每个 Q 块 i (i = 1..T_r 并行):
    将 Q_i 加载到 SRAM
    对每个 KV 块 j = 1..T_c (顺序):
        将 K_j, V_j 加载到 SRAM
        S_ij = scale × Q_i × K_j^T           # [B_r, B_c] 在 SRAM 中
        m_new = max(m_i, rowmax(S_ij))       # 更新行最大值
        P = exp(S_ij - m_new)                 # 局部 softmax 分子
        l_new = exp(m_i - m_new) × l_i + rowsum(P)
        O_i = (exp(m_i - m_new) × O_i + P × V_j) / l_new
        m_i = m_new, l_i = l_new
    L_i = m_i + log(l_i)                      # 存储 logsumexp

关键操作:

  1. Q 块并行: 每个输出块由一个 CUDA block 独立计算。
  2. KV 块顺序: 跨所有 key 累积注意力。
  3. 输出修正: 发现新最大值时调整运行和。

反向传播算法

输入:  Q, K, V, O, L, dO
输出: dQ, dK, dV

对每个 KV 块 j:
    将 K_j, V_j 加载到 SRAM
    初始化 dK_j = 0, dV_j = 0
    对每个 Q 块 i:
        将 Q_i, O_i, dO_i, L_i 加载到 SRAM
        S_ij = scale × Q_i × K_j^T
        P_ij = exp(S_ij - L_i)                # 重计算注意力权重
        D_i = rowsum(dO_i ⊙ O_i)              # 对角项
        dV_j += P_ij^T × dO_i                 # V 梯度
        dP_ij = dO_i × V_j^T
        dS_ij = P_ij ⊙ (dP_ij - D_i)          # Softmax Jacobian
        dQ_i += scale × dS_ij × K_j           # Q 梯度
        dK_j += scale × dS_ij^T × Q_i        # K 梯度

梯度流:

  1. dV: 使用重计算的注意力权重对上游梯度加权求和。
  2. dQ, dK: 通过 softmax Jacobian,使用重计算的 P
  3. 内存高效: 任何时刻都不需要 O(N2) 存储。

因果掩码

对于自回归模型,位置 i 只能 attend 到位置 i。FlashAttention 的块结构支持高效因果掩码:

情况处理方式
完全跳过KV 块起始列 > Q 块结束行 跳过整个块
部分掩码块内应用掩码(设为

效率提升: 约 50% 的块可完全跳过,计算量减少一半。

因果掩码块

图 4:块级因果掩码。下三角块完全计算;对角块部分掩码;上三角块跳过。


FP16 实现

本实现完整支持 FP16(半精度)的前向与反向传播。

实现策略

FP16 输入在内部转为 FP32 计算,最终输出转回 FP16:

输入: halfQ,K,V加载FP32 寄存器计算FP32 累加器存储halfO,L

数值精度

操作精度
矩阵乘法 (Q×KT)FP32
Softmax 计算FP32
累加FP32
最终输出FP16

优势:

  • 数值稳定性与 FP32 相当。
  • 内存带宽减半(张量缩小 2 倍)。
  • 所有现代 GPU 均支持(compute capability 5.3)。

内存复杂度分析

方法前向内存反向内存HBM IO
标准注意力O(N2)O(N2)O(N2+Nd)
FlashAttentionO(N)O(N)O(N2d2M)

其中 M 为 SRAM 容量。当 M=Θ(Nd) 时,IO 复杂度趋近 O(Nd),这是最优的,因为仅输入输出就需 Θ(Nd)

实际内存节省

序列长度标准注意力FlashAttention节省
1,0244 MB8 KB99.8%
4,09664 MB32 KB99.95%
16,3841 GB128 KB99.99%

实现亮点

分块配置

head_dimBrBc每块 SRAM
326464$\sim$32 KB
646464$\sim$64 KB
1283232$\sim$128 KB

优化技术

技术收益
向量化内存访问float4 加载/存储实现合并访问
Launch Bounds__launch_bounds__(128) 控制寄存器压力
动态共享内存根据 head_dim 运行时分配
流安全显式 workspace 生命周期管理
Warp 级原语__shfl_sync 用于 warp 内归约

数据类型支持

数据类型前向反向
FP32 (float)完整完整
FP16 (half)完整完整

参考文献

  1. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

    • Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
    • NeurIPS 2022
    • arXiv:2205.14135
  2. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  3. Online normalizer calculation for softmax

  4. NVIDIA CUDA Programming Guide - Shared Memory

Stable v0.3.0 baseline. Lean CUDA FlashAttention reference.