This guide demonstrates how to implement a mixed-precision GEMM kernel where input matrices have different data types. The kernel automatically handles type conversion and optional scaling operations.

Overview

Mixed-precision GEMM is essential for efficient AI inference workloads. This implementation supports:
  • Narrow precision tensors: int4, int8, uint8, fp8
  • Wide precision tensors: fp16, bf16
  • Two transformation modes:
    • Convert-only: Direct type conversion
    • Convert-scale: Type conversion followed by element-wise scaling
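The two transformation modes can be sketched in plain Python to show their reference semantics. The helper names below are illustrative only; the kernel performs these steps per tile in registers rather than over whole lists.

```python
def convert_only(a_narrow):
    """Convert-only: widen each narrow-precision value to the MMA dtype."""
    return [float(x) for x in a_narrow]

def convert_scale(a_narrow, scales, granularity_k):
    """Convert-scale: widen, then multiply by the scale shared by each
    group of `granularity_k` consecutive K-mode elements."""
    return [float(x) * scales[i // granularity_k] for i, x in enumerate(a_narrow)]

a = [2, -3, 4, 5]                   # e.g. int8 values along K
scales = [0.5, 0.25]                # one scale per 2 K elements
print(convert_only(a))              # [2.0, -3.0, 4.0, 5.0]
print(convert_scale(a, scales, 2))  # [1.0, -1.5, 1.0, 1.25]
```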

Architecture

The mixed-input GEMM uses a warp-specialized persistent kernel design:
Input A (narrow)     Input B (wide)
     │                    │
     ├─ TMA Load          ├─ TMA Load
     ↓                    ↓
 Shared Memory        Shared Memory
     ↓                    │
 Transform Warp           │
(convert + scale)         │
     ↓                    ↓
Shared/Tensor Memory ──→ MMA Warps
     ↓                    ↓
 Accumulator (wide precision)

 Epilogue Warps

 Output C

Implementation

Step 1: Define the Kernel Configuration

Set up the mixed-precision GEMM kernel:

import cutlass
import cutlass.cute as cute
import cutlass.utils.mixed_input_helpers as mixed_input_utils
from cutlass.utils.mixed_input_helpers import TransformMode

class MixedInputGemmKernel:
    def __init__(
        self,
        scale_granularity_m: int,
        scale_granularity_k: int,
        acc_dtype: type[cutlass.Numeric],
        use_2cta_instrs: bool,
        mma_tiler_mnk: tuple[int, int, int],
        cluster_shape_mn: tuple[int, int],
        use_tma_store: bool,
        shuffle_a: bool,
    ):
        self.scale_granularity_m = scale_granularity_m
        self.scale_granularity_k = scale_granularity_k
        
        # Determine transformation mode
        if cutlass.const_expr(
            self.scale_granularity_m == 0 and self.scale_granularity_k == 0
        ):
            self.scale_mode = TransformMode.ConvertOnly
        else:
            self.scale_mode = TransformMode.ConvertScale
        
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        self.mma_tiler = mma_tiler_mnk
        self.cluster_shape_mn = cluster_shape_mn
        self.use_tma_store = use_tma_store
        self.shuffle_a = shuffle_a
Step 2: Configure Warp Specialization

Assign specific roles to different warps:

# Define specialized warp IDs
self.epilog_warp_id = (0, 1, 2, 3)  # 4 warps for epilogue
self.mma_warp_id = 4                # 1 warp for MMA
self.tma_warp_id = 5                # 1 warp for TMA A/B load
self.scale_tma_warp_id = 6          # 1 warp for scale TMA load
self.transform_warp_id = (8, 9, 10, 11)  # 4 warps for transformation (warp 7 is left unassigned)

# Set register allocation per warp type
self.num_regs_epilogue_warps = 192
self.num_regs_mma_warp = 96
self.num_regs_tma_warps = 96
self.num_regs_transform_warps = 208
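The warp-ID assignment above can be sketched as a pure-Python dispatch from thread index to role. The function and role names are illustrative, not the kernel's actual API; warp 7 is left unassigned, matching the listing.

```python
EPILOG_WARPS = (0, 1, 2, 3)
MMA_WARP = 4
TMA_WARP = 5
SCALE_TMA_WARP = 6
TRANSFORM_WARPS = (8, 9, 10, 11)

def warp_role(tidx):
    warp_idx = tidx // 32  # 32 threads per warp
    if warp_idx in EPILOG_WARPS:
        return "epilogue"
    if warp_idx == MMA_WARP:
        return "mma"
    if warp_idx == TMA_WARP:
        return "tma"
    if warp_idx == SCALE_TMA_WARP:
        return "scale_tma"
    if warp_idx in TRANSFORM_WARPS:
        return "transform"
    return "unused"

print(warp_role(0))       # epilogue
print(warp_role(4 * 32))  # mma
print(warp_role(8 * 32))  # transform
```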
Step 3: Implement the Transform Warp

Transform warps convert narrow-precision data to wide precision:

# In transform warps (warp_idx >= transform_warp_id[0])
if warp_idx >= self.transform_warp_id[0]:
    cute.arch.setmaxregister_increase(self.num_regs_transform_warps)
    transform_local_tidx = tidx - 32 * self.transform_warp_id[0]
    
    # Partition tensors for transformation
    (
        src_copy_a,
        dst_copy_a,
        tAsA_input,
        tAsA_transform,
    ) = mixed_input_utils.transform_partition(
        self.transform_a_source,
        self.scale_mode,
        copy_atom_a_input,
        copy_atom_a_transform,
        sA_input,
        transformed_output,
        transform_local_tidx,
    )
    
    # Allocate register memory
    tArA = cute.make_rmem_tensor(
        tAsA_input[(None, None, None, None, 0)].shape,
        tAsA_input.element_type
    )
    tArA_transform = cute.make_rmem_tensor(
        tAsA_input[(None, None, None, None, 0)].shape,
        self.mma_dtype
    )
Step 4: Handle Convert-Scale Mode

For scaled narrow inputs such as int4, apply element-wise scaling after conversion:

if cutlass.const_expr(self.scale_mode == TransformMode.ConvertScale):
    # Partition scale tensor
    smem_thr_copy_S, tSsS_trans, tSrS_copy, tSrS = (
        mixed_input_utils.scale_partition(
            self.scale_granularity_k,
            self.cta_tile_shape_mnk,
            sS_input,
            transform_local_tidx,
        )
    )
    
    # Transform loop with scaling
    for k_tile in range(k_tile_cnt):
        # Wait for input A to be loaded
        a_load2trans_consumer_state = a_load2trans_pipeline.consumer_wait(
            a_load2trans_consumer_state
        )
        
        # Load narrow-precision A from shared memory
        cute.copy(src_copy_a, tAsA_input[k_tile], tArA)
        
        # Load scale factors
        scale_k_tile = k_tile // num_k_tiles_per_scale
        cute.copy(smem_thr_copy_S, tSsS_trans[scale_k_tile], tSrS_copy)
        
        # Convert and scale
        a_converted = tArA.load().to(self.mma_dtype)
        a_scaled = a_converted * tSrS.load()
        tArA_transform.store(a_scaled)
        
        # Store transformed result
        cute.copy(dst_copy_a, tArA_transform, tAsA_transform[k_tile])
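The `scale_k_tile = k_tile // num_k_tiles_per_scale` indexing in the loop above reflects that each scale covers `scale_granularity_k` K elements, i.e. several consecutive k tiles reuse the same scale tile. A small sketch with illustrative tile sizes:

```python
# Each scale spans scale_granularity_k // mma_tile_k consecutive k tiles.
mma_tile_k = 64
scale_granularity_k = 256
num_k_tiles_per_scale = scale_granularity_k // mma_tile_k  # 4

# Map the first 8 k tiles to their scale tiles.
mapping = [k_tile // num_k_tiles_per_scale for k_tile in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```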
Step 5: Pipeline Coordination

Coordinate the specialized warps using pipeline barriers:

# Initialize pipelines
a_load2trans_pipeline = pipeline.PipelineTmaAsync.create(
    barrier_storage=storage.a_load2trans_full_mbar_ptr.data_ptr(),
    num_stages=self.num_load2trans_stage,
    producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
    consumer_group=pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        self.num_mcast_ctas_a * len(self.transform_warp_id),
    ),
    tx_count=self.num_tma_load_bytes_a,
)

trans2mma_pipeline = pipeline.PipelineAsyncUmma.create(
    barrier_storage=storage.a_trans2mma_full_mbar_ptr.data_ptr(),
    num_stages=self.num_trans2mma_stage,
    producer_group=pipeline.CooperativeGroup(
        pipeline.Agent.Thread,
        32 * len(self.transform_warp_id) * cta_v_size,
    ),
    consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
)
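The multi-stage producer/consumer handshake above can be sketched with a bounded queue: the TMA warp (producer) fills stages, a transform warp (consumer) drains them, and the queue's back-pressure plays the role of the "empty" barrier. This is only an analogy to the mbarrier-based pipeline, not its implementation.

```python
import queue
import threading

NUM_STAGES = 4  # plays the role of num_load2trans_stage
K_TILES = 8

stages = queue.Queue(maxsize=NUM_STAGES)
results = []

def tma_producer():
    for k_tile in range(K_TILES):
        stages.put(f"tile{k_tile}")   # blocks when all stages are full

def transform_consumer():
    for _ in range(K_TILES):
        tile = stages.get()           # blocks until a stage is full
        results.append(tile + ":transformed")
        stages.task_done()            # signals the stage is empty again

t0 = threading.Thread(target=tma_producer)
t1 = threading.Thread(target=transform_consumer)
t0.start(); t1.start()
t0.join(); t1.join()
print(results[0], results[-1])  # tile0:transformed tile7:transformed
```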

Running Examples

python examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py \
  --a_dtype Int8 \
  --b_dtype BFloat16 \
  --scale_granularity_m 0 \
  --scale_granularity_k 0 \
  --c_dtype BFloat16 \
  --acc_dtype Float32 \
  --mma_tiler_mnk 128,128,64 \
  --cluster_shape_mn 1,1 \
  --mnkl 256,512,8192,1

Performance Profiling

ncu python examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py \
  --a_dtype Int8 --b_dtype BFloat16 \
  --scale_granularity_m 0 --scale_granularity_k 0 \
  --c_dtype BFloat16 --acc_dtype Float32 \
  --mma_tiler_mnk 128,128,64 --cluster_shape_mn 1,1 \
  --mnkl 256,512,8192,1 \
  --warmup_iterations 1 --iterations 10 --skip_ref_check

Key Concepts

Scale granularity determines how many elements share the same scale factor:
  • scale_granularity_m: Number of M-mode elements per scale
  • scale_granularity_k: Number of K-mode elements per scale
For a GEMM with shape (M, N, K, L), the scale tensor shape is:
(M // scale_granularity_m, K // scale_granularity_k, L)
Example: For M=1024, K=6144, L=16 with granularity (1, 256):
  • Scale tensor shape: (1024, 24, 16)
  • Each element in M has its own scale
  • Every 256 elements in K share a scale
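The shape formula above is a one-liner; this helper (illustrative, not part of the library) reproduces the worked example:

```python
def scale_shape(m, k, l, scale_granularity_m, scale_granularity_k):
    # (M // scale_granularity_m, K // scale_granularity_k, L)
    return (m // scale_granularity_m, k // scale_granularity_k, l)

print(scale_shape(1024, 6144, 16, 1, 256))  # (1024, 24, 16)
```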
Warp specialization improves performance by:
  1. Reduced synchronization: Warps operate independently
  2. Better register allocation: Each warp type has optimized register usage
  3. Increased parallelism: Different operations happen concurrently
  4. Cache efficiency: Specialized warps have more predictable access patterns
Transformed data can be stored in one of two locations.
Shared Memory (SMEM):
  • Used for K-major layouts
  • Allows all MMA warps to access
  • Higher SMEM usage
Tensor Memory (TMEM):
  • Used for M-major layouts (Blackwell only)
  • Direct accumulator access
  • Saves SMEM bandwidth

Data Type Support

Narrow Precision (A) | Wide Precision (B) | Transform Mode | Accumulator
Int8, Uint8          | Float16, BFloat16  | Convert-Only   | Float32
Int4                 | Float16, BFloat16  | Convert-Scale  | Float32
Float8E4M3FN         | Float16, BFloat16  | Convert-Only   | Float32
Float8E5M2           | Float16, BFloat16  | Convert-Only   | Float32

Constraints

  • Scale granularity M must be 1 (current limitation)
  • Scale granularity K must be a multiple of MMA tile K
  • Scale tensor must be in M-major layout
  • Narrow precision tensor A must use supported types
  • Both input tensors require 16-byte alignment on contiguous dimension
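The scale-related constraints above can be expressed as explicit checks. The checking function and its names are illustrative, not the example's actual validation code:

```python
def check_scale_config(scale_granularity_m, scale_granularity_k, mma_tile_k):
    if scale_granularity_m == 0 and scale_granularity_k == 0:
        return True  # convert-only mode: no scale tensor at all
    if scale_granularity_m != 1:
        raise ValueError("scale granularity M must be 1 (current limitation)")
    if scale_granularity_k % mma_tile_k != 0:
        raise ValueError("scale granularity K must be a multiple of MMA tile K")
    return True

print(check_scale_config(0, 0, 64))    # True (convert-only)
print(check_scale_config(1, 256, 64))  # True (convert-scale)
```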

Source Code


Complete source: examples/python/CuTeDSL/blackwell/mixed_input_gemm/mixed_input_gemm.py
