FlashAttention Algorithm Deep Dive
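FlashAttention is an IO-aware exact attention algorithm that reduces memory complexity from $O(N^2)$ to $O(N)$ in the sequence length $N$ by never materializing the full attention matrix in HBM.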
Table of Contents
- Standard Attention Bottleneck
- Core FlashAttention Concepts
- Forward Pass Algorithm
- Backward Pass Algorithm
- Causal Masking
- FP16 Implementation
- Memory Complexity Analysis
- Implementation Highlights
- References
Standard Attention Bottleneck
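Standard self-attention is defined as:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$
This expands to three materialized intermediate matrices:
$$S = \frac{QK^T}{\sqrt{d}} \in \mathbb{R}^{N \times N}, \quad P = \text{softmax}(S) \in \mathbb{R}^{N \times N}, \quad O = PV \in \mathbb{R}^{N \times d}$$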
Core Problem:
| Issue | Impact |
|---|---|
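| Memory Usage | $S$ and $P$ each need $O(N^2)$ memory |
| Bandwidth Bottleneck | GPU compute throughput far outpaces HBM bandwidth, so attention is memory-bound |
| IO Operations | $O(N^2)$ HBM reads and writes for $S$ and $P$ |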
Figure 1: Q/K/V tiling into SRAM blocks. Intermediate $S$ and $P$ tiles stay in SRAM and are never written to HBM.
Core FlashAttention Concepts
1. Tiling
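Divide $Q$ into $T_r$ blocks of $B_r$ rows and $K$, $V$ into $T_c$ blocks of $B_c$ rows, sized so that one $Q$ block, one $K$/$V$ block pair, and their $B_r \times B_c$ score tile fit in on-chip SRAM together.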
Block Size Selection:
| GPU Architecture | SRAM Size (per SM) | Typical Block Size |
|---|---|---|
| Volta (V100) | 96 KB | |
| Ampere (A100) | 164 KB | |
| Hopper (H100) | 228 KB | |
Why Tiling Works:
- Each block fits in fast on-chip SRAM (roughly 19 TB/s on an A100, versus $\sim$2 TB/s for HBM).
- Avoids repeated HBM accesses for intermediate results.
- Enables parallel processing of independent Q blocks.
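As a rough illustration of the tiling arithmetic, the sketch below counts the Q and KV blocks for a given sequence length and estimates the SRAM needed for one tile set. The function name `tiling_plan` and the accounting (one $Q$ tile, one $K$ tile, one $V$ tile, one score tile, and one output accumulator, all at 4 bytes per element) are assumptions made for this example; the implementation described in this document may budget shared memory differently, so the numbers need not match the Block Configuration table.

```python
import math

def tiling_plan(seq_len, head_dim, block_rows, block_cols, bytes_per_elem=4):
    """Return tile counts and an estimated SRAM footprint for one tile set.

    Hypothetical helper: the footprint formula is an assumption, not the
    accounting used by any particular kernel.
    """
    t_r = math.ceil(seq_len / block_rows)   # number of Q blocks (T_r)
    t_c = math.ceil(seq_len / block_cols)   # number of KV blocks (T_c)
    elems = (
        block_rows * head_dim        # Q_i tile
        + 2 * block_cols * head_dim  # K_j and V_j tiles
        + block_rows * block_cols    # S_ij / P_ij score tile
        + block_rows * head_dim      # O_i accumulator
    )
    return {"T_r": t_r, "T_c": t_c, "sram_bytes": elems * bytes_per_elem}

plan = tiling_plan(seq_len=4096, head_dim=64, block_rows=64, block_cols=64)
print(plan)
```

Under these assumptions the estimate is 80 KiB for 64×64 blocks at `head_dim=64`, which is the kind of check that decides whether a block size fits a given GPU's shared memory.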
2. Online Softmax
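Standard softmax requires two passes over each row (find the row maximum, then exponentiate and normalize). Online softmax instead keeps a running maximum $m$ and a running sum $l$ per row, so each row is handled in a single pass over the KV blocks.
For each KV block $j$, the running state of Q block $i$ is updated as:
$$m_{\text{new}} = \max\left(m_i, \text{rowmax}(S_{ij})\right)$$
$$l_{\text{new}} = e^{m_i - m_{\text{new}}} \, l_i + \text{rowsum}\left(e^{S_{ij} - m_{\text{new}}}\right)$$
Figure 2: Online softmax state updates. When a new KV block reveals a larger row max, previous outputs are rescaled by $e^{m_i - m_{\text{new}}}$.
Key Insight: When processing a new KV block, previous outputs must be corrected by $e^{m_i - m_{\text{new}}}$, because they were accumulated relative to the old maximum.
Numerical Stability: Tracking the running maximum ensures every exponent is at most zero, so each exponential is at most 1 and cannot overflow.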
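The running-max and running-sum updates can be checked on a single row. The sketch below streams over a row in blocks, keeps only the pair $(m, l)$, and compares the result with an ordinary two-pass softmax. The function names and the sample row are made up for this example; it is a scalar illustration of the idea, not the kernel's code.

```python
import math

def softmax_two_pass(row):
    """Reference softmax: one pass for the max, one for the normalized exps."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def online_softmax_stats(row, block_size):
    """Stream over `row` in blocks, keeping only a running max m and sum l."""
    m, l = -math.inf, 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new maximum, then add this block's terms.
        l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

# Large entries late in the row force the running max to change mid-stream.
row = [1.0, 3.0, -2.0, 8.0, 0.5, 7.5, 1000.0, 999.0]
m, l = online_softmax_stats(row, block_size=3)
online = [math.exp(x - m) / l for x in row]
reference = softmax_two_pass(row)
```

Because every exponent is taken relative to the current maximum, the value 1000.0 never reaches `math.exp` directly, which would otherwise overflow.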
3. Recomputation
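Standard backward pass stores the full $N \times N$ attention matrix $P$ from the forward pass. FlashAttention stores only the output $O$ and the per-row logsumexp $L$, and recomputes $S$ and $P$ block by block during the backward pass.
| Phase | Storage | Memory |
|---|---|---|
| Forward | Output $O$ and logsumexp $L$ | $O(Nd)$ |
| Backward | Recompute $S_{ij}$ and $P_{ij}$ per block in SRAM | No $N \times N$ buffer in HBM |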
Trade-off: Recomputation adds roughly 33% more FLOPs but significantly reduces HBM IO, so the net effect is a speedup.
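Figure 3: Backward pass recomputes $P_{ij}$ from $Q_i$, $K_j$, and the stored logsumexp $L_i$ instead of reading an $N \times N$ matrix from HBM.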
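The reason a single scalar per row is enough for recomputation is that $P_{ij} = e^{S_{ij} - L_i}$ when $L_i$ is the row's logsumexp. A minimal sketch on one row of scores follows; the helper name and the sample values are illustrative only.

```python
import math

def softmax_row_with_lse(scores):
    """Return the softmax of one row and its logsumexp L = m + log(l)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps], m + math.log(total)

scores = [0.3, -1.2, 2.5, 0.0, 4.1, -0.7]
p_forward, lse = softmax_row_with_lse(scores)  # forward pass keeps only `lse`

# Backward pass: rebuild any block of P from the scores and the saved scalar.
block = slice(2, 5)
p_recomputed = [math.exp(s - lse) for s in scores[block]]
```

The recomputed block matches the forward-pass probabilities without ever seeing the rest of the row, so each tile can be rebuilt independently.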
Forward Pass Algorithm
    Input:  Q, K, V ∈ R^(N×d), scale = 1/√d
    Output: O ∈ R^(N×d), L ∈ R^N
    Initialize: O = 0, m = -∞, l = 0 (per row)

    For each Q block i (parallel over i = 1..T_r):
        Load Q_i to SRAM
        For each KV block j = 1..T_c (sequential):
            Load K_j, V_j to SRAM
            S_ij  = scale × Q_i × K_j^T                  # [B_r, B_c] in SRAM
            m_new = max(m_i, rowmax(S_ij))               # Update row max
            P     = exp(S_ij - m_new)                    # Local softmax numerator
            l_new = exp(m_i - m_new) × l_i + rowsum(P)
            # O_i is stored normalized, so multiply the old normalizer l_i back in before rescaling
            O_i   = (l_i × exp(m_i - m_new) × O_i + P × V_j) / l_new
            m_i = m_new, l_i = l_new
        L_i = m_i + log(l_i)                             # Store logsumexp

Key Operations:
- Parallel over Q blocks: Each output block computed independently by one CUDA block.
- Sequential over KV blocks: Accumulate attention across all keys.
- Output correction: Rescale the running sum and the accumulated output when a new maximum is found.
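The forward pass can be exercised end to end on small inputs. The sketch below is a plain-Python rendering of the tiled loop, compared against a naive reference that materializes every score row. It is an illustration under stated assumptions, not the CUDA kernel: the function names, the list-of-lists layout, and the random test data are all invented for this example. Because the output is kept normalized after every KV block, the update multiplies the old output by the previous normalizer `l[i]` before rescaling.

```python
import math
import random

def naive_attention(Q, K, V):
    """Reference: materializes a full score row per query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        total = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(len(V))) / total
                    for c in range(len(V[0]))])
    return out

def flash_forward(Q, K, V, block_rows, block_cols):
    """Tiled forward pass; returns (O, L) with L the per-row logsumexp."""
    n, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    O = [[0.0] * d for _ in range(n)]
    L = [0.0] * n
    for r0 in range(0, n, block_rows):            # Q blocks (independent)
        rows = range(r0, min(r0 + block_rows, n))
        m = {i: -math.inf for i in rows}          # running row max
        l = {i: 0.0 for i in rows}                # running row sum
        for c0 in range(0, n, block_cols):        # KV blocks (sequential)
            cols = range(c0, min(c0 + block_cols, n))
            for i in rows:
                s = [scale * sum(a * b for a, b in zip(Q[i], K[j])) for j in cols]
                m_new = max(m[i], max(s))
                p = [math.exp(x - m_new) for x in s]
                alpha = math.exp(m[i] - m_new)    # rescale factor for old state
                l_new = alpha * l[i] + sum(p)
                for c in range(d):
                    pv = sum(pj * V[j][c] for pj, j in zip(p, cols))
                    # Undo the old normalization, rescale, add, renormalize.
                    O[i][c] = (l[i] * alpha * O[i][c] + pv) / l_new
                m[i], l[i] = m_new, l_new
        for i in rows:
            L[i] = m[i] + math.log(l[i])
    return O, L

random.seed(0)
N, D = 10, 4
Q = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
K = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]

O_tiled, L = flash_forward(Q, K, V, block_rows=4, block_cols=3)
O_ref = naive_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(O_tiled, O_ref) for a, b in zip(ra, rb))
```

The block sizes are deliberately chosen not to divide the sequence length, so the ragged final blocks are covered as well.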
Backward Pass Algorithm
    Input:  Q, K, V, O, L, dO
    Output: dQ, dK, dV
    Initialize: dQ = 0

    For each KV block j:
        Load K_j, V_j to SRAM
        Initialize dK_j = 0, dV_j = 0
        For each Q block i:
            Load Q_i, O_i, dO_i, L_i to SRAM
            S_ij  = scale × Q_i × K_j^T
            P_ij  = exp(S_ij - L_i)            # Recompute attention weights
            D_i   = rowsum(dO_i ⊙ O_i)         # Diagonal term
            dV_j += P_ij^T × dO_i              # V gradient
            dP_ij = dO_i × V_j^T
            dS_ij = P_ij ⊙ (dP_ij - D_i)       # Softmax Jacobian
            dQ_i += scale × dS_ij × K_j        # Q gradient (accumulated across KV blocks)
            dK_j += scale × dS_ij^T × Q_i      # K gradient

Gradient Flow:
- dV: Weighted sum of upstream gradients using recomputed attention weights.
- dQ, dK: Flow through the softmax Jacobian using the recomputed $P_{ij}$.
- Memory efficient: No $O(N^2)$ storage needed at any point.
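The diagonal term $D_i = \text{rowsum}(dO_i \odot O_i)$ deserves a short justification, since the softmax Jacobian formally needs $\sum_k P_{ik}\, dP_{ik}$ over the entire row, which would span every KV block. The derivation below uses only $O = PV$ and $dP = dO\,V^T$ from the pseudocode; here $i$ and $j$ index individual rows and columns rather than blocks, and $c$ indexes the head dimension.

$$
\begin{aligned}
\sum_j P_{ij}\, dP_{ij} &= \sum_j P_{ij} \sum_c dO_{ic}\, V_{jc} \\
&= \sum_c dO_{ic} \sum_j P_{ij}\, V_{jc} \\
&= \sum_c dO_{ic}\, O_{ic} = D_i
\end{aligned}
$$

Substituting into the softmax Jacobian gives $dS_{ij} = P_{ij}\left(dP_{ij} - \sum_k P_{ik}\, dP_{ik}\right) = P_{ij}\,(dP_{ij} - D_i)$, so each block needs only its own $P_{ij}$ and $dP_{ij}$ plus the per-row value $D_i$.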
Causal Masking
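For autoregressive models, position $i$ may attend only to positions $j \le i$. The mask is applied at block granularity:
| Case | Handling |
|---|---|
| Full skip | KV block start column > Q block end row: skip the whole block |
| Partial mask | Apply mask within block (set scores to $-\infty$ where $j > i$) |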
Efficiency Gain: For long sequences, close to 50% of blocks are skipped entirely, which roughly halves the computation.
Figure 4: Causal masking at block granularity. Lower-triangular blocks are fully computed; diagonal blocks are partially masked; upper-triangular blocks are skipped.
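The three block categories can be counted directly. The sketch below classifies every (Q block, KV block) pair for a causal mask; the function name and the category labels are chosen for this example, and the skip test is the "KV block start column > Q block end row" condition from the table.

```python
def classify_causal_blocks(seq_len, block_rows, block_cols):
    """Count blocks per category for a causal mask (query i sees keys j <= i)."""
    counts = {"full": 0, "partial": 0, "skip": 0}
    for r0 in range(0, seq_len, block_rows):
        r_end = min(r0 + block_rows, seq_len) - 1      # last query row in block
        for c0 in range(0, seq_len, block_cols):
            c_end = min(c0 + block_cols, seq_len) - 1  # last key column in block
            if c0 > r_end:
                counts["skip"] += 1      # every key lies in the future
            elif c_end <= r0:
                counts["full"] += 1      # every key is visible to every query
            else:
                counts["partial"] += 1   # straddles the diagonal: mask per element
    return counts

counts = classify_causal_blocks(seq_len=1024, block_rows=64, block_cols=64)
skip_fraction = counts["skip"] / sum(counts.values())
print(counts, f"{skip_fraction:.1%} skipped")
```

With 16 blocks per side, 120 of the 256 blocks (about 47%) are skipped; the fraction approaches 50% as the number of blocks grows.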
FP16 Implementation
This implementation fully supports FP16 (half precision) for both forward and backward passes.
Implementation Strategy
FP16 inputs are converted to FP32 internally for computation, then converted back to FP16 for output.
Numerical Precision
| Operation | Precision |
|---|---|
| Matrix multiplication ($QK^T$, $PV$) | FP32 |
| Softmax computation | FP32 |
| Accumulation | FP32 |
| Final output | FP16 |
Benefits:
- Numerical stability comparable to FP32.
- Reduced memory bandwidth ($2\times$ smaller tensors).
- Supported on all modern GPUs (compute capability $\ge$ 5.3).
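To see why the accumulation row in the precision table matters, the sketch below emulates FP16 and FP32 rounding on the CPU with Python's `struct` module and sums 4,096 FP16 values with each kind of accumulator. This is an emulation written for this example, not the CUDA implementation; the helper names and the test values are assumptions.

```python
import struct

def to_fp16(x):
    """Round to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x):
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def sum_with_accumulator(values, round_acc):
    """Sum `values`, rounding the accumulator after every addition."""
    acc = 0.0
    for v in values:
        acc = round_acc(acc + v)
    return acc

values = [to_fp16(0.01)] * 4096     # FP16 inputs, e.g. softmax numerators
exact = sum(values)                 # float64 reference

fp16_acc = sum_with_accumulator(values, to_fp16)            # FP16 accumulator
fp32_acc = to_fp16(sum_with_accumulator(values, to_fp32))   # FP32, round once at the end

print(f"reference {exact:.4f}  fp16-accumulated {fp16_acc:.4f}  fp32-accumulated {fp32_acc:.4f}")
```

The FP16 accumulator stalls well short of the reference once the addend falls below half its rounding step, while the FP32 accumulator followed by a single final rounding stays close. The same effect applies to the running sum $l$ and the output accumulator in attention, which add up many small terms.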
Memory Complexity Analysis
| Method | Forward Memory | Backward Memory | HBM IO |
|---|---|---|---|
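| Standard Attention | $O(N^2)$ | $O(N^2)$ | $\Theta(Nd + N^2)$ |
| FlashAttention | $O(N)$ | $O(N)$ | $\Theta(N^2 d^2 / M)$ |
Where $N$ is the sequence length, $d$ is the head dimension, and $M$ is the SRAM size.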
Real Memory Savings
| Sequence Length | Standard Attention | FlashAttention | Savings |
|---|---|---|---|
| 1,024 | 4 MB | 8 KB | 99.8% |
| 4,096 | 64 MB | 32 KB | 99.95% |
| 16,384 | 1 GB | 128 KB | 99.99% |
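The figures in this table can be reproduced with a few lines of arithmetic. The byte counts used below are inferred from the table rather than stated in it: one $N \times N$ matrix at 4 bytes per element for standard attention, and 8 bytes of per-row statistics for FlashAttention, both for a single head and a single batch element. Treat them as assumptions of this sketch.

```python
def attention_matrix_bytes(seq_len, bytes_per_elem=4):
    """One N x N score matrix (assumed FP32, single head, single batch)."""
    return seq_len * seq_len * bytes_per_elem

def flash_stats_bytes(seq_len, bytes_per_row=8):
    """Per-row softmax statistics (assumed 8 bytes per row)."""
    return seq_len * bytes_per_row

rows = []
for n in (1024, 4096, 16384):
    std = attention_matrix_bytes(n)
    flash = flash_stats_bytes(n)
    rows.append((n, std, flash, 100.0 * (1 - flash / std)))

for n, std, flash, saved in rows:
    print(f"{n:>6}  {std / 2**20:6.0f} MiB  {flash / 2**10:5.0f} KiB  {saved:.2f}%")
```

Under these assumptions the ratio is $2/N$, so the savings grow with sequence length, and the table's MB and KB values correspond to binary units (MiB and KiB).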
Implementation Highlights
Block Configuration
| head_dim | $B_r$ | $B_c$ | SRAM per Block |
|---|---|---|---|
| 32 | 64 | 64 | $\sim$32 KB |
| 64 | 64 | 64 | $\sim$64 KB |
| 128 | 32 | 32 | $\sim$128 KB |
Optimization Techniques
| Technique | Benefit |
|---|---|
| Vectorized Memory Access | float4 loads/stores for coalesced bandwidth |
| Launch Bounds | __launch_bounds__(128) controls register pressure |
| Dynamic Shared Memory | Runtime allocation based on head_dim |
| Stream Safety | Explicit workspace lifetime management |
| Warp-level Primitives | __shfl_sync for intra-warp reduction |
Data Type Support
| Data Type | Forward | Backward |
|---|---|---|
| FP32 (float) | Full | Full |
| FP16 (half) | Full | Full |
References
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
- Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
- NeurIPS 2022
- arXiv:2205.14135
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Tri Dao
- ICLR 2024
- arXiv:2307.08691
Online normalizer calculation for softmax
- Maxim Milakov, Natalia Gimelshein
- arXiv:1805.02867
NVIDIA CUDA Programming Guide - Shared Memory