This guide demonstrates how to perform General Matrix Multiplication (GEMM) operations using the CUTLASS Python interface, including both the high-level API and the CuTe DSL.

Overview

CUTLASS provides two main Python interfaces for GEMM operations:
  1. CUTLASS Python Interface: High-level API for ease of use
  2. CuTe DSL: Low-level Python DSL for maximum control and performance

High-Level CUTLASS Python Interface

The CUTLASS Python interface prioritizes ease of use with a simple, high-level API.

Installation

Install via PyPI:
pip install nvidia-cutlass
Or from source:
pip install .

Basic GEMM Example

Here’s a simple example using the high-level interface:
import cutlass
import numpy as np

# Create a GEMM plan
plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)

# Create input tensors
A = np.ones((1024, 1024), dtype=np.float16)
B = np.ones((1024, 1024), dtype=np.float16)
C = np.zeros((1024, 1024), dtype=np.float16)
D = np.zeros((1024, 1024), dtype=np.float16)

# Run the GEMM: D = A @ B + C
plan.run(A, B, C, D)

print(f"Result shape: {D.shape}")
print(f"First element: {D[0, 0]}")
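Because every input element is 1, each entry of A @ B is a sum of 1024 ones, so a quick NumPy reference makes it easy to sanity-check the device result. A minimal sketch (accumulating in FP32, as GEMM accumulators typically do, before casting back to FP16):

```python
import numpy as np

# Reference for the example above: D should equal A @ B + C.
A = np.ones((1024, 1024), dtype=np.float16)
B = np.ones((1024, 1024), dtype=np.float16)
C = np.zeros((1024, 1024), dtype=np.float16)

# Accumulate in FP32 to mirror a typical FP16-in / FP32-accumulate GEMM.
D_ref = (A.astype(np.float32) @ B.astype(np.float32) + C).astype(np.float16)

print(D_ref[0, 0])  # 1024.0
```

Comparing `D` from `plan.run` against `D_ref` with `np.allclose` is a cheap correctness check before moving to larger problems.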

Features

The high-level interface provides:
  • Simple API: Requires only a few parameters to get started
  • Sensible defaults: Automatically selects reasonable kernel configurations
  • Configuration enumeration: Lists available configurations for your hardware
  • Descriptive exceptions: Python-friendly error messages instead of C++ compile errors
  • Framework integration: Easy export to PyTorch CUDA extensions

Supported Operations

  • Standard GEMMs
  • GEMMs with fused elementwise epilogues (e.g., ReLU)
  • Stream-K swizzling (pre-SM90)
  • Grouped GEMM (pre-SM90)
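A fused elementwise epilogue applies a function such as ReLU to the accumulator before the result is written back, avoiding a separate kernel launch. As a NumPy reference for the math a ReLU-fused GEMM computes (this sketches the computation, not the CUTLASS API):

```python
import numpy as np

def gemm_relu_reference(A, B, C, alpha=1.0, beta=1.0):
    """Reference for a GEMM with a fused ReLU epilogue:
    D = max(alpha * A @ B + beta * C, 0)."""
    return np.maximum(alpha * (A @ B) + beta * C, 0.0)

A = np.array([[1.0, -2.0], [3.0, 4.0]])
B = np.eye(2)
C = np.zeros((2, 2))
print(gemm_relu_reference(A, B, C))  # negative entries are clamped to 0
```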

CuTe DSL GEMM Examples

For advanced users who need maximum control, CUTLASS provides the CuTe DSL, a Python-embedded domain-specific language for writing high-performance kernels.

Simple SIMT GEMM (Ampere)

Here’s an example of a dense FP32 GEMM using SIMT operations:
import cutlass
import cutlass.cute as cute
import cutlass.cute.testing as testing
import cuda.bindings.driver as cuda

class SGemm:
    def __init__(
        self,
        cta_tiler=(128, 128, 8),
        num_stages=3,
        num_threads=256,
    ):
        self._cta_tiler = cta_tiler
        self._num_stages = num_stages
        self._num_threads = num_threads
        self._bM, self._bN, self._bK = cta_tiler

    @cute.jit
    def __call__(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mC: cute.Tensor,
        stream: cuda.CUstream,
    ):
        # Configure shared memory layouts
        sA_layout = cute.make_layout(
            (self._bM, self._bK, self._num_stages),
            stride=(1, self._bM + 4, self._bK * (self._bM + 4)),
        )
        sB_layout = cute.make_layout(
            (self._bN, self._bK, self._num_stages),
            stride=(1, self._bN + 4, self._bK * (self._bN + 4)),
        )

        # Create tiled MMA
        atoms_layout = cute.make_layout(
            (self._num_threads // 16, 16, 1),
            stride=(16, 1, 0)
        )
        op = cute.nvgpu.MmaUniversalOp(cutlass.Float32)
        tiled_mma = cute.make_tiled_mma(op, atoms_layout)

        # Launch kernel
        grid_dim = (*cute.ceil_div(mC.shape, (self._bM, self._bN)), 1)
        self.kernel(mA, mB, mC, sA_layout, sB_layout, tiled_mma).launch(
            grid=grid_dim,
            block=[self._num_threads, 1, 1],
            stream=stream,
        )

    @cute.kernel
    def kernel(
        self,
        mA: cute.Tensor,
        mB: cute.Tensor,
        mC: cute.Tensor,
        sA_layout: cute.Layout,
        sB_layout: cute.Layout,
        tiled_mma: cute.TiledMma,
    ):
        # Kernel implementation
        # ... (load data, compute GEMM, store results)
        pass
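The `+ 4` in the shared-memory strides above is bank-conflict padding: it shifts each K-slice so that accesses land in different 32-bank shared-memory banks. A small model (assuming one 4-byte float per bank slot, 32 banks) shows the effect:

```python
# Model shared memory as 32 four-byte banks: an element at word offset
# m + k * lda lands in bank (m + k * lda) % 32.
def bank(m, k, lda):
    return (m + k * lda) % 32

bM = 128
unpadded = [bank(0, k, bM) for k in range(8)]    # stride bM = 128
padded = [bank(0, k, bM + 4) for k in range(8)]  # stride bM + 4 = 132

print(unpadded)  # [0, 0, 0, 0, 0, 0, 0, 0] -- all eight accesses hit bank 0
print(padded)    # [0, 4, 8, 12, 16, 20, 24, 28] -- spread over eight banks
```

Because 128 is a multiple of 32, the unpadded stride maps every K-slice to the same bank; the padded stride of 132 spreads them out, which is what the `sA_layout` and `sB_layout` definitions exploit.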
Run the example:
python examples/python/CuTeDSL/ampere/sgemm.py --mnk 8192,8192,8192
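With an 8192×8192×8192 problem and the default 128×128 CTA tile, the `cute.ceil_div` in `__call__` produces a 64×64 grid of thread blocks. A plain-Python sketch of that launch arithmetic:

```python
def ceil_div(a, b):
    """Number of b-sized tiles needed to cover a elements."""
    return (a + b - 1) // b

M, N = 8192, 8192
bM, bN = 128, 128  # default cta_tiler M/N from SGemm

grid = (ceil_div(M, bM), ceil_div(N, bN), 1)
print(grid)  # (64, 64, 1)
```

Note that `ceil_div` rounds up, so problem sizes that are not multiples of the tile still get full coverage (e.g. M = 8193 yields 65 tiles along M).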

Blackwell GEMM with TMA

For the latest Blackwell architecture, CUTLASS provides highly optimized kernels using the Tensor Memory Accelerator (TMA):
import cutlass
import cutlass.cute as cute
from cutlass.cute.nvgpu import tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils

class DenseGemmKernel:
    def __init__(
        self,
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        use_tma_store,
    ):
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        self.mma_tiler_mn = mma_tiler_mn
        self.cluster_shape_mn = cluster_shape_mn
        self.use_tma_store = use_tma_store

    @cute.jit
    def __call__(self, a, b, c, stream, epilogue_op=lambda x: x):
        # Setup TMA atoms for efficient memory access
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            a.element_type,
            self.a_major_mode,
            self.b_major_mode,
            self.acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
        )

        # Launch kernel with cluster configuration
        self.kernel(tiled_mma, ...).launch(
            grid=grid,
            block=[128, 1, 1],
            cluster=(*self.cluster_shape_mn, 1),
            stream=stream,
        )
Run the Blackwell example:
python examples/python/CuTeDSL/blackwell/dense_gemm.py \
  --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
  --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
  --mnkl 8192,8192,8192,1 \
  --use_tma_store --use_2cta_instrs
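To see why TMA multicast across a cluster cuts L2 traffic: with `--cluster_shape_mn 2,1`, the two CTAs stacked along M need the identical B tile, and multicast fetches it from L2 once for both instead of once per CTA. A back-of-the-envelope sketch (the 128-wide N tile and FP16 element size come from the command above; the K depth of 64 is an illustrative assumption):

```python
# Bytes in one B tile per K step: bN x bK FP16 elements (2 bytes each).
bN, bK = 128, 64           # bK = 64 is assumed for illustration
bytes_per_tile = bN * bK * 2

cluster_m = 2              # CTAs along M that share the same B tile
without_multicast = cluster_m * bytes_per_tile  # each CTA reads its own copy
with_multicast = bytes_per_tile                 # one multicast read feeds both

print(without_multicast // with_multicast)  # 2x fewer B-tile reads from L2
```

The saving scales with the cluster dimension: the larger the cluster along M, the more CTAs one multicast of B can feed (and symmetrically for A along N).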

Key Features

SIMT GEMM (Ampere):
  • FPU-based matrix multiply-accumulate
  • Multi-stage pipeline for latency hiding
  • Vectorized memory copies
  • Bank conflict reduction with padding
Blackwell GEMM:
  • Tensor Memory Accelerator (TMA) for efficient memory operations
  • tcgen05.mma instructions for matrix operations
  • TMA multicast with cluster for reduced L2 traffic
  • Support for 2-CTA instructions
  • Multi-stage pipeline

Profiling

Using NVIDIA Nsight Compute

Profile your GEMM kernels with NCU:
ncu python your_gemm_script.py \
  --mnk 8192,8192,8192 \
  --iterations 10 \
  --skip_ref_check

Performance Tips

  1. Choose appropriate tile sizes: Balance occupancy and shared memory usage
  2. Use TMA on Hopper/Blackwell: Significantly reduces memory access overhead
  3. Enable clustering: Improves L2 cache utilization
  4. Profile different configurations: Use the profiler to find optimal parameters
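As a concrete instance of the tile-size / shared-memory trade-off in tip 1, the SIMT kernel's default configuration (128×128×8 tile, 3 pipeline stages, FP32, with the +4 bank-conflict padding) consumes about 25 KB of shared memory per block:

```python
bM, bN, bK = 128, 128, 8   # default cta_tiler from SGemm
stages = 3                 # default num_stages
bytes_per_elem = 4         # FP32
pad = 4                    # padding from the sA/sB layouts

sA = (bM + pad) * bK * stages * bytes_per_elem
sB = (bN + pad) * bK * stages * bytes_per_elem
total = sA + sB

print(total, "bytes")  # 25344 bytes (~24.75 KB)
```

That footprint fits comfortably under the 48 KB static shared-memory limit; doubling the tile or stage count would not, which is exactly the occupancy balance the tip describes.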

Data Type Support

High-Level Interface

  • FP32, FP16, BF16
  • INT8, INT4
  • TF32 (on Ampere+)

CuTe DSL

Ampere SIMT:
  • FP32
Blackwell TensorCore:
  • FP16, BF16, TF32
  • INT8, UINT8
  • FP8 (E4M3FN, E5M2)
  • Mixed precision accumulation (FP32, FP16, INT32)

Next Steps

Grouped GEMM

Learn how to perform batched GEMMs with different problem sizes

Custom Epilogue

Fuse custom operations into the GEMM epilogue

Example Code

Find complete working examples in the CUTLASS repository:
  • High-level interface: python/README.md
  • SIMT GEMM: examples/python/CuTeDSL/ampere/sgemm.py
  • Blackwell GEMM: examples/python/CuTeDSL/blackwell/dense_gemm.py
  • Hopper GEMM: examples/python/CuTeDSL/hopper/dense_gemm.py
