Skip to content

vs CUTLASS

本文档对比 Mini-Inference Engine 与 NVIDIA CUTLASS。


CUTLASS 简介

CUTLASS (CUDA Templates for Linear Algebra Subroutines) 是 NVIDIA 开源的 GEMM 模板库。

特点

特点说明
模板化高度可配置的 kernel 设计
Tensor Core完整的 Tensor Core 支持
代码质量生产级代码,学习价值极高
持续更新跟进最新 GPU 架构

代码结构

cutlass/
├── include/
│   ├── gemm/           # GEMM 核心实现
│   │   ├── kernel/     # Kernel 实现
│   │   ├── thread/     # 线程级操作
│   │   └── warp/       # 线程束级操作
│   ├── arch/           # 架构相关
│   ├── transform/      # 数据变换
│   └── epilogue/       # 后处理(融合)
└── examples/           # 示例代码

本项目与 CUTLASS 的关系

复杂度对比

方面本项目CUTLASS
代码量~3000 行~50000+ 行
模板使用最少大量
配置项~10 个~100+ 个
学习曲线平缓陡峭

功能对比

功能本项目CUTLASS
FP32 GEMM
FP16 GEMM
INT8 GEMM
Tensor Core
Batch GEMM
算子融合✅ (简单)✅ (完整)
多 GPU

CUTLASS 的核心概念

1. 分层抽象

CUTLASS 将 GEMM 分解为多个层次:

cpp
// 伪代码展示分层结构
namespace cutlass::gemm {

// Threadblock 级:计算一个 C 的 tile
class GemmKernel {
    // Warp 级:计算 tile 的一部分
    using WarpIterators = ...;
    
    // Thread 级:计算 warp 内的部分
    using ThreadIterators = ...;
};

}

2. 模板参数

CUTLASS 使用大量模板参数配置 kernel:

cpp
cutlass::gemm::device::Gemm<
    float,                          // ElementA
    cutlass::layout::RowMajor,      // LayoutA
    float,                          // ElementB
    cutlass::layout::ColumnMajor,   // LayoutB
    float,                          // ElementC
    cutlass::layout::RowMajor,      // LayoutC
    float,                          // ElementAccumulator
    cutlass::arch::OpClassSimt,     // OpClass
    cutlass::arch::Sm80             // ArchTag
> gemm_op;

3. Epilogue 融合

CUTLASS 的 Epilogue 机制支持算子融合:

cpp
using Epilogue = cutlass::epilogue::thread::LinearCombination<
    float,          // Output type
    4,              // Elements per access
    float,          // Accumulator type
    float           // Scale bias type
>;

本项目可借鉴的设计

1. 分层设计

本项目的四层架构参考了 CUTLASS 的设计:

Application Layer  →  Benchmark / Tests
Engine Layer       →  InferenceEngine / Tensor
Kernel Layer       →  7-Level GEMM
Infrastructure     →  MemoryPool / StreamManager

2. 参数化设计

本项目的 AutoTuner 参考 CUTLASS 的参数化思路:

cpp
struct GemmConfig {
    int BLOCK_M;
    int BLOCK_N;
    int BLOCK_K;
    int THREAD_M;
    int THREAD_N;
};

3. 性能分析

学习 CUTLASS 的 profiling 方法:

cpp
// CUTLASS 内置的性能分析
cutlass::profiler::GemmProfiler<
    GemmKernel,
    ProblemSize
> profiler;
profiler.run();

学习路径

推荐顺序

Week 1-2: 本项目

    │  理解 GEMM 优化基础
    │  掌握共享内存、寄存器分块


Week 3-4: CUTLASS Examples

    │  阅读基本示例
    │  理解模板参数


Week 5-6: CUTLASS Source

    │  深入 kernel 实现
    │  学习 warp 级操作


Week 7+:  CUTLASS 高级特性

    │  Tensor Core
    │  Epilogue 融合
    │  多 GPU 并行

CUTLASS 学习资源

官方资源

推荐示例

示例学习重点
0_basic_gemm基本用法
10_planar_complex复杂布局
15_gemm_universal通用接口
23_gemm_grouped批量 GEMM
27_gemm_with_epilogue算子融合

总结

方面本项目CUTLASS
定位入门教学生产级库
复杂度
学习曲线平缓陡峭
生产可用教学用途

最佳实践: 本项目作为 CUTLASS 的前置学习材料,帮助理解 CUTLASS 的设计思想。

MIT License | CUDA GEMM optimization tutorial