Overview
The CUTLASS GEMM API provides highly optimized matrix multiplication kernels for NVIDIA GPUs. It computes the general matrix product D = alpha * (A * B) + beta * C, where A, B, C, and D are matrices, and alpha and beta are scalars.
cutlass::gemm::device::Gemm
The primary GEMM device operator template.

Template Signature

Source: include/cutlass/gemm/device/gemm.h
Template Parameters

- ElementA: Data type for elements of matrix A (e.g., float, cutlass::half_t, cutlass::bfloat16_t)
- LayoutA: Memory layout for matrix A: cutlass::layout::ColumnMajor (column-major, Fortran-style) or cutlass::layout::RowMajor (row-major, C-style)
- ElementB: Data type for elements of matrix B
- LayoutB: Memory layout for matrix B
- ElementC: Data type for elements of matrices C and D
- LayoutC: Memory layout for matrices C and D
- ElementAccumulator: Data type used for internal accumulation. Higher precision than the inputs can improve accuracy.
- OperatorClass: Compute architecture class: arch::OpClassSimt (CUDA cores), arch::OpClassTensorOp (Tensor Cores), or arch::OpClassWmmaTensorOp (WMMA Tensor Cores, Volta)
- ArchTag: Target architecture (e.g., arch::Sm70, arch::Sm75, arch::Sm80, arch::Sm90, arch::Sm100)
- ThreadblockShape: Threadblock tile size, specified as cutlass::gemm::GemmShape<M, N, K>. Example: cutlass::gemm::GemmShape<128, 128, 32>
- WarpShape: Warp tile size, specified as cutlass::gemm::GemmShape<M, N, K>. Example: cutlass::gemm::GemmShape<64, 64, 32>
- InstructionShape: Instruction-level tile size for Tensor Core operations. Example: cutlass::gemm::GemmShape<16, 8, 16> for FP16 Tensor Cores on SM80
- Stages: Number of pipeline stages for overlapping data movement and computation
- SplitKSerial: Enables split-K with serial reduction for better load balancing
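As a sketch, an FP32 SIMT GEMM for SM80 could be instantiated as follows; the layouts here are illustrative choices, and the later template parameters are omitted because they have architecture-specific defaults:

```cpp
#include "cutlass/gemm/device/gemm.h"

// Illustrative instantiation: FP32 GEMM on CUDA cores for SM80,
// column-major A, B, and C/D, accumulating in FP32.
using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,   // ElementA, LayoutA
    float, cutlass::layout::ColumnMajor,   // ElementB, LayoutB
    float, cutlass::layout::ColumnMajor,   // ElementC, LayoutC
    float,                                 // ElementAccumulator
    cutlass::arch::OpClassSimt,            // OperatorClass: CUDA cores
    cutlass::arch::Sm80                    // ArchTag
    // ThreadblockShape, WarpShape, InstructionShape, Stages, ...
    // default per architecture and can be overridden explicitly.
>;
```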
Arguments Structure

Source: include/cutlass/gemm/device/gemm.h:292

- problem_size: Problem dimensions as GemmCoord(M, N, K), where the output is M×N
- ref_A: Reference to matrix A in device memory, with pointer and leading dimension
- ref_B: Reference to matrix B in device memory
- ref_C: Reference to source matrix C in device memory
- ref_D: Reference to destination matrix D in device memory (may alias C)
- epilogue: Parameters for the epilogue operation (typically contains the alpha and beta scalars)
- split_k_slices: Number of partitions along the K dimension for split-K mode
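A minimal sketch of populating the structure (lda, ldb, and ldc are the leading dimensions; the Gemm instantiation is an assumed FP32 column-major configuration):

```cpp
#include "cutlass/gemm/device/gemm.h"

using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor>;

// A, B, C are assumed to be valid device pointers.
void make_args(int M, int N, int K,
               float const *A, int lda, float const *B, int ldb,
               float *C, int ldc, float alpha, float beta) {
  Gemm::Arguments args(
      {M, N, K},       // problem_size: GemmCoord(M, N, K)
      {A, lda},        // ref_A: pointer + leading dimension
      {B, ldb},        // ref_B
      {C, ldc},        // ref_C (source)
      {C, ldc},        // ref_D (destination; may alias C)
      {alpha, beta});  // epilogue parameters
  (void)args;          // pass to the operator's initialize()/run()
}
```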
Member Functions

- can_implement(): Checks whether the kernel can execute the given problem. Returns Status::kSuccess if feasible. Source: include/cutlass/gemm/device/gemm.h:360
- get_workspace_size(): Returns the required workspace size in bytes (needed for split-K mode). Source: include/cutlass/gemm/device/gemm.h:382
- initialize(): Initializes kernel parameters. Must be called before run(). Source: include/cutlass/gemm/device/gemm.h:403
- update(): Lightweight update of kernel parameters (pointers and epilogue parameters only). Source: include/cutlass/gemm/device/gemm.h:454
- run(): Launches the kernel on the specified CUDA stream. Source: include/cutlass/gemm/device/gemm.h:473
- operator()(): Convenience function that calls initialize() followed by run(). Source: include/cutlass/gemm/device/gemm.h:508

Basic Usage Example
From examples/00_basic_gemm/basic_gemm.cu:
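The original listing is not reproduced here; the following is a condensed sketch in the spirit of that example, wrapping a single-precision column-major GEMM:

```cpp
#include "cutlass/gemm/device/gemm.h"

// Sketch (not verbatim from the example): single-precision GEMM,
// C = alpha * A * B + beta * C, all matrices column-major.
cudaError_t CutlassSgemmNN(int M, int N, int K, float alpha,
                           float const *A, int lda,
                           float const *B, int ldb, float beta,
                           float *C, int ldc) {
  using Gemm = cutlass::gemm::device::Gemm<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;

  Gemm gemm_op;
  Gemm::Arguments args({M, N, K}, {A, lda}, {B, ldb},
                       {C, ldc}, {C, ldc},   // D aliases C here
                       {alpha, beta});

  // operator() performs initialize() followed by run().
  cutlass::Status status = gemm_op(args);
  return (status == cutlass::Status::kSuccess) ? cudaSuccess
                                               : cudaErrorUnknown;
}
```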
Advanced Features
Tensor Core GEMM (FP16)
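A sketch of an FP16 Tensor Core instantiation for SM80; the layouts and tile shapes below are typical choices, not the only valid ones:

```cpp
#include "cutlass/gemm/device/gemm.h"

// FP16 inputs/outputs, FP32 accumulation, Tensor Cores on SM80.
using GemmFp16 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,     // A
    cutlass::half_t, cutlass::layout::ColumnMajor,  // B
    cutlass::half_t, cutlass::layout::RowMajor,     // C and D
    float,                                          // accumulator
    cutlass::arch::OpClassTensorOp,                 // Tensor Cores
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,         // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,           // warp tile
    cutlass::gemm::GemmShape<16, 8, 16>>;           // instruction tile
```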
Split-K Parallelization
For small M and N with large K:

Split-K Example
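A sketch of serial split-K usage. The SplitKSerial flag is a trailing template parameter, so the intermediate parameters are restated with their usual SIMT defaults; the slice count of 8 is illustrative:

```cpp
#include <cstdint>
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/util/device_memory.h"

using GemmSplitK = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float,
    cutlass::arch::OpClassSimt,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 8>,  // threadblock tile (SIMT default)
    cutlass::gemm::GemmShape<32, 64, 8>,    // warp tile
    cutlass::gemm::GemmShape<1, 1, 1>,      // SIMT "instruction" shape
    cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2,      // Stages
    1, 1,   // AlignmentA, AlignmentB
    true>;  // SplitKSerial enabled

cutlass::Status run_split_k(int M, int N, int K,
                            float const *A, float const *B, float *C,
                            float alpha, float beta) {
  int split_k_slices = 8;  // illustrative: partition K into 8 slices

  GemmSplitK::Arguments args({M, N, K}, {A, M}, {B, K}, {C, M}, {C, M},
                             {alpha, beta}, split_k_slices);

  // Split-K requires workspace for the serial reduction of partials.
  size_t bytes = GemmSplitK::get_workspace_size(args);
  cutlass::device_memory::allocation<uint8_t> workspace(bytes);

  GemmSplitK gemm_op;
  cutlass::Status status = gemm_op.initialize(args, workspace.get());
  if (status == cutlass::Status::kSuccess) {
    status = gemm_op.run();
  }
  return status;
}
```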
Batched GEMM
For multiple GEMMs with the same dimensions:

Batched GEMM
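A sketch of strided batched GEMM via the related GemmBatched operator, assuming a fixed stride between consecutive matrices in each batch:

```cpp
#include "cutlass/gemm/device/gemm_batched.h"

using GemmBatched = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor>;

// Every batch entry shares the same M, N, K; strideA/B/C give the
// element offset between consecutive matrices in device memory.
cutlass::Status run_batched(int M, int N, int K, int batch_count,
                            float const *A, int64_t strideA,
                            float const *B, int64_t strideB,
                            float *C, int64_t strideC,
                            float alpha, float beta) {
  GemmBatched gemm_op;
  return gemm_op({{M, N, K},
                  {A, M}, strideA,   // ref_A + batch stride
                  {B, K}, strideB,   // ref_B + batch stride
                  {C, M}, strideC,   // ref_C (source)
                  {C, M}, strideC,   // ref_D (destination)
                  {alpha, beta},
                  batch_count});
}
```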
Supported Data Types
Floating Point
- float (FP32)
- cutlass::half_t (FP16)
- cutlass::bfloat16_t (BF16)
- cutlass::tfloat32_t (TF32)
- double (FP64)
Integer
- int8_t (INT8)
- uint8_t (UINT8)
- int32_t (INT32)
- cutlass::int4b_t (INT4)
- cutlass::uint4b_t (UINT4)
Complex
- cutlass::complex<float>
- cutlass::complex<double>
- cutlass::complex<cutlass::half_t>
Special
- cutlass::float_e4m3_t (FP8 E4M3)
- cutlass::float_e5m2_t (FP8 E5M2)
Performance Tuning
Tile Size Selection
Choosing appropriate tile sizes is critical for performance:

- Threadblock tile: Larger tiles (128×128 or 256×128) generally perform better for large matrices but may have lower occupancy.
- Warp tile: Should divide evenly into the threadblock tile. Common choices: 64×64, 32×64.
- Instruction tile: Determined by the hardware (e.g., 16×8×16 for FP16 Tensor Cores on Ampere).
Pipeline Stages
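More stages deepen the software pipeline that overlaps global-memory loads with computation, at the cost of extra shared memory. As an illustrative sketch, Stages follows the epilogue and threadblock-swizzle template parameters, which are restated with typical values here so that the stage count can be set explicitly:

```cpp
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

// Illustrative: an FP16 Tensor Core GEMM for SM80 with a 4-deep
// pipeline. More stages use more shared memory per threadblock.
using Gemm4Stage = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, float, float>,            // epilogue op
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    4>;                                               // Stages
```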
Memory Alignment
Ensure data is properly aligned:
Error Handling
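A sketch of checking feasibility before launch and translating the Status enum into a readable message (the wrapper function is illustrative; can_implement and cutlassGetStatusString are CUTLASS APIs):

```cpp
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"

using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor>;

bool launch_checked(Gemm &gemm_op, Gemm::Arguments const &args) {
  // Verify the kernel supports this problem before launching.
  cutlass::Status status = Gemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "GEMM not supported: "
              << cutlassGetStatusString(status) << "\n";
    return false;
  }
  status = gemm_op(args);  // initialize() + run()
  if (status != cutlass::Status::kSuccess) {
    std::cerr << "GEMM failed: "
              << cutlassGetStatusString(status) << "\n";
    return false;
  }
  return true;
}
```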
Related APIs
GemmUniversal
cutlass::gemm::device::GemmUniversal - Universal GEMM supporting multiple modes (batched, array, split-K)

GemmSplitKParallel
cutlass::gemm::device::GemmSplitKParallel - Split-K with parallel reduction

GemmComplex
cutlass::gemm::device::GemmComplex - Complex-valued GEMM

GemmSparse
cutlass::gemm::device::GemmSparse - Sparse matrix multiplication (2:4 structured sparsity)

See Also
- Epilogue Operations - Customize output operations
- CuTe Library - Modern tensor abstraction for writing kernels
- GEMM Examples - Complete examples in the repository