#include <cstdlib>
#include <iostream>

#include "cutlass/gemm/device/gemm.h"

// Checks a CUDA runtime call and aborts with a diagnostic on failure.
#define CUDA_CHECK(call)                                                     \
    do {                                                                     \
        cudaError_t err_ = (call);                                           \
        if (err_ != cudaSuccess) {                                           \
            std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__     \
                      << ": " << cudaGetErrorString(err_) << std::endl;      \
            std::abort();                                                    \
        }                                                                    \
    } while (0)

// Runs a single-precision CUTLASS GEMM, D = alpha * A * B + beta * C,
// with all matrices stored column-major in device memory.
int main() {
    // Problem dimensions: A is M x K, B is K x N, C and D are M x N.
    int M = 1024;
    int N = 1024;
    int K = 1024;

    // Define the GEMM operation (float data, column-major layouts).
    using ColumnMajor = cutlass::layout::ColumnMajor;
    using CutlassGemm = cutlass::gemm::device::Gemm<
        float,        // Data type of A matrix
        ColumnMajor,  // Layout of A matrix
        float,        // Data type of B matrix
        ColumnMajor,  // Layout of B matrix
        float,        // Data type of C matrix
        ColumnMajor   // Layout of C matrix
    >;

    // Allocate device memory; widen to size_t before multiplying so the
    // byte counts cannot overflow int for larger problem sizes.
    float *A = nullptr;
    float *B = nullptr;
    float *C = nullptr;
    size_t size_A = static_cast<size_t>(M) * K * sizeof(float);
    size_t size_B = static_cast<size_t>(K) * N * sizeof(float);
    size_t size_C = static_cast<size_t>(M) * N * sizeof(float);
    CUDA_CHECK(cudaMalloc(&A, size_A));
    CUDA_CHECK(cudaMalloc(&B, size_B));
    CUDA_CHECK(cudaMalloc(&C, size_C));

    // Initialize matrices (simplified - you'd fill with real data).
    CUDA_CHECK(cudaMemset(A, 0, size_A));
    CUDA_CHECK(cudaMemset(B, 0, size_B));
    CUDA_CHECK(cudaMemset(C, 0, size_C));

    // Epilogue scalars for D = alpha * A * B + beta * C.
    float alpha = 1.0f;
    float beta = 0.0f;

    // For column-major storage the leading dimension is the row count:
    // lda = M (A is M x K), ldb = K (B is K x N), ldc = M (C/D are M x N).
    // The original code passed {A, K} / {B, N} / {C, N} — those are the
    // ROW-major leading dimensions and index device memory incorrectly
    // when the layouts are ColumnMajor.
    CutlassGemm gemm_op;
    CutlassGemm::Arguments args(
        {M, N, K},     // Problem dimensions
        {A, M},        // Tensor-ref for A (lda = M)
        {B, K},        // Tensor-ref for B (ldb = K)
        {C, M},        // Tensor-ref for C (ldc = M)
        {C, M},        // Tensor-ref for D (output, written in place over C)
        {alpha, beta}  // Epilogue scalars
    );

    // Launch the GEMM kernel; free device memory before any early return.
    cutlass::Status status = gemm_op(args);
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "CUTLASS GEMM kernel failed" << std::endl;
        cudaFree(A);
        cudaFree(B);
        cudaFree(C);
        return -1;
    }

    // Wait for completion; also surfaces asynchronous execution errors.
    CUDA_CHECK(cudaDeviceSynchronize());

    std::cout << "GEMM completed successfully!" << std::endl;

    // Cleanup
    CUDA_CHECK(cudaFree(A));
    CUDA_CHECK(cudaFree(B));
    CUDA_CHECK(cudaFree(C));
    return 0;
}
The CUTLASS Python interface provides a high-level API for running CUTLASS kernels from Python.
1
Install the CUTLASS Python package
pip install nvidia-cutlass
Or install from source:
cd ${CUTLASS_PATH} && pip install .
2
Create a Python GEMM script
Create a file named my_gemm.py:
my_gemm.py
import cutlass
import numpy as np

# Plan a half-precision (FP16), row-major GEMM using the CUTLASS
# Python interface.
plan = cutlass.op.Gemm(
    element=np.float16,
    layout=cutlass.LayoutType.RowMajor,
)

# Problem size: A is (M, K), B is (K, N), C and D are (M, N).
M, N, K = 1024, 1024, 1024

# Random FP16 operands; C starts at zero and D receives the output.
A = np.random.randn(M, K).astype(np.float16)
B = np.random.randn(K, N).astype(np.float16)
C = np.zeros((M, N), dtype=np.float16)
D = np.zeros((M, N), dtype=np.float16)

# Execute D = A @ B + C on the GPU.
plan.run(A, B, C, D)

print("GEMM completed successfully!")
print(f"Output shape: {D.shape}")
print(f"Output sample: {D[0, :5]}")

# Cross-check against a NumPy reference computation.
reference = A @ B + C
error = np.max(np.abs(D - reference))
print(f"Max error vs NumPy: {error}")