FlashAttention 算法详解
FlashAttention 是一种 IO-aware 的精确注意力算法,将内存复杂度从
目录
标准注意力瓶颈
标准自注意力定义为:
这会产生三个需要物化的中间矩阵:
核心问题:
| 问题 | 影响 |
|---|---|
| 内存占用 | |
| 带宽瓶颈 | GPU 算力 |
| IO 操作 |
图 1:Q/K/V 分块加载到 SRAM。中间量
核心 FlashAttention 概念
1. 分块 (Tiling)
将
分块大小选择:
| GPU 架构 | SRAM 容量 | 典型 |
|---|---|---|
| Volta (V100) | 96 KB | |
| Ampere (A100) | 164 KB | |
| Hopper (H100) | 228 KB |
分块为何有效:
- 每块放入高速 SRAM($\sim
\sim$2 TB/s)。 - 避免中间结果反复访问 HBM。
- 独立的 Q 块可并行处理。
2. Online Softmax
标准 softmax 需要对每行两次遍历(找
对每个 KV 块
图 2:Online softmax 状态更新。当新的 KV 块揭示更大的行最大值时,先前输出被
关键洞察: 处理新 KV 块时,若全局行最大值改变,先前输出必须通过
数值稳定性: 追踪运行最大值确保
3. 重计算 (Recomputation)
标准反向传播存储
| 阶段 | 存储内容 | 内存 |
|---|---|---|
| 前向 | 仅输出 | |
| 反向 | 从 |
权衡: 增加约 33% 的额外 FLOPs,但显著减少 HBM IO,整体仍加速。
图 3:反向传播在 SRAM 中从 forward 输出重计算
前向传播算法
输入: Q, K, V ∈ R^(N×d), scale = 1/√d
输出: O ∈ R^(N×d), L ∈ R^N
初始化: O = 0, m = -∞, l = 0 (每行)
对每个 Q 块 i (i = 1..T_r 并行):
将 Q_i 加载到 SRAM
对每个 KV 块 j = 1..T_c (顺序):
将 K_j, V_j 加载到 SRAM
S_ij = scale × Q_i × K_j^T # [B_r, B_c] 在 SRAM 中
m_new = max(m_i, rowmax(S_ij)) # 更新行最大值
P = exp(S_ij - m_new) # 局部 softmax 分子
l_new = exp(m_i - m_new) × l_i + rowsum(P)
O_i = (exp(m_i - m_new) × O_i + P × V_j) / l_new
m_i = m_new, l_i = l_new
L_i = m_i + log(l_i) # 存储 logsumexp关键操作:
- Q 块并行: 每个输出块由一个 CUDA block 独立计算。
- KV 块顺序: 跨所有 key 累积注意力。
- 输出修正: 发现新最大值时调整运行和。
反向传播算法
输入: Q, K, V, O, L, dO
输出: dQ, dK, dV
对每个 KV 块 j:
将 K_j, V_j 加载到 SRAM
初始化 dK_j = 0, dV_j = 0
对每个 Q 块 i:
将 Q_i, O_i, dO_i, L_i 加载到 SRAM
S_ij = scale × Q_i × K_j^T
P_ij = exp(S_ij - L_i) # 重计算注意力权重
D_i = rowsum(dO_i ⊙ O_i) # 对角项
dV_j += P_ij^T × dO_i # V 梯度
dP_ij = dO_i × V_j^T
dS_ij = P_ij ⊙ (dP_ij - D_i) # Softmax Jacobian
dQ_i += scale × dS_ij × K_j # Q 梯度
dK_j += scale × dS_ij^T × Q_i # K 梯度梯度流:
- dV: 使用重计算的注意力权重对上游梯度加权求和。
- dQ, dK: 通过 softmax Jacobian,使用重计算的
。 - 内存高效: 任何时刻都不需要
存储。
因果掩码
对于自回归模型,位置
| 情况 | 处理方式 |
|---|---|
| 完全跳过 | KV 块起始列 |
| 部分掩码 | 块内应用掩码(设为 |
效率提升: 约 50% 的块可完全跳过,计算量减少一半。
图 4:块级因果掩码。下三角块完全计算;对角块部分掩码;上三角块跳过。
FP16 实现
本实现完整支持 FP16(半精度)的前向与反向传播。
实现策略
FP16 输入在内部转为 FP32 计算,最终输出转回 FP16:
数值精度
| 操作 | 精度 |
|---|---|
| 矩阵乘法 ( | FP32 |
| Softmax 计算 | FP32 |
| 累加 | FP32 |
| 最终输出 | FP16 |
优势:
- 数值稳定性与 FP32 相当。
- 内存带宽减半(张量缩小 2 倍)。
- 所有现代 GPU 均支持(compute capability
5.3)。
内存复杂度分析
| 方法 | 前向内存 | 反向内存 | HBM IO |
|---|---|---|---|
| 标准注意力 | |||
| FlashAttention |
其中
实际内存节省
| 序列长度 | 标准注意力 | FlashAttention | 节省 |
|---|---|---|---|
| 1,024 | 4 MB | 8 KB | 99.8% |
| 4,096 | 64 MB | 32 KB | 99.95% |
| 16,384 | 1 GB | 128 KB | 99.99% |
实现亮点
分块配置
| head_dim | 每块 SRAM | ||
|---|---|---|---|
| 32 | 64 | 64 | $\sim$32 KB |
| 64 | 64 | 64 | $\sim$64 KB |
| 128 | 32 | 32 | $\sim$128 KB |
优化技术
| 技术 | 收益 |
|---|---|
| 向量化内存访问 | float4 加载/存储实现合并访问 |
| Launch Bounds | __launch_bounds__(128) 控制寄存器压力 |
| 动态共享内存 | 根据 head_dim 运行时分配 |
| 流安全 | 显式 workspace 生命周期管理 |
| Warp 级原语 | __shfl_sync 用于 warp 内归约 |
数据类型支持
| 数据类型 | 前向 | 反向 |
|---|---|---|
FP32 (float) | 完整 | 完整 |
FP16 (half) | 完整 | 完整 |
参考文献
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
- NeurIPS 2022
- arXiv:2205.14135
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Tri Dao
- ICLR 2024
- arXiv:2307.08691
Online normalizer calculation for softmax
- Maxim Milakov, Natalia Gimelshein
- arXiv:1805.02867
NVIDIA CUDA Programming Guide - Shared Memory