## Overview

Mixed-precision GEMM is essential for efficient AI inference workloads. This implementation supports:

- Narrow-precision tensors: int4, int8, uint8, fp8
- Wide-precision tensors: fp16, bf16
- Two transformation modes:
  - Convert-only: direct type conversion
  - Convert-scale: type conversion followed by element-wise scaling
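The two modes differ only in whether an element-wise scale multiply follows the type conversion. A NumPy sketch of the arithmetic (the values and the scalar scale are illustrative; the kernel operates on int4/int8/uint8/fp8 tiles in shared memory):

```python
import numpy as np

# Narrow-precision input (int8 here for illustration)
a_int8 = np.array([-64, 32, 127, -8], dtype=np.int8)
scales = np.float16(0.05)  # illustrative scale factor

# Convert-only: direct conversion to the wide MMA type
a_convert_only = a_int8.astype(np.float16)

# Convert-scale: conversion followed by element-wise scaling
a_convert_scale = a_int8.astype(np.float16) * scales
```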
## Architecture

The mixed-input GEMM uses a warp-specialized persistent kernel design.

## Implementation
```python
import cutlass
import cutlass.cute as cute
import cutlass.utils.mixed_input_helpers as mixed_input_utils
from cutlass.utils.mixed_input_helpers import TransformMode


class MixedInputGemmKernel:
    def __init__(
        self,
        scale_granularity_m: int,
        scale_granularity_k: int,
        acc_dtype: type[cutlass.Numeric],
        use_2cta_instrs: bool,
        mma_tiler_mnk: tuple[int, int, int],
        cluster_shape_mn: tuple[int, int],
        use_tma_store: bool,
        shuffle_a: bool,
    ):
        self.scale_granularity_m = scale_granularity_m
        self.scale_granularity_k = scale_granularity_k
        # Determine transformation mode
        if cutlass.const_expr(
            self.scale_granularity_m == 0 and self.scale_granularity_k == 0
        ):
            self.scale_mode = TransformMode.ConvertOnly
        else:
            self.scale_mode = TransformMode.ConvertScale
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        self.mma_tiler = mma_tiler_mnk
        self.cluster_shape_mn = cluster_shape_mn
        self.use_tma_store = use_tma_store
        # Define specialized warp IDs
        self.epilog_warp_id = (0, 1, 2, 3)  # 4 warps for epilogue
        self.mma_warp_id = 4  # 1 warp for MMA
        self.tma_warp_id = 5  # 1 warp for TMA A/B load
        self.scale_tma_warp_id = 6  # 1 warp for scale TMA load
        self.transform_warp_id = (8, 9, 10, 11)  # 4 warps for transformation
        # Set register allocation per warp type
        self.num_regs_epilogue_warps = 192
        self.num_regs_mma_warp = 96
        self.num_regs_tma_warps = 96
        self.num_regs_transform_warps = 208
```
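The granularity-to-mode rule in the constructor reduces to a simple predicate. A standalone sketch (the `Enum` below is a local stand-in for the real `TransformMode` import, used here so the logic can run in isolation):

```python
from enum import Enum


class TransformMode(Enum):
    # Stand-in for cutlass.utils.mixed_input_helpers.TransformMode
    ConvertOnly = 0
    ConvertScale = 1


def select_transform_mode(scale_granularity_m: int,
                          scale_granularity_k: int) -> TransformMode:
    # Granularity (0, 0) means no scale tensor is supplied: convert-only.
    if scale_granularity_m == 0 and scale_granularity_k == 0:
        return TransformMode.ConvertOnly
    # Any non-zero granularity implies a scale tensor and convert-scale.
    return TransformMode.ConvertScale


print(select_transform_mode(0, 0))    # TransformMode.ConvertOnly
print(select_transform_mode(1, 256))  # TransformMode.ConvertScale
```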
Inside the device kernel, the transform warps read narrow-precision A from shared memory, convert it to the MMA data type, and optionally apply scale factors:

```python
# In transform warps (warp_idx >= transform_warp_id[0])
if warp_idx >= self.transform_warp_id[0]:
    cute.arch.setmaxregister_increase(self.num_regs_transform_warps)
    transform_local_tidx = tidx - 32 * self.transform_warp_id[0]
    # Partition tensors for transformation
    (
        src_copy_a,
        dst_copy_a,
        tAsA_input,
        tAsA_transform,
    ) = mixed_input_utils.transform_partition(
        self.transform_a_source,
        self.scale_mode,
        copy_atom_a_input,
        copy_atom_a_transform,
        sA_input,
        transformed_output,
        transform_local_tidx,
    )
    # Allocate register memory
    tArA = cute.make_rmem_tensor(
        tAsA_input[(None, None, None, None, 0)].shape,
        tAsA_input.element_type,
    )
    tArA_transform = cute.make_rmem_tensor(
        tAsA_input[(None, None, None, None, 0)].shape,
        self.mma_dtype,
    )
    if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
        # Partition scale tensor
        smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS = (
            mixed_input_utils.scale_partition(
                self.scale_granularity_k,
                self.cta_tile_shape_mnk,
                sS_input,
                transform_local_tidx,
            )
        )
        # Transform loop with scaling
        for k_tile in range(k_tile_cnt):
            # Wait for input A to be loaded
            a_load2trans_consumer_state = a_load2trans_pipeline.consumer_wait(
                a_load2trans_consumer_state
            )
            # Load narrow-precision A from shared memory
            cute.copy(src_copy_a, tAsA_input[k_tile], tArA)
            # Load scale factors
            scale_k_tile = k_tile // num_k_tiles_per_scale
            cute.copy(smem_thr_copy_S, tSsS_trans[scale_k_tile], tSrS_copy)
            # Convert and scale
            a_converted = tArA.load().to(self.mma_dtype)
            a_scaled = a_converted * tSrS.load()
            tArA_transform.store(a_scaled)
            # Store transformed result
            cute.copy(dst_copy_a, tArA_transform, tAsA_transform[k_tile])
```
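The `scale_k_tile = k_tile // num_k_tiles_per_scale` index above maps several MMA k-tiles onto one scale tile whenever the scale granularity spans multiple k-tiles. A sketch of that mapping (the 64-element k-tile width is an assumed, illustrative value):

```python
# Map MMA k-tiles to scale k-tiles for scale_granularity_k = 256
# with a hypothetical 64-element k-tile.
k_tile_size = 64
scale_granularity_k = 256
num_k_tiles_per_scale = scale_granularity_k // k_tile_size  # 4

# Four consecutive k-tiles reuse the same scale tile before advancing.
mapping = [k_tile // num_k_tiles_per_scale for k_tile in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```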
The hand-offs between the TMA load warp, the transform warps, and the MMA warp are coordinated by asynchronous pipelines:

```python
# Initialize pipelines
a_load2trans_pipeline = pipeline.PipelineTmaAsync.create(
    barrier_storage=storage.a_load2trans_full_mbar_ptr.data_ptr(),
    num_stages=self.num_load2trans_stage,
    producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
    consumer_group=pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        self.num_mcast_ctas_a * len(self.transform_warp_id),
    ),
    tx_count=self.num_tma_load_bytes_a,
)
trans2mma_pipeline = pipeline.PipelineAsyncUmma.create(
    barrier_storage=storage.a_trans2mma_full_mbar_ptr.data_ptr(),
    num_stages=self.num_trans2mma_stage,
    producer_group=pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        32 * len(self.transform_warp_id) * cta_v_size,
    ),
    consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
)
```
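Both pipelines cycle through a fixed number of buffer stages. A minimal sketch of the stage/phase bookkeeping that mbarrier-based pipelines of this kind rely on (the stage count is illustrative, not taken from the kernel):

```python
# Multi-stage circular-buffer state: (stage index, phase bit).
num_stages = 4  # illustrative stage count


def next_state(stage: int, phase: int) -> tuple[int, int]:
    # Advancing past the last stage wraps to stage 0 and flips the
    # phase bit, which is how producers/consumers detect buffer reuse.
    stage += 1
    if stage == num_stages:
        return 0, phase ^ 1
    return stage, phase


state = (0, 0)
for _ in range(5):
    state = next_state(*state)
print(state)  # (1, 1): wrapped once, so the phase bit flipped
```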
## Running Examples

## Performance Profiling

## Key Concepts
### Scale Granularity

Scale granularity determines how many elements share the same scale factor:

- scale_granularity_m: number of M-mode elements per scale factor
- scale_granularity_k: number of K-mode elements per scale factor

Example: for M=1024, K=6144, L=16 with granularity (1, 256):

- Scale tensor shape: (1024, 24, 16)
- Each element in M has its own scale
- Every 256 elements in K share a scale
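The scale tensor shape follows directly from dividing each mode by its granularity. A small helper illustrating the rule (`scale_tensor_shape` is a hypothetical name, not part of the library):

```python
def scale_tensor_shape(M: int, K: int, L: int,
                       granularity_m: int, granularity_k: int) -> tuple:
    # One scale factor covers a granularity_m x granularity_k block,
    # replicated across the L batch mode.
    return (M // granularity_m, K // granularity_k, L)


# The (M=1024, K=6144, L=16), granularity (1, 256) example above:
print(scale_tensor_shape(1024, 6144, 16, 1, 256))  # (1024, 24, 16)
```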
### Warp Specialization Benefits
Warp specialization improves performance by:
- Reduced synchronization: Warps operate independently
- Better register allocation: Each warp type has optimized register usage
- Increased parallelism: Different operations happen concurrently
- Cache efficiency: Specialized warps have more predictable access patterns
### Transformation Storage Location

Transformed data can be stored in one of two places:

Shared Memory (SMEM):
- Used for K-major layouts
- Accessible by all MMA warps
- Higher SMEM usage

Tensor Memory (TMEM):
- Used for M-major layouts (Blackwell only)
- Direct accumulator access
- Saves SMEM bandwidth
## Data Type Support
| Narrow Precision (A) | Wide Precision (B) | Transform Mode | Accumulator |
|---|---|---|---|
| Int8, Uint8 | Float16, BFloat16 | Convert-Only | Float32 |
| Int4 | Float16, BFloat16 | Convert-Scale | Float32 |
| Float8E4M3FN | Float16, BFloat16 | Convert-Only | Float32 |
| Float8E5M2 | Float16, BFloat16 | Convert-Only | Float32 |
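The table can be read as a compatibility check. A Python-side sketch mirroring it (the lowercase dtype strings are stand-ins for the cutlass dtype objects, and `is_supported` is a hypothetical helper, not a library API):

```python
# (narrow A dtype, wide B dtype) pairs from the support table above
SUPPORTED = {
    ("int8", "float16"), ("int8", "bfloat16"),
    ("uint8", "float16"), ("uint8", "bfloat16"),
    ("int4", "float16"), ("int4", "bfloat16"),
    ("float8_e4m3fn", "float16"), ("float8_e4m3fn", "bfloat16"),
    ("float8_e5m2", "float16"), ("float8_e5m2", "bfloat16"),
}


def is_supported(a_dtype: str, b_dtype: str) -> bool:
    # All supported combinations accumulate in float32.
    return (a_dtype, b_dtype) in SUPPORTED


print(is_supported("int4", "float16"))   # True
print(is_supported("int8", "float32"))   # False
```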
## Constraints
## Source Code

Complete source: examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py