CUTLASS implements a hierarchically blocked structure that maps efficiently to NVIDIA GPU architectures. Understanding these optimization techniques will help you achieve peak performance for your specific workloads.
## Hierarchical GEMM Structure
CUTLASS organizes GEMM computations into a hierarchical tiled structure that targets different levels of the GPU memory hierarchy and execution model:
```cpp
// Threadblock-level: partition the GEMM into CTA tiles
for (int cta_n = 0; cta_n < GemmN; cta_n += CtaTileN) {
  for (int cta_m = 0; cta_m < GemmM; cta_m += CtaTileM) {
    // GEMM mainloop over K
    for (int cta_k = 0; cta_k < GemmK; cta_k += CtaTileK) {
      // Warp-level: partition the CTA tile among warps
      for (int warp_n = 0; warp_n < CtaTileN; warp_n += WarpTileN) {
        for (int warp_m = 0; warp_m < CtaTileM; warp_m += WarpTileM) {
          for (int warp_k = 0; warp_k < CtaTileK; warp_k += WarpTileK) {
            // Instruction-level: issue MMA instructions
            for (int mma_k = 0; mma_k < WarpTileK; mma_k += MmaK) {
              for (int mma_n = 0; mma_n < WarpTileN; mma_n += MmaN) {
                for (int mma_m = 0; mma_m < WarpTileM; mma_m += MmaM) {
                  mma_instruction(d, a, b, c);  // Tensor Core operation
                }
              }
            }
          }
        }
      }
    }
  }
}
```
This structure targets:

- Concurrency among threadblocks, warps, and CUDA/Tensor Cores
- Memory locality within shared memory and registers
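The memory-locality benefit can be quantified as arithmetic intensity: over one K-slice, a threadblock tile performs 2·CtaTileM·CtaTileN·CtaTileK flops while fetching only the corresponding A and B sub-tiles from global memory. A small sketch (element size and tile shapes are illustrative):

```python
def tile_arithmetic_intensity(cta_m, cta_n, cta_k, bytes_per_element=2):
    """Flops per byte of global-memory traffic for one threadblock tile
    over a single CtaTileK slice of the K dimension."""
    flops = 2 * cta_m * cta_n * cta_k  # one multiply-add counts as 2 flops
    bytes_moved = (cta_m * cta_k + cta_k * cta_n) * bytes_per_element
    return flops / bytes_moved

# Larger tiles reuse each loaded element more often:
print(tile_arithmetic_intensity(128, 128, 32))  # -> 64.0
print(tile_arithmetic_intensity(64, 64, 32))    # -> 32.0
```

Doubling the tile edge doubles the flops per byte, which is why large tiles are preferred whenever occupancy allows.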
## Tile Size Selection

### Threadblock Tile Selection

The threadblock tile size (`ThreadblockShape::{kM, kN, kK}`) is critical for performance.
**Large problems.** Recommended: 128×128×32 or 256×128×32:

```cpp
using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>;
```

Larger threadblock tiles:

- ✅ Better data reuse
- ✅ Fewer global memory fetches
- ✅ Higher arithmetic intensity
- ⚠️ May reduce occupancy
- ⚠️ Higher register pressure

**Small problems.** Recommended: 64×64×32 or 128×64×32:

```cpp
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
```

Smaller threadblock tiles:

- ✅ Better occupancy
- ✅ More threadblocks launched
- ✅ Better for small M/N dimensions
- ⚠️ More global memory traffic
- ⚠️ Less data reuse

**Tall/skinny problems.** Recommended: non-square tiles.

For M >> N, use 256×64 or 128×64:

```cpp
using ThreadblockShape = cutlass::gemm::GemmShape<256, 64, 32>;
```

For N >> M, use 64×256 or 64×128:

```cpp
using ThreadblockShape = cutlass::gemm::GemmShape<64, 256, 32>;
```
### Warp Tile Configuration

Warp-level tile sizes affect shared memory access patterns:

```cpp
// Typical configuration for Ampere/Hopper Tensor Cores
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;

// Alternative configurations
using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>;  // More parallelism
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;  // Different aspect ratio
```

Key considerations:

- Larger warp tiles increase data reuse but may cause shared memory bank conflicts
- The warp tile must divide the threadblock tile evenly
- Match the warp tile aspect ratio to the problem's characteristics
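The divisibility constraint (the warp tile dividing the threadblock tile evenly) can be checked mechanically before committing to a configuration. A minimal sketch using plain tuples, not a CUTLASS API:

```python
def warp_tile_fits(cta_shape, warp_shape):
    """Check that the warp tile divides the threadblock tile evenly in
    M, N, and K; return the resulting warps per threadblock, or None."""
    cta_m, cta_n, cta_k = cta_shape
    warp_m, warp_n, warp_k = warp_shape
    if cta_m % warp_m or cta_n % warp_n or cta_k % warp_k:
        return None  # invalid configuration
    return (cta_m // warp_m) * (cta_n // warp_n)

print(warp_tile_fits((128, 128, 32), (64, 64, 32)))  # -> 4
print(warp_tile_fits((128, 128, 32), (48, 64, 32)))  # -> None
```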
## Pipeline Optimization

### Software Pipelining

CUTLASS uses software pipelining to hide memory latency by overlapping computation with data movement.

#### Double Buffering in Shared Memory

Allocate two tiles in shared memory:

- One tile for the current computation
- One tile for loading the next iteration's data

```cpp
constexpr int kStages = 3;  // Number of pipeline stages
```

Recommended stage counts:

- Ampere: 3-4 stages
- Hopper: 4-7 stages (with async TMA)
- Blackwell: dynamic, based on problem size
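The practical upper bound on the stage count is the shared memory budget: every stage buffers one CtaTileM×CtaTileK tile of A and one CtaTileK×CtaTileN tile of B. A rough budget check (the capacity value is illustrative; actual per-CTA limits vary by architecture and carveout, and this ignores epilogue and padding overheads):

```python
def max_stages(cta_m, cta_n, cta_k, bytes_per_element, smem_bytes):
    """Largest pipeline stage count whose A and B tile buffers
    fit in the given shared-memory budget."""
    per_stage = (cta_m * cta_k + cta_k * cta_n) * bytes_per_element
    return smem_bytes // per_stage

# FP16 tiles of 128x128x32 against ~163 KB of shared memory
print(max_stages(128, 128, 32, 2, 163 * 1024))  # -> 10
```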
#### Register-Level Pipelining

Double-buffer warp-level fragments:

- One fragment feeding the current MMA computation
- One fragment receiving loads from shared memory

This overlaps shared memory loads with Tensor Core operations.
### Async Copy

Use asynchronous copy instructions (`cp.async`) on Ampere and newer:

```cpp
// Enable async copy for global -> shared memory transfers
using GlobalCopyPolicy = cutlass::arch::CacheOpAsync;
```

Benefits:

- Hides memory latency
- Reduces register pressure
- Enables deeper pipelines
### Multi-Stage Pipeline Configuration

On Ampere (SM80), the stage count is a kernel template parameter:

```cpp
using GemmKernel = cutlass::gemm::kernel::DefaultGemm<
    float,                                   // ElementA
    cutlass::layout::RowMajor,               // LayoutA
    float,                                   // ElementB
    cutlass::layout::ColumnMajor,            // LayoutB
    float,                                   // ElementC
    cutlass::layout::RowMajor,               // LayoutC
    float,                                   // ElementAccumulator
    cutlass::arch::OpClassTensorOp,          // Operator class
    cutlass::arch::Sm80,                     // Architecture
    cutlass::gemm::GemmShape<128, 128, 32>,  // Threadblock shape
    cutlass::gemm::GemmShape<64, 64, 32>,    // Warp shape
    cutlass::gemm::GemmShape<16, 8, 16>,     // Instruction shape
    3                                        // Stages
>::GemmKernel;
```

On Hopper (SM90), stages are configured through the collective mainloop (see Advanced Hopper Features below).
## Memory Access Optimization

### Threadblock Rasterization

Control how threadblocks are mapped onto the GEMM problem to improve L2 cache locality:

```cpp
// Identity rasterization (default)
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;

// Horizontal rasterization with swizzling for better cache reuse
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle<4>;

// Batched identity swizzle
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;
```

Command-line control in the profiler:

```shell
$ ./tools/profiler/cutlass_profiler --operation=Gemm \
    --raster_order=along_m --swizzle_size=4 \
    --m=2048 --n=2048 --k=2048
```
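The idea behind swizzled rasterization can be modeled as a remapping of the linear block index so that consecutively scheduled blocks form a short strip of tile rows, keeping their shared A and B tiles resident in L2. A toy model of such a mapping, not the CUTLASS implementation:

```python
def swizzled_block_coords(block_idx, grid_m, grid_n, swizzle=4):
    """Map a linear block index to (tile_m, tile_n) so that consecutive
    blocks fall into a strip of `swizzle` tile rows."""
    strip = swizzle * grid_n                     # blocks per strip of rows
    s, r = divmod(block_idx, strip)
    rows = min(swizzle, grid_m - s * swizzle)    # last strip may be short
    return s * swizzle + r % rows, r // rows

# The first eight blocks of an 8x8 grid walk down a 4-row strip:
print([swizzled_block_coords(i, 8, 8) for i in range(8)])
# -> [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
```

Without swizzling, those eight blocks would span eight distinct tile rows of A; with it they touch only four, roughly halving the working set in L2.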
### Alignment Requirements

Ensure proper memory alignment for vectorized loads; misaligned accesses can significantly degrade performance.

```cpp
// Specify alignment (in elements)
static int const kAlignmentA = 8;  // 8 x FP16 = 128-bit accesses
static int const kAlignmentB = 8;

using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t,                         // ElementA
    cutlass::layout::RowMajor,               // LayoutA
    cutlass::half_t,                         // ElementB
    cutlass::layout::ColumnMajor,            // LayoutB
    cutlass::half_t,                         // ElementC
    cutlass::layout::RowMajor,               // LayoutC
    cutlass::half_t,                         // ElementAccumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,  // Threadblock shape
    cutlass::gemm::GemmShape<64, 64, 32>,    // Warp shape
    cutlass::gemm::GemmShape<16, 8, 16>,     // Instruction shape
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, kAlignmentA,
        cutlass::half_t, cutlass::half_t
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3,                                       // Stages
    kAlignmentA,
    kAlignmentB
>;
```
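Alignment can also be validated on the host before launch: both the base pointers and the leading dimensions must be multiples of the access width, or the kernel must fall back to narrower loads. A hedged sketch of such a check (the helper is illustrative, not a CUTLASS function):

```python
def is_aligned(ptr, leading_dim, element_bytes, alignment_elements):
    """True if a matrix supports vectorized accesses of
    `alignment_elements` elements (e.g. 8 x FP16 = 128 bits)."""
    vec_bytes = alignment_elements * element_bytes
    return ptr % vec_bytes == 0 and leading_dim % alignment_elements == 0

# 128-bit access for FP16: pointer and leading dim must be multiples of 8
print(is_aligned(0x7F0000000000, 4096, 2, 8))  # -> True
print(is_aligned(0x7F0000000000, 4095, 2, 8))  # -> False
```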
## Parallelized Reductions (Split-K)

### When to Use Split-K

Split-K parallelizes the reduction over the K dimension. It is useful when:

- M and N are small but K is large
- The M×N tiling alone does not expose enough parallelism
- GPU occupancy is low
### Serial Split-K

Each partition writes partial results; a final reduction pass combines them:

```cpp
using Gemm = cutlass::gemm::device::Gemm<
    // ... types and shapes ...
    cutlass::gemm::kernel::DefaultGemmWithSplitK<
        // ... configuration ...
    >::GemmKernel
>;

// At runtime
int split_k_slices = 8;
typename Gemm::Arguments args{
    problem_size,
    {d_A, lda},
    {d_B, ldb},
    {d_C, ldc},
    {d_D, ldd},
    {alpha, beta},
    split_k_slices
};
```
### Parallel Split-K

Atomic operations perform the reduction in parallel:

```cpp
typename Gemm::Arguments args{
    problem_size,
    {d_A, lda},
    {d_B, ldb},
    {d_C, ldc},
    {d_D, ldd},
    {alpha, beta},
    split_k_slices,
    cutlass::gemm::GemmUniversalMode::kGemm,
    cutlass::gemm::kernel::SplitKMode::kParallel
};
```
Trade-offs of the parallel mode:

- ✅ No separate reduction kernel needed
- ✅ Lower latency for small reductions
- ⚠️ Atomic operations may contend
- ⚠️ Only works with certain data types
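Numerically, both modes compute the same thing: K is partitioned, each slice produces a partial product, and the partials are summed (by a reduction kernel in serial mode, by atomics in parallel mode). A NumPy sketch of the decomposition:

```python
import numpy as np

def split_k_gemm(A, B, slices):
    """Sum of partial GEMMs over K partitions equals the full GEMM."""
    k = A.shape[1]
    bounds = [k * s // slices for s in range(slices + 1)]
    partials = [A[:, lo:hi] @ B[lo:hi, :]
                for lo, hi in zip(bounds, bounds[1:])]
    return sum(partials)  # the reduction pass

rng = np.random.default_rng(0)
A, B = rng.standard_normal((4, 64)), rng.standard_normal((64, 4))
print(np.allclose(split_k_gemm(A, B, 8), A @ B))  # -> True
```

Note that floating-point sums reassociate here, so split-K results can differ from the unsplit GEMM by rounding error, which is why `allclose` rather than exact equality is the right check.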
### Optimal Split-K Factor

Choose the number of split-K slices based on problem size:

```python
def calculate_split_k(m, n, k, num_sms):
    """Heuristic for split-K factor selection."""
    # Parallelism available from the M x N tiling
    tile_m, tile_n = 128, 128
    num_tiles = (m // tile_m) * (n // tile_n)

    # If there is already enough parallelism, don't split K
    if num_tiles >= num_sms * 2:
        return 1

    # Otherwise, split K to reach the target occupancy
    target_blocks = num_sms * 4
    split_k = (target_blocks + num_tiles - 1) // num_tiles

    # Clamp to a reasonable range
    return min(split_k, 32)
```
## Tensor Core Optimization

### Instruction Shape Selection

Different architectures support different MMA instruction shapes.

Ampere-style `mma.sync` shapes (CUTLASS 2.x `GemmShape`):

```cpp
// FP16 Tensor Cores
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

// TF32 Tensor Cores
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

// INT8 Tensor Cores
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;

// FP64 Tensor Cores
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
```

Hopper WGMMA shapes (CUTLASS 3.x CuTe `Shape`):

```cpp
// FP16/BF16 WGMMA
using TileShape = Shape<_128, _128, _64>;
using InstructionShape = Shape<_64, _64, _16>;

// FP8 WGMMA
using TileShape = Shape<_128, _128, _128>;
using InstructionShape = Shape<_64, _128, _32>;

// INT8 WGMMA
using TileShape = Shape<_128, _128, _64>;
using InstructionShape = Shape<_64, _64, _32>;
```

Newer configurations extend these with larger tiles and block-scaled formats:

```cpp
// FP16/BF16 with larger tiles
using TileShape = Shape<_256, _256, _64>;

// FP4 block-scaled GEMM
using TileShape = Shape<_128, _256, _256>;

// Mixed precision with runtime types
using TileShape = Shape<_128, _128, _128>;
```
### Maximizing Tensor Core Utilization

**1. Use appropriate data types.** Match your data types to the available Tensor Core instructions:

- FP16/BF16: best performance on modern GPUs
- TF32: good balance for FP32 workloads
- INT8/FP8: highest throughput for quantized models

**2. Ensure tile sizes are multiples of the instruction shape.**

```cpp
// ✅ Good: 128 is a multiple of 16 and of 8
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

// ❌ Bad: 100 is not divisible by 16
using ThreadblockShape = cutlass::gemm::GemmShape<100, 100, 32>;
```

**3. Minimize padding.** Pad matrices to multiples of the tile sizes:

```cpp
int padded_m = (m + 127) / 128 * 128;
int padded_n = (n + 127) / 128 * 128;
int padded_k = (k + 31) / 32 * 32;
```
## Epilogue Optimization

The epilogue performs the final transformation: D = α·AB + β·C.

### Fused Epilogue Operations

Fuse additional operations, such as ReLU or GELU activations and bias addition, into the epilogue to save memory bandwidth. For example, ReLU:

```cpp
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    float,                                     // ElementOutput
    128 / cutlass::sizeof_bits<float>::value,  // Count (elements per access)
    float,                                     // ElementAccumulator
    float                                      // ElementCompute
>;
```
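As a numerical reference for what the fused epilogue computes, `LinearCombinationRelu` applies D = max(0, α·accumulator + β·C) elementwise in a single pass over the output, avoiding a separate activation kernel and its extra round trip through global memory. A NumPy sketch:

```python
import numpy as np

def linear_combination_relu(acc, C, alpha=1.0, beta=0.0):
    """Fused epilogue: D = max(0, alpha * accumulator + beta * C)."""
    return np.maximum(alpha * acc + beta * C, 0.0)

acc = np.array([[2.0, -3.0], [0.5, -0.5]])  # accumulator tile from the mainloop
C = np.ones((2, 2))                          # source matrix C
print(linear_combination_relu(acc, C, alpha=1.0, beta=1.0))
```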
### Residual Matrix Support

For transformer models, efficiently fuse residual additions:

```cpp
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    kAlignment,
    ElementAccumulator,
    ElementCompute,
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling  // Skip C matrix
>;
```
## Advanced Hopper Features

### Tensor Memory Accelerator (TMA)

Hopper's TMA provides hardware-accelerated asynchronous memory copies:

```cpp
// CUTLASS 3.x automatically uses TMA for Hopper
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma<
    cutlass::gemm::collective::MainloopSm90TmaGmmaWarpSpecialized<
        kStages,
        ClusterShape
    >
>;
```

Benefits:

- Hardware-managed async copies
- Reduced register pressure
- Better pipelining
- Multi-dimensional addressing
### Warp Specialization

Dedicate warps to specific tasks:

```cpp
using CollectiveMainloop = cutlass::gemm::collective::MainloopSm90TmaGmmaWarpSpecialized<
    3,  // Stages
    cutlass::gemm::collective::StageCountAutoCarveout<sizeof(SharedStorage)>
>;
```

- Producer warps load data via TMA
- Consumer warps execute MMA instructions

Benefits: better latency hiding and higher throughput.
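The producer/consumer arrangement behaves like a bounded buffer with `kStages` slots: producers stall only when every stage is full, consumers stall only when every stage is empty. A toy Python model of this decoupling (threads stand in for warp groups; this is not CUTLASS code):

```python
import queue
import threading

def pipeline(num_tiles, stages=3):
    """Toy producer/consumer pipeline: a bounded queue of `stages`
    slots decouples tile loading from tile consumption."""
    slots, consumed = queue.Queue(maxsize=stages), []

    def producer():                      # stands in for the TMA/load warps
        for t in range(num_tiles):
            slots.put(t)                 # blocks when all stages are full
        slots.put(None)                  # end-of-work marker

    def consumer():                      # stands in for the MMA warps
        while (t := slots.get()) is not None:
            consumed.append(t)           # "issue MMA" on tile t

    threads = [threading.Thread(target=f) for f in (producer, consumer)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return consumed

print(pipeline(8))  # -> [0, 1, 2, 3, 4, 5, 6, 7]
```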
### Cluster Launch

Thread block clusters enable communication between threadblocks:

```cpp
using ClusterShape = cutlass::gemm::GemmShape<2, 1, 1>;
```

Use clusters for:

- Distributed shared memory access
- Reductions across threadblocks
- Flash Attention and other advanced algorithms
## Performance Tuning Workflow

**1. Profile your baseline.**

```shell
$ ./tools/profiler/cutlass_profiler --operation=Gemm \
    --m=4096 --n=4096 --k=4096 --output=baseline.csv
```

**2. Experiment with tile sizes.** Test different threadblock configurations:

```shell
$ ./tools/profiler/cutlass_profiler --operation=Gemm \
    --cta_m=64,128,256 --cta_n=64,128,256 --cta_k=32,64 \
    --m=4096 --n=4096 --k=4096
```

**3. Optimize pipeline depth.**

```shell
$ ./tools/profiler/cutlass_profiler --operation=Gemm \
    --stages=3,4,5,6,7 --m=4096 --n=4096 --k=4096
```

**4. Test rasterization strategies.**

```shell
$ ./tools/profiler/cutlass_profiler --operation=Gemm \
    --raster_order=heuristic,along_m,along_n \
    --swizzle_size=1,2,4,8 --m=4096 --n=4096 --k=4096
```

**5. Consider split-K for small problems.**

```shell
$ ./tools/profiler/cutlass_profiler --operation=Gemm \
    --split-k-slices=1,2,4,8,16 --m=256 --n=256 --k=8192
```
## Next Steps

- Profiling Guide: learn how to use the CUTLASS profiler effectively
- Benchmarks: compare performance across different configurations