Custom epilogues allow you to fuse element-wise operations directly into the GEMM kernel, eliminating separate kernel launches and improving memory bandwidth efficiency.

Overview

Instead of launching three separate kernels:
D = matmul(A, B)  # Kernel 1
D = D * alpha + C * beta  # Kernel 2
E = relu(D)  # Kernel 3
You can fuse everything into a single kernel:
D = relu(matmul(A, B) * alpha + C * beta)  # Single fused kernel
This approach:
  • Reduces memory traffic
  • Eliminates kernel launch overhead
  • Improves overall throughput
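To make the benefit concrete, here is a host-side NumPy sketch (shapes chosen arbitrarily for illustration) showing that the three-pass version and the single fused expression compute the same result:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32)).astype(np.float32)
B = rng.standard_normal((32, 48)).astype(np.float32)
C = rng.standard_normal((64, 48)).astype(np.float32)
alpha, beta = 2.0, 1.0

# Unfused: three separate "kernels", three passes over memory
D = A @ B                      # Kernel 1: GEMM
D = D * alpha + C * beta       # Kernel 2: alpha-beta scaling
E_unfused = np.maximum(D, 0)   # Kernel 3: ReLU

# Fused: one expression, one pass over the accumulator
E_fused = np.maximum(A @ B * alpha + C * beta, 0)

assert np.allclose(E_unfused, E_fused)
```

The math is identical; what fusion changes is where it happens: in the fused kernel the intermediate D never leaves registers.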

Epilogue Fusion Configuration (EFC)

For Blackwell (SM100) and later architectures, CUTLASS provides the Epilogue Fusion Configuration (EFC) framework for defining custom epilogues.

Basic Example: Alpha-Beta Scaling

This example demonstrates a GEMM with custom epilogue that computes:
Y = A × B
D = (A × B) × alpha + C × beta + X × x_factor
Step 1: Define the kernel class

Start by extending the base EFC kernel:
import cutlass
import cutlass.cute as cute
from common_dense_gemm_efc import DenseGemmEFC

class DenseGemmAlphaBeta(DenseGemmEFC):
    """Implements batched GEMM with custom epilogue fusion.
    
    Computes:
    - Y = A * B (accumulator stored to Y)
    - D = (A * B) * alpha + C * beta + X * x_factor
    """
    
    def __init__(
        self,
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
    ):
        super().__init__(
            acc_dtype,
            use_2cta_instrs,
            mma_tiler_mn,
            cluster_shape_mn,
        )
Step 2: Define tensor arguments

Specify the input and output tensors (this snippet assumes the torch helpers are imported as import cutlass.torch as cutlass_torch):
def create_arguments(
    self,
    l, m, n, k,
    a_major, b_major, cd_major,
    ab_dtype,
    c_dtype, d_dtype, x_dtype, y_dtype,
):
    # Standard A, B tensors from parent class
    std_args = super().create_arguments(
        l, m, n, k, a_major, b_major, cd_major, ab_dtype
    )
    
    # Create auxiliary tensors C, X, D, Y
    c_tensor = cutlass_torch.matrix(l, m, n, cd_major == "m", c_dtype)
    x_tensor = cutlass_torch.matrix(l, m, n, cd_major == "m", x_dtype)
    d_tensor = cutlass_torch.matrix(l, m, n, cd_major == "m", d_dtype)
    y_tensor = cutlass_torch.matrix(l, m, n, cd_major == "m", y_dtype)
    
    return (*std_args, c_tensor, x_tensor, d_tensor, y_tensor)
Step 3: Define the epilogue operation

Implement the custom fusion logic:
@cute.jit
def epilogue_operation(
    self,
    accumulator: cute.Tensor,
    c: cute.Tensor,
    x: cute.Tensor,
    alpha: float,
    beta: float,
    x_factor: float,
) -> tuple:
    # Store a copy of the accumulator to Y
    y = accumulator
    
    # Compute D = accumulator * alpha + C * beta + X * x_factor
    d = accumulator * alpha + c * beta + x * x_factor
    
    return (d, y)
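While developing, it helps to keep a host-side reference for the epilogue math so the kernel output can be checked. A minimal NumPy sketch (tensor shapes and values here are illustrative, not part of the EFC API):

```python
import numpy as np

def reference_epilogue(acc, c, x, alpha, beta, x_factor):
    """Host-side reference: Y = acc, D = acc*alpha + C*beta + X*x_factor."""
    y = acc.copy()
    d = acc * alpha + c * beta + x * x_factor
    return d, y

acc = np.arange(6, dtype=np.float32).reshape(2, 3)
c = np.ones((2, 3), dtype=np.float32)
x = np.full((2, 3), 2.0, dtype=np.float32)

d, y = reference_epilogue(acc, c, x, alpha=2.0, beta=1.0, x_factor=3.0)
assert np.allclose(y, acc)
assert np.allclose(d, acc * 2.0 + 7.0)  # C*beta = 1, X*x_factor = 6
```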
Step 4: Configure the epilogue

Set up the epilogue fusion configuration:
def get_epilogue_config(self):
    return {
        'inputs': [
            {'name': 'C', 'dtype': self.c_dtype, 'load_method': 'tma'},
            {'name': 'X', 'dtype': self.x_dtype, 'load_method': 'tma'},
        ],
        'outputs': [
            {'name': 'D', 'dtype': self.d_dtype, 'store_method': 'tma'},
            {'name': 'Y', 'dtype': self.y_dtype, 'store_method': 'tma'},
        ],
        'scalars': [
            {'name': 'alpha', 'dtype': 'float'},
            {'name': 'beta', 'dtype': 'float'},
            {'name': 'x_factor', 'dtype': 'float'},
        ],
        'operation': self.epilogue_operation,
    }
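Mistakes in the configuration dictionary (duplicate names, a missing operation) tend to surface as confusing compile errors, so it can pay to sanity-check it up front. The validator below is a sketch, not part of the EFC framework:

```python
def validate_epilogue_config(config):
    """Sanity-check an epilogue config dict (illustrative helper, not an EFC API)."""
    for key in ("inputs", "outputs", "scalars", "operation"):
        if key not in config:
            raise ValueError(f"missing key: {key!r}")
    names = [t["name"] for t in config["inputs"] + config["outputs"] + config["scalars"]]
    if len(names) != len(set(names)):
        raise ValueError("tensor/scalar names must be unique")
    if not callable(config["operation"]):
        raise ValueError("'operation' must be callable")
    return True

cfg = {
    "inputs": [{"name": "C", "dtype": "f16", "load_method": "tma"}],
    "outputs": [{"name": "D", "dtype": "f16", "store_method": "tma"}],
    "scalars": [{"name": "alpha", "dtype": "float"}],
    "operation": lambda acc, c, alpha: acc * alpha + c,
}
assert validate_epilogue_config(cfg)
```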

Running the Example

Command Line Usage

python examples/blackwell/epilogue/custom_epilogue_dense_gemm.py \
  --ab_dtype Float16 --c_dtype Float16 --d_dtype Float16 \
  --acc_dtype Float32 --epi_dtype Float32 \
  --x_dtype Float16 --y_dtype Float16 \
  --mma_tiler_mn 128,128 --cluster_shape_mn 2,1 \
  --mnkl 8192,8192,8192,1 \
  --use_2cta_instrs --alpha 2.0 --beta 1.0 --x_factor 3.0

With NCU Profiling

ncu python examples/blackwell/epilogue/custom_epilogue_dense_gemm.py \
  --ab_dtype Float16 --c_dtype Float16 --d_dtype Float16 \
  --acc_dtype Float32 --epi_dtype Float32 \
  --x_dtype Float16 --y_dtype Float16 \
  --mma_tiler_mn 128,128 --cluster_shape_mn 2,1 \
  --mnkl 8192,8192,8192,1 \
  --use_2cta_instrs --alpha 2.0 --beta 1.0 --x_factor 3.0 \
  --warmup_iterations 1 --iterations 10 --skip_ref_check

Advanced Examples

Activation Functions

Fuse common activation functions:
@cute.jit
def epilogue_with_relu(self, accumulator, c, alpha, beta):
    # Compute D = alpha * A * B + beta * C
    d = accumulator * alpha + c * beta
    
    # Apply ReLU: max(0, x)
    d = cute.where(d > 0, d, cute.full_like(d, 0))
    
    return d

Multiple Outputs

Generate multiple output tensors:
@cute.jit
def epilogue_multi_output(self, accumulator, c):
    # Output 1: Standard result
    d = accumulator + c
    
    # Output 2: Squared result
    d_squared = d * d
    
    # Output 3: Normalized result
    d_norm = d / cute.norm(d)
    
    return (d, d_squared, d_norm)
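The same multi-output pattern is easy to verify on the host. In the NumPy sketch below, np.linalg.norm stands in for the on-device reduction written as cute.norm above:

```python
import numpy as np

def reference_multi_output(acc, c):
    d = acc + c
    d_squared = d * d
    d_norm = d / np.linalg.norm(d)   # Frobenius norm over the whole tile
    return d, d_squared, d_norm

acc = np.array([[3.0, 0.0], [0.0, 4.0]], dtype=np.float32)
c = np.zeros((2, 2), dtype=np.float32)

d, d_sq, d_n = reference_multi_output(acc, c)
assert np.isclose(np.linalg.norm(d_n), 1.0)  # normalized output has unit norm
```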

Complex Expressions

Implement sophisticated fusion patterns:
@cute.jit
def epilogue_complex(self, accumulator, c, x, bias, scale):
    # Compute: (A * B + bias) * scale + C + tanh(X)
    result = (accumulator + bias) * scale
    result = result + c
    result = result + cute.tanh(x)
    
    # Apply GELU activation
    result = cute.gelu(result)
    
    return result
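For reference, GELU is commonly implemented via its tanh approximation (whether cute.gelu uses the exact erf form or this approximation is not specified here). A NumPy version of the fusion above:

```python
import numpy as np

def gelu_tanh(x):
    """Tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x**3)))."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def reference_epilogue_complex(acc, c, x, bias, scale):
    # (A*B + bias) * scale + C + tanh(X), then GELU
    result = (acc + bias) * scale + c + np.tanh(x)
    return gelu_tanh(result)

res = reference_epilogue_complex(
    acc=np.zeros((2, 2)), c=np.ones((2, 2)), x=np.zeros((2, 2)),
    bias=0.0, scale=1.0,
)  # gelu(1) for every element
```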

Legacy Epilogue Interface (Pre-SM90)

For older architectures, use the high-level Python interface:
import cutlass
import numpy as np

# Define a custom activation
def relu(x):
    return np.maximum(0, x)

# Create GEMM with fused epilogue
plan = cutlass.op.Gemm(
    element=np.float16,
    layout=cutlass.LayoutType.RowMajor,
    epilogue=relu,  # Fuse ReLU into GEMM
)

# Run the fused operation
A = np.random.randn(1024, 512).astype(np.float16)
B = np.random.randn(512, 2048).astype(np.float16)
C = np.zeros((1024, 2048), dtype=np.float16)
D = np.zeros((1024, 2048), dtype=np.float16)

plan.run(A, B, C, D)
See examples/python/deprecated/01_epilogue.ipynb for more details.

Supported Operations

The epilogue can include:

Mathematical Operations

  • Addition, subtraction, multiplication, division
  • Power, exponential, logarithm
  • Trigonometric functions (sin, cos, tan)

Activation Functions

  • ReLU, Leaky ReLU, PReLU
  • Sigmoid, tanh
  • GELU, SiLU
  • Softmax (with limitations)

Memory Operations

  • Load from multiple input tensors
  • Store to multiple output tensors
  • Conditional stores

Type Conversions

  • Mixed precision computations
  • Type casting between FP32, FP16, BF16, INT8
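The value of a wider accumulator is easy to demonstrate on the host: naively summing thousands of FP16 products in FP16 stagnates once the running sum's ULP exceeds the addend, while FP32 accumulation stays accurate (NumPy sketch, values chosen to make the effect obvious):

```python
import numpy as np

k = 4096
a = np.full(k, 0.01, dtype=np.float16)
b = np.full(k, 0.01, dtype=np.float16)

# FP32 accumulation (what an FP32 accumulator in the MMA provides)
acc_f32 = np.float32(0.0)
for ai, bi in zip(a, b):
    acc_f32 += np.float32(ai) * np.float32(bi)

# Naive FP16 accumulation: the sum stagnates around 0.25, where one
# FP16 ULP (~2.4e-4) exceeds twice each ~1e-4 addend
acc_f16 = np.float16(0.0)
for ai, bi in zip(a, b):
    acc_f16 = np.float16(acc_f16 + ai * bi)

print(float(acc_f32), float(acc_f16))  # ~0.41 vs ~0.25
```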

Performance Considerations

Follow these guidelines for optimal performance:
  • Minimize memory traffic: Read each input tensor only once
  • Limit register pressure: Avoid storing too many intermediate values
  • Use vectorized operations: Leverage SIMD instructions where possible
  • Consider data types: Match precision requirements to avoid unnecessary conversions
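The first guideline can be quantified. In the unfused three-kernel version from the Overview, the intermediate D is written to global memory and read back between kernels; fusion removes those round trips. A back-of-the-envelope count for FP16 with m = n = 8192 (illustrative only):

```python
m = n = 8192
bytes_per_elem = 2  # FP16
tensor_bytes = m * n * bytes_per_elem  # one full m*n tensor: 128 MiB

# Unfused: kernel 1 writes D; kernel 2 reads D and C, writes D;
# kernel 3 reads D, writes E  ->  6 passes over an m*n tensor
unfused = 6 * tensor_bytes

# Fused: read C once, write the final result once  ->  2 passes
fused = 2 * tensor_bytes

print(f"unfused {unfused / 2**30:.2f} GiB vs fused {fused / 2**30:.2f} GiB")
```

Here fusion cuts epilogue-related global-memory traffic to a third, before counting the eliminated kernel-launch overhead.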

Memory Bandwidth

Epilogue fusion is most beneficial when:
Compute time >> Memory bandwidth time
Profile your kernel to ensure the epilogue doesn’t become memory-bound.

Register Usage

Monitor register usage with NCU:
ncu --metrics launch__registers_per_thread python your_script.py
High register pressure can reduce occupancy.
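As a rough illustration of why (assuming a 64K-entry 32-bit register file per SM, which is typical of recent NVIDIA GPUs; check the specs for your part):

```python
REGISTER_FILE = 65536  # 32-bit registers per SM (typical; verify for your GPU)

def max_threads_per_sm(regs_per_thread):
    """Occupancy upper bound from register usage alone (ignores other limiters)."""
    return REGISTER_FILE // regs_per_thread

# Doubling per-thread register use halves the resident-thread bound
print(max_threads_per_sm(64), max_threads_per_sm(128))  # 1024 512
```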

Data Type Support

Supported Input Types (A, B)

  • FP16, BF16
  • TF32
  • INT8, UINT8
  • FP8 (E4M3FN, E5M2)

Supported Accumulator Types

  • FP32 (for all floating-point inputs)
  • FP16 (for FP16 and FP8 inputs)
  • INT32 (for INT8/UINT8 inputs)

Supported Output Types (C, D)

  • FP32, FP16, BF16
  • INT32, INT8, UINT8
  • FP8 (E4M3FN, E5M2) with FP32 accumulator

Constraints

Be aware of these limitations:
  • MMA tiler M must be 64/128 (single CTA) or 128/256 (2-CTA mode)
  • MMA tiler N must be 32-256 in steps of 32
  • Cluster shape must be power of 2, total size ≤ 16
  • Contiguous dimensions must be 16-byte aligned
  • All epilogue tensors must have the same major order (row or column)
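These limits can be checked before launching anything. The helper below is hypothetical (the rules are transcribed from the list above; it is not part of the CUTLASS API):

```python
def check_tile_config(mma_m, mma_n, cluster_m, cluster_n, use_2cta_instrs):
    """Validate MMA tiler and cluster shape against the documented constraints."""
    valid_m = (128, 256) if use_2cta_instrs else (64, 128)
    if mma_m not in valid_m:
        raise ValueError(f"MMA tiler M must be one of {valid_m}, got {mma_m}")
    if not (32 <= mma_n <= 256 and mma_n % 32 == 0):
        raise ValueError(f"MMA tiler N must be 32-256 in steps of 32, got {mma_n}")
    size = cluster_m * cluster_n
    if size > 16 or size <= 0 or (size & (size - 1)) != 0:
        raise ValueError(f"cluster size must be a power of 2 and <= 16, got {size}")
    return True

# Matches the command-line example: --mma_tiler_mn 128,128 --cluster_shape_mn 2,1
assert check_tile_config(128, 128, 2, 1, use_2cta_instrs=True)
```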

Debugging Tips

Enable Verification

# Compare against reference implementation
--skip_ref_check  # Remove this flag during development

Print Intermediate Values

@cute.jit
def epilogue_debug(self, accumulator, c):
    # Print for debugging (only for small tensors!)
    if cute.thread_idx() == 0:
        cute.printf("Accumulator[0] = %f\n", accumulator[0])
    
    return accumulator + c

Check Memory Alignment

assert tensor.data_ptr() % 16 == 0, "Tensor must be 16-byte aligned"

Examples in the Repository

Find complete working examples:
  • Custom epilogue: examples/python/CuTeDSL/blackwell/epilogue/custom_epilogue_dense_gemm.py
  • Activation fusion: examples/python/CuTeDSL/blackwell/epilogue/activation_custom_epilogue_dense_gemm.py
  • Synthetic examples: examples/python/CuTeDSL/blackwell/epilogue/synthetic_custom_epilogue_dense_gemm.py
  • Legacy interface: examples/python/deprecated/01_epilogue.ipynb

Next Steps

Basic GEMM

Master basic GEMM operations first

Grouped GEMM

Combine custom epilogues with grouped operations
