GEMM Operations

GEMM (General Matrix Multiplication) is the fundamental operation at the heart of CUTLASS. This page explains how CUTLASS decomposes GEMM operations hierarchically to achieve optimal performance on NVIDIA GPUs.

What is GEMM?

GEMM computes the matrix product:
D = α × (A × B) + β × C
Where:
  • A is an M×K matrix
  • B is a K×N matrix
  • C is an M×N matrix (source)
  • D is an M×N matrix (destination)
  • α and β are scalar coefficients
CUTLASS supports various GEMM variants including batched GEMM, grouped GEMM, split-K GEMM, and sparse GEMM.

GEMM Coordinate System

CUTLASS uses a coordinate system to navigate the GEMM problem space. The GemmCoord structure represents positions within the computation:
struct GemmCoord : public Coord<3, int> {
  // GEMM M dimension - rows of the output C matrix
  static int const kM = 0;
  
  // GEMM N dimension - columns of the output C matrix
  static int const kN = 1;
  
  // GEMM K dimension - inner dimension of the GEMM problem
  static int const kK = 2;
  
  // Construct from dimensions
  GemmCoord(Index m, Index n, Index k);
  
  // Access methods
  Index m() const;  // Rows
  Index n() const;  // Columns
  Index k() const;  // Inner dimension
};
Reference: include/cutlass/gemm_coord.h:86

Hierarchical Decomposition

CUTLASS decomposes GEMM into a hierarchy of smaller operations to efficiently utilize GPU hardware:
┌─────────────────────────────────────────┐
│           Device (Grid)                 │
│  Full GEMM: M × N × K                   │
└───────────────┬─────────────────────────┘

                ├─ Threadblock Tiles
                │  e.g., 128×128×8

                ├─ Warp Tiles  
                │  e.g., 32×32×8

                └─ Thread Tiles
                   e.g., 8×8×8

1. Device Level

The entire GEMM problem is mapped across the GPU grid. Each CUDA threadblock processes one or more tiles of the output matrix.

2. Threadblock Level

Each threadblock computes a tile (e.g., 128×128) of the output matrix by:
  1. Loading tiles from global memory to shared memory
  2. Performing warp-level operations
  3. Writing results back to global memory
template <
  int M = 1,   // Rows of matrix product
  int N = 1,   // Columns of matrix product  
  int K = 1    // Inner dimension
>
struct GemmShape {
  static int const kM = M;
  static int const kN = N;
  static int const kK = K;
  
  static int const kMN = M * N;    // Elements in output
  static int const kMNK = M * N * K;  // Total operations
};
Reference: include/cutlass/gemm_coord.h:42

3. Warp Level

Warps (groups of 32 threads) collaborate to compute smaller tiles using Tensor Core instructions when available.

4. Thread Level

Individual threads process the smallest granularity of data, performing scalar or vector operations.
This hierarchical approach enables CUTLASS to:
  • Maximize data reuse through shared memory
  • Leverage Tensor Cores for accelerated computation
  • Achieve high memory bandwidth utilization
  • Scale across different GPU architectures

Basic GEMM Example

Here’s a simple example instantiating a CUTLASS GEMM kernel:
#include "cutlass/gemm/device/gemm.h"

cudaError_t CutlassSgemmNN(
  int M, int N, int K,
  float alpha,
  float const *A, int lda,
  float const *B, int ldb,
  float beta,
  float *C, int ldc) {

  using ColumnMajor = cutlass::layout::ColumnMajor;
  
  // Define CUTLASS GEMM type
  using CutlassGemm = cutlass::gemm::device::Gemm<
    float,        // Data-type of A matrix
    ColumnMajor,  // Layout of A matrix
    float,        // Data-type of B matrix
    ColumnMajor,  // Layout of B matrix
    float,        // Data-type of C matrix
    ColumnMajor>; // Layout of C matrix
  
  CutlassGemm gemm_operator;
  
  // Construct arguments
  CutlassGemm::Arguments args(
    {M, N, K},   // Problem dimensions
    {A, lda},    // Tensor-ref for A
    {B, ldb},    // Tensor-ref for B
    {C, ldc},    // Tensor-ref for C (source)
    {C, ldc},    // Tensor-ref for D (destination)
    {alpha, beta}  // Scalars
  );
  
  // Launch kernel
  cutlass::Status status = gemm_operator(args);
  
  return (status == cutlass::Status::kSuccess) 
    ? cudaSuccess : cudaErrorUnknown;
}
Reference: examples/00_basic_gemm/basic_gemm.cu:79
The threadblock tile size (e.g., 128×128×8) is a critical tuning parameter that affects performance. CUTLASS provides sensible defaults but allows customization.

Batched and Grouped GEMM

CUTLASS extends basic GEMM to support multiple matrix multiplications:

Batched GEMM

Computes multiple identically sized GEMMs in parallel:
struct BatchedGemmCoord : public Coord<4, int> {
  static int const kM = 0;      // Rows
  static int const kN = 1;      // Columns  
  static int const kK = 2;      // Inner dimension
  static int const kBatch = 3;  // Batch dimension
  
  BatchedGemmCoord(Index m, Index n, Index k, Index b);
};
Reference: include/cutlass/gemm_coord.h:252

Grouped GEMM

Computes multiple GEMMs with different sizes in a single kernel launch, ideal for dynamic batching scenarios.

Data Movement Strategies

Efficient GEMM requires careful orchestration of data movement:
  1. Global Memory → Shared Memory
    • Use cooperative loads across threadblock
    • Leverage asynchronous copy instructions (SM80+)
  2. Shared Memory → Registers
    • Partition data across warps
    • Use swizzling to avoid bank conflicts
  3. Register → Tensor Cores
    • Feed matrix fragments to MMA instructions
    • Maximize computational throughput
Different memory levels have vastly different bandwidths:
  • Registers: ~20 TB/s (per SM)
  • Shared Memory: ~10 TB/s (per SM)
  • L2 Cache: ~3-5 TB/s
  • Global Memory (HBM): ~2-3 TB/s
CUTLASS optimizes data reuse at each level to minimize global memory accesses.

GEMM Variants

CUTLASS supports numerous GEMM specializations:
  • Split-K: Parallelizes the K dimension across threadblocks
  • Stream-K: Dynamic work distribution for improved load balancing
  • Sparse GEMM: Exploits structured sparsity in matrices
  • Complex GEMM: Native support for complex number arithmetic
  • Mixed Precision: Different input/output data types

Next Steps

CuTe Library

Learn about the tensor abstraction layer

Tensor Cores

Understand hardware-accelerated matrix operations

Memory Layouts

Explore data layout strategies

Quick Start

Build your first CUTLASS kernel
