Batched GEMM Example
This example demonstrates how to use CUTLASS to compute batched GEMM operations in two different ways:
- Strided batched GEMM: matrices separated by a fixed stride in memory
- Array GEMM: arbitrary pointers to each matrix in the batch
Overview
Batched GEMM operations compute multiple independent matrix multiplications in a single launch.
Key Concepts
- Strided batched GEMM: Efficient when matrices are laid out with uniform spacing
- Array GEMM: Flexible approach for arbitrary memory layouts
- Batch stride: Distance in memory between consecutive matrices
- Performance optimization: Amortize kernel launch overhead across multiple operations
Memory Layout
Consider a batch of 2 matrices with dimensions M=6, N=3, K=2:
Matrix C Layout (M=6, N=3, batch=2)
(batch_idx, row_idx, column_idx) denotes each element.
The batch stride is: batch_stride_C = ldc * N
Implementation
#include "cutlass/gemm/device/gemm_batched.h"
/// Runs batch_count independent single-precision GEMMs (C = alpha * A * B + beta * C)
/// using CUTLASS's strided batched kernel. Consecutive matrices of each operand
/// are assumed to be separated by a fixed element stride (batch_stride_*).
/// All operands are column-major.
///
/// Returns cudaSuccess on success, cudaErrorUnknown if CUTLASS reports failure.
cudaError_t cutlass_strided_batched_sgemm(
    int m,
    int n,
    int k,
    float alpha,
    float const *A,
    int lda,
    long long int batch_stride_A,
    float const *B,
    int ldb,
    long long int batch_stride_B,
    float *C,
    int ldc,
    long long int batch_stride_C,
    float beta,
    int batch_count) {

  // Single-precision batched GEMM; every operand uses a column-major layout.
  using Gemm = cutlass::gemm::device::GemmBatched<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;

  Gemm batched_gemm;

  // Note: C appears twice — once as the accumulator source (scaled by beta)
  // and once as the destination of the epilogue.
  cutlass::Status status = batched_gemm({
      {m, n, k},
      {A, lda}, batch_stride_A,
      {B, ldb}, batch_stride_B,
      {C, ldc}, batch_stride_C,
      {C, ldc}, batch_stride_C,
      {alpha, beta},
      batch_count});

  return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}
#include "cutlass/gemm/device/gemm_array.h"
/// Runs batch_count independent single-precision GEMMs where each operand of
/// each batch item is addressed through its own device pointer (A[i], B[i],
/// C[i]), so matrices may live anywhere in device memory. All operands are
/// column-major.
///
/// A, B, C are device arrays of device pointers, one entry per batch item.
/// Returns cudaSuccess on success, cudaErrorUnknown if CUTLASS reports failure.
cudaError_t cutlass_array_sgemm(
    int m,
    int n,
    int k,
    float alpha,
    float const * const *A, // Array of pointers
    int lda,
    float const * const *B, // Array of pointers
    int ldb,
    float * const *C, // Array of pointers
    int ldc,
    float beta,
    int batch_count) {

  // Single-precision array GEMM; every operand uses a column-major layout.
  using Gemm = cutlass::gemm::device::GemmArray<
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor,
      float, cutlass::layout::ColumnMajor>;

  Gemm array_gemm;

  // As in the strided variant, C is both the accumulator source and the
  // destination of the epilogue.
  cutlass::Status status = array_gemm({
      {m, n, k},
      A, lda,
      B, ldb,
      C, ldc,
      C, ldc,
      {alpha, beta},
      batch_count});

  return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}
// Problem size: deliberately non-round dimensions exercise the kernel's
// boundary (partial-tile) handling.
int m = 520, n = 219, k = 129;
int batch_count = 17;
// Leading dimensions (column-major storage).
int lda = m;
// B uses an interleaved layout: the leading dimension spans all batches
// (k * batch_count) while consecutive B matrices start only k elements
// apart — see batch_stride_B below.
int ldb = k * batch_count;
int ldc = m;
// Stride between consecutive matrices in the batch
long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
long long int batch_stride_B = static_cast<long long int>(k);
long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);
// Allocate host memory for pointers
std::vector<float*> host_ptr_A(batch_count);
std::vector<float*> host_ptr_B(batch_count);
std::vector<float*> host_ptr_C(batch_count);
// Matrices can be in any order - not required to be uniformly spaced
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
// A, B, C are assumed to be device allocations holding all batch items,
// defined earlier in the example (not shown here).
for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
}
// Allocate device memory for pointer arrays
// (the array GEMM kernel dereferences these on the device, so the pointer
// tables themselves must live in device memory).
float const **ptr_A;
float const **ptr_B;
float **ptr_C;
// NOTE(review): cudaMalloc/cudaMemcpy return codes are unchecked and
// ptr_A/ptr_B/ptr_C are never freed here — acceptable for a doc snippet,
// but production code should check each call and free the allocations.
cudaMalloc(&ptr_A, batch_count * sizeof(float*));
cudaMalloc(&ptr_B, batch_count * sizeof(float*));
cudaMalloc(&ptr_C, batch_count * sizeof(float*));
// Copy pointers to device
cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
// Launch array GEMM
cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
/// Entry point: runs the strided batched GEMM path first (use_array == false),
/// then the array-of-pointers path (use_array == true). Prints "Passed." after
/// each successful mode and stops at the first failure.
int main() {
  for (bool use_array : {false, true}) {
    // run_batched_gemm is defined elsewhere in this example.
    cudaError_t status = run_batched_gemm(use_array);
    if (status != cudaSuccess) {
      return -1;
    }
    std::cout << "Passed." << std::endl;
  }
  return 0;
}
Building and Running
Build the example
Run the example
Source Code Location
The complete source code for this example is available at:examples/05_batched_gemm/batched_gemm.cu
What This Example Demonstrates
- Two batching modes: Both strided and array-based batched GEMM
- Flexible memory layouts: How to handle both regular and irregular memory patterns
- Pointer management: Setting up device pointer arrays for array GEMM
- Correctness verification: Reference implementation for validating results
Performance Considerations
- Strided batched GEMM is typically faster when matrices are uniformly spaced because:
  - Simpler addressing logic
  - Better memory access patterns
  - Less pointer indirection
- Array GEMM provides flexibility when:
  - Matrices are scattered in memory
  - Each batch item comes from different allocations
  - You need arbitrary ordering of operations
Key Takeaways
- Use `GemmBatched` for strided batched operations with uniform spacing
- Use `GemmArray` for arbitrary pointer arrays with irregular layouts
- Batch operations amortize kernel launch overhead across multiple GEMMs
- Both approaches share the same underlying optimizations for individual matrix multiplications
Next Steps
- Learn about Basic GEMM for single matrix multiplication
- Explore Fused Operations to combine GEMM with activation functions
- Check out Convolution for batched convolution operations