This guide demonstrates how to implement a dense General Matrix Multiply (GEMM) kernel using the CUTLASS CuTe DSL. The examples show complete implementations for both the Ampere and Blackwell architectures.

Overview

The GEMM kernel computes C = A * B, where:
  • Matrix A is MxKxL (L is the batch dimension)
  • Matrix B is NxKxL
  • Matrix C is MxNxL
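The shapes above can be checked against a plain NumPy reference (illustrative only, not part of the CuTe DSL example; it assumes the usual convention in these examples that B is stored N-by-K, so C[m, n, l] = Σ_k A[m, k, l] · B[n, k, l]):

```python
import numpy as np

# Illustrative NumPy reference for the kernel's semantics (not part of the
# CuTe DSL example). B is stored N-by-K per batch, so each batch computes
# C[:, :, l] = A[:, :, l] @ B[:, :, l].T
def reference_gemm(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    M, K, L = A.shape
    N, K_b, L_b = B.shape
    assert (K, L) == (K_b, L_b), "A and B must agree on K and L"
    C = np.empty((M, N, L), dtype=A.dtype)
    for l in range(L):
        C[:, :, l] = A[:, :, l] @ B[:, :, l].T  # B batch is N x K
    return C

A = np.random.rand(8, 4, 2).astype(np.float32)
B = np.random.rand(6, 4, 2).astype(np.float32)
assert reference_gemm(A, B).shape == (8, 6, 2)
```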
Key features:
  • Utilizes GPU tensor cores for MMA operations
  • Multi-stage pipelining to overlap compute and memory access
  • Shared memory buffering for efficient data transfer
  • Support for various data types (fp16, bf16, fp8, int8, tf32)

Ampere Architecture Example

Define the Kernel Class

Create a class to encapsulate the GEMM kernel configuration:
import math
from typing import Tuple, Type

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils

class TensorOpGemm:
    def __init__(
        self,
        ab_dtype: Type[cutlass.Numeric],
        c_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        atom_layout_mnk: Tuple[int, int, int],
    ):
        self.ab_dtype = ab_dtype
        self.c_dtype = c_dtype
        self.acc_dtype = acc_dtype
        self.cta_tiler = (128, 128, 32)  # Tile shape
        self.num_stages = 3  # Pipeline stages
        self.atom_layout_mnk = atom_layout_mnk
        atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
        self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32  # one warp per MMA atom
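The thread-count arithmetic can be made concrete in plain Python (a sketch; one warp of 32 threads executes each MMA atom in the layout, and `cta_thread_count` is an illustrative helper, not part of the example):

```python
# Plain-Python sketch of the thread-count arithmetic: each MMA atom in the
# layout is executed by one warp of 32 threads, so a CTA needs
# atom_lay_M * atom_lay_N * atom_lay_K warps in total.
def cta_thread_count(atom_layout_mnk):
    atom_lay_m, atom_lay_n, atom_lay_k = atom_layout_mnk
    return atom_lay_m * atom_lay_n * atom_lay_k * 32

print(cta_thread_count((2, 2, 1)))  # 4 warps -> 128 threads per CTA
```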
Configure Shared Memory Layouts

Set up swizzled shared memory layouts to avoid bank conflicts:
def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
    major_mode_size = (
        smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
    )
    major_mode_size = min(major_mode_size, 64)  # cap the swizzle atom width
    
    swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
    swizzle_bits = min(swizzle_bits, 3)
    
    layout_atom_outer = (
        cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
        if major_mode == utils.LayoutEnum.ROW_MAJOR
        else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
    )
    
    layout_atom = cute.make_composed_layout(
        cute.make_swizzle(swizzle_bits, 3, 3),
        0,
        layout_atom_outer,
    )
    return cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
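The swizzle-bit arithmetic can be checked in isolation with plain Python (a sketch mirroring the function above; `swizzle_bits` is an illustrative helper, not a CuTe API):

```python
import math

# Illustrative helper mirroring the swizzle arithmetic above (not a CuTe API):
# the number of copy-width chunks along the major mode determines how many
# swizzle bits are needed, capped at 3.
def swizzle_bits(major_mode_size: int, dtype_width: int, copy_bits: int) -> int:
    major_mode_size = min(major_mode_size, 64)
    bits = int(math.log2(major_mode_size * dtype_width // copy_bits))
    return min(bits, 3)

print(swizzle_bits(64, 16, 128))  # fp16, 128-bit copies: 8 chunks -> 3 bits
print(swizzle_bits(32, 16, 128))  # narrower tile: 4 chunks -> 2 bits
```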
Create Tiled MMA and Copy Atoms

Define the MMA operation and copy atoms:
# Create MMA atom with 16x8x16 shape
op = cute.nvgpu.warp.MmaF16BF16Op(
    self.ab_dtype, self.acc_dtype, (16, 8, 16)
)

# Create tiled MMA
tC = cute.make_layout(self.atom_layout_mnk)
tiled_mma = cute.make_tiled_mma(op, tC)

# Create async copy atom for global to shared memory
atom_async_copy = cute.make_copy_atom(
    cute.nvgpu.cpasync.CopyG2SOp(
        cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
    ),
    mA.element_type,
    num_bits_per_copy=128,
)
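The `num_bits_per_copy=128` setting fixes how many elements each copy instruction moves per thread, depending on element width. A quick plain-Python sketch (`elements_per_copy` is an illustrative helper, not a CuTe API):

```python
# Illustrative helper (not a CuTe API): how many elements a single 128-bit
# vectorized copy moves per thread, as a function of element width.
def elements_per_copy(dtype_width_bits: int, copy_bits: int = 128) -> int:
    assert copy_bits % dtype_width_bits == 0
    return copy_bits // dtype_width_bits

print(elements_per_copy(16))  # fp16/bf16: 8 elements per copy
print(elements_per_copy(8))   # fp8/int8: 16 elements per copy
print(elements_per_copy(32))  # tf32/fp32: 4 elements per copy
```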
Implement the Kernel Function

The main kernel performs the GEMM computation (shown simplified; the prologue and per-stage buffer indexing are omitted):
@cute.kernel
def kernel(
    self,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
    # ... other parameters
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, bidz = cute.arch.block_idx()
    
    # Allocate shared memory
    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
    sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)
    
    # Get thread tiles
    thr_copy_A = tiled_copy_A.get_slice(tidx)
    tAgA = thr_copy_A.partition_S(gA)
    tAsA = thr_copy_A.partition_D(sA)
    
    # Allocate accumulator
    thr_mma = tiled_mma.get_slice(tidx)
    tCrC = tiled_mma.make_fragment_C(gC)
    tCrC.fill(0.0)
    
    # Main loop: copy and compute
    for k_tile in range(k_tile_count):
        # Copy from global to shared memory
        cute.copy(tiled_copy_A, tAgA[k_tile], tAsA[k_tile])
        cute.copy(tiled_copy_B, tBgB[k_tile], tBsB[k_tile])
        cute.arch.cp_async_commit_group()
        cute.arch.cp_async_wait_group(num_stages - 2)
        cute.arch.sync_threads()
        
        # Perform MMA
        cute.gemm(tiled_mma, tCrC, tCrA[k_tile], tCrB[k_tile], tCrC)
    
    # Store results back to global memory
    cute.copy(tiled_copy_C, tCrC, tCgC)
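With multiple pipeline stages, the SMEM tile buffers are reused round-robin across K-tiles. A plain-Python sketch of the circular indexing (the real kernel tracks this with cp.async commit/wait groups; `stage_index` is an illustrative helper):

```python
# Illustrative sketch: with num_stages SMEM buffers, K-tiles are written
# round-robin, so the buffer for tile k is simply k modulo num_stages.
def stage_index(k_tile: int, num_stages: int) -> int:
    return k_tile % num_stages

print([stage_index(k, 3) for k in range(6)])  # [0, 1, 2, 0, 1, 2]
```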

Blackwell Architecture Example

For Blackwell GPUs, the implementation uses more advanced features:
class DenseGemmKernel:
    def __init__(
        self,
        acc_dtype: Type[cutlass.Numeric],
        use_2cta_instrs: bool,
        mma_tiler_mn: Tuple[int, int],
        cluster_shape_mn: Tuple[int, int],
        use_tma_store: bool,
    ):
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        self.cluster_shape_mn = cluster_shape_mn
        self.mma_tiler_mn = mma_tiler_mn
        self.use_tma_store = use_tma_store
        self.cta_group = (
            tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        )

Key Blackwell Features

Tensor Memory Accelerator (TMA)

TMA provides efficient bulk asynchronous copies between global and shared memory:
# Setup TMA load for A
a_op = sm100_utils.cluster_shape_to_tma_atom_A(
    self.cluster_shape_mn, tiled_mma.thr_id
)
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
    a_op, a, a_smem_layout, self.mma_tiler, tiled_mma,
    self.cluster_layout_vmnk.shape
)

# Use TMA to copy data
cute.copy(
    tma_atom_a,
    tAgA[(None, k_tile_idx)],
    tAsA[(None, producer_handle.index)],
    tma_bar_ptr=producer_handle.barrier,
    mcast_mask=a_full_mcast_mask,
)
Cluster Computing

Blackwell supports CTA (thread block) clusters for better parallelism and multicast loads:
# Compute cluster layout
self.cluster_layout_vmnk = cute.tiled_divide(
    cute.make_layout((*self.cluster_shape_mn, 1)),
    (tiled_mma.thr_id.shape,),
)

# Launch with cluster
self.kernel(...).launch(
    grid=grid,
    block=[self.threads_per_cta, 1, 1],
    cluster=(*self.cluster_shape_mn, 1),
    stream=stream,
)

Running the Example

python examples/python/CuTeDSL/ampere/tensorop_gemm.py \
  --mnkl 8192,8192,8192,1 \
  --atom_layout_mnk 2,2,1 \
  --ab_dtype Float16 \
  --c_dtype Float16 \
  --acc_dtype Float32 \
  --a_major m --b_major n --c_major n

Performance Profiling

Use NVIDIA Nsight Compute to profile your kernel:
ncu python examples/python/CuTeDSL/ampere/tensorop_gemm.py \
  --mnkl 8192,8192,8192,1 \
  --atom_layout_mnk 2,2,1 \
  --ab_dtype Float16 \
  --c_dtype Float16 \
  --acc_dtype Float32 \
  --a_major m --b_major n --c_major n \
  --skip_ref_check --iterations 2

Key Concepts

Tile Shape

The tile shape determines how the problem is divided across CTAs:
  • M dimension: affects register usage and occupancy
  • N dimension: should align with the MMA instruction shape
  • K dimension: larger values mean fewer main-loop iterations and better latency hiding
Common configurations:
  • Ampere: 128x128x32, 128x256x32, 256x128x32
  • Blackwell: 128x128x64, 256x128x128
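A back-of-envelope estimate of the arithmetic intensity implied by a tile shape (an illustrative sketch; real traffic also depends on L2 reuse and the epilogue, and `tile_arithmetic_intensity` is a hypothetical helper):

```python
# Back-of-envelope sketch (ignores L2 reuse and the epilogue): a CTA computing
# an m x n output block over k accumulation steps performs 2*m*n*k FLOPs while
# loading (m*k + n*k) elements of A and B from global memory.
def tile_arithmetic_intensity(m, n, k, bytes_per_element):
    flops = 2 * m * n * k
    gmem_bytes = (m * k + n * k) * bytes_per_element
    return flops / gmem_bytes

# fp16 (2 bytes), 128x128 tile: note that k cancels, so per-CTA GMEM intensity
# is driven by m and n; larger output tiles are more bandwidth-friendly.
print(tile_arithmetic_intensity(128, 128, 32, 2))  # 64.0 FLOPs per byte
```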
Memory Hierarchy

Understanding the memory hierarchy is crucial:
  1. Global Memory (GMEM): Input matrices A, B, output C
  2. Shared Memory (SMEM): Tile-level cache, swizzled to avoid bank conflicts
  3. Registers (RMEM): Thread-level storage for MMA operands
  4. Tensor Memory (TMEM): Blackwell-specific accumulator storage
Pipelining

Multi-stage pipelining overlaps memory transfers with computation:
  • Prefetch: asynchronously load upcoming tiles from GMEM into SMEM
  • Compute: issue MMA instructions on the tile already resident in SMEM
  • Store: write completed accumulator results back to GMEM
More stages allow deeper prefetching and better overlap, but each additional stage consumes another SMEM buffer.
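The SMEM cost of adding stages is easy to estimate (a back-of-envelope sketch that ignores alignment padding and epilogue buffers; `smem_bytes` is an illustrative helper):

```python
# Back-of-envelope sketch (ignores alignment padding and epilogue buffers):
# each pipeline stage holds one A tile and one B tile in shared memory.
def smem_bytes(cta_m, cta_n, cta_k, num_stages, bytes_per_element):
    a_tile = cta_m * cta_k * bytes_per_element
    b_tile = cta_n * cta_k * bytes_per_element
    return num_stages * (a_tile + b_tile)

# The Ampere example: 128x128x32 tile, 3 stages, fp16 (2 bytes per element)
print(smem_bytes(128, 128, 32, 3, 2))  # 49152 bytes = 48 KiB
```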

Source Code

Ampere GEMM

Complete source: examples/python/CuTeDSL/ampere/tensorop_gemm.py

Blackwell GEMM

Complete source: examples/python/CuTeDSL/blackwell/dense_gemm.py
