Basic GEMM Example
This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference matrix multiply kernel to verify its correctness.

Overview
The CUTLASS GEMM template computes the general matrix product (GEMM) using single-precision floating-point arithmetic:

C = alpha * A * B + beta * C

Key Concepts
- GEMM kernel instantiation: Defining and launching a CUTLASS GEMM kernel
- Template parameters: Configuring data types and matrix layouts
- Argument objects: Passing parameters to CUTLASS kernels
- Reference validation: Verifying correctness against a naive implementation
Implementation
```cpp
#include "cutlass/gemm/device/gemm.h"

cudaError_t CutlassSgemmNN(
    int M,
    int N,
    int K,
    float alpha,
    float const *A,
    int lda,
    float const *B,
    int ldb,
    float beta,
    float *C,
    int ldc) {

  // Define a CUTLASS GEMM type
  using ColumnMajor = cutlass::layout::ColumnMajor;

  using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
                                                  ColumnMajor,  // Layout of A matrix
                                                  float,        // Data-type of B matrix
                                                  ColumnMajor,  // Layout of B matrix
                                                  float,        // Data-type of C matrix
                                                  ColumnMajor>; // Layout of C matrix

  CutlassGemm gemm_operator;

  CutlassGemm::Arguments args({M, N, K},      // GEMM problem dimensions
                              {A, lda},       // Tensor-ref for source matrix A
                              {B, ldb},       // Tensor-ref for source matrix B
                              {C, ldc},       // Tensor-ref for source matrix C
                              {C, ldc},       // Tensor-ref for destination matrix D
                              {alpha, beta}); // Scalars used in the epilogue

  cutlass::Status status = gemm_operator(args);

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
```

Note that the destination matrix D can be different from the source matrix C, allowing for out-of-place operation.
```cpp
cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {

  // Compute leading dimensions for each column-major matrix
  int lda = M;
  int ldb = K;
  int ldc = M;

  // Define pointers to matrices in GPU device memory
  float *A;
  float *B;
  float *C_cutlass;
  float *C_reference;

  // Allocate matrices in GPU device memory
  cudaMalloc(&A, sizeof(float) * M * K);
  cudaMalloc(&B, sizeof(float) * K * N);
  cudaMalloc(&C_cutlass, sizeof(float) * M * N);
  cudaMalloc(&C_reference, sizeof(float) * M * N);

  // Initialize matrices with distinct seeds (implementation details omitted)
  InitializeMatrix(A, M, K, 0);
  InitializeMatrix(B, K, N, 17);
  InitializeMatrix(C_cutlass, M, N, 101);
  cudaMemcpy(C_reference, C_cutlass, sizeof(float) * M * N, cudaMemcpyDeviceToDevice);

  // Launch CUTLASS GEMM
  CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc);

  // Launch reference GEMM for verification
  ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc);

  // Verify results (implementation details omitted)
  // ...

  cudaFree(C_reference);
  cudaFree(C_cutlass);
  cudaFree(B);
  cudaFree(A);

  return cudaSuccess;
}
```
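The `ReferenceGemm` call above runs a naive CUDA kernel whose body is omitted here. As a sketch of the math it verifies, the following host-side loop nest computes the same column-major GEMM; the function name and host-side form are illustrative, not the example's actual device kernel, which assigns one thread per output element:

```cpp
// Naive column-major GEMM on the host: C = alpha * A * B + beta * C.
// Illustrative sketch only; the example's ReferenceGemm performs the same
// computation on the GPU.
void ReferenceGemmHost(int M, int N, int K,
                       float alpha, float const *A, int lda,
                       float const *B, int ldb,
                       float beta, float *C, int ldc) {
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) {
      float accumulator = 0.0f;
      for (int k = 0; k < K; ++k) {
        // Column-major: element (row, col) lives at row + col * ld.
        accumulator += A[i + k * lda] * B[k + j * ldb];
      }
      C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc];
    }
  }
}
```

Because both implementations start from the same `C` contents (hence the `cudaMemcpy` of `C_cutlass` into `C_reference` above), their outputs can be compared element-by-element.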
```cpp
#include <iostream>
#include <sstream>

int main(int argc, const char *arg[]) {

  // GEMM problem dimensions (default: 128x128x128)
  int problem[3] = { 128, 128, 128 };

  for (int i = 1; i < argc && i < 4; ++i) {
    std::stringstream ss(arg[i]);
    ss >> problem[i - 1];
  }

  // Scalars used for linear scaling (default: alpha=1, beta=0)
  float scalars[2] = { 1, 0 };

  for (int i = 4; i < argc && i < 6; ++i) {
    std::stringstream ss(arg[i]);
    ss >> scalars[i - 4];
  }

  cudaError_t result = TestCutlassGemm(
    problem[0], // GEMM M dimension
    problem[1], // GEMM N dimension
    problem[2], // GEMM K dimension
    scalars[0], // alpha
    scalars[1]  // beta
  );

  if (result == cudaSuccess) {
    std::cout << "Passed." << std::endl;
  }

  return result == cudaSuccess ? 0 : -1;
}
```
Building and Running
Build the example
Run the example
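The example is built with CMake like the rest of CUTLASS. A typical sequence, assuming you are at the root of a CUTLASS checkout with the CUDA toolkit installed, might look like the following; the target name and output path follow the repository's conventions and may vary by version:

```shell
# Configure an out-of-source build (set the SM architecture of your GPU)
mkdir -p build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=80

# Build just this example
make 00_basic_gemm

# Run with optional positional arguments: M N K alpha beta
./examples/00_basic_gemm/00_basic_gemm 256 256 256 1 0
```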
Source Code Location
The complete source code for this example is available at: examples/00_basic_gemm/basic_gemm.cu
What This Example Demonstrates
- Minimal CUTLASS usage: This example deliberately uses minimal CUTLASS components to show the simplest path to getting started
- Template-based kernel instantiation: How to define a GEMM kernel using CUTLASS templates
- Argument construction: The pattern for passing arguments to CUTLASS kernels
- Correctness verification: How to validate CUTLASS output against a reference implementation
Key Takeaways
- CUTLASS provides simplified abstractions for high-performance GEMM operations
- The cutlass::gemm::device::Gemm template handles kernel instantiation
- Argument objects provide a structured way to pass parameters to kernels
- CUTLASS defaults (e.g., 128x128x8 tile size) are chosen for good general performance
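Those defaults can be overridden through additional template arguments. As a configuration sketch, assuming the CUTLASS 2.x device-level API, a SIMT SGEMM with an explicit threadblock tile might be declared as follows; the exact parameter list and defaults vary across CUTLASS versions, so treat the shape and architecture tags here as illustrative:

```cpp
#include "cutlass/gemm/device/gemm.h"

using ColumnMajor = cutlass::layout::ColumnMajor;

// Illustrative: override the default threadblock tile (128x128x8 for SIMT SGEMM).
using SgemmSmallTile = cutlass::gemm::device::Gemm<
    float, ColumnMajor,                  // Data-type and layout of A
    float, ColumnMajor,                  // Data-type and layout of B
    float, ColumnMajor,                  // Data-type and layout of C
    float,                               // Accumulator data-type
    cutlass::arch::OpClassSimt,          // Use SIMT (CUDA cores)
    cutlass::arch::Sm70,                 // Target architecture
    cutlass::gemm::GemmShape<64, 64, 8>  // Threadblock tile M x N x K
>;
```

Smaller tiles can improve occupancy for small problem sizes at the cost of peak throughput on large ones.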
Next Steps
- Explore Batched GEMM for processing multiple independent matrix multiplications
- Learn about Fused Operations to combine GEMM with element-wise operations
- Check out Convolution for convolutional neural network operations