Fused Operations Example
This example demonstrates how to fuse element-wise operations (bias addition and ReLU) with the GEMM computation in a single kernel, improving performance by eliminating intermediate memory operations.
Overview
Instead of computing GEMM and then applying bias and ReLU in separate kernel launches, this example applies the bias and ReLU inside the GEMM kernel's epilogue, so the result is written to global memory only once.
Key Concepts
- Epilogue: The final stage of a GEMM kernel that processes the accumulator
- Fused operations: Combining multiple operations to reduce memory bandwidth
- Custom epilogue: Using CUTLASS epilogue templates for different output transformations
- Bias broadcasting: Applying per-row or per-column bias efficiently
Implementation
using ElementAccumulator = float; // Accumulator data type
using ElementComputeEpilogue = ElementAccumulator; // Epilogue computation type
using ElementInputA = cutlass::half_t; // Input matrix A (FP16)
using ElementInputB = cutlass::half_t; // Input matrix B (FP16)
using ElementOutput = float; // Output matrix D (FP32)
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput, // Output data type
128 / cutlass::sizeof_bits<ElementOutput>::value, // Elements per vector access
ElementAccumulator, // Accumulator data type
ElementComputeEpilogue, // Epilogue computation type
cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // D = ReLU(alpha * accumulator + bias); no beta scaling
// Use Tensor Cores
using MMAOp = cutlass::arch::OpClassTensorOp;
// Target Turing architecture (SM75)
using SmArch = cutlass::arch::Sm75;
// Threadblock tile: 128x128x32
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
// Warp tile: 64x64x32
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;
// MMA instruction shape: 16x8x8
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;
// Threadblock swizzling
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
// Pipeline stages
constexpr int NumStages = 2;
using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ShapeMMAThreadBlock,
ShapeMMAWarp,
ShapeMMAOp,
EpilogueOp,
SwizzleThreadBlock,
NumStages>;
const int length_m = 5120;
const int length_n = 4096;
const int length_k = 4096;
// Create problem size
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
// Initialize tensors
cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
// Bias vector: M x 1 (one bias per row for column-major output)
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias({problem_size.m(), 1});
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());
// Reference output for host-side verification
cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(problem_size.mn());
// Fill tensors with random data
cutlass::reference::host::TensorFillRandomUniform(
tensor_a.host_view(), 1, ElementInputA(4), ElementInputA(-4), 0);
cutlass::reference::host::TensorFillRandomUniform(
tensor_b.host_view(), 1, ElementInputB(4), ElementInputB(-4), 0);
cutlass::reference::host::TensorFillRandomUniform(
tensor_c_bias.host_view(), 1, ElementOutput(4), ElementOutput(-4), 0);
// Copy to device
tensor_a.sync_device();
tensor_b.sync_device();
tensor_c_bias.sync_device();
tensor_d.sync_device();
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
int split_k_slices = 1;
// Create arguments
typename Gemm::Arguments arguments{
problem_size, // Problem size
tensor_a.device_ref(), // Tensor A
tensor_b.device_ref(), // Tensor B
// Bias vector with stride 0 in the N dimension to broadcast
{tensor_c_bias.device_data(), 0},
tensor_d.device_ref(), // Output tensor
{alpha}, // Alpha (no beta)
split_k_slices}; // Split-K slices
// Allocate workspace
size_t workspace_size = Gemm::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
// Initialize and run
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(arguments);
if (status != cutlass::Status::kSuccess) {
std::cerr << "Kernel cannot be executed" << std::endl;
return -1;
}
status = gemm_op.initialize(arguments, workspace.get());
if (status != cutlass::Status::kSuccess) {
std::cerr << "Kernel initialization failed" << std::endl;
return -1;
}
status = gemm_op();
if (status != cutlass::Status::kSuccess) {
std::cerr << "Kernel execution failed" << std::endl;
return -1;
}
// Compute reference: GEMM followed by bias + ReLU
cutlass::reference::device::Gemm<...> gemm_device_reference;
gemm_device_reference(
problem_size,
alpha,
tensor_a.device_ref(),
tensor_b.device_ref(),
0, // beta = 0
tensor_ref_d.device_ref());
cudaDeviceSynchronize();
// Copy to host and apply bias + ReLU
tensor_ref_d.sync_host();
for (int i = 0; i < problem_size.m(); ++i) {
for (int j = 0; j < problem_size.n(); ++j) {
tensor_ref_d.at({i, j}) = std::max(
ElementOutput(0),
ElementOutput(tensor_ref_d.at({i, j}) + tensor_c_bias.at({i, 0}))
);
}
}
// Compare results
tensor_d.sync_host();
bool passed = cutlass::reference::host::TensorEquals(
tensor_d.host_view(),
tensor_ref_d.host_view());
std::cout << (passed ? "Passed" : "Failed") << std::endl;
Building and Running
Build the example
Run the example
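The steps above can be sketched as follows. This assumes a standard out-of-tree CUTLASS CMake build and that the build target is named `12_gemm_bias_relu` after the example directory; consult the CUTLASS build documentation for your setup.

```shell
# Build the example (SM75 / Turing, matching cutlass::arch::Sm75 above)
mkdir -p build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS=75
make 12_gemm_bias_relu

# Run the example
./examples/12_gemm_bias_relu/12_gemm_bias_relu
```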
Source Code Location
The complete source code for this example is available at: examples/12_gemm_bias_relu/gemm_bias_relu.cu
What This Example Demonstrates
- Epilogue customization: Using custom epilogue operations for fused kernels
- Bias broadcasting: Efficiently applying per-row bias using stride-0 tensors
- Activation functions: Fusing ReLU activation with GEMM
- Performance optimization: Eliminating intermediate memory operations
- Mixed precision: Using FP16 inputs with FP32 accumulation and output
Performance Benefits
Fusing operations provides significant benefits:
- Reduced memory bandwidth: No need to write/read intermediate results
- Lower latency: Single kernel launch instead of multiple
- Better cache utilization: Data stays in cache between operations
- Higher throughput: Overlapped computation and memory operations
Bias Layout Requirements
Important notes about bias layout:
- For column-major output: bias must be M×1 (per-row bias)
- For row-major output: bias would be 1×N (per-column bias)
- Use stride 0 in the non-broadcast dimension to efficiently broadcast the bias vector
Key Takeaways
- CUTLASS epilogues enable efficient fusion of element-wise operations
- LinearCombinationRelu combines scaling, bias addition, and ReLU activation
- Bias broadcasting is achieved using stride-0 tensor references
- Fused operations significantly reduce memory traffic and improve performance
- The same pattern works for other epilogue operations (GELU, sigmoid, etc.)
Next Steps
- Explore Basic GEMM for simple matrix multiplication
- Learn about Batched GEMM for multiple independent operations
- Check out other epilogue operations in cutlass/epilogue/thread/