Overview
Epilogue operators perform element-wise transformations on GEMM output before writing to memory. They apply operations like linear combinations, activation functions, and element-wise operations on accumulator and source tensor fragments.
Header: cutlass/epilogue/thread/linear_combination.h (and related headers)
LinearCombination
Applies a linear combination operator to arrays of elements: D = alpha * accumulator + beta * source.
Template Signature
template <
typename ElementOutput_,
int Count,
typename ElementAccumulator_ = ElementOutput_,
typename ElementCompute_ = ElementOutput_,
ScaleType::Kind Scale = ScaleType::Default,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
typename ElementSource_ = ElementOutput_
>
class LinearCombination;
Template Parameters
ElementOutput_
typename
Data type used to load and store tensors. Examples: float, half_t, int8_t
Count
int
Number of elements computed per operation. Usually 128/sizeof_bits<ElementOutput_>, but can be 64 or 32 when there is less data to store.
ElementAccumulator_
typename
default:"ElementOutput_"
Accumulator data type. Often higher precision than output, e.g., float accumulators for half_t output.
ElementCompute_
typename
default:"ElementOutput_"
Data type used to compute linear combination
Scale
ScaleType::Kind
default:"ScaleType::Default"
Controls alpha and beta scaling behavior. Options: Default, NoBetaScaling, OnlyAlphaScaling, Nothing
Round
FloatRoundStyle
default:"FloatRoundStyle::round_to_nearest"
Rounding mode for numeric conversions
ElementSource_
typename
default:"ElementOutput_"
Source tensor data type
Member Types
using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementScalar = ElementCompute;
using ElementC = ElementSource_;
using ElementD = ElementOutput_;
static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;
static FloatRoundStyle const kRound = Round;
Parameters Structure
struct Params {
ElementCompute alpha;
ElementCompute beta;
ElementCompute const *alpha_ptr;
ElementCompute const *beta_ptr;
ElementCompute const* const* alpha_ptr_array;
ElementCompute const* const* beta_ptr_array;
CUTLASS_HOST_DEVICE
Params();
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha, ElementCompute beta);
CUTLASS_HOST_DEVICE
Params(ElementCompute alpha);
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr);
CUTLASS_HOST_DEVICE
Params(ElementCompute const *alpha_ptr);
CUTLASS_HOST_DEVICE
Params(ElementCompute const* const* alpha_ptr_array,
ElementCompute const* const* beta_ptr_array);
CUTLASS_HOST_DEVICE
Params(ElementCompute const* const* alpha_ptr_array);
};
Constructors
CUTLASS_HOST_DEVICE
explicit LinearCombination(Params const &params, int group_idx);
CUTLASS_HOST_DEVICE
explicit LinearCombination(const Params & params);
Constructs the function object, possibly loading alpha/beta from pointers in device memory.
Member Functions
is_source_needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const;
Returns true if the source tensor is needed (i.e., beta != 0).
set_k_partition
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count);
Sets the k-partition index and partition count; required for split-K serial reduction in the epilogue.
operator()
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentSource const &source) const;
Computes linear scaling with source: D = alpha * accumulator + beta * source.
Implementation (simplified):
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);
if (Scale == ScaleType::NoBetaScaling)
intermediate = converted_source;
else
intermediate = mul_add_source(beta_, converted_source);
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);
return destination_converter(intermediate);
Usage Example
From examples/00_basic_gemm/basic_gemm.cu:103-127:
using CutlassGemm = cutlass::gemm::device::Gemm<
float, // ElementA
cutlass::layout::ColumnMajor, // LayoutA
float, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
float, // ElementC
cutlass::layout::ColumnMajor>; // LayoutC
CutlassGemm::Arguments args(
{M, N, K},
{A, lda},
{B, ldb},
{C, ldc},
{C, ldc},
{alpha, beta} // LinearCombination::Params
);
cutlass::Status status = gemm_operator(args);
LinearCombinationRelu
Linear combination followed by ReLU activation: D = ReLU(alpha * accumulator + beta * source).
Header: cutlass/epilogue/thread/linear_combination_relu.h
Template Signature
template <
typename ElementOutput_,
int Count,
typename ElementAccumulator_ = ElementOutput_,
typename ElementCompute_ = ElementOutput_,
ScaleType::Kind Scale = ScaleType::Default,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationRelu;
Usage
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
float, // output
128 / cutlass::sizeof_bits<float>::value,
float, // accumulator
float // computation
>;
using GemmKernel = cutlass::gemm::device::Gemm<
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, // accumulator
cutlass::arch::OpClassSimt,
cutlass::arch::Sm70,
cutlass::gemm::GemmShape<128, 128, 8>,
cutlass::gemm::GemmShape<32, 64, 8>,
cutlass::gemm::GemmShape<1, 1, 1>,
EpilogueOp // Use ReLU epilogue
>;
LinearCombinationGELU
Linear combination followed by GELU activation.
Header: cutlass/epilogue/thread/linear_combination_gelu.h
Template Signature
template <
typename ElementOutput_,
int Count,
typename ElementAccumulator_ = ElementOutput_,
typename ElementCompute_ = ElementOutput_,
ScaleType::Kind Scale = ScaleType::Default,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
bool IsFast = true
>
class LinearCombinationGELU;
Template Parameters
IsFast
bool
default:"true"
If true, uses a fast approximation of GELU. If false, uses the precise GELU calculation.
LinearCombinationClamp
Linear combination with clamping to specified range.
Header: cutlass/epilogue/thread/linear_combination_clamp.h
Template Signature
template <
typename ElementOutput_,
int Count,
typename ElementAccumulator_ = ElementOutput_,
typename ElementCompute_ = ElementOutput_,
ScaleType::Kind Scale = ScaleType::Default,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationClamp;
Parameters Structure
struct Params {
ElementCompute alpha;
ElementCompute beta;
ElementCompute clamp_min;
ElementCompute clamp_max;
CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta,
ElementCompute clamp_min,
ElementCompute clamp_max
);
};
Other Epilogue Operators
CUTLASS provides many specialized epilogue operators:
Activation Functions
- LinearCombinationSigmoid - Sigmoid activation
- LinearCombinationSilu - SiLU/Swish activation
- LinearCombinationHardSwish - Hard Swish activation
- LinearCombinationLeakyRelu - Leaky ReLU activation
Bias Operations
- LinearCombinationBiasRelu - Add bias and apply ReLU
- LinearCombinationBiasElementwise - Add bias with element-wise op
Specialized Operations
- LinearCombinationResidualBlock - Residual connection fusion
- LinearCombinationGeneric - Customizable generic epilogue
- LinearCombinationPlanarComplex - Complex number operations
ScaleType Enumeration
namespace ScaleType {
enum Kind {
Default, // Apply both alpha and beta scaling
NoBetaScaling, // Source is added without beta scaling (beta treated as 1)
OnlyAlphaScaling, // Source is not loaded
Nothing // No scaling, just type conversion
};
}
Usage
// Only alpha scaling - source is not loaded, saving memory bandwidth
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
float,
128 / cutlass::sizeof_bits<float>::value,
float,
float,
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;
Custom Epilogue Example
// Custom epilogue: D = alpha * Tanh(accumulator) + beta * source
template <typename T, int N>
struct LinearCombinationTanh {
using FragmentOutput = cutlass::Array<T, N>;
using FragmentAccumulator = cutlass::Array<T, N>;
using FragmentSource = cutlass::Array<T, N>;
struct Params {
T alpha;
T beta;
};
T alpha_;
T beta_;
CUTLASS_HOST_DEVICE
LinearCombinationTanh(Params const &params)
: alpha_(params.alpha), beta_(params.beta) {}
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return beta_ != T(0); }
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accum,
FragmentSource const &source) const {
FragmentOutput result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
T tanh_val = tanh(float(accum[i]));
result[i] = alpha_ * tanh_val + beta_ * source[i];
}
return result;
}
};
See Also
- Gemm - Uses epilogue operators for output transformation
- Array - Fragment type used by epilogue operators
- TensorRef - Source and destination tensor references