CuFlash-Attn从零实现的 CUDA FlashAttention
技术白皮书 · O(N) 内存 · FP32/FP16 · 前向与反向
FlashAttention 将内存复杂度从 O(N²) 降至 O(N),支持更长的序列训练。
| 序列长度 | 标准注意力 | FlashAttention | 内存节省 |
|---|---|---|---|
| 1,024 | 4 MB | 8 KB | 99.8% |
| 4,096 | 64 MB | 32 KB | 99.95% |
| 16,384 | 1 GB | 128 KB | 99.99% |
| 65,536 | 16 GB | 512 KB | 99.97% |
在 NVIDIA A100 80GB 上测试,FP16 精度,启用因果掩码。
| 配置 | FlashAttention | 标准注意力 | 加速比 |
|---|---|---|---|
| Batch=1, Seq=1024 | 45.2 tok/s | 12.1 tok/s | 3.7x |
| Batch=8, Seq=1024 | 312.5 tok/s | 45.3 tok/s | 6.9x |
| Batch=32, Seq=1024 | 892.1 tok/s | 98.7 tok/s | 9.0x |
5 分钟内构建并运行:
git clone https://github.com/AICL-Lab/cuflash-attn.git
cd cuflash-attn
cmake --preset release
cmake --build --preset release
ctest --preset release --output-on-failure#include "cuflash/flash_attention.h"
auto err = cuflash::flash_attention_forward(
d_Q, d_K, d_V, d_O, d_L,
batch_size, num_heads, seq_len, head_dim,
scale, true, stream
);import ctypes
lib = ctypes.CDLL("./build/release/libcuflash_attn.so")
lib.cuflash_attention_forward_f32(
q_ptr, k_ptr, v_ptr, o_ptr, l_ptr,
B, H, N, D, scale, True, None
)