
Overview

The CUTLASS GEMM API provides highly optimized matrix multiplication kernels for NVIDIA GPUs. It computes:
D = alpha * (A @ B) + beta * C
where A, B, C, and D are matrices and alpha and beta are scalars.
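For reference, the computation can be expressed as a naive host-side loop. This is illustrative only (the function name and row-major convention are assumptions, not CUTLASS API):

```cpp
#include <vector>

// Reference (unoptimized) implementation of D = alpha * (A @ B) + beta * C
// for row-major A (M×K), B (K×N), and C/D (M×N). Illustrative only.
void gemm_reference(int M, int N, int K, float alpha,
                    const std::vector<float>& A,
                    const std::vector<float>& B,
                    float beta,
                    const std::vector<float>& C,
                    std::vector<float>& D) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];  // inner product over K
      }
      D[i * N + j] = alpha * acc + beta * C[i * N + j];  // epilogue
    }
  }
}
```

CUTLASS produces the same result, but tiled across threadblocks, warps, and Tensor Core instructions.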

cutlass::gemm::device::Gemm

The primary GEMM device operator template.

Template Signature

include/cutlass/gemm/device/gemm.h
template <
  /// Element type for A matrix operand
  typename ElementA_,
  /// Layout type for A matrix operand
  typename LayoutA_,
  /// Element type for B matrix operand
  typename ElementB_,
  /// Layout type for B matrix operand
  typename LayoutB_,
  /// Element type for C and D matrix operands
  typename ElementC_,
  /// Layout type for C and D matrix operands
  typename LayoutC_,
  /// Element type for internal accumulation
  typename ElementAccumulator_ = ElementC_,
  /// Operator class tag
  typename OperatorClass_ = arch::OpClassSimt,
  /// Tag indicating architecture to tune for
  typename ArchTag_ = arch::Sm70,
  /// Threadblock-level tile size (concept: GemmShape)
  typename ThreadblockShape_ = /* default based on config */,
  /// Warp-level tile size (concept: GemmShape)
  typename WarpShape_ = /* default based on config */,
  /// Instruction-level tile size (concept: GemmShape)
  typename InstructionShape_ = /* default based on config */,
  /// Epilogue output operator
  typename EpilogueOutputOp_ = /* default based on config */,
  /// Threadblock-level swizzling operator
  typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
  /// Number of stages used in the pipelined mainloop
  int Stages = /* default based on config */,
  /// Access granularity of A matrix in units of elements
  int AlignmentA = /* default based on config */,
  /// Access granularity of B matrix in units of elements
  int AlignmentB = /* default based on config */,
  /// If true, kernel supports split-K with serial reduction
  bool SplitKSerial = false,
  /// Operation performed by GEMM
  typename Operator_ = /* default based on config */,
  /// Gather operand A by using an index array
  bool GatherA = false,
  /// Gather operand B by using an index array
  bool GatherB = false,
  /// Scatter result D by using an index array
  bool ScatterD = false,
  /// Permute result D
  typename PermuteDLayout = layout::NoPermute
>
class Gemm;

Template Parameters

ElementA_
typename
Data type for elements of matrix A (e.g., float, cutlass::half_t, cutlass::bfloat16_t)
LayoutA_
typename
Memory layout for matrix A:
  • cutlass::layout::ColumnMajor - Column-major (Fortran-style)
  • cutlass::layout::RowMajor - Row-major (C-style)
ElementB_
typename
Data type for elements of matrix B
LayoutB_
typename
Memory layout for matrix B
ElementC_
typename
Data type for elements of matrices C and D
LayoutC_
typename
Memory layout for matrices C and D
ElementAccumulator_
typename
default:"ElementC_"
Data type used for internal accumulation. Higher precision than inputs can improve accuracy.
OperatorClass_
typename
default:"arch::OpClassSimt"
Compute architecture:
  • arch::OpClassSimt - CUDA cores
  • arch::OpClassTensorOp - Tensor Cores
  • arch::OpClassWmmaTensorOp - WMMA Tensor Cores (Volta)
ArchTag_
typename
default:"arch::Sm70"
Target architecture (e.g., arch::Sm70, arch::Sm75, arch::Sm80, arch::Sm90, arch::Sm100)
ThreadblockShape_
typename
Threadblock tile size specified as cutlass::gemm::GemmShape<M, N, K>. Example: cutlass::gemm::GemmShape<128, 128, 32>
WarpShape_
typename
Warp tile size specified as cutlass::gemm::GemmShape<M, N, K>. Example: cutlass::gemm::GemmShape<64, 64, 32>
InstructionShape_
typename
Instruction-level tile size for Tensor Core operations. Example: cutlass::gemm::GemmShape<16, 8, 16> for FP16 Tensor Cores on SM80
Stages
int
Number of pipeline stages for overlapping data movement and computation
SplitKSerial
bool
default:"false"
Enable split-K with serial reduction for better load balancing

Arguments Structure

include/cutlass/gemm/device/gemm.h:292
struct Arguments {
  GemmCoord problem_size;
  TensorRef<ElementA const, LayoutA> ref_A;
  TensorRef<ElementB const, LayoutB> ref_B;
  TensorRef<ElementC const, LayoutC> ref_C;
  TensorRef<ElementC, LayoutC> ref_D;
  typename EpilogueOutputOp::Params epilogue;
  int split_k_slices;
  // For gather+scatter operations
  int const *gather_A_indices;
  int const *gather_B_indices;
  int const *scatter_D_indices;
};
problem_size
GemmCoord
Problem dimensions as GemmCoord(M, N, K) where output is M×N
ref_A
TensorRef
Reference to matrix A in device memory with pointer and leading dimension
ref_B
TensorRef
Reference to matrix B in device memory
ref_C
TensorRef
Reference to source matrix C in device memory
ref_D
TensorRef
Reference to destination matrix D in device memory (may alias C)
epilogue
EpilogueOutputOp::Params
Parameters for epilogue operation (typically contains alpha and beta scalars)
split_k_slices
int
default:"1"
Number of partitions along K dimension for split-K mode

Member Functions

static Status can_implement(Arguments const &args)
function
Checks if the kernel can execute the given problem. Returns Status::kSuccess if feasible. Source: include/cutlass/gemm/device/gemm.h:360
static size_t get_workspace_size(Arguments const &args)
function
Returns the required workspace size in bytes (needed for split-K mode). Source: include/cutlass/gemm/device/gemm.h:382
Status initialize(Arguments const &args, void *workspace, cudaStream_t stream)
function
Initializes kernel parameters. Must be called before run(). Source: include/cutlass/gemm/device/gemm.h:403
Status update(Arguments const &args, void *workspace)
function
Lightweight update of kernel parameters (pointers and epilogue params only). Source: include/cutlass/gemm/device/gemm.h:454
Status run(cudaStream_t stream)
function
Launches the kernel on the specified CUDA stream. Source: include/cutlass/gemm/device/gemm.h:473
Status operator()(Arguments const &args, void *workspace, cudaStream_t stream)
function
Convenience function that calls initialize() followed by run(). Source: include/cutlass/gemm/device/gemm.h:508

Basic Usage Example

From examples/00_basic_gemm/basic_gemm.cu:
#include "cutlass/gemm/device/gemm.h"

// Define type for single-precision CUTLASS GEMM
using ColumnMajor = cutlass::layout::ColumnMajor;

using CutlassGemm = cutlass::gemm::device::Gemm<
  float,        // ElementA
  ColumnMajor,  // LayoutA
  float,        // ElementB
  ColumnMajor,  // LayoutB
  float,        // ElementC
  ColumnMajor   // LayoutC
>;
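With the type defined, a launch follows the pattern below. This is a condensed sketch of the same basic example; it assumes A, B, and C are float device pointers with leading dimensions lda, ldb, ldc, and writes the result over C, as examples/00_basic_gemm does. Error checking is omitted for brevity:

```cpp
// Sketch: launching the CutlassGemm defined above.
CutlassGemm gemm_operator;

CutlassGemm::Arguments args(
  {M, N, K},       // problem size (GemmCoord)
  {A, lda},        // TensorRef to A
  {B, ldb},        // TensorRef to B
  {C, ldc},        // TensorRef to C (source)
  {C, ldc},        // TensorRef to D (destination; aliases C here)
  {alpha, beta});  // epilogue scalars

// operator() runs initialize() then run(); workspace and stream default to null.
cutlass::Status status = gemm_operator(args);
```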

Advanced Features

Tensor Core GEMM (FP16)

using GemmTensorOp = cutlass::gemm::device::Gemm<
  cutlass::half_t,                              // ElementA
  cutlass::layout::RowMajor,                    // LayoutA
  cutlass::half_t,                              // ElementB
  cutlass::layout::RowMajor,                    // LayoutB
  cutlass::half_t,                              // ElementC
  cutlass::layout::RowMajor,                    // LayoutC
  float,                                        // ElementAccumulator
  cutlass::arch::OpClassTensorOp,               // OpClass
  cutlass::arch::Sm80,                          // ArchTag
  cutlass::gemm::GemmShape<128, 128, 32>,       // ThreadblockShape
  cutlass::gemm::GemmShape<64, 64, 32>,         // WarpShape
  cutlass::gemm::GemmShape<16, 8, 16>,          // InstructionShape
  cutlass::epilogue::thread::LinearCombination< 
    cutlass::half_t, 8, float, float>,          // EpilogueOp
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3                                             // Stages
>;

Split-K Parallelization

Useful when M and N are small and K is large. Serial split-K requires the SplitKSerial template parameter to be true; for split-K with parallel reduction, use GemmSplitKParallel instead.
Split-K Example
int split_k_slices = 4;  // Partition K dimension

CutlassGemm::Arguments args(
  {M, N, K},
  {A, lda},
  {B, ldb},
  {C, ldc},
  {D, ldd},
  {alpha, beta},
  split_k_slices  // Enable split-K
);

// Get workspace size
size_t workspace_size = CutlassGemm::get_workspace_size(args);
void* workspace;
cudaMalloc(&workspace, workspace_size);

// Launch with workspace
CutlassGemm gemm_op;
gemm_op(args, workspace, stream);
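Conceptually, split-K divides the K dimension across slices whose partial products are then reduced. A plain C++ sketch of the partitioning arithmetic (illustrative only, not CUTLASS internals; the function name is hypothetical):

```cpp
// Extent of the K dimension handled by one slice: K is divided as evenly
// as possible, with the first (K % slices) slices taking one extra element.
int slice_k_extent(int K, int slices, int slice_idx) {
  int base = K / slices;
  int rem  = K % slices;
  return base + (slice_idx < rem ? 1 : 0);
}
```

The workspace allocated above holds the partial accumulators (or semaphores, in the serial case) that the reduction step consumes.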

Batched GEMM

For multiple GEMMs with the same dimensions:
Batched GEMM
#include "cutlass/gemm/device/gemm_batched.h"

using GemmBatched = cutlass::gemm::device::GemmBatched<
  float, cutlass::layout::RowMajor,
  float, cutlass::layout::RowMajor,
  float, cutlass::layout::RowMajor
>;

GemmBatched::Arguments args(
  {M, N, K},
  {A, lda},
  lda * M,          // batch_stride_A
  {B, ldb},
  ldb * K,          // batch_stride_B
  {C, ldc},
  ldc * M,          // batch_stride_C
  {D, ldd},
  ldd * M,          // batch_stride_D
  {alpha, beta},
  batch_count
);
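The batch strides above follow directly from the leading dimensions. For row-major operands, one matrix spans (leading dimension × row count) elements; a minimal sketch of that arithmetic (names here are illustrative, not CUTLASS API):

```cpp
struct BatchStrides { long long A, B, C, D; };

// For row-major A (M×K, lda >= K), B (K×N, ldb >= N), and C/D (M×N,
// ldc/ldd >= N), the element stride between consecutive matrices in a
// batch is the leading dimension times the number of rows.
BatchStrides rowmajor_batch_strides(int M, int N, int K,
                                    int lda, int ldb, int ldc, int ldd) {
  return { static_cast<long long>(lda) * M,
           static_cast<long long>(ldb) * K,
           static_cast<long long>(ldc) * M,
           static_cast<long long>(ldd) * M };
}
```

For column-major operands, the row counts are replaced by the corresponding column counts.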

Supported Data Types

Floating Point

  • float (FP32)
  • cutlass::half_t (FP16)
  • cutlass::bfloat16_t (BF16)
  • cutlass::tfloat32_t (TF32)
  • double (FP64)

Integer

  • int8_t (INT8)
  • uint8_t (UINT8)
  • int32_t (INT32)
  • cutlass::int4b_t (INT4)
  • cutlass::uint4b_t (UINT4)

Complex

  • cutlass::complex<float>
  • cutlass::complex<double>
  • cutlass::complex<cutlass::half_t>

Special

  • cutlass::float_e4m3_t (FP8 E4M3)
  • cutlass::float_e5m2_t (FP8 E5M2)

Performance Tuning

Tile Size Selection

Choosing appropriate tile sizes is critical for performance:
  • Threadblock Tile: Larger tiles (128×128 or 256×128) generally perform better for large matrices but may have lower occupancy.
  • Warp Tile: Should divide evenly into the threadblock tile. Common choices: 64×64, 32×64.
  • Instruction Tile: Determined by hardware (e.g., 16×8×16 for FP16 Tensor Cores on Ampere).
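The divisibility constraint also fixes the number of warps per threadblock. A plain C++ sketch of the check (illustrative arithmetic only; TileShape is a stand-in for cutlass::gemm::GemmShape):

```cpp
#include <cassert>

struct TileShape { int m, n, k; };  // stand-in for GemmShape<M, N, K>

// A threadblock tile must be evenly covered by warp tiles in M and N;
// the warp count per threadblock follows directly.
int warps_per_threadblock(TileShape tb, TileShape warp) {
  assert(tb.m % warp.m == 0 && tb.n % warp.n == 0);
  return (tb.m / warp.m) * (tb.n / warp.n);
}
```

For example, a 128×128 threadblock tile with 64×64 warp tiles yields 4 warps, i.e. 128 threads per threadblock.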

Pipeline Stages

// More stages = better overlap but higher register/shared memory usage.
// Stages is a compile-time template parameter of Gemm:
//   Stages = 3   // good fit for SM80 and newer
//   Stages = 2   // good fit for SM75

Memory Alignment

Ensure data is properly aligned:
// AlignmentA and AlignmentB are compile-time template parameters and
// should be powers of 2; higher alignment enables vectorized loads.
//   AlignmentA = 8   // 8 elements (16 bytes for FP16)
//   AlignmentB = 8
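The access granularity in bytes is just the alignment in elements times the element size; 16 bytes corresponds to the 128-bit vectorized loads the hardware favors. A one-line sketch of the arithmetic (the function name is illustrative):

```cpp
#include <cstddef>

// Access granularity in bytes = alignment (in elements) × element size.
// E.g., 8 half-precision (2-byte) elements give 16-byte (128-bit) accesses.
constexpr std::size_t access_bytes(int alignment_elems, std::size_t elem_size) {
  return static_cast<std::size_t>(alignment_elems) * elem_size;
}
```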

Error Handling

Error Handling
CutlassGemm gemm_op;

// Check if problem is supported
cutlass::Status status = CutlassGemm::can_implement(args);
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Problem not supported: "
            << cutlass::cutlassGetStatusString(status)
            << std::endl;
  return;
}

// Initialize and run
status = gemm_op.initialize(args, workspace, stream);
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Initialization failed" << std::endl;
  return;
}

status = gemm_op.run(stream);
if (status != cutlass::Status::kSuccess) {
  std::cerr << "Kernel execution failed" << std::endl;
  return;
}

GemmUniversal

cutlass::gemm::device::GemmUniversal - Universal GEMM supporting multiple modes (batched, array, split-K)

GemmSplitKParallel

cutlass::gemm::device::GemmSplitKParallel - Split-K with parallel reduction

GemmComplex

cutlass::gemm::device::GemmComplex - Complex-valued GEMM

GemmSparse

cutlass::gemm::device::GemmSparse - Sparse matrix multiplication (2:4 structured sparsity)
