
GEMM Kernels

General Matrix Multiply (GEMM) is the core computational primitive of deep learning. TensorCraft-HPC provides a progressive optimization path, from a naive kernel up to Tensor Core implementations.

Overview

GEMM computes C = α × A × B + β × C (a reference implementation is sketched after this list), where:

  • A is an M×K matrix
  • B is a K×N matrix
  • C is an M×N matrix
  • α and β are scalar coefficients
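
In scalar form, every element of C is a K-term dot product scaled by α, plus β times the previous value of C. The loop below spells this out; it is an illustrative host-side sketch (essentially what the Level 1 "Naive" kernel computes), not the library's implementation:

```cpp
#include <cstddef>

// Reference GEMM: C = alpha * A * B + beta * C, all matrices row-major.
template <typename T>
void gemm_reference(const T* A, const T* B, T* C,
                    std::size_t M, std::size_t N, std::size_t K,
                    T alpha, T beta) {
    for (std::size_t i = 0; i < M; ++i) {
        for (std::size_t j = 0; j < N; ++j) {
            T acc = T(0);
            for (std::size_t k = 0; k < K; ++k)
                acc += A[i * K + k] * B[k * N + j];  // dot(row i of A, col j of B)
            C[i * N + j] = alpha * acc + beta * C[i * N + j];
        }
    }
}
```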

Why GEMM Matters

GEMM typically accounts for 80-90% of the computation in modern neural networks: dense, convolutional, and attention layers all reduce to matrix multiplies. Understanding its optimization is therefore crucial for high-performance AI systems.

Optimization Path

TensorCraft-HPC provides 4 levels of GEMM optimization:

| Level | Name | Key Technique | Performance (% of cuBLAS) |
|-------|------|---------------|---------------------------|
| 1 | Naive | Direct triple loop | ~5% |
| 2 | Tiled | Shared memory blocking | ~45% |
| 3 | Double Buffer | Pipelined memory access | ~75% |
| 4 | Tensor Core | WMMA instructions | ~92% |
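
The step from Level 1 to Level 2 is the single biggest idea in GEMM optimization: stage tiles of A and B in shared memory so that each global-memory load is reused TILE times. The kernel below is a minimal sketch of that technique, not TensorCraft-HPC's Level 2 kernel; for clarity it uses one thread per output element (32×32 = 1024 threads per block), whereas the library uses 256-thread blocks (see Thread Block Configuration below):

```cpp
constexpr int TILE = 32;

// Minimal shared-memory tiled GEMM (C = A × B), illustrating the Level 2 idea.
// Launch with dim3(TILE, TILE) threads and one block per 32×32 tile of C.
__global__ void gemm_tiled_sketch(const float* A, const float* B, float* C,
                                  int M, int N, int K) {
    __shared__ float As[TILE][TILE];              // staged tile of A
    __shared__ float Bs[TILE][TILE];              // staged tile of B

    int row = blockIdx.y * TILE + threadIdx.y;    // C row owned by this thread
    int col = blockIdx.x * TILE + threadIdx.x;    // C column owned by this thread
    float acc = 0.0f;

    for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
        // Cooperative load of one tile pair, zero-padding out-of-range elements.
        int a_col = t * TILE + threadIdx.x;
        int b_row = t * TILE + threadIdx.y;
        As[threadIdx.y][threadIdx.x] = (row < M && a_col < K) ? A[row * K + a_col] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] = (b_row < K && col < N) ? B[b_row * N + col] : 0.0f;
        __syncthreads();

        for (int k = 0; k < TILE; ++k)            // multiply the staged tiles
            acc += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();                          // don't overwrite tiles too early
    }

    if (row < M && col < N)
        C[row * N + col] = acc;
}
```

Levels 3 and 4 build on the same structure: double buffering overlaps the next tile load with the current tile's arithmetic, and the Tensor Core level replaces the inner loop with WMMA fragment operations.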

API Reference

Core Functions

gemm<T>(A, B, C, M, N, K, alpha, beta)

Performs general matrix multiplication.

```cpp
template<typename T>
void gemm(
    const T* A,      // Input matrix A (M×K)
    const T* B,      // Input matrix B (K×N)
    T* C,            // Output matrix C (M×N)
    size_t M,        // Rows of A and C
    size_t N,        // Columns of B and C
    size_t K,        // Columns of A / Rows of B
    T alpha = 1.0,   // Scalar multiplier for A×B
    T beta = 0.0     // Scalar multiplier for C
);
```

Template Parameters:

  • T — Data type: float, double, half (FP16), or __nv_bfloat16

Example:

```cpp
#include "tensorcraft/kernels/gemm.hpp"
#include "tensorcraft/memory/tensor.hpp"

using namespace tensorcraft;

// Create matrices
FloatTensor A({4096, 4096});
FloatTensor B({4096, 4096});
FloatTensor C({4096, 4096});

// Initialize A and B with data...

// Compute C = A × B
kernels::gemm(A.data(), B.data(), C.data(), 4096, 4096, 4096);
```
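
Because alpha defaults to 1 and beta to 0, the call above simply overwrites C with A × B. Passing explicit coefficients scales the product and accumulates into the existing C:

```cpp
// C = 2 × (A × B) + 0.5 × C — scale the product and blend in the previous C
kernels::gemm(A.data(), B.data(), C.data(), 4096, 4096, 4096, 2.0f, 0.5f);
```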

Specialized Variants

gemm_fp16 — FP16 Tensor Core

Optimized for FP16 computation using Tensor Cores.

```cpp
void gemm_fp16(
    const half* A,
    const half* B,
    half* C,
    size_t M, size_t N, size_t K
);
```

Requirements:

  • SM70+ (Volta or later; a runtime check is sketched below)
  • CUDA 11.0+
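
Behavior on pre-Volta devices is not specified, so it can be worth guarding the call. A minimal sketch using the standard CUDA runtime API (this helper is not part of TensorCraft-HPC):

```cpp
#include <cuda_runtime.h>

// True if the current device satisfies the SM70+ requirement above.
bool fp16_tensor_cores_available(int device = 0) {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device) != cudaSuccess)
        return false;
    return prop.major >= 7;  // SM70 (Volta) introduced Tensor Cores
}
```

If the check fails, the templated gemm<half> path remains available as a fallback.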

gemm_batched — Batched GEMM

Computes batch_count independent GEMM operations that share the same M, N, and K dimensions.

```cpp
template<typename T>
void gemm_batched(
    const T* const A[],   // Array of pointers to A matrices
    const T* const B[],   // Array of pointers to B matrices
    T* const C[],         // Array of pointers to C matrices
    size_t batch_count,
    size_t M, size_t N, size_t K
);
```

Performance Benchmarks

A100 80GB, FP16 Tensor Core

| Matrix Size | TensorCraft | cuBLAS | Ratio |
|-------------|-------------|--------|-------|
| 512×512 | 0.15 ms | 0.14 ms | 93% |
| 1024×1024 | 0.82 ms | 0.71 ms | 87% |
| 2048×2048 | 3.1 ms | 2.8 ms | 89% |
| 4096×4096 | 12.1 ms | 11.0 ms | 91% |
| 8192×8192 | 95.2 ms | 88.0 ms | 92% |

Scaling Across Architectures

| GPU | SM | TensorCraft (4096², FP16) | cuBLAS | Ratio |
|-----|----|---------------------------|--------|-------|
| V100 | 70 | 14.2 ms | 12.8 ms | 89% |
| A100 | 80 | 12.1 ms | 11.0 ms | 91% |
| H100 | 90 | 8.5 ms | 7.8 ms | 92% |

Usage Examples

Basic Usage

```cpp
#include "tensorcraft/kernels/gemm.hpp"

// FP32 GEMM
tensorcraft::kernels::gemm(A_f32, B_f32, C_f32, M, N, K);

// FP16 GEMM (Tensor Core)
tensorcraft::kernels::gemm_fp16(A_f16, B_f16, C_f16, M, N, K);
```

With Python Bindings

```python
import tensorcraft_ops as tc
import numpy as np

# Create matrices
A = np.random.randn(4096, 4096).astype(np.float16)
B = np.random.randn(4096, 4096).astype(np.float16)

# GPU-accelerated GEMM
C = tc.gemm(A, B)
```

Batched Processing

```cpp
#include "tensorcraft/kernels/gemm.hpp"

#include <vector>

std::vector<const half*> A_batch(batch_size);
std::vector<const half*> B_batch(batch_size);
std::vector<half*> C_batch(batch_size);

// Initialize batch pointers...

tensorcraft::kernels::gemm_batched(
    A_batch.data(), B_batch.data(), C_batch.data(),
    batch_size, M, N, K
);
```
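
The "Initialize batch pointers" step is elided above. One common layout, shown here as an assumption rather than a library requirement, packs all matrices contiguously into three device buffers (A_buf, B_buf, and C_buf are hypothetical pointers allocated elsewhere) and derives each batch entry by stride:

```cpp
// Hypothetical contiguous buffers: batch_size matrices stored back to back.
// A_buf: batch_size × M × K, B_buf: batch_size × K × N, C_buf: batch_size × M × N.
for (size_t b = 0; b < batch_size; ++b) {
    A_batch[b] = A_buf + b * M * K;  // b-th M×K input
    B_batch[b] = B_buf + b * K * N;  // b-th K×N input
    C_batch[b] = C_buf + b * M * N;  // b-th M×N output
}
```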

Implementation Details

Memory Layout

All matrices are expected in row-major order:

```
A[M×K]: A[0,0], A[0,1], ..., A[0,K-1], A[1,0], ...
B[K×N]: B[0,0], B[0,1], ..., B[0,N-1], B[1,0], ...
C[M×N]: C[0,0], C[0,1], ..., C[0,N-1], C[1,0], ...
```
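
Equivalently, element (i, j) of a matrix with `cols` columns sits at offset i × cols + j, which is how the raw pointers passed to gemm are interpreted:

```cpp
#include <cstddef>

// Row-major offset of element (i, j) in a matrix with `cols` columns,
// e.g. A[i][j] lives at A[row_major_index(i, j, K)].
inline std::size_t row_major_index(std::size_t i, std::size_t j, std::size_t cols) {
    return i * cols + j;
}
```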

Thread Block Configuration

| Optimization | Threads per Block | Tile Size |
|--------------|-------------------|-----------|
| Tiled | 256 | 32×32 |
| Double Buffer | 256 | 32×32 × 2 (double-buffered) |
| Tensor Core | 128 | 64×64 (WMMA) |
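
To make the table concrete: one block covers one output tile, so the grid is the number of tiles across C. A sketch of the launch geometry for the Tiled configuration (the library computes this internally):

```cpp
#include <cuda_runtime.h>
#include <cstddef>

// Grid covering an M×N output with 32×32 tiles (Tiled configuration above).
dim3 tiled_grid(std::size_t M, std::size_t N) {
    constexpr std::size_t TILE = 32;
    return dim3(static_cast<unsigned>((N + TILE - 1) / TILE),   // tile columns
                static_cast<unsigned>((M + TILE - 1) / TILE));  // tile rows
}
// Launched with 256 threads per block, each thread produces
// 32 × 32 / 256 = 4 elements of C.
```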

Shared Memory Usage

  • Tiled: 2 × 32 × 32 × sizeof(T) per block
  • Double Buffer: 4 × 32 × 32 × sizeof(T) per block
  • Tensor Core: 2 × 64 × 64 × sizeof(T) per block (worked out in bytes below)
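
Worked out for T = half (2 bytes), these formulas give footprints that fit comfortably within the 48 KB of static shared memory available per block on supported architectures:

```cpp
#include <cuda_fp16.h>
#include <cstddef>

// Per-block shared-memory footprints from the formulas above, for T = half.
constexpr std::size_t kTiled      = 2 * 32 * 32 * sizeof(half);  //  4 KiB
constexpr std::size_t kDoubleBuf  = 4 * 32 * 32 * sizeof(half);  //  8 KiB
constexpr std::size_t kTensorCore = 2 * 64 * 64 * sizeof(half);  // 16 KiB
```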
