
FlashAttention Algorithm Deep Dive

FlashAttention is an IO-aware exact attention algorithm that reduces memory complexity from $O(N^2)$ to $O(N)$ while matching standard attention numerically.




Standard Attention Bottleneck

Standard self-attention is defined as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

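Computing this directly materializes three matrices, two of them $N \times N$ intermediates:

$$S = \frac{QK^\top}{\sqrt{d}} \in \mathbb{R}^{N \times N}, \qquad P = \mathrm{softmax}(S) \in \mathbb{R}^{N \times N}, \qquad O = PV \in \mathbb{R}^{N \times d}$$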
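As a point of comparison for the rest of this page, here is a minimal pure-Python sketch of the standard computation. It is illustrative only: the function name `standard_attention` and the list-of-rows layout are assumptions made for this example, not part of the CUDA implementation. Note that it returns the full $N \times N$ matrices $S$ and $P$, which is exactly what FlashAttention avoids.

```python
import math

def standard_attention(Q, K, V):
    """Naive attention on lists of rows; materializes the full S and P."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # S = Q K^T / sqrt(d): an N x N list of lists.
    S = [[scale * sum(q * k for q, k in zip(q_row, k_row)) for k_row in K]
         for q_row in Q]
    # P = row-wise softmax(S), subtracting the row max for stability.
    P = []
    for row in S:
        m = max(row)
        e = [math.exp(s - m) for s in row]
        total = sum(e)
        P.append([x / total for x in e])
    # O = P V: an N x d list of lists.
    O = [[sum(p * v_row[c] for p, v_row in zip(p_row, V))
          for c in range(len(V[0]))]
         for p_row in P]
    return S, P, O

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
S, P, O = standard_attention(Q, K, V)
print(len(S), len(S[0]))  # both equal N: S is N x N
```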

Core Problem: $S$ and $P$ have $O(N^2)$ size and must reside in HBM (device memory). For large $N$:

| Issue | Impact |
|---|---|
| Memory Usage | $N=4096$, 32 heads → $\sim$2 GB just for $S$ and $P$ |
| Bandwidth Bottleneck | GPU compute ≫ HBM bandwidth; time dominated by data movement |
| IO Operations | $S$ and $P$ each require a write to and a read from HBM: $4 \times O(N^2)$ operations total |
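The memory figure can be checked with a few lines of arithmetic. This sketch assumes 2 bytes per element (FP16) and a batch size of 1; neither is stated in the table, so treat both as assumptions of the example rather than the author's exact accounting.

```python
def score_matrix_bytes(n, heads, bytes_per_elem=2, matrices=2):
    """Bytes needed to hold `matrices` N x N buffers (S and P) for every head.

    bytes_per_elem=2 assumes FP16 storage; batch size is assumed to be 1.
    """
    return matrices * heads * n * n * bytes_per_elem

total = score_matrix_bytes(n=4096, heads=32)
print(total / 2**30, "GiB")  # 2.0 GiB
```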

Tiling Overview

Figure 1: Q/K/V tiling into SRAM blocks. Intermediate S and P never touch HBM.


Core FlashAttention Concepts

1. Tiling

Divide Q, K, V into blocks that fit in SRAM (shared memory / L1 cache):

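$$
\begin{aligned}
Q &= [Q_1, Q_2, \ldots, Q_{T_r}], & Q_i &\in \mathbb{R}^{B_r \times d} \\
K &= [K_1, K_2, \ldots, K_{T_c}], & K_j &\in \mathbb{R}^{B_c \times d} \\
V &= [V_1, V_2, \ldots, V_{T_c}], & V_j &\in \mathbb{R}^{B_c \times d}
\end{aligned}
$$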

Block Size Selection:

| GPU Architecture | SRAM Size | Typical $B_r \times B_c$ |
|---|---|---|
| Volta (V100) | 96 KB | 64×64 |
| Ampere (A100) | 164 KB | 128×64 |
| Hopper (H100) | 228 KB | 128×128 |

Why Tiling Works:

  • Each block fits in fast SRAM ($\sim$19 TB/s) instead of slow HBM ($\sim$2 TB/s).
  • Avoids repeated HBM accesses for intermediate results.
  • Enables parallel processing of independent Q blocks.
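The partitioning itself is simple row slicing. The sketch below is a pure-Python illustration with made-up small sizes; `split_rows` is a hypothetical helper, and the ragged last block is one possible way to handle $N$ not divisible by the block size (a real kernel might pad or mask instead).

```python
import math

def split_rows(M, block_rows):
    """Split a matrix (list of rows) into consecutive row blocks."""
    return [M[i:i + block_rows] for i in range(0, len(M), block_rows)]

N, d, B_r, B_c = 10, 4, 4, 2          # illustrative sizes only
Q = [[float(r * d + c) for c in range(d)] for r in range(N)]
K = [[float(r + c) for c in range(d)] for r in range(N)]

Q_blocks = split_rows(Q, B_r)         # T_r blocks of at most B_r rows
K_blocks = split_rows(K, B_c)         # T_c blocks of at most B_c rows
T_r, T_c = math.ceil(N / B_r), math.ceil(N / B_c)

print(T_r, T_c, [len(b) for b in Q_blocks])  # 3 5 [4, 4, 2]
```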

2. Online Softmax

Standard softmax requires two passes over each row (find max → compute exp → sum → normalize). FlashAttention uses online softmax to update incrementally in a single pass over KV blocks:

For each KV block j processed by Q block i:

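$$
\begin{aligned}
m_{ij}^{\text{new}} &= \max\left(m_{ij}^{\text{old}},\ \mathrm{rowmax}(S_{ij})\right) \\
\tilde{P}_{ij} &= \exp\left(S_{ij} - m_{ij}^{\text{new}}\right) \\
l_{ij}^{\text{new}} &= \exp\left(m_{ij}^{\text{old}} - m_{ij}^{\text{new}}\right) l_{ij}^{\text{old}} + \mathrm{rowsum}(\tilde{P}_{ij}) \\
O_{ij}^{\text{new}} &= \frac{l_{ij}^{\text{old}} \exp\left(m_{ij}^{\text{old}} - m_{ij}^{\text{new}}\right) O_{ij}^{\text{old}} + \tilde{P}_{ij} V_j}{l_{ij}^{\text{new}}}
\end{aligned}
$$

The factor $l_{ij}^{\text{old}}$ in the last line undoes the normalization already applied to $O_{ij}^{\text{old}}$ before it is rescaled to the new maximum; without it, contributions from earlier blocks would be divided by the row sum more than once.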

Online Softmax State Machine

Figure 2: Online softmax state updates. When a new KV block reveals a larger row max, previous outputs are rescaled by $\exp(m^{\text{old}} - m^{\text{new}})$.

Key Insight: When processing a new KV block, previous outputs must be corrected by $\exp(m^{\text{old}} - m^{\text{new}})$ because the global row maximum may have changed.

Numerical Stability: Tracking the running maximum ensures exp() never overflows, even for large attention scores.
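The running-max and running-sum updates can be tried on a single row of scores. This is a minimal pure-Python sketch, not the kernel code; `online_softmax_stats` and the block size of 2 are assumptions chosen for illustration. The row deliberately contains very large scores to show that no `exp()` call overflows.

```python
import math

def softmax(row):
    """Reference softmax with the full-row max subtracted."""
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

def online_softmax_stats(row, block_size):
    """One pass over `row` in blocks, tracking running max m and sum l."""
    m, l = -math.inf, 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms.
        l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in block)
        m = m_new
    return m, l

row = [1.0, 3.0, 2.0, 1000.0, 999.0, 5.0]
m, l = online_softmax_stats(row, block_size=2)
online = [math.exp(x - m) / l for x in row]
max_diff = max(abs(a - b) for a, b in zip(online, softmax(row)))
print(max_diff)
```

Every argument passed to `exp()` is at most zero, because the running maximum is never smaller than any score seen so far.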

3. Recomputation

The standard backward pass stores the $O(N^2)$ attention matrix $P$ for gradient computation. FlashAttention's strategy:

| Phase | Storage | Memory |
|---|---|---|
| Forward | Output $O$ and logsumexp $L$ only | $O(N)$ |
| Backward | Recompute $P$ from $Q, K, V, O, L$ on-the-fly | $O(N)$ |

Trade-off: Recomputation adds roughly 33% more FLOPs but cuts HBM IO substantially, which yields a net speedup.
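Recomputation works because one scalar per row, the logsumexp $L_i$, is enough to rebuild any entry of $P$ from its score. A small pure-Python check on a single row, with made-up scores:

```python
import math

def logsumexp(row):
    """Numerically stable log(sum(exp(row)))."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

def softmax(row):
    m = max(row)
    e = [math.exp(x - m) for x in row]
    s = sum(e)
    return [x / s for x in e]

scores = [0.5, -1.0, 2.0, 0.0]        # one row of S (illustrative values)
L = logsumexp(scores)                 # the only per-row statistic kept from forward
P_recomputed = [math.exp(s - L) for s in scores]

max_diff = max(abs(a - b) for a, b in zip(P_recomputed, softmax(scores)))
print(max_diff)
```

Each entry is recovered independently, so a block of $P$ can be rebuilt in SRAM without touching any other block.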

Backward Recompute Flow

Figure 3: Backward pass recomputes $P_{ij}$ in SRAM from forward outputs. No $O(N^2)$ matrix is stored.


Forward Pass Algorithm

Input:  Q, K, V ∈ R^(N×d), scale = 1/√d
Output: O ∈ R^(N×d), L ∈ R^N

Initialize: O = 0, m = -∞, l = 0  (per row)

For each Q block i (parallel over i = 1..T_r):
    Load Q_i to SRAM
    For each KV block j = 1..T_c (sequential):
        Load K_j, V_j to SRAM
        S_ij = scale × Q_i × K_j^T           # [B_r, B_c] in SRAM
        m_new = max(m_i, rowmax(S_ij))        # Update row max
        P = exp(S_ij - m_new)                 # Local softmax numerator
        l_new = exp(m_i - m_new) × l_i + rowsum(P)
        O_i = (l_i × exp(m_i - m_new) × O_i + P × V_j) / l_new   # l_i undoes the previous normalization
        m_i = m_new, l_i = l_new
    L_i = m_i + log(l_i)                      # Store logsumexp

Key Operations:

  1. Parallel over Q blocks: Each output block is computed independently by one CUDA thread block.
  2. Sequential over KV blocks: Attention is accumulated across all keys.
  3. Output correction: The running output and running sum are rescaled whenever a new row maximum is found.
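The loop structure above can be exercised end to end on the CPU. The following is a pure-Python sketch under stated assumptions, not the CUDA kernel: `tiled_attention` and `naive_attention` are hypothetical names, the inputs are random, and the sketch keeps an unnormalized accumulator and divides by the row sum once after the KV loop, which is equivalent to renormalizing the running output at every step.

```python
import math
import random

def naive_attention(Q, K, V):
    """Reference: materializes a full row of S and P for every query."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / z
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, B_r, B_c):
    """Tiled forward pass; returns the output O and per-row logsumexp L."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    O, L = [], []
    for i0 in range(0, len(Q), B_r):             # Q blocks: independent
        Qi = Q[i0:i0 + B_r]
        m = [-math.inf] * len(Qi)                # running row max
        l = [0.0] * len(Qi)                      # running row sum
        acc = [[0.0] * d for _ in Qi]            # unnormalized running output
        for j0 in range(0, len(K), B_c):         # KV blocks: sequential
            Kj, Vj = K[j0:j0 + B_c], V[j0:j0 + B_c]
            for r, q in enumerate(Qi):
                s = [scale * sum(a * b for a, b in zip(q, k)) for k in Kj]
                m_new = max(m[r], max(s))
                p = [math.exp(x - m_new) for x in s]
                alpha = math.exp(m[r] - m_new)   # rescales earlier blocks
                l[r] = alpha * l[r] + sum(p)
                acc[r] = [alpha * a + sum(pj * v[c] for pj, v in zip(p, Vj))
                          for c, a in enumerate(acc[r])]
                m[r] = m_new
        for r in range(len(Qi)):
            O.append([a / l[r] for a in acc[r]]) # normalize once per row
            L.append(m[r] + math.log(l[r]))      # logsumexp for backward
    return O, L

random.seed(0)
N, d = 7, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
           for _ in range(3))
O, L = tiled_attention(Q, K, V, B_r=3, B_c=2)
ref = naive_attention(Q, K, V)
max_err = max(abs(a - b) for ro, rr in zip(O, ref) for a, b in zip(ro, rr))
print(max_err)
```

The block sizes 3 and 2 do not divide $N = 7$ on purpose, so the ragged final blocks are covered as well.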

Backward Pass Algorithm

Input:  Q, K, V, O, L, dO
Output: dQ, dK, dV

For each KV block j:
    Load K_j, V_j to SRAM
    Initialize dK_j = 0, dV_j = 0
    For each Q block i:
        Load Q_i, O_i, dO_i, L_i to SRAM
        S_ij = scale × Q_i × K_j^T
        P_ij = exp(S_ij - L_i)                # Recompute attention weights
        D_i = rowsum(dO_i ⊙ O_i)              # Diagonal term
        dV_j += P_ij^T × dO_i                 # V gradient
        dP_ij = dO_i × V_j^T
        dS_ij = P_ij ⊙ (dP_ij - D_i)          # Softmax Jacobian
        dQ_i += scale × dS_ij × K_j           # Q gradient
        dK_j += scale × dS_ij^T × Q_i        # K gradient

Gradient Flow:

  1. dV: Weighted sum of upstream gradients using recomputed attention weights.
  2. dQ, dK: Through softmax Jacobian using recomputed P.
  3. Memory efficient: No $O(N^2)$ storage needed at any point.
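The softmax Jacobian step in the pseudocode, and the reason $D_i$ can be computed from $dO_i$ and $O_i$ alone, follow from a short derivation. In the sketch below, $j$ and $k$ index individual key positions within query row $i$ (not blocks), $V_k$ is the $k$-th row of $V$, and the $1/\sqrt{d}$ scale is left out because the pseudocode applies it separately in the $dQ$ and $dK$ updates.

```latex
\begin{aligned}
P_{ik} &= \frac{\exp(S_{ik})}{\sum_{k'} \exp(S_{ik'})}
  \quad\Rightarrow\quad
  \frac{\partial P_{ik}}{\partial S_{ij}} = P_{ik}\,(\delta_{jk} - P_{ij}) \\
dS_{ij} &= \sum_k dP_{ik}\,\frac{\partial P_{ik}}{\partial S_{ij}}
         = P_{ij}\Big(dP_{ij} - \sum_k P_{ik}\, dP_{ik}\Big) \\
\sum_k P_{ik}\, dP_{ik} &= \sum_k P_{ik}\,(dO_i \cdot V_k)
         = dO_i \cdot \sum_k P_{ik} V_k
         = dO_i \cdot O_i = D_i
\end{aligned}
```

The last line is what makes the backward pass tileable: the row-wide sum over all keys collapses to a dot product of two length-$d$ vectors that are already available, so no block needs to see the rest of its row of $P$.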

Causal Masking

For autoregressive models, position $i$ can only attend to positions $j \le i$. FlashAttention's block structure enables efficient causal masking:

| Case | Handling |
|---|---|
| Full skip | KV block start column > Q block end row → skip entire block |
| Partial mask | Apply mask within block (set to $-\infty$) |

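Efficiency Gain: With $T$ blocks per side (assuming $B_r = B_c$), $T(T-1)/2$ of the $T^2$ blocks lie strictly above the diagonal and are skipped entirely, so the saving approaches 50% of the computation as $T$ grows.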
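The block-level decision can be written as two comparisons on the row and column ranges of a tile. This is a pure-Python sketch with hypothetical helper names; it uses 0-based block indices and assumes $N$ is divisible by both block sizes.

```python
def classify_block(i, j, B_r, B_c):
    """Classify tile (Q block i, KV block j) under a causal mask (0-indexed)."""
    q_start, q_end = i * B_r, (i + 1) * B_r - 1   # query rows covered
    k_start, k_end = j * B_c, (j + 1) * B_c - 1   # key columns covered
    if k_start > q_end:
        return "skip"      # every column is in the future of every row
    if k_end <= q_start:
        return "full"      # every column is visible to every row
    return "partial"       # straddles the diagonal: mask inside the tile

def count_blocks(N, B_r, B_c):
    """Count tiles of each kind; assumes N is divisible by B_r and B_c."""
    counts = {"skip": 0, "partial": 0, "full": 0}
    for i in range(N // B_r):
        for j in range(N // B_c):
            counts[classify_block(i, j, B_r, B_c)] += 1
    return counts

counts = count_blocks(N=1024, B_r=64, B_c=64)
print(counts)  # {'skip': 120, 'partial': 16, 'full': 120}
```

For this configuration 120 of 256 tiles (about 47%) are skipped, and the fraction moves toward one half as the number of blocks per side increases.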

Causal Masking Blocks

Figure 4: Causal masking at block granularity. Lower-triangular blocks are fully computed; diagonal blocks are partially masked; upper-triangular blocks are skipped.


FP16 Implementation

This implementation fully supports FP16 (half precision) for both forward and backward passes.

Implementation Strategy

FP16 inputs are converted to FP32 internally for computation, then converted back to FP16 for output:

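$$\text{Input: half } Q, K, V \xrightarrow{\text{load}} \text{FP32 registers} \xrightarrow{\text{compute}} \text{FP32 accumulators} \xrightarrow{\text{store}} \text{half } O, L$$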

Numerical Precision

| Operation | Precision |
|---|---|
| Matrix multiplication ($QK^\top$) | FP32 |
| Softmax computation | FP32 |
| Accumulation | FP32 |
| Final output | FP16 |

Benefits:

  • Numerical stability comparable to FP32.
  • Reduced memory bandwidth (2× smaller tensors).
  • Supported on all modern GPUs (compute capability ≥ 5.3).
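The value of FP32 accumulation can be seen with a CPU-side emulation. The sketch below uses Python's `struct` module to round to IEEE 754 half (`"e"`) and single (`"f"`) precision after every addition; it is an illustration of rounding behaviour with made-up data, not a model of the kernel itself.

```python
import struct

def to_half(x):
    """Round a Python float to the nearest FP16 value (IEEE 754 binary16)."""
    return struct.unpack("e", struct.pack("e", x))[0]

def to_single(x):
    """Round a Python float to the nearest FP32 value (IEEE 754 binary32)."""
    return struct.unpack("f", struct.pack("f", x))[0]

values = [to_half(0.01)] * 4096       # FP16 inputs, as loaded from memory
exact = sum(values)                   # double-precision reference

half_acc = 0.0
for v in values:                      # accumulate in FP16
    half_acc = to_half(half_acc + v)

single_acc = 0.0
for v in values:                      # accumulate in FP32
    single_acc = to_single(single_acc + v)
single_then_half = to_half(single_acc)  # convert once, at the end

print(exact, half_acc, single_then_half)
```

In this run the FP16 accumulator stops growing once the spacing between adjacent FP16 values exceeds twice the addend, while the FP32 path stays close to the reference and loses precision only in the final conversion.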

Memory Complexity Analysis

| Method | Forward Memory | Backward Memory | HBM IO |
|---|---|---|---|
| Standard Attention | $O(N^2)$ | $O(N^2)$ | $O(N^2 + Nd)$ |
| FlashAttention | $O(N)$ | $O(N)$ | $O(N^2 d^2 / M)$ |

Where M is SRAM size. When M=Θ(Nd), IO complexity approaches O(Nd), which is optimal since the inputs and outputs alone are Θ(Nd).
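To get a feel for the two IO expressions, they can be evaluated with constants dropped. The numbers below are order-of-magnitude only: `M` is a hypothetical SRAM capacity measured in elements, chosen for illustration rather than taken from any GPU datasheet.

```python
def standard_io(N, d):
    """Asymptotic HBM accesses for standard attention (constants dropped)."""
    return N * N + N * d

def flash_io(N, d, M):
    """Asymptotic HBM accesses for FlashAttention (constants dropped)."""
    return N * N * d * d / M

N, d = 4096, 64
M = 48 * 1024    # hypothetical SRAM capacity in elements, not a measured value

ratio = standard_io(N, d) / flash_io(N, d, M)
print(ratio)                                   # roughly 12x fewer accesses
print(flash_io(N, d, M=N * d) == N * d)        # True: the M = Nd limit
```

The ratio grows with $M / d^2$, which is why the advantage is largest for small head dimensions and large on-chip memory.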

Real Memory Savings

| Sequence Length | Standard Attention | FlashAttention | Savings |
|---|---|---|---|
| 1,024 | 4 MB | 8 KB | 99.8% |
| 4,096 | 64 MB | 32 KB | 99.95% |
| 16,384 | 1 GB | 128 KB | 99.99% |
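One set of assumptions that reproduces these figures is a single FP32 $N \times N$ matrix for standard attention versus 8 bytes of per-row statistics for FlashAttention. Both are guesses made for this example (the table does not state its accounting), and binary units are used throughout.

```python
def standard_bytes(N, bytes_per_elem=4):
    """One N x N matrix; 4 bytes per element assumes FP32 (an assumption)."""
    return N * N * bytes_per_elem

def flash_bytes(N, bytes_per_row=8):
    """Per-row statistics only; 8 bytes per row is an assumption."""
    return N * bytes_per_row

for N in (1024, 4096, 16384):
    std, fa = standard_bytes(N), flash_bytes(N)
    saving = 100 * (1 - fa / std)
    print(N, std // 2**20, "MiB", fa // 2**10, "KiB", f"{saving:.2f}%")
# 1024 4 MiB 8 KiB 99.80%
# 4096 64 MiB 32 KiB 99.95%
# 16384 1024 MiB 128 KiB 99.99%
```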

Implementation Highlights

Block Configuration

| head_dim | $B_r$ | $B_c$ | SRAM per Block |
|---|---|---|---|
| 32 | 64 | 64 | $\sim$32 KB |
| 64 | 64 | 64 | $\sim$64 KB |
| 128 | 32 | 32 | $\sim$128 KB |

Optimization Techniques

| Technique | Benefit |
|---|---|
| Vectorized Memory Access | `float4` loads/stores for coalesced bandwidth |
| Launch Bounds | `__launch_bounds__(128)` controls register pressure |
| Dynamic Shared Memory | Runtime allocation based on `head_dim` |
| Stream Safety | Explicit workspace lifetime management |
| Warp-level Primitives | `__shfl_sync` for intra-warp reduction |
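The intra-warp reduction pattern can be illustrated without a GPU. The sketch below simulates a butterfly (XOR) exchange across 32 lanes in pure Python, the pattern that CUDA code typically expresses with `__shfl_xor_sync`; it is an assumption that the kernel uses this particular variant, and `shfl_xor` here is a hypothetical stand-in for the hardware shuffle, not a real API.

```python
def shfl_xor(values, lane_mask):
    """Simulated warp shuffle: lane i reads the value held by lane i ^ lane_mask."""
    return [values[i ^ lane_mask] for i in range(len(values))]

def warp_reduce_max(values):
    """Butterfly max-reduction over 32 simulated lanes.

    After log2(32) = 5 exchange steps every lane holds the warp-wide max,
    which is what a row-max computation needs.
    """
    assert len(values) == 32
    offset = 16
    while offset > 0:
        other = shfl_xor(values, offset)
        values = [max(a, b) for a, b in zip(values, other)]
        offset //= 2
    return values

lanes = [float((7 * i) % 32) for i in range(32)]  # a permutation of 0..31
result = warp_reduce_max(lanes)
print(result[0], result[31])
```

Replacing `max` with addition gives the matching row-sum reduction; in both cases no shared memory is needed because values move directly between registers of the same warp.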

Data Type Support

| Data Type | Forward | Backward |
|---|---|---|
| FP32 (float) | Full | Full |
| FP16 (half) | Full | Full |

References

  1. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

    • Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
    • NeurIPS 2022
    • arXiv:2205.14135
  2. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

  3. Online normalizer calculation for softmax

  4. NVIDIA CUDA Programming Guide - Shared Memory

Stable v0.3.0 baseline. Lean CUDA FlashAttention reference.