Grouped GEMM enables you to compute a batch of GEMM operations where each operation can have a different problem size, unlike standard batched GEMM where all operations have the same dimensions.

Overview

Grouped GEMM is ideal for scenarios where you need to perform multiple matrix multiplications with varying dimensions, such as:
  • Multi-head attention with different head sizes
  • Variable-length sequence processing
  • Sparse neural network layers
  • Mixture-of-experts (MoE) models

What is Grouped GEMM?

Grouped GEMM differs from “Batched Array” GEMM:
  • Batched GEMM: All matrices have the same dimensions (M, N, K)
  • Grouped GEMM: Each group can have different dimensions
Each group performs: C[i] = A[i] × B[i] where A[i], B[i], and C[i] can have different sizes.
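To make the semantics concrete, here is a minimal NumPy reference for what grouped GEMM computes: one independent matmul per group, each with its own (M, N, K). The shapes below are illustrative only, not tied to any kernel configuration.

```python
import numpy as np

# Per-group (M, N, K) -- deliberately different per group
problem_sizes = [(8, 16, 4), (3, 5, 7), (12, 2, 9)]

As = [np.random.randn(m, k).astype(np.float32) for m, n, k in problem_sizes]
Bs = [np.random.randn(k, n).astype(np.float32) for m, n, k in problem_sizes]

# Grouped GEMM: C[i] = A[i] @ B[i], one independent GEMM per group
Cs = [a @ b for a, b in zip(As, Bs)]

for (m, n, k), c in zip(problem_sizes, Cs):
    assert c.shape == (m, n)
```

A hardware grouped GEMM kernel fuses this loop into a single launch, which is where the savings over launching one kernel per group come from.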

Implementation Example

Here’s an abridged example using the Blackwell architecture with the CuTe DSL:

Step 1: Define the Kernel

# NOTE: the `cuda` module (for cuda.CUstream below) is assumed to come
# from the cuda-python bindings used by the CuTe DSL examples.
import cuda.bindings.driver as cuda

import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
import cutlass.utils.blackwell_helpers as sm100_utils

class GroupedGemmKernel:
    def __init__(
        self,
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        tensormap_update_mode=utils.TensorMapUpdateMode.SMEM,
    ):
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = cluster_shape_mn
        self.tensormap_update_mode = tensormap_update_mode
        
        self.cta_group = (
            tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        )

Step 2: Call the Kernel

    @cute.jit
    def __call__(
        self,
        initial_a: cute.Tensor,
        initial_b: cute.Tensor,
        initial_c: cute.Tensor,
        group_count: int,
        problem_shape_mnkl: cute.Tensor,
        strides_abc: cute.Tensor,
        tensor_address_abc: cute.Tensor,
        total_num_clusters: int,
        tensormap_cute_tensor: cute.Tensor,
        max_active_clusters: int,
        stream: cuda.CUstream,
    ):
        # Setup attributes based on input tensors
        self.a_dtype = initial_a.element_type
        self.b_dtype = initial_b.element_type
        self.c_dtype = initial_c.element_type
        
        # Configure TMA atoms for A, B, C
        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            self.a_dtype,
            self.a_major_mode,
            self.b_major_mode,
            self.acc_dtype,
            self.cta_group,
            self.mma_tiler[:2],
        )
        
        # Setup TMA for each tensor
        a_op = sm100_utils.cluster_shape_to_tma_atom_A(
            self.cluster_shape_mn, tiled_mma.thr_id
        )
        tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
            a_op, initial_a, a_smem_layout,
            self.mma_tiler, tiled_mma,
            self.cluster_layout_vmnk.shape,
        )
        
        # Launch kernel
        self.kernel(...).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=(*self.cluster_shape_mn, 1),
            stream=stream,
        )

Step 3: Implement the Device Kernel

    @cute.kernel
    def kernel(
        self,
        tiled_mma: cute.TiledMma,
        tma_atom_a: cute.CopyAtom,
        mA_mkl: cute.Tensor,
        tma_atom_b: cute.CopyAtom,
        mB_nkl: cute.Tensor,
        tma_atom_c: cute.CopyAtom,
        mC_mnl: cute.Tensor,
        # ... other parameters
    ):
        # Warp specialization
        warp_idx = cute.arch.warp_idx()
        
        # TMA warp: Load data
        if warp_idx == self.tma_warp_id:
            # Update tensormaps for each group
            # Perform TMA loads
            pass
        
        # MMA warp: Compute
        if warp_idx == self.mma_warp_id:
            # Perform matrix multiply-accumulate
            pass
        
        # Epilogue warps: Store results
        if warp_idx < self.mma_warp_id:
            # Store results to global memory
            pass

Running the Example

1. Prepare problem sizes

Define the dimensions for each group:
problem_sizes = [
    (8192, 1280, 32, 1),   # Group 0: M=8192, N=1280, K=32, L=1
    (16, 384, 1536, 1),     # Group 1: M=16, N=384, K=1536, L=1
    (640, 1280, 16, 1),     # Group 2: M=640, N=1280, K=16, L=1
    (640, 160, 16, 1),      # Group 3: M=640, N=160, K=16, L=1
]
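As a quick sanity check on the groups above, the per-group FLOP counts (2·M·N·K·L per GEMM) show how unbalanced this workload is; Group 0 dominates. The helper below is illustrative only.

```python
problem_sizes = [
    (8192, 1280, 32, 1),   # Group 0
    (16, 384, 1536, 1),    # Group 1
    (640, 1280, 16, 1),    # Group 2
    (640, 160, 16, 1),     # Group 3
]

# 2*M*N*K*L multiply-add FLOPs per group
flops_per_group = [2 * m * n * k * l for m, n, k, l in problem_sizes]
total_flops = sum(flops_per_group)  # 719,454,208 FLOPs in total
```

Group 0 accounts for roughly 93% of the total work, which is exactly the kind of imbalance the persistent tile scheduler (described below) is designed to absorb.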
2. Run from command line

Execute the grouped GEMM example:
python examples/blackwell/grouped_gemm.py \
  --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
  --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \
  --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \
  --num_groups 4 --tensormap_update_mode SMEM
3. Profile with NCU

Analyze performance:
ncu python examples/blackwell/grouped_gemm.py \
  --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
  --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \
  --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \
  --num_groups 4 --tensormap_update_mode SMEM \
  --warmup_iterations 1 --iterations 10 --skip_ref_check

Key Features

Warp Specialization

Grouped GEMM uses specialized warps for different tasks:
  • TMA Warp: Handles tensormap updates and data loading
  • MMA Warp: Performs matrix multiply-accumulate operations
  • Epilogue Warps: Handle result storage and post-processing
This specialization improves latency hiding and overall performance.

Tensormap Update Modes

Grouped GEMM supports two modes for updating tensormaps:
# Update tensormaps via shared memory
tensormap_update_mode = utils.TensorMapUpdateMode.SMEM
# Buffers 3 tensormaps (A, B, C) in SMEM (128 B each)
# Better for workloads with frequent group changes

# Update tensormaps directly in global memory
tensormap_update_mode = utils.TensorMapUpdateMode.GMEM
# No SMEM buffering for tensormaps
Performance varies by workload—profile both modes to find the optimal choice.

Persistent Tile Scheduling

The kernel uses persistent tile scheduling to:
  • Minimize kernel launch overhead
  • Improve load balancing across groups
  • Make better use of hardware resources
# Tile scheduler handles work distribution
tile_sched = utils.StaticPersistentGroupTileScheduler.create(
    tile_sched_params,
    bid, grid_dim,
    cluster_tile_shape_mnk,
    utils.create_initial_search_state(),
    group_count,
    problem_sizes_mnkl,
)
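The core quantity such a scheduler distributes is the number of output CTA tiles each group contributes for a given MMA tile shape. The sketch below is illustrative (`tiles_per_group` is not a library API), assuming the (M, N, K, L) problem list and a 128×64 tiler from the example above.

```python
import math

def tiles_per_group(problem_sizes_mnkl, tile_m, tile_n):
    """CTA tile count per group: ceil(M / tile_m) * ceil(N / tile_n)."""
    return [
        math.ceil(m / tile_m) * math.ceil(n / tile_n)
        for m, n, _k, _l in problem_sizes_mnkl
    ]

# For the four example groups with a 128x64 tiler, Group 0 alone
# yields 1280 of the 1401 total tiles.
counts = tiles_per_group(
    [(8192, 1280, 32, 1), (16, 384, 1536, 1),
     (640, 1280, 16, 1), (640, 160, 16, 1)],
    128, 64,
)
```

A persistent scheduler lets a fixed grid of clusters walk this flattened tile list, so small groups (6 tiles) and large ones (1280 tiles) share the same launch.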

Constraints and Considerations

The following constraints apply to grouped GEMM:
  • Only FP16 and BF16 data types are supported for A and B
  • Output (C) can be FP16, BF16, or FP32
  • The contiguous dimension must be 16-byte aligned
  • Batch size (L) must be 1 for each group
  • All groups must have the same majorness for A, B, and C
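A hypothetical pre-flight check mirroring these constraints might look like the sketch below; the function name and argument layout are illustrative, and only the rules themselves (L = 1, 16-byte alignment of the contiguous dimension) come from the list above.

```python
def validate_group(m, n, k, l, contiguous_extent, dtype_bytes=2):
    """Illustrative constraint check for one group (dtype_bytes=2 for FP16/BF16)."""
    if l != 1:
        raise ValueError("Batch size L must be 1 for each group")
    if (contiguous_extent * dtype_bytes) % 16 != 0:
        raise ValueError("Contiguous dimension must be 16-byte aligned")

# e.g. a K-major FP16 tensor: K is the contiguous dimension
validate_group(8192, 1280, 32, 1, contiguous_extent=32, dtype_bytes=2)
```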

Performance Optimization

Choosing MMA Tile Size

Select tile sizes based on your problem sizes:
  • Small problems: Use smaller tiles (64×64, 128×64)
  • Large problems: Use larger tiles (128×128, 256×128)
  • Mixed sizes: Choose a balanced tile size
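An illustrative heuristic following this guidance is sketched below; the thresholds and the function name are made up for the sketch, not part of any API.

```python
def pick_mma_tiler(problem_sizes_mnkl):
    """Pick an (M, N) MMA tile from the largest group extents."""
    max_m = max(m for m, _n, _k, _l in problem_sizes_mnkl)
    max_n = max(n for _m, n, _k, _l in problem_sizes_mnkl)
    if max_m <= 128 and max_n <= 128:
        return (64, 64)      # small problems fit small tiles
    if max_m >= 1024 and max_n >= 1024:
        return (128, 128)    # large problems amortize big tiles
    return (128, 64)         # balanced default for mixed sizes
```

Whatever heuristic you use, profile: the best tile for a grouped workload depends on how the per-group shapes mix, not just on the largest group.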

Cluster Configuration

Cluster shape affects performance:
  • (1,1): No clustering, good for small problems
  • (2,1) or (1,2): Light clustering, balanced approach
  • (2,2): Maximum clustering, best for large tiles

Memory Alignment

Ensure proper alignment for optimal performance:
# Check alignment: the contiguous dimension times the element size
# must be a multiple of 16 bytes (e.g. m elements of Float16, 2 B each)
m, dtype_size = 8192, 2
assert m * dtype_size % 16 == 0, "Contiguous dimension must be 16-byte aligned"

Complete Working Example

Find the full implementation:
examples/python/CuTeDSL/blackwell/grouped_gemm.py
The example includes:
  • Full kernel implementation with warp specialization
  • Tensormap management for variable problem sizes
  • Reference implementation for correctness checking
  • Performance benchmarking utilities

Legacy Grouped GEMM

For pre-SM90 architectures, use the high-level Python interface:
import cutlass
import numpy as np

# Define problem sizes (one GEMM per group)
num_groups = 4
problems = [
    cutlass.op.GroupedGemmArguments(
        A=np.random.randn(128, 256).astype(np.float16),
        B=np.random.randn(256, 512).astype(np.float16),
        C=np.zeros((128, 512), dtype=np.float16),
        D=np.zeros((128, 512), dtype=np.float16),
    )
    for _ in range(num_groups)
]

# Create and run grouped GEMM
plan = cutlass.op.GroupedGemm(element=np.float16)
plan.run(problems)
See examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb for details.

Next Steps

Basic GEMM

Start with single GEMM operations

Custom Epilogue

Add custom operations to grouped GEMM
