Overview
The GEMM kernel computes C = A * B, where:
- Matrix A is MxKxL (L is batch dimension)
- Matrix B is NxKxL
- Matrix C is MxNxL
The implementation:
- Utilizes GPU tensor cores for MMA operations
- Multi-stage pipelining to overlap compute and memory access
- Shared memory buffering for efficient data transfer
- Support for various data types (fp16, bf16, fp8, int8, tf32)
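The shape convention above can be pinned down with a plain-Python reference (a hypothetical checking helper, not part of the kernel): K is the contracted mode and L the batch mode.

```python
def gemm_reference(A, B, M, N, K, L):
    """Reference batched GEMM: C[m][n][l] = sum over k of A[m][k][l] * B[n][k][l].

    A is indexed [m][k][l] and B is indexed [n][k][l], matching the
    MxKxL / NxKxL / MxNxL convention above; K is contracted, L is batch.
    """
    return [
        [
            [sum(A[m][k][l] * B[n][k][l] for k in range(K)) for l in range(L)]
            for n in range(N)
        ]
        for m in range(M)
    ]
```

This is useful as a correctness oracle when validating kernel output on small problem sizes.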
Ampere Architecture Example
```python
import math
from typing import Tuple, Type

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils


class TensorOpGemm:
    def __init__(
        self,
        ab_dtype: Type[cutlass.Numeric],
        c_dtype: Type[cutlass.Numeric],
        acc_dtype: Type[cutlass.Numeric],
        atom_layout_mnk: Tuple[int, int, int],
    ):
        self.ab_dtype = ab_dtype
        self.c_dtype = c_dtype
        self.acc_dtype = acc_dtype
        self.cta_tiler = (128, 128, 32)  # CTA tile shape (M, N, K)
        self.num_stages = 3  # pipeline stages
        self.atom_layout_mnk = atom_layout_mnk
        atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk
        # One warp (32 threads) per MMA atom in the layout
        self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32

    def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler):
        # Size of the contiguous (major) mode of the smem tile, capped at 64
        major_mode_size = (
            smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0]
        )
        major_mode_size = 64 if major_mode_size >= 64 else major_mode_size
        # Swizzle bits derived from how many vectorized copies span the major mode
        swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits))
        swizzle_bits = min(swizzle_bits, 3)
        layout_atom_outer = (
            cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1))
            if major_mode == utils.LayoutEnum.ROW_MAJOR
            else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size))
        )
        # Compose the swizzle with the layout atom to avoid smem bank conflicts
        layout_atom = cute.make_composed_layout(
            cute.make_swizzle(swizzle_bits, 3, 3),
            0,
            layout_atom_outer,
        )
        return cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2))
```
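The swizzle-bit computation in `_make_smem_layout_AB` is plain integer arithmetic, so it can be sanity-checked outside the DSL. A standalone mirror of that logic (a hypothetical helper, not part of the example):

```python
import math


def smem_swizzle_bits(major_mode_size: int, dtype_width_bits: int,
                      copy_bits: int = 128) -> int:
    """Mirror of the swizzle-bit computation in _make_smem_layout_AB."""
    major_mode_size = min(major_mode_size, 64)  # major mode capped at 64 elements
    bits = int(math.log2(major_mode_size * dtype_width_bits // copy_bits))
    return min(bits, 3)  # the hardware swizzle pattern uses at most 3 bits
```

For fp16 (16-bit elements) with 128-bit copies, a 64-element major mode yields the maximum 3 swizzle bits, while a 32-element major mode yields 2.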
During setup, the MMA atom, the tiled MMA, and the asynchronous copy atom are created (excerpt):

```python
# Create an MMA atom with 16x8x16 (M, N, K) instruction shape
op = cute.nvgpu.warp.MmaF16BF16Op(
    self.ab_dtype, self.acc_dtype, (16, 8, 16)
)
# Tile the atom across warps according to atom_layout_mnk
tC = cute.make_layout(self.atom_layout_mnk)
tiled_mma = cute.make_tiled_mma(op, tC)

# Async copy atom for 128-bit global-to-shared transfers (cp.async)
atom_async_copy = cute.make_copy_atom(
    cute.nvgpu.cpasync.CopyG2SOp(
        cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL
    ),
    mA.element_type,
    num_bits_per_copy=128,
)
```
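The `num_bits_per_copy=128` choice means each copy instruction moves a full 16-byte vector; the element count per copy follows from the dtype width. A trivial check (hypothetical helper):

```python
def elems_per_copy(copy_bits: int, dtype_width_bits: int) -> int:
    """Elements moved by one vectorized copy instruction."""
    assert copy_bits % dtype_width_bits == 0
    return copy_bits // dtype_width_bits
```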
The device kernel allocates shared memory, partitions tiles per thread, runs the pipelined mainloop, and writes results back (excerpt; names such as `sA_layout`, `gA`, and `tiled_copy_A` come from the omitted setup code):

```python
@cute.kernel
def kernel(
    self,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
    # ... other parameters
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, bidy, bidz = cute.arch.block_idx()

    # Allocate swizzled shared-memory buffers for the A and B tiles
    smem = cutlass.utils.SmemAllocator()
    sA = smem.allocate_tensor(mA.element_type, sA_layout, 16)
    sB = smem.allocate_tensor(mB.element_type, sB_layout, 16)

    # Per-thread views of the global (source) and shared (destination) tiles
    thr_copy_A = tiled_copy_A.get_slice(tidx)
    tAgA = thr_copy_A.partition_S(gA)
    tAsA = thr_copy_A.partition_D(sA)

    # Per-thread accumulator fragment, zero-initialized
    thr_mma = tiled_mma.get_slice(tidx)
    tCrC = tiled_mma.make_fragment_C(gC)
    tCrC.fill(0.0)

    # Mainloop: overlap async copies with MMA compute
    for k_tile in range(k_tile_count):
        # Issue async copies of the next A/B tiles into shared memory
        cute.copy(tiled_copy_A, tAgA[k_tile], tAsA[k_tile])
        cute.copy(tiled_copy_B, tBgB[k_tile], tBsB[k_tile])
        cute.arch.cp_async_commit_group()
        # Wait until at most num_stages - 2 copy groups remain in flight
        cute.arch.cp_async_wait_group(num_stages - 2)
        cute.arch.sync_threads()
        # Accumulate: tCrC += tCrA[k_tile] * tCrB[k_tile]
        cute.gemm(tiled_mma, tCrC, tCrA[k_tile], tCrB[k_tile], tCrC)

    # Epilogue: write the accumulator back to global memory
    cute.copy(tiled_copy_C, tCrC, tCgC)
```
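The host-side launch is not shown above; as an illustrative sketch (`grid_shape` is a hypothetical helper, not part of the example), the grid typically follows the CTA tiler, with one CTA per output tile and the batch mode mapped to the grid z dimension:

```python
def grid_shape(M, N, L, cta_tiler=(128, 128, 32)):
    """One CTA per (128, 128) output tile, one grid z-slice per batch.

    The K mode is not tiled across the grid; each CTA iterates over
    ceil(K / 32) k-tiles in its mainloop instead.
    """
    tile_m, tile_n, _ = cta_tiler
    return ((M + tile_m - 1) // tile_m, (N + tile_n - 1) // tile_n, L)
```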
Blackwell Architecture Example
For Blackwell GPUs, the implementation uses more advanced features.

Key Blackwell Features

Tensor Memory Accelerator (TMA)
TMA provides efficient bulk memory transfers between global and shared memory.

Running the Example
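Both examples ship as standalone scripts; a typical invocation from the CUTLASS repository root (no extra flags are assumed here; each script defines its own CLI arguments) looks like:

```shell
# Paths as listed in the Source Code section
python examples/python/CuTeDSL/ampere/tensorop_gemm.py      # Ampere variant
python examples/python/CuTeDSL/blackwell/dense_gemm.py      # Blackwell variant
```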
Performance Profiling
Use NVIDIA Nsight Compute to profile your kernel.

Key Concepts

Tile Shape Selection
The tile shape determines how the problem is divided:
- M dimension: Affects register usage and occupancy
- N dimension: Should align with MMA instruction shape
- K dimension: Larger values improve arithmetic intensity
Common configurations:
- Ampere: 128x128x32, 128x256x32, 256x128x32
- Blackwell: 128x128x64, 256x128x128
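A back-of-envelope model makes the M/N trade-off concrete (hypothetical helper; fp16 assumed, C write-back ignored): each mainloop step performs 2*m*n*k flops against (m + n)*k loaded elements, so k cancels and intensity scales as m*n / (m + n). Larger K tiles instead amortize per-step synchronization and pipeline overhead.

```python
def tile_flops_per_byte(m, n, bytes_per_elem=2):
    """FLOPs per byte of A/B traffic for an (m, n) CTA tile.

    Per k-step: 2*m*n*k flops vs. (m + n)*k loaded elements, so the
    ratio is m*n / ((m + n) * bytes_per_elem), independent of k.
    """
    return m * n / ((m + n) * bytes_per_elem)
```

For fp16, a 256x128 tile achieves higher arithmetic intensity than 128x128, at the cost of more registers and shared memory per CTA.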
Memory Hierarchy
Understanding the memory hierarchy is crucial:
- Global Memory (GMEM): Input matrices A, B, output C
- Shared Memory (SMEM): Tile-level cache, swizzled to avoid bank conflicts
- Registers (RMEM): Thread-level storage for MMA operands
- Tensor Memory (TMEM): Blackwell-specific accumulator storage
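Shared-memory capacity couples the tile shape to the stage count; a back-of-envelope sizing (hypothetical helper, fp16 assumed) for the Ampere configuration above:

```python
def smem_bytes(cta_tiler=(128, 128, 32), num_stages=3, bytes_per_elem=2):
    """Bytes of shared memory for the A and B tile buffers across all stages."""
    m, n, k = cta_tiler
    per_stage = (m * k + n * k) * bytes_per_elem  # one A tile + one B tile
    return per_stage * num_stages
```

The default 128x128x32 fp16 configuration with 3 stages needs 48 KiB, which fits comfortably in an Ampere SM's shared memory.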
Pipeline Stages
Multi-stage pipelining overlaps memory and compute:
- Stage 1: Load next tile from GMEM to SMEM
- Stage 2: Compute current tile MMA
- Stage 3: Write previous results to GMEM
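The staging scheme can be sketched as a toy schedule (pure Python, no GPU semantics; `pipeline_schedule` is illustrative only): with S stages the loop keeps S - 1 loads in flight, which is what the `cp_async_wait_group(num_stages - 2)` call in the kernel above expresses.

```python
def pipeline_schedule(num_k_tiles, num_stages=3):
    """Toy event trace: prefetch num_stages - 1 tiles, then in steady
    state issue 'load tile i + stages - 1' alongside 'compute tile i'."""
    events = []
    in_flight = min(num_stages - 1, num_k_tiles)
    for t in range(in_flight):            # prologue: fill the pipeline
        events.append(f"load {t}")
    next_load = in_flight
    for t in range(num_k_tiles):          # mainloop: overlap load and compute
        if next_load < num_k_tiles:
            events.append(f"load {next_load}")
            next_load += 1
        events.append(f"compute {t}")
    return events
```

Every compute step is preceded by its load, and loads run ahead of compute by up to two tiles, hiding global-memory latency behind MMA work.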
Source Code
Ampere GEMM
Complete source:
examples/python/CuTeDSL/ampere/tensorop_gemm.py

Blackwell GEMM
Complete source:
Complete source:
examples/python/CuTeDSL/blackwell/dense_gemm.py