Overview
CUTLASS provides highly optimized convolution operations for 2D and 3D inputs using the implicit GEMM algorithm. This approach reformulates convolution as a matrix multiplication, leveraging CUTLASS’s GEMM infrastructure for maximum performance.
Implicit GEMM convolution achieves performance comparable to cuDNN while providing greater flexibility and customization through C++ templates.
Convolution Types
CUTLASS supports three main convolution operations:
Fprop (forward propagation): Output = Conv(Activation, Filter)
Dgrad (activation gradient, backprop): dActivation = Conv(dOutput, Filter)
Wgrad (filter gradient, weight update): dFilter = Conv(Activation, dOutput)
cutlass::conv::device::ImplicitGemmConvolution
The primary convolution device operator template.
Template Structure
include/cutlass/conv/device/implicit_gemm_convolution.h:52
template <typename ImplicitGemmKernel_>
class ImplicitGemmConvolution {
public:
  using UnderlyingKernel = ImplicitGemmKernel_;

  using ElementA = typename UnderlyingKernel::ElementA;
  using LayoutA = typename UnderlyingKernel::LayoutA;
  using ElementB = typename UnderlyingKernel::ElementB;
  using LayoutB = typename UnderlyingKernel::LayoutB;
  using ElementC = typename UnderlyingKernel::ElementC;
  using LayoutC = typename UnderlyingKernel::LayoutC;
  using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator;

  static cutlass::conv::Operator const kConvolutionalOperator;
  static int const kConvDim;  // 2 for Conv2d, 3 for Conv3d
};
Key Type Aliases
ElementA: element type of operand A
  Fprop: input activations
  Dgrad: output gradients (dOutput)
  Wgrad: output gradients (dOutput)
ElementB: element type of operand B
  Fprop: filter/kernel weights
  Dgrad: filter/kernel weights
  Wgrad: input activations
ElementC: element type of the output tensor
  Fprop: output activations
  Dgrad: activation gradients (dActivation)
  Wgrad: filter gradients (dFilter)
Convolution operation type:
conv::Operator::kFprop - Forward propagation
conv::Operator::kDgrad - Activation gradient
conv::Operator::kWgrad - Filter gradient
conv::Operator::kDeconv - Deconvolution (transposed convolution)
Problem Size (Conv2d)
include/cutlass/conv/conv2d_problem_size.h
struct Conv2dProblemSize {
int N; // Batch size
int H; // Input height
int W; // Input width
int C; // Input channels
int P; // Output height
int Q; // Output width
int K; // Output channels
int R; // Filter height
int S; // Filter width
int pad_h; // Padding height
int pad_w; // Padding width
int stride_h; // Stride height
int stride_w; // Stride width
int dilation_h; // Dilation height
int dilation_w; // Dilation width
conv::Mode mode;     // kCrossCorrelation or kConvolution
int split_k_slices; // For split-K parallelization
int groups; // For grouped convolution
};
N: batch size (number of images)
H, W: input spatial dimensions (height × width)
C: number of input channels
K: number of output channels (filters)
R, S: filter spatial dimensions (height × width)
P, Q: output spatial dimensions (height × width), computed from input size, padding, stride, and dilation
pad_h, pad_w: padding applied to top/left (in Fprop)
stride_h, stride_w: filter stride
dilation_h, dilation_w: filter dilation (atrous convolution)
groups: number of groups for grouped convolution; when > 1, inputs and outputs are split into groups with independent convolutions
Tensor Layouts
CUTLASS convolution supports multiple tensor layouts:
NHWC (TensorNHWC)
// Activation: (N, H, W, C)
cutlass::layout::TensorNHWC
// Preferred for Tensor Core operations on Ampere and newer
Layout: channels are the fastest-varying dimension (most contiguous in memory)
NCHW (TensorNCHW)
// Activation: (N, C, H, W)
cutlass::layout::TensorNCHW
// Traditional PyTorch/cuDNN layout
Layout: spatial dimensions are contiguous (W fastest-varying)
Filter Layouts
// NHWC format: (K, R, S, C)
cutlass::layout::TensorNHWC
// NCHW format: (K, C, R, S)
cutlass::layout::TensorNCHW
Forward Propagation (Fprop) Example
From examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu:
Kernel Definition
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"

using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;

using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;

using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
  ElementInputA,
  LayoutInputA,
  ElementInputB,
  LayoutInputB,
  ElementOutput,
  LayoutOutput,
  ElementAccumulator,
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 64>,  // Threadblock tile
  cutlass::gemm::GemmShape<64, 64, 64>,    // Warp tile
  cutlass::gemm::GemmShape<16, 8, 16>,     // Instruction tile
  cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementComputeEpilogue
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,  // Stages
  cutlass::arch::OpMultiplyAdd
>::Kernel;

using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
Activation Gradient (Dgrad) Example
using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad<
  ElementInputA,  // dY (output gradient)
  LayoutInputA,
  ElementInputB,  // Filter (same as fprop)
  LayoutInputB,
  ElementOutput,  // dX (input gradient - output of dgrad)
  LayoutOutput,
  ElementAccumulator,
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 64>,
  cutlass::gemm::GemmShape<64, 64, 64>,
  cutlass::gemm::GemmShape<16, 8, 16>,
  cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementComputeEpilogue
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,
  cutlass::arch::OpMultiplyAdd
>::Kernel;

using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dDgradKernel>;
Filter Gradient (Wgrad) Example
using Conv2dWgradKernel = typename cutlass::conv::kernel::DefaultConv2dWgrad<
  ElementInputA,  // dY (output gradient)
  LayoutInputA,
  ElementInputB,  // X (input activation)
  LayoutInputB,
  ElementOutput,  // dW (filter gradient - output of wgrad)
  LayoutOutput,
  ElementAccumulator,
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 64>,
  cutlass::gemm::GemmShape<64, 64, 64>,
  cutlass::gemm::GemmShape<16, 8, 16>,
  cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementComputeEpilogue
  >,
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,
  cutlass::arch::OpMultiplyAdd
>::Kernel;

using Conv2dWgrad = cutlass::conv::device::ImplicitGemmConvolution<Conv2dWgradKernel>;
Advanced Features
Grouped Convolution
cutlass::conv::Conv2dProblemSize problem_size(
  N, H, W, C,
  K, R, S,
  P, Q,
  pad_h, pad_w,
  stride_h, stride_w,
  dilation_h, dilation_w,
  cutlass::conv::Mode::kCrossCorrelation,
  1,  // split_k_slices
  4   // groups (4 independent convolutions)
);
// C and K must be divisible by groups
// Each group processes C/groups input channels -> K/groups output channels
Grouped convolution is commonly used in architectures like ResNeXt and MobileNet. CUTLASS supports both single-group and multi-group modes.
Depthwise Convolution
Depthwise convolution is a special case where groups = C = K:
// Depthwise: each input channel convolved independently
int channels = 256;

cutlass::conv::Conv2dProblemSize depthwise_problem(
  N, H, W, channels,  // Input
  channels,           // K = C for depthwise
  3, 3,               // 3x3 filter
  H, W,               // Same spatial size (padding=1, stride=1)
  1, 1,               // Padding
  1, 1,               // Stride
  1, 1,               // Dilation
  cutlass::conv::Mode::kCrossCorrelation,
  1,                  // split_k_slices
  channels            // groups = C = K
);
Strided Convolution
// Downsampling with stride=2
cutlass::conv::Conv2dProblemSize problem_size(
  N, 56, 56, C,  // Input 56x56
  K, 3, 3,       // 3x3 filter
  28, 28,        // Output 28x28 = (56 + 2*1 - 3)/2 + 1
  1, 1,          // Padding
  2, 2,          // stride_h=2, stride_w=2
  1, 1,          // Dilation
  cutlass::conv::Mode::kCrossCorrelation
);
Dilated Convolution
// Dilated convolution for a larger receptive field
cutlass::conv::Conv2dProblemSize problem_size(
  N, H, W, C,
  K, 3, 3,
  P, Q,
  2, 2,  // Padding (larger to compensate for dilation)
  1, 1,  // Stride
  2, 2,  // dilation_h=2, dilation_w=2 (effective 5x5 filter)
  cutlass::conv::Mode::kCrossCorrelation
);
Deconvolution (Transposed Convolution)
using DeconvKernel = typename cutlass::conv::kernel::DefaultConv2dDeconv<
  ElementInputA, LayoutInputA,
  ElementInputB, LayoutInputB,
  ElementOutput, LayoutOutput,
  ElementAccumulator,
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  ThreadblockShape,
  WarpShape,
  InstructionShape,
  EpilogueOp,
  ThreadblockSwizzle,
  Stages,
  MathOperator
>::Kernel;

using Deconv = cutlass::conv::device::ImplicitGemmConvolution<DeconvKernel>;

// Upsampling: input 28x28 -> output 56x56 with stride=2
3D Convolution
For volumetric data (e.g., video, medical imaging):
struct Conv3dProblemSize {
int N; // Batch
int D; // Depth
int H; // Height
int W; // Width
int C; // Input channels
int Z; // Output depth
int P; // Output height
int Q; // Output width
int K; // Output channels
int T; // Filter depth
int R; // Filter height
int S; // Filter width
int pad_d, pad_h, pad_w; // Padding
int stride_d, stride_h, stride_w; // Stride
int dilation_d, dilation_h, dilation_w; // Dilation
};
Memory Access Patterns
Iterator Algorithms
CUTLASS provides several iterator algorithms that trade flexibility for speed:
conv::IteratorAlgorithm::kAnalytic
Predicates computed analytically at runtime (slower, most flexible)
conv::IteratorAlgorithm::kOptimized
Optimized iterators with precomputed predicates (faster)
conv::IteratorAlgorithm::kFixedStrideDilation
Specialized for stride and dilation fixed at compile time
Operator Verification
ImplicitGemm conv_op;
cutlass::Status status = ImplicitGemm::can_implement(arguments);

if (status == cutlass::Status::kErrorInvalidProblem) {
  // Problem size not supported (e.g., misaligned channels)
} else if (status == cutlass::Status::kErrorNotSupported) {
  // Feature combination not supported
} else if (status == cutlass::Status::kSuccess) {
  // Can proceed with kernel launch
}
Common reasons for can_implement failure:
Output channels K not aligned to required boundary
Invalid group count (C or K not divisible by groups)
Unsupported stride/dilation combination
Tensor size exceeds 2^31 elements
Tile Size Selection
For NHWC layout on Ampere/Hopper, use larger threadblock tiles:
GemmShape<128, 128, 64> or GemmShape<128, 256, 64>
For NCHW layout or older architectures:
GemmShape<128, 128, 32> or GemmShape<64, 64, 32>
Pipeline Stages
// More stages overlap global->shared data movement with compute
constexpr int Stages = 3;  // Ampere/Hopper (hardware async copy, cp.async)
// constexpr int Stages = 2;  // Turing and earlier (software double buffering)
Split-K for Small Batches
// When N is small, parallelize across the GEMM-K (reduction) dimension
cutlass::conv::Conv2dProblemSize problem_size(
  1,            // Small batch
  224, 224, 3,  // Input 224x224x3
  64, 7, 7,     // 64 filters of 7x7
  112, 112,     // Output 112x112
  3, 3,         // Padding
  2, 2,         // Stride
  1, 1,         // Dilation
  cutlass::conv::Mode::kCrossCorrelation,
  4             // split_k_slices=4 for better parallelism
);
Example Applications
Image Classification: ResNet, VGG, EfficientNet convolution layers
Object Detection: YOLO, Faster R-CNN feature extractors
Semantic Segmentation: U-Net, DeepLab upsampling and downsampling
Video Processing: 3D convolutions for action recognition
Fusion Opportunities
Convolution Bias: add bias and activation in the epilogue for fused operations
Batch Normalization: can be fused into the convolution epilogue
Residual Connections: use the epilogue to add skip connections
Gather/Scatter Conv: sparse or irregular convolution patterns
Comparison with cuDNN
Feature            CUTLASS                         cuDNN
Performance        Comparable (often within 5%)    Highly optimized
Flexibility        Full template customization     Limited to API
Epilogue Fusion    Arbitrary C++ functors          Limited built-ins
Batched Ops        Yes, with custom layouts        Yes
Grouped Conv       Yes, single/multi-group         Yes
Mixed Precision    Any combination                 Predefined types
Code Generation    Template instantiation          Precompiled library
See Also
GEMM API: convolution builds on the GEMM infrastructure
Epilogue Operations: fuse activations and bias into convolution
Conv Examples: see examples/09_turing_tensorop_conv2dfprop/ and related examples
Performance Guide: optimize convolution kernels for your workload