The Gemm class provides a high-level interface for constructing, compiling, and running General Matrix Multiply operations in CUTLASS.

Class Reference

cutlass.op.Gemm

from cutlass.op import Gemm

plan = Gemm(
    A=None,
    B=None, 
    C=None,
    D=None,
    alpha=1.0,
    beta=0.0,
    element_A=None,
    element_B=None,
    element_C=None,
    element_D=None,
    element_accumulator=None,
    element=None,
    layout_A=None,
    layout_B=None,
    layout_C=None,
    layout=None,
    cc=None,
    kernel_cc=None
)
Constructs a GEMM operation that computes D = alpha * (A @ B) + beta * C. The data types and layouts of the operands are bound to the Gemm object for its lifetime and cannot be changed after construction.
A
tensor, optional
Input tensor A with shape (M, K). Can be torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout are inferred from this tensor.
B
tensor, optional
Input tensor B with shape (K, N). Can be torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout are inferred from this tensor.
C
tensor, optional
Input tensor C with shape (M, N). Can be torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout are inferred from this tensor.
D
tensor, optional
Output tensor D with shape (M, N). Can be torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout are inferred from this tensor.
alpha
scalar, optional
default:"1.0"
Scalar multiplier for the A @ B product.
beta
scalar, optional
default:"0.0"
Scalar multiplier for the C tensor. When beta=0, C is not read.
element_A
cutlass.DataType, optional
Data type for operand A. Overrides generic element parameter.
element_B
cutlass.DataType, optional
Data type for operand B. Overrides generic element parameter.
element_C
cutlass.DataType, optional
Data type for operand C. Overrides generic element parameter.
element_D
cutlass.DataType, optional
Data type for operand D. Overrides generic element parameter.
element_accumulator
cutlass.DataType, optional
Data type used for accumulation. If not specified, defaults to the element type.
element
cutlass.DataType, optional
Generic data type for all operands (A, B, C, D) and accumulator. Can be overridden per operand.
layout_A
cutlass.LayoutType, optional
Memory layout for operand A. Overrides generic layout parameter.
layout_B
cutlass.LayoutType, optional
Memory layout for operand B. Overrides generic layout parameter.
layout_C
cutlass.LayoutType, optional
Memory layout for operand C and D. Overrides generic layout parameter.
layout
cutlass.LayoutType, optional
Generic layout for all operands. Can be overridden per operand.
cc
int, optional
Compute capability of the target device (e.g., 90 for H100). Defaults to auto-detection.
kernel_cc
int, optional
Compute capability of the kernel to generate. Useful for using older kernel implementations on newer hardware.
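As a plain NumPy reference for the computation above (this does not use CUTLASS; it only mirrors the math, including the rule that C is not read when beta is zero):

```python
import numpy as np

def gemm_reference(A, B, C, alpha=1.0, beta=0.0):
    """NumPy reference for D = alpha * (A @ B) + beta * C."""
    D = alpha * (A @ B)
    if beta != 0.0:  # when beta == 0, C is never read
        D = D + beta * C
    return D

A = np.arange(6, dtype=np.float32).reshape(2, 3)  # (M, K) = (2, 3)
B = np.ones((3, 2), dtype=np.float32)             # (K, N) = (3, 2)
C = np.full((2, 2), 10.0, dtype=np.float32)       # (M, N) = (2, 2)
D = gemm_reference(A, B, C, alpha=2.0, beta=1.0)  # [[16, 16], [34, 34]]
```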

Methods

run()

args = plan.run(
    A=None,
    B=None,
    C=None, 
    D=None,
    alpha=None,
    beta=None,
    sync=True,
    print_module=False
)
Executes the GEMM operation.
A
tensor, optional
Input tensor A. If not provided, uses the tensor from initialization.
B
tensor, optional
Input tensor B. If not provided, uses the tensor from initialization.
C
tensor, optional
Input tensor C. If not provided, uses the tensor from initialization.
D
tensor, optional
Output tensor D. If not provided, uses the tensor from initialization.
alpha
scalar, optional
Scalar alpha. If not provided, uses the value from initialization.
beta
scalar, optional
Scalar beta. If not provided, uses the value from initialization.
sync
bool
default:"True"
If True, waits for the kernel to complete. If False, returns immediately.
print_module
bool
default:"False"
If True, prints the generated CUDA C++ code.
Returns: GemmArguments object that can be used to synchronize or retrieve results.

compile()

plan.compile(
    tile_description=None,
    alignment=None
)
Explicitly compiles the GEMM kernel. This step is optional: run() compiles the kernel on first use if it has not been compiled already.
tile_description
cutlass.TileDescription, optional
Custom tile description for advanced kernel configuration.
alignment
int, optional
Memory alignment requirement in elements. Defaults to automatic selection.

plan()

operation = plan.plan(
    tile_description=None,
    alignment=None
)
Generates an execution plan without compiling.
Returns: Operation descriptor with the selected kernel parameters.

Properties

activation

plan.activation = cutlass.epilogue.relu
Sets the activation function to fuse into the epilogue. Available activations:
  • cutlass.epilogue.identity (default)
  • cutlass.epilogue.relu
  • cutlass.epilogue.gelu
  • cutlass.epilogue.sigmoid
  • cutlass.epilogue.tanh
  • cutlass.epilogue.silu
  • cutlass.epilogue.hardswish
  • cutlass.epilogue.leaky_relu
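As a NumPy sketch (not CUTLASS code), a relu-fused epilogue is mathematically equivalent to applying relu after the scaled accumulation:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def gemm_relu_reference(A, B, C, alpha=1.0, beta=1.0):
    """NumPy reference for a relu-fused epilogue: D = relu(alpha * (A @ B) + beta * C)."""
    return relu(alpha * (A @ B) + beta * C)

A = np.array([[1.0, -2.0]], dtype=np.float32)   # (1, 2)
B = np.array([[3.0], [1.0]], dtype=np.float32)  # (2, 1)
C = np.zeros((1, 1), dtype=np.float32)
D = gemm_relu_reference(A, B, C)  # A @ B = [[1.0]], relu leaves it unchanged
```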

swizzle

plan.swizzle = cutlass.swizzle.ThreadblockSwizzleStreamK
Sets the threadblock swizzling function for improved performance.

Examples

Basic GEMM

import torch
from cutlass.op import Gemm

M, N, K = 512, 1024, 256

# Create tensors
A = torch.randn((M, K), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
C = torch.zeros((M, N), device='cuda', dtype=torch.float16)
D = torch.zeros((M, N), device='cuda', dtype=torch.float16)

# Run GEMM: D = A @ B
plan = Gemm(A=A, B=B, C=C, D=D, alpha=1.0, beta=0.0)
plan.run()

GEMM with Data Type Configuration

import numpy as np
import cutlass
from cutlass.op import Gemm

# Create plan with explicit types
plan = Gemm(
    element_A=cutlass.DataType.f16,
    element_B=cutlass.DataType.f16,
    element_C=cutlass.DataType.f32,
    element_D=cutlass.DataType.f32,
    element_accumulator=cutlass.DataType.f32,
    layout=cutlass.LayoutType.RowMajor
)

plan.compile()

# Run with different tensors
A1 = np.random.randn(128, 256).astype(np.float16)
B1 = np.random.randn(256, 64).astype(np.float16)
C1 = np.zeros((128, 64), dtype=np.float32)
D1 = np.zeros((128, 64), dtype=np.float32)

plan.run(A1, B1, C1, D1)

GEMM with Activation Fusion

import torch
from cutlass.op import Gemm
import cutlass

plan = Gemm(
    element=torch.float32,
    layout=cutlass.LayoutType.RowMajor
)
plan.activation = cutlass.epilogue.relu

M, N, K = 512, 1024, 256

A = torch.randn((M, K), device='cuda')
B = torch.randn((K, N), device='cuda')
C = torch.zeros((M, N), device='cuda')
D = torch.zeros((M, N), device='cuda')

# Computes D = relu(A @ B + C)
plan.run(A, B, C, D, alpha=1.0, beta=1.0)

Mixed Precision GEMM

import torch
import cutlass
from cutlass.op import Gemm

# FP16 inputs, FP32 accumulation and output
plan = Gemm(
    element_A=cutlass.DataType.f16,
    element_B=cutlass.DataType.f16,
    element_C=cutlass.DataType.f32,
    element_D=cutlass.DataType.f32,
    element_accumulator=cutlass.DataType.f32,
    layout=cutlass.LayoutType.RowMajor
)

M, N, K = 512, 1024, 256

A = torch.randn((M, K), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
C = torch.zeros((M, N), device='cuda', dtype=torch.float32)
D = torch.zeros((M, N), device='cuda', dtype=torch.float32)

plan.run(A, B, C, D)

Column-Major GEMM (Fortran Layout)

import torch
import cutlass
from cutlass.op import Gemm

plan = Gemm(
    element=torch.float32,
    layout_A=cutlass.LayoutType.ColumnMajor,
    layout_B=cutlass.LayoutType.ColumnMajor,
    layout_C=cutlass.LayoutType.ColumnMajor
)

M, N, K = 512, 1024, 256

# Tensors are stored in column-major order: .t().contiguous().t()
# yields a tensor whose underlying storage is column-major
A = torch.randn((M, K), device='cuda').t().contiguous().t()
B = torch.randn((K, N), device='cuda').t().contiguous().t()
C = torch.zeros((M, N), device='cuda').t().contiguous().t()
D = torch.zeros((M, N), device='cuda').t().contiguous().t()

plan.run(A, B, C, D)
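With NumPy operands, the same column-major storage is obtained with np.asfortranarray (a plain NumPy fact, independent of CUTLASS):

```python
import numpy as np

# np.asfortranarray copies into column-major (Fortran-order) storage
A = np.asfortranarray(np.random.randn(4, 3).astype(np.float32))
assert A.flags['F_CONTIGUOUS']
# In column-major storage, stepping down a column moves one element
# (4 bytes for float32); stepping across a row moves a whole column (4 * 4 bytes)
assert A.strides == (4, 16)
```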

Asynchronous Execution

import torch
from cutlass.op import Gemm

plan = Gemm(element=torch.float32)

M, N, K = 512, 1024, 256
A = torch.randn((M, K), device='cuda')
B = torch.randn((K, N), device='cuda')
C = torch.zeros((M, N), device='cuda')
D = torch.zeros((M, N), device='cuda')

# Launch kernel asynchronously
args = plan.run(A, B, C, D, sync=False)

# Do other CPU work here while the GEMM runs on the GPU

# Wait for GEMM to complete
args.sync()

# Access results
result = D.cpu().numpy()

Performance Tips

  • Alignment: Ensure tensor dimensions are multiples of 8 or 16 for optimal performance with FP16/BF16 data types.
  • Persistent Tensors: Reuse tensor allocations across multiple run() calls to minimize memory allocation overhead.
  • Auto-tuning: The default parameters prioritize correctness over performance. For production workloads, profile and tune tile sizes, kernel schedules, and related parameters.
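The alignment tip can be checked with a small helper (illustrative only; the alignment CUTLASS actually selects depends on the data type and hardware):

```python
def is_well_aligned(M, N, K, multiple=8):
    """Check whether all GEMM dimensions are multiples of `multiple`
    (8 is a common choice for FP16/BF16; 16 can be even better)."""
    return all(dim % multiple == 0 for dim in (M, N, K))

print(is_well_aligned(512, 1024, 256))   # well-aligned problem
print(is_well_aligned(513, 1024, 256))   # M=513 breaks alignment
```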

Type Aliases

For convenience, you can pass native framework dtypes in place of cutlass.DataType values:

Python Type       CUTLASS Type
torch.float16     cutlass.DataType.f16
torch.float32     cutlass.DataType.f32
torch.float64     cutlass.DataType.f64
torch.bfloat16    cutlass.DataType.bf16
numpy.float16     cutlass.DataType.f16
numpy.float32     cutlass.DataType.f32
numpy.float64     cutlass.DataType.f64
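The NumPy rows of this table can be illustrated with a small lookup (names only, not the actual cutlass.DataType objects, which require the cutlass package):

```python
import numpy as np

# Illustrative mapping from NumPy dtypes to CUTLASS DataType names
NUMPY_TO_CUTLASS = {
    np.dtype(np.float16): 'f16',
    np.dtype(np.float32): 'f32',
    np.dtype(np.float64): 'f64',
}

A = np.zeros((2, 2), dtype=np.float16)
name = NUMPY_TO_CUTLASS[A.dtype]  # 'f16'
```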

Source Code

Implementation: cutlass/python/cutlass_cppgen/op/gemm.py
