## Overview
Grouped GEMM is ideal for scenarios where you need to perform multiple matrix multiplications with varying dimensions, such as:

- Multi-head attention with different head sizes
- Variable-length sequence processing
- Sparse neural network layers
- Mixture-of-experts (MoE) models
## What is Grouped GEMM?
Grouped GEMM differs from “Batched Array” GEMM:

- Batched GEMM: All matrices share the same dimensions (M, N, K)
- Grouped GEMM: Each group can have different dimensions
C[i] = A[i] × B[i] where A[i], B[i], and C[i] can have different sizes.
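The semantics above can be expressed as a short NumPy reference (a sketch for correctness checking only; the function name and problem sizes are illustrative, and the real kernel runs all groups on the GPU rather than in a Python loop):

```python
import numpy as np

def grouped_gemm_reference(As, Bs):
    """Reference grouped GEMM: one independent matmul per group.

    Each group i may have its own (M_i, N_i, K_i); the only requirement
    is that A[i] is (M_i, K_i) and B[i] is (K_i, N_i).
    """
    return [a @ b for a, b in zip(As, Bs)]

# Three groups with different problem sizes (M, N, K).
rng = np.random.default_rng(0)
sizes = [(128, 256, 64), (32, 32, 512), (1, 1024, 8)]
As = [rng.standard_normal((m, k)) for m, n, k in sizes]
Bs = [rng.standard_normal((k, n)) for m, n, k in sizes]
Cs = grouped_gemm_reference(As, Bs)
# Each output C[i] has its own shape (M_i, N_i).
```
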
## Implementation Example
Here’s a complete example using the Blackwell architecture with CuTe DSL.

### Step 1: Define the Kernel
### Step 2: Call the Kernel
### Step 3: Implement the Device Kernel
## Running the Example
## Key Features
### Warp Specialization
Grouped GEMM uses specialized warps for different tasks:

- TMA Warp: Handles tensormap updates and data loading
- MMA Warp: Performs matrix multiply-accumulate operations
- Epilogue Warps: Handle result storage and post-processing
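The division of labor can be sketched as a role map keyed on the warp index (the warp counts and function name here are illustrative assumptions, not the kernel's actual configuration):

```python
# Hypothetical warp-role assignment for a warp-specialized kernel.
# The counts below are illustrative, not the real kernel's layout.
TMA_WARPS = 1       # load tiles via TMA and update tensormaps
MMA_WARPS = 1       # issue matrix multiply-accumulate instructions
EPILOGUE_WARPS = 4  # store results and run post-processing

def warp_role(warp_idx: int) -> str:
    """Map a warp index within the CTA to its specialized role."""
    if warp_idx < TMA_WARPS:
        return "tma"
    if warp_idx < TMA_WARPS + MMA_WARPS:
        return "mma"
    return "epilogue"
```

On the GPU this branch runs once per warp at kernel entry, so each warp spends the rest of the kernel in its own producer or consumer loop.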
### Tensormap Update Modes
Grouped GEMM supports two modes for updating tensormaps.

### Persistent Tile Scheduling
The kernel uses persistent tile scheduling to:

- Minimize kernel launch overhead
- Improve load balancing across groups
- Utilize hardware resources more effectively
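The load-balancing effect can be modeled in a few lines of Python: a fixed pool of persistent CTAs consumes a flattened tile index that spans all groups, so work is spread evenly regardless of group boundaries. The round-robin assignment and helper names below are illustrative; the real kernel typically uses an atomic work counter rather than a precomputed schedule.

```python
from itertools import accumulate

def tiles_per_group(group_sizes, tile_m, tile_n):
    """Number of output tiles each group's (M, N) produces for a tile shape."""
    return [((m + tile_m - 1) // tile_m) * ((n + tile_n - 1) // tile_n)
            for m, n in group_sizes]

def persistent_schedule(group_sizes, tile_m, tile_n, num_ctas):
    """Assign every (group, tile) pair to a persistent CTA.

    A strided loop over the flattened tile index models the distribution
    an atomic work counter would produce.
    """
    counts = tiles_per_group(group_sizes, tile_m, tile_n)
    starts = [0] + list(accumulate(counts))
    work = {cta: [] for cta in range(num_ctas)}
    for flat in range(starts[-1]):
        # Map the flat index back to (group, tile-within-group).
        g = next(i for i in range(len(counts)) if flat < starts[i + 1])
        work[flat % num_ctas].append((g, flat - starts[g]))
    return work

# Three groups of different sizes, 128x128 tiles, 4 persistent CTAs.
work = persistent_schedule([(256, 256), (128, 64), (512, 128)], 128, 128, 4)
```
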
## Constraints and Considerations
## Performance Optimization
### Choosing MMA Tile Size
Select tile sizes based on your problem sizes:

- Small problems: Use smaller tiles (64×64, 128×64)
- Large problems: Use larger tiles (128×128, 256×128)
- Mixed sizes: Choose a balanced tile size
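The guidance above could be encoded as a simple heuristic; the function name and thresholds below are hypothetical assumptions for illustration, not values taken from the library:

```python
def pick_mma_tile(problems):
    """Hypothetical tile-size heuristic following the guidance above.

    `problems` is a list of (M, N, K) tuples; thresholds are illustrative.
    """
    avg_mn = sum(m * n for m, n, _ in problems) / len(problems)
    if avg_mn <= 64 * 64:
        return (64, 64)    # small problems: less tile-quantization waste
    if avg_mn >= 256 * 256:
        return (128, 128)  # large problems: higher arithmetic intensity
    return (128, 64)       # mixed sizes: balanced choice
```
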
### Cluster Configuration
Cluster shape affects performance:

- (1,1): No clustering, good for small problems
- (2,1) or (1,2): Light clustering, balanced approach
- (2,2): Maximum clustering, best for large tiles
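As a rough illustration, the mapping from tile size to cluster shape might look like the following (the function and its cutoffs are assumptions for the sketch, not the library's selection logic):

```python
def pick_cluster_shape(tile_m, tile_n):
    """Illustrative mapping from MMA tile size to cluster shape."""
    if tile_m >= 128 and tile_n >= 128:
        return (2, 2)   # large tiles amortize multicast across a 2x2 cluster
    if max(tile_m, tile_n) >= 128:
        # Light clustering along the larger tile dimension.
        return (2, 1) if tile_m >= tile_n else (1, 2)
    return (1, 1)       # small tiles: no clustering
```
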
### Memory Alignment
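A quick way to check the alignment TMA typically requires is to verify both the base address and the leading-dimension stride in bytes (the 16-byte figure matches TMA's usual requirement, but the helper name and signature are hypothetical):

```python
def is_tma_aligned(base_addr: int, leading_dim: int, dtype_bytes: int,
                   alignment: int = 16) -> bool:
    """Check that a tensor's base pointer and leading-dimension stride
    (in bytes) are both multiples of the required alignment."""
    return (base_addr % alignment == 0
            and (leading_dim * dtype_bytes) % alignment == 0)
```
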
Ensure proper alignment for optimal performance.

## Complete Working Example
The full example includes:

- Full kernel implementation with warp specialization
- Tensormap management for variable problem sizes
- Reference implementation for correctness checking
- Performance benchmarking utilities
## Legacy Grouped GEMM
For pre-SM90 architectures, use the high-level Python interface; see examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb for details.
## Next Steps
- Basic GEMM: Start with single GEMM operations
- Custom Epilogue: Add custom operations to grouped GEMM