Batched GEMM Example

This example demonstrates how to use CUTLASS to compute batched GEMM operations in two different ways:
  1. Strided batched GEMM: Matrices separated by a fixed stride in memory
  2. Array GEMM: Arbitrary pointers to each matrix in the batch

Overview

Batched GEMM operations compute multiple independent matrix multiplications:
C[i] = alpha * (A[i] x B[i]) + beta * C[i]  for i = 0 to batch_count-1
This is common in many applications including neural network training, computer graphics, and scientific computing.
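For reference, the batched computation can be sketched as a naive host-side loop (illustrative only; assumes column-major storage and strided batching, and is not how CUTLASS implements it):

```cpp
#include <vector>

// Naive reference for strided batched GEMM with column-major storage:
// C[b] = alpha * A[b] * B[b] + beta * C[b] for each batch item b.
// Element (row, col) of batch b lives at [b * batch_stride + col * ld + row].
void reference_batched_gemm(int m, int n, int k, float alpha,
                            const std::vector<float>& A, int lda, long long batch_stride_A,
                            const std::vector<float>& B, int ldb, long long batch_stride_B,
                            std::vector<float>& C, int ldc, long long batch_stride_C,
                            float beta, int batch_count) {
  for (int b = 0; b < batch_count; ++b) {
    for (int col = 0; col < n; ++col) {
      for (int row = 0; row < m; ++row) {
        float acc = 0.0f;
        for (int kk = 0; kk < k; ++kk) {
          acc += A[b * batch_stride_A + kk * lda + row] *
                 B[b * batch_stride_B + col * ldb + kk];
        }
        float& c = C[b * batch_stride_C + col * ldc + row];
        c = alpha * acc + beta * c;
      }
    }
  }
}
```

Each batch item is an independent GEMM; the device kernels below produce the same result in a single launch.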

Key Concepts

  • Strided batched GEMM: Efficient when matrices are laid out with uniform spacing
  • Array GEMM: Flexible approach for arbitrary memory layouts
  • Batch stride: Distance in memory between consecutive matrices
  • Performance optimization: Amortize kernel launch overhead across multiple operations

Memory Layout

Consider a batch of 2 matrices with dimensions M=6, N=3, K=2:

Matrix C Layout (M=6, N=3, batch=2)

-------------------------------------------------------------
| (0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-------------------------------------------------------------
| (0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-------------------------------------------------------------
|   ...   |   ...   |   ...   |   ...   |   ...   |   ...   |
-------------------------------------------------------------
|           batch 0           |           batch 1           |

Each element is denoted (batch_idx, row_idx, column_idx). With column-major storage, consecutive C matrices are separated by batch_stride_C = ldc * N.
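A hypothetical helper makes the addressing concrete (assuming column-major storage; `batched_offset` is not a CUTLASS API, just an illustration):

```cpp
#include <cstddef>

// Offset of element (batch_idx, row, col) in a column-major strided batch.
// For matrix C in this example, batch_stride = ldc * N.
inline std::size_t batched_offset(int batch_idx, int row, int col,
                                  int ld, long long batch_stride) {
  return static_cast<std::size_t>(batch_idx) * batch_stride +
         static_cast<std::size_t>(col) * ld + row;
}
```

With M = 6, ldc = 6, N = 3 as above, element (1,0,0) begins exactly one batch stride (ldc * N = 18 elements) after element (0,0,0).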

Implementation

Strided Batched GEMM

Use when your matrices are laid out with uniform spacing in memory:
#include "cutlass/gemm/device/gemm_batched.h"

cudaError_t cutlass_strided_batched_sgemm(
  int m, 
  int n,
  int k,
  float alpha,
  float const *A,
  int lda,
  long long int batch_stride_A,
  float const *B,
  int ldb,
  long long int batch_stride_B,
  float *C,
  int ldc,
  long long int batch_stride_C,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmBatched<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {m, n, k},
    {A, lda}, 
    batch_stride_A,
    {B, ldb}, 
    batch_stride_B,
    {C, ldc}, 
    batch_stride_C,
    {C, ldc}, 
    batch_stride_C,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
Array GEMM

Use when matrices are scattered in memory with irregular spacing:
#include "cutlass/gemm/device/gemm_array.h"

cudaError_t cutlass_array_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const * const *A,  // Array of pointers
  int lda,
  float const * const *B,  // Array of pointers
  int ldb,
  float * const *C,        // Array of pointers
  int ldc,
  float beta,
  int batch_count) {

  using Gemm = cutlass::gemm::device::GemmArray<
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor,
    float, cutlass::layout::ColumnMajor
  >;

  Gemm gemm_op;

  cutlass::Status status = gemm_op({
    {m, n, k},
    A, lda,
    B, ldb,
    C, ldc,
    C, ldc,
    {alpha, beta},
    batch_count
  });

  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
Calculate batch strides

For strided batched GEMM, calculate the stride between consecutive matrices:
int m = 520, n = 219, k = 129;
int batch_count = 17;

// A and C matrices are stored back to back
int lda = m;
int ldc = m;

// B matrices are column-interleaved: column j of every batch item is stored
// consecutively, so the leading dimension spans the whole batch
int ldb = k * batch_count;

// Stride between consecutive matrices in the batch
long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
long long int batch_stride_B = static_cast<long long int>(k);  // one column slot in the interleaved layout
long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);
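Given these strides, the total allocation size for each operand follows directly. A minimal sketch (assuming A and C matrices are stored back to back and B is column-interleaved with ldb = k * batch_count, as above):

```cpp
// Element counts to allocate for each operand under the layout above.
struct BatchSizes { long long a, b, c; };

BatchSizes batched_alloc_counts(int m, int n, int k, int batch_count) {
  long long lda = m;
  long long ldb = static_cast<long long>(k) * batch_count;
  long long ldc = m;
  return {
    batch_count * lda * k,  // A: batch_count matrices of lda x k
    ldb * n,                // B: one interleaved region of ldb x n
    batch_count * ldc * n   // C: batch_count matrices of ldc x n
  };
}
```

Multiplying in 64-bit avoids overflow for large batches even though each dimension fits in an int.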
Set up arrays of pointers for Array GEMM

For array GEMM, create arrays of pointers to each matrix:
// Allocate host memory for pointers
std::vector<float*> host_ptr_A(batch_count);
std::vector<float*> host_ptr_B(batch_count);
std::vector<float*> host_ptr_C(batch_count);

// Matrices can be in any order - not required to be uniformly spaced
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
for (int b_idx = 0; b_idx < batch_count; b_idx++) {
  host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
  host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
  host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
}

// Allocate device memory for pointer arrays
float const **ptr_A;
float const **ptr_B;
float **ptr_C;

cudaMalloc(&ptr_A, batch_count * sizeof(float*));
cudaMalloc(&ptr_B, batch_count * sizeof(float*));
cudaMalloc(&ptr_C, batch_count * sizeof(float*));

// Copy pointers to device
cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);

// Launch array GEMM
cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
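The pointer-permutation step can be exercised entirely on the host with plain C++ (a sketch; `make_permuted_ptrs` is a hypothetical helper, not part of CUTLASS):

```cpp
#include <cstddef>
#include <vector>

// Host-only sketch: build a permuted pointer array over one flat buffer,
// mirroring how the array-GEMM pointer tables are filled above.
std::vector<float*> make_permuted_ptrs(float* base, long long batch_stride,
                                       const std::vector<int>& permutation) {
  std::vector<float*> ptrs(permutation.size());
  for (std::size_t i = 0; i < permutation.size(); ++i) {
    // Each pointer addresses the start of one matrix in the batch
    ptrs[i] = base + permutation[i] * batch_stride;
  }
  return ptrs;
}
```

Because array GEMM only sees the pointers, the underlying matrices may come from a single strided allocation (as here) or from entirely separate allocations.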
Complete example

Here’s a complete example running both approaches:
#include <iostream>

// run_batched_gemm(bool) is defined in the full example source
// (examples/05_batched_gemm/batched_gemm.cu)
int main() {
  cudaError_t result = cudaSuccess;
  
  // Run both strided batched GEMM and array GEMM
  for (bool use_array : {false, true}) {
    result = run_batched_gemm(use_array);
    if (result == cudaSuccess) {
      std::cout << "Passed." << std::endl;
    } else {
      break;
    }
  }

  return result == cudaSuccess ? 0 : -1;
}

Building and Running

Build the example

cd /path/to/cutlass
mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS='75;80;86'
make 05_batched_gemm

Run the example

./examples/05_batched_gemm/05_batched_gemm
Expected output:
Running strided batched gemm
Passed.
Running array gemm
Passed.

Source Code Location

The complete source code for this example is available at:
  • examples/05_batched_gemm/batched_gemm.cu

What This Example Demonstrates

  1. Two batching modes: Both strided and array-based batched GEMM
  2. Flexible memory layouts: How to handle both regular and irregular memory patterns
  3. Pointer management: Setting up device pointer arrays for array GEMM
  4. Correctness verification: Reference implementation for validating results
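Verification can be as simple as an elementwise comparison against a reference result (a minimal sketch; the shipped example uses its own reference implementation):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Returns true if every element of `got` is within `tol` of `expected`.
bool allclose(const std::vector<float>& got,
              const std::vector<float>& expected, float tol) {
  if (got.size() != expected.size()) return false;
  for (std::size_t i = 0; i < got.size(); ++i) {
    if (std::fabs(got[i] - expected[i]) > tol) return false;
  }
  return true;
}
```

For single-precision GEMM an absolute tolerance is usually sufficient at these sizes; larger K accumulations may warrant a relative-error comparison instead.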

Performance Considerations

  • Strided batched GEMM is typically faster when matrices are uniformly spaced because:
    • Simpler addressing logic
    • Better memory access patterns
    • Less pointer indirection
  • Array GEMM provides flexibility when:
    • Matrices are scattered in memory
    • Each batch item comes from different allocations
    • You need arbitrary ordering of operations

Key Takeaways

  • Use GemmBatched for strided batched operations with uniform spacing
  • Use GemmArray for arbitrary pointer arrays with irregular layouts
  • Batch operations amortize kernel launch overhead across multiple GEMMs
  • Both approaches share the same underlying optimizations for individual matrix multiplications
