Skip to content

DIYFlashAttention

用 Triton 从零构建 FlashAttention,掌握 GPU 内核优化的核心技术

⚡ 内存减少 99% 🚀 速度提升 1.6x 📖 教育级代码

为什么选择这个项目?

紧凑但真实:代码量控制在可完整阅读的范围内,但不是玩具示例。你可以:

  • ✅ 在 GPU 上运行真实基准测试
  • ✅ 对比 PyTorch SDPA 的性能差异
  • ✅ 理解每一行代码背后的设计决策

你将学到什么

主题收获
GPU 内存层级HBM → L2 → SRAM → 寄存器的数据流动
Triton 编程自动分块、autotune、内核优化技巧
FlashAttention 算法在线 softmax、因果掩码、变长序列处理
性能调优块大小选择、occupancy 优化、内存分析

项目数据

2+
核心 Triton 内核
O(N)
注意力内存复杂度
6
GPU 架构支持
99%
内存节省(长序列)

快速开始

bash
# 安装
pip install diy-flash-attention

# 或者从源码安装
pip install -e ".[dev]"

# 验证安装
python -c "from kernels import flash_attention; print('✓ 安装成功')"

运行示例

python
import torch
from kernels import flash_attention

# FlashAttention — 长序列内存减少 99%
q = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)

out = flash_attention(q, k, v, causal=True)  # GPT 风格因果掩码
print(f"输出形状: {out.shape}")  # [2, 8, 4096, 64]

学习路径

🧑‍💻
内核开发者
从教程开始,逐行理解 FlashAttention 实现
推荐:教程 → API → 性能指南
🔬
研究人员
快速查阅 API 契约,复现和修改内核
推荐:API 参考 → 源码
🚀
性能工程师
深入性能调优,理解块大小和架构适配
推荐:性能指南 → 基准测试
📚
学习者
系统学习 GPU 编程和注意力优化
推荐:教程 → 速查表 → FAQ
开始你的 FlashAttention 学习之旅
从教程入手,理解实现;用 API 参考,确认契约;看性能指南,获取证据。

语言切换

Forward-only educational Triton FlashAttention project · MIT License