The Gemm class provides a high-level interface for constructing, compiling, and running General Matrix Multiply operations in CUTLASS.
Class Reference
cutlass.op.Gemm
from cutlass.op import Gemm
plan = Gemm(
    A=None,
    B=None,
    C=None,
    D=None,
    alpha=1.0,
    beta=0.0,
    element_A=None,
    element_B=None,
    element_C=None,
    element_D=None,
    element_accumulator=None,
    element=None,
    layout_A=None,
    layout_B=None,
    layout_C=None,
    layout=None,
    cc=None,
    kernel_cc=None
)
Constructs a GEMM operation that computes: D = alpha * (A @ B) + beta * C
The data types and layouts of operands are bound to the Gemm object throughout its lifetime and cannot be changed after construction.
A
tensor, optional
Input tensor A with shape (M, K). Can be a torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout of operand A are inferred from this tensor.
B
tensor, optional
Input tensor B with shape (K, N). Can be a torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout of operand B are inferred from this tensor.
C
tensor, optional
Input tensor C with shape (M, N). Can be a torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout of operand C are inferred from this tensor.
D
tensor, optional
Output tensor D with shape (M, N). Can be a torch.Tensor, numpy.ndarray, or cupy.ndarray. If provided, the element type and layout of operand D are inferred from this tensor.
alpha
scalar, optional
default: 1.0
Scalar multiplier for the A @ B product.
beta
scalar, optional
default: 0.0
Scalar multiplier for the C tensor. When beta=0, C is not read.
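As a reference for the alpha/beta semantics, the computation can be sketched in plain NumPy (an illustration of the math only, not the CUTLASS code path; `gemm_reference` is a hypothetical helper):

```python
import numpy as np

def gemm_reference(A, B, C, alpha=1.0, beta=0.0):
    """Reference semantics: D = alpha * (A @ B) + beta * C."""
    if beta == 0.0:
        # Matches the documented behavior: with beta=0, C is not read
        return alpha * (A @ B)
    return alpha * (A @ B) + beta * C

A = np.arange(6, dtype=np.float32).reshape(2, 3)
B = np.ones((3, 2), dtype=np.float32)
C = np.full((2, 2), 10.0, dtype=np.float32)

D0 = gemm_reference(A, B, C, alpha=1.0, beta=0.0)  # C ignored
D1 = gemm_reference(A, B, C, alpha=2.0, beta=1.0)  # 2*(A @ B) + C
```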
element_A
cutlass.DataType, optional
Data type for operand A. Overrides generic element parameter.
element_B
cutlass.DataType, optional
Data type for operand B. Overrides generic element parameter.
element_C
cutlass.DataType, optional
Data type for operand C. Overrides generic element parameter.
element_D
cutlass.DataType, optional
Data type for operand D. Overrides generic element parameter.
element_accumulator
cutlass.DataType, optional
Data type used for accumulation. If not specified, defaults to the element type.
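The accumulator type matters because summing many low-precision products in the input precision loses accuracy. A plain-NumPy illustration of the effect (not CUTLASS code): accumulating 4096 halves in float16 stalls once the running sum reaches 1024, where the float16 spacing is 1.0 and adding 0.5 rounds back to the same value.

```python
import numpy as np

x = np.full(4096, 0.5, dtype=np.float16)

# Accumulate step by step in float16: the sum gets stuck at 1024.0
s16 = np.float16(0.0)
for v in x:
    s16 = np.float16(s16 + v)

# Accumulating in float32 gives the exact answer, 2048.0
s32 = x.astype(np.float32).sum()
```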
element
cutlass.DataType, optional
Generic data type for all operands (A, B, C, D) and accumulator. Can be overridden per operand.
layout_A
cutlass.LayoutType, optional
Memory layout for operand A. Overrides generic layout parameter.
layout_B
cutlass.LayoutType, optional
Memory layout for operand B. Overrides generic layout parameter.
layout_C
cutlass.LayoutType, optional
Memory layout for operand C and D. Overrides generic layout parameter.
layout
cutlass.LayoutType, optional
Generic layout for all operands. Can be overridden per operand.
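RowMajor corresponds to C-contiguous storage (each row contiguous in memory) and ColumnMajor to Fortran-contiguous storage (each column contiguous), which NumPy makes visible through strides. A small sketch of the distinction:

```python
import numpy as np

M, K = 4, 3
row_major = np.zeros((M, K), dtype=np.float32)  # C order: rows contiguous
col_major = np.asfortranarray(row_major)        # Fortran order: columns contiguous

# Byte strides (row step, column step):
# row-major advances K*4 bytes per row; column-major advances 4 bytes per row
```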
cc
int, optional
Compute capability of the target device (e.g., 90 for H100). Defaults to auto-detection.
kernel_cc
int, optional
Compute capability of the kernel to generate. Useful for running older kernel implementations on newer hardware.
Methods
run()
args = plan.run(
    A=None,
    B=None,
    C=None,
    D=None,
    alpha=None,
    beta=None,
    sync=True,
    print_module=False
)
Executes the GEMM operation.
A
tensor, optional
Input tensor A. If not provided, uses the tensor bound at initialization.
B
tensor, optional
Input tensor B. If not provided, uses the tensor bound at initialization.
C
tensor, optional
Input tensor C. If not provided, uses the tensor bound at initialization.
D
tensor, optional
Output tensor D. If not provided, uses the tensor bound at initialization.
alpha
scalar, optional
Scalar alpha. If not provided, uses the value from initialization.
beta
scalar, optional
Scalar beta. If not provided, uses the value from initialization.
sync
bool, optional
default: True
If True, waits for the kernel to complete before returning. If False, returns immediately after launch.
print_module
bool, optional
default: False
If True, prints the generated CUDA C++ code.
Returns: GemmArguments object that can be used to synchronize or retrieve results.
compile()
plan.compile(
    tile_description=None,
    alignment=None
)
Explicitly compiles the GEMM kernel. This step is optional: run() compiles the kernel automatically on first use if needed.
tile_description
cutlass.TileDescription, optional
Custom tile description for advanced kernel configuration.
alignment
int, optional
Memory alignment requirement in elements. Defaults to automatic selection.
plan()
operation = plan.plan(
    tile_description=None,
    alignment=None
)
Generates an execution plan without compiling. Returns the selected operation configuration.
Returns: Operation descriptor with selected kernel parameters.
Properties
activation
plan.activation = cutlass.epilogue.relu
Sets the activation function to fuse into the epilogue. Available activations:
cutlass.epilogue.identity (default)
cutlass.epilogue.relu
cutlass.epilogue.gelu
cutlass.epilogue.sigmoid
cutlass.epilogue.tanh
cutlass.epilogue.silu
cutlass.epilogue.hardswish
cutlass.epilogue.leaky_relu
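A fused activation is applied to the epilogue result, i.e. D = act(alpha * (A @ B) + beta * C), which avoids a separate pass over D. As a plain-NumPy reference (the formulas below are the standard definitions of these activations, not code taken from CUTLASS):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def silu(x):
    # x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

alpha, beta = 1.0, 1.0
A = np.array([[1.0, -2.0]], dtype=np.float32)
B = np.array([[3.0], [4.0]], dtype=np.float32)
C = np.array([[1.0]], dtype=np.float32)

# The activation wraps the whole epilogue expression
D = relu(alpha * (A @ B) + beta * C)  # relu(-5 + 1) = 0
```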
swizzle
plan.swizzle = cutlass.swizzle.ThreadblockSwizzleStreamK
Sets the threadblock swizzling function for improved performance.
Examples
Basic GEMM
import torch
from cutlass.op import Gemm
M, N, K = 512, 1024, 256
# Create tensors
A = torch.randn((M, K), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
C = torch.zeros((M, N), device='cuda', dtype=torch.float16)
D = torch.zeros((M, N), device='cuda', dtype=torch.float16)
# Run GEMM: D = A @ B
plan = Gemm(A=A, B=B, C=C, D=D, alpha=1.0, beta=0.0)
plan.run()
GEMM with Data Type Configuration
import numpy as np
import cutlass
from cutlass.op import Gemm
# Create plan with explicit types
plan = Gemm(
    element_A=cutlass.DataType.f16,
    element_B=cutlass.DataType.f16,
    element_C=cutlass.DataType.f32,
    element_D=cutlass.DataType.f32,
    element_accumulator=cutlass.DataType.f32,
    layout=cutlass.LayoutType.RowMajor
)
plan.compile()
# Run with different tensors
A1 = np.random.randn(128, 256).astype(np.float16)
B1 = np.random.randn(256, 64).astype(np.float16)
C1 = np.zeros((128, 64), dtype=np.float32)
D1 = np.zeros((128, 64), dtype=np.float32)
plan.run(A1, B1, C1, D1)
GEMM with Activation Fusion
import torch
from cutlass.op import Gemm
import cutlass
plan = Gemm(
    element=torch.float32,
    layout=cutlass.LayoutType.RowMajor
)
plan.activation = cutlass.epilogue.relu
M, N, K = 512, 1024, 256
A = torch.randn((M, K), device='cuda')
B = torch.randn((K, N), device='cuda')
C = torch.zeros((M, N), device='cuda')
D = torch.zeros((M, N), device='cuda')
# Computes D = relu(A @ B + C)
plan.run(A, B, C, D, alpha=1.0, beta=1.0)
Mixed Precision GEMM
import torch
import cutlass
from cutlass.op import Gemm
# FP16 inputs, FP32 accumulation and output
plan = Gemm(
    element_A=cutlass.DataType.f16,
    element_B=cutlass.DataType.f16,
    element_C=cutlass.DataType.f32,
    element_D=cutlass.DataType.f32,
    element_accumulator=cutlass.DataType.f32,
    layout=cutlass.LayoutType.RowMajor
)
M, N, K = 512, 1024, 256
A = torch.randn((M, K), device='cuda', dtype=torch.float16)
B = torch.randn((K, N), device='cuda', dtype=torch.float16)
C = torch.zeros((M, N), device='cuda', dtype=torch.float32)
D = torch.zeros((M, N), device='cuda', dtype=torch.float32)
plan.run(A, B, C, D)
Column-Major GEMM (Fortran Layout)
import torch
import cutlass
from cutlass.op import Gemm
plan = Gemm(
    element=torch.float32,
    layout_A=cutlass.LayoutType.ColumnMajor,
    layout_B=cutlass.LayoutType.ColumnMajor,
    layout_C=cutlass.LayoutType.ColumnMajor
)
M, N, K = 512, 1024, 256
# Transpose, make contiguous, then transpose back: the tensor keeps its
# original shape but is now stored in column-major order
A = torch.randn((M, K), device='cuda').t().contiguous().t()
B = torch.randn((K, N), device='cuda').t().contiguous().t()
C = torch.zeros((M, N), device='cuda').t().contiguous().t()
D = torch.zeros((M, N), device='cuda').t().contiguous().t()
plan.run(A, B, C, D)
Asynchronous Execution
import torch
from cutlass.op import Gemm
plan = Gemm(element=torch.float32)
M, N, K = 512, 1024, 256
A = torch.randn((M, K), device='cuda')
B = torch.randn((K, N), device='cuda')
C = torch.zeros((M, N), device='cuda')
D = torch.zeros((M, N), device='cuda')
# Launch kernel asynchronously
args = plan.run(A, B, C, D, sync=False)
# Overlap other CPU work here while the kernel runs
# Wait for GEMM to complete
args.sync()
# Access results
result = D.cpu().numpy()
Performance Tips
Alignment: Ensure tensor dimensions are multiples of 8 or 16 for optimal performance with FP16/BF16 data types.
Persistent Tensors: Reuse tensor allocations across multiple run() calls to minimize memory allocation overhead.
Auto-tuning: The default parameters prioritize correctness over performance. For production workloads, consider profiling and tuning tile sizes, kernel schedules, and other parameters.
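The alignment guideline above is a simple divisibility check; a minimal sketch (`extents_aligned` is a hypothetical helper, not part of the CUTLASS API):

```python
def extents_aligned(M, N, K, multiple=8):
    """True when every GEMM extent is a multiple of `multiple` elements."""
    return all(dim % multiple == 0 for dim in (M, N, K))

# The 512 x 1024 x 256 problem used in the examples satisfies the
# 8-element guideline for FP16/BF16; 510 in any extent would not.
```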
Type Aliases
For convenience, you can use native tensor types instead of CUTLASS types:
| Python Type | CUTLASS Type |
|---|---|
| torch.float16 | cutlass.DataType.f16 |
| torch.float32 | cutlass.DataType.f32 |
| torch.float64 | cutlass.DataType.f64 |
| torch.bfloat16 | cutlass.DataType.bf16 |
| numpy.float16 | cutlass.DataType.f16 |
| numpy.float32 | cutlass.DataType.f32 |
| numpy.float64 | cutlass.DataType.f64 |
Source Code
Implementation: cutlass/python/cutlass_cppgen/op/gemm.py