Skip to main content

Overview

Epilogue operators perform element-wise transformations on GEMM output before writing to memory. They apply operations like linear combinations, activation functions, and element-wise operations on accumulator and source tensor fragments. Header: cutlass/epilogue/thread/linear_combination.h (and related headers)

LinearCombination

Applies a linear combination operator to arrays of elements: D = alpha * accumulator + beta * source.

Template Signature

template <
  typename ElementOutput_,
  int Count,
  typename ElementAccumulator_ = ElementOutput_,
  typename ElementCompute_ = ElementOutput_,
  ScaleType::Kind Scale = ScaleType::Default,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
  typename ElementSource_ = ElementOutput_
>
class LinearCombination;

Template Parameters

ElementOutput_
typename
Data type used to load and store tensors. Examples: float, half_t, int8_t
Count
int
Number of elements computed per operation. Usually 128 / cutlass::sizeof_bits&lt;ElementOutput_&gt;::value (as in the usage examples below), but can be 64 or 32 when there is less data to store.
ElementAccumulator_
typename
default:"ElementOutput_"
Accumulator data type. Often higher precision than output, e.g., float accumulators for half_t output.
ElementCompute_
typename
default:"ElementOutput_"
Data type used to compute linear combination
Scale
ScaleType::Kind
default:"ScaleType::Default"
Controls alpha and beta scaling behavior. Options: Default, NoBetaScaling, OnlyAlphaScaling, Nothing
Round
FloatRoundStyle
default:"FloatRoundStyle::round_to_nearest"
Rounding mode for numeric conversions
ElementSource_
typename
default:"ElementOutput_"
Source tensor data type

Member Types

using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementScalar = ElementCompute;
using ElementC = ElementSource_;
using ElementD = ElementOutput_;

static int const kCount = Count;
static const ScaleType::Kind kScale = Scale;

using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using FragmentCompute = Array<ElementCompute, kCount>;

static FloatRoundStyle const kRound = Round;

Parameters Structure

struct Params {
  ElementCompute alpha;
  ElementCompute beta;
  ElementCompute const *alpha_ptr;
  ElementCompute const *beta_ptr;
  ElementCompute const* const* alpha_ptr_array;
  ElementCompute const* const* beta_ptr_array;
  
  CUTLASS_HOST_DEVICE
  Params();
  
  CUTLASS_HOST_DEVICE
  Params(ElementCompute alpha, ElementCompute beta);
  
  CUTLASS_HOST_DEVICE
  Params(ElementCompute alpha);
  
  CUTLASS_HOST_DEVICE
  Params(ElementCompute const *alpha_ptr, ElementCompute const *beta_ptr);
  
  CUTLASS_HOST_DEVICE
  Params(ElementCompute const *alpha_ptr);
  
  CUTLASS_HOST_DEVICE
  Params(ElementCompute const* const* alpha_ptr_array,
         ElementCompute const* const* beta_ptr_array);
  
  CUTLASS_HOST_DEVICE
  Params(ElementCompute const* const* alpha_ptr_array);
};

Constructors

CUTLASS_HOST_DEVICE
explicit LinearCombination(Params const &params, int group_idx);

CUTLASS_HOST_DEVICE
explicit LinearCombination(const Params & params);
Constructs the function object, possibly loading alpha/beta from pointers in device memory.

Member Functions

is_source_needed

CUTLASS_HOST_DEVICE
bool is_source_needed() const;
Returns true if the source tensor is needed (i.e., beta != 0).

set_k_partition

CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count);
Functionally required for serial reduction in the epilogue.

operator()

CUTLASS_HOST_DEVICE
FragmentOutput operator()(
    FragmentAccumulator const &accumulator,
    FragmentSource const &source) const;
Computes linear scaling with source: D = alpha * accumulator + beta * source. Implementation (simplified):
FragmentCompute converted_source = source_converter(source);
FragmentCompute converted_accumulator = accumulator_converter(accumulator);

if (Scale == ScaleType::NoBetaScaling)
  intermediate = converted_source;
else
  intermediate = mul_add_source(beta_, converted_source);

intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate);

return destination_converter(intermediate);

Usage Example

From examples/00_basic_gemm/basic_gemm.cu:103-127:
using CutlassGemm = cutlass::gemm::device::Gemm<
  float,                      // ElementA
  cutlass::layout::ColumnMajor,  // LayoutA
  float,                      // ElementB
  cutlass::layout::ColumnMajor,  // LayoutB
  float,                      // ElementC
  cutlass::layout::ColumnMajor>; // LayoutC

CutlassGemm::Arguments args(
  {M, N, K},
  {A, lda},
  {B, ldb},
  {C, ldc},
  {C, ldc},
  {alpha, beta}  // LinearCombination::Params
);

cutlass::Status status = gemm_operator(args);

LinearCombinationRelu

Linear combination followed by ReLU activation: D = ReLU(alpha * accumulator + beta * source). Header: cutlass/epilogue/thread/linear_combination_relu.h

Template Signature

template <
  typename ElementOutput_,
  int Count,
  typename ElementAccumulator_ = ElementOutput_,
  typename ElementCompute_ = ElementOutput_,
  ScaleType::Kind Scale = ScaleType::Default,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationRelu;

Usage

using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
  float,  // output
  128 / cutlass::sizeof_bits<float>::value,
  float,  // accumulator
  float   // computation
>;

using GemmKernel = cutlass::gemm::device::Gemm<
  float, cutlass::layout::RowMajor,
  float, cutlass::layout::RowMajor,
  float, cutlass::layout::RowMajor,
  float,  // accumulator
  cutlass::arch::OpClassSimt,
  cutlass::arch::Sm70,
  cutlass::gemm::GemmShape<128, 128, 8>,
  cutlass::gemm::GemmShape<32, 64, 8>,
  cutlass::gemm::GemmShape<1, 1, 1>,
  EpilogueOp  // Use ReLU epilogue
>;

LinearCombinationGELU

Linear combination followed by GELU activation. Header: cutlass/epilogue/thread/linear_combination_gelu.h

Template Signature

template <
  typename ElementOutput_,
  int Count,
  typename ElementAccumulator_ = ElementOutput_,
  typename ElementCompute_ = ElementOutput_,
  ScaleType::Kind Scale = ScaleType::Default,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
  bool IsFast = true
>
class LinearCombinationGELU;

Template Parameters

IsFast
bool
default:"true"
If true, uses fast approximation of GELU. If false, uses precise GELU calculation.

LinearCombinationClamp

Linear combination with clamping to specified range. Header: cutlass/epilogue/thread/linear_combination_clamp.h

Template Signature

template <
  typename ElementOutput_,
  int Count,
  typename ElementAccumulator_ = ElementOutput_,
  typename ElementCompute_ = ElementOutput_,
  ScaleType::Kind Scale = ScaleType::Default,
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationClamp;

Parameters Structure

struct Params {
  ElementCompute alpha;
  ElementCompute beta;
  ElementCompute clamp_min;
  ElementCompute clamp_max;
  
  CUTLASS_HOST_DEVICE
  Params(
    ElementCompute alpha,
    ElementCompute beta,
    ElementCompute clamp_min,
    ElementCompute clamp_max
  );
};

Other Epilogue Operators

CUTLASS provides many specialized epilogue operators:

Activation Functions

  • LinearCombinationSigmoid - Sigmoid activation
  • LinearCombinationSilu - SiLU/Swish activation
  • LinearCombinationHardSwish - Hard Swish activation
  • LinearCombinationLeakyRelu - Leaky ReLU activation

Bias Operations

  • LinearCombinationBiasRelu - Add bias and apply ReLU
  • LinearCombinationBiasElementwise - Add bias with element-wise op

Specialized Operations

  • LinearCombinationResidualBlock - Residual connection fusion
  • LinearCombinationGeneric - Customizable generic epilogue
  • LinearCombinationPlanarComplex - Complex number operations

ScaleType Enumeration

namespace ScaleType {
  enum Kind {
    Default,           // Apply both alpha and beta scaling
    NoBetaScaling,     // Source added without beta scaling (beta treated as 1)
    OnlyAlphaScaling,  // Source is not loaded
    Nothing            // No scaling, just type conversion
  };
}

Usage

// Only alpha scaling - source tensor is not loaded, saving memory bandwidth
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
  float,
  128 / cutlass::sizeof_bits<float>::value,
  float,
  float,
  cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
>;

Custom Epilogue Example

// Custom epilogue: D = alpha * Tanh(accumulator) + beta * source
template <typename T, int N>
struct LinearCombinationTanh {

  /// Fragment types exchanged with the epilogue (N elements of type T each).
  using FragmentOutput = cutlass::Array<T, N>;
  using FragmentAccumulator = cutlass::Array<T, N>;
  using FragmentSource = cutlass::Array<T, N>;

  /// Host-constructible parameters for D = alpha * tanh(accum) + beta * source.
  struct Params {
    T alpha;
    T beta;
  };

  T alpha_;
  T beta_;

  CUTLASS_HOST_DEVICE
  LinearCombinationTanh(Params const &params)
    : alpha_(params.alpha), beta_(params.beta) {}

  /// Epilogues skip loading the source tensor when this returns false.
  CUTLASS_HOST_DEVICE
  bool is_source_needed() const { return beta_ != T(0); }

  /// Computes D = alpha * tanh(accum) + beta * source, element-wise.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(
      FragmentAccumulator const &accum,
      FragmentSource const &source) const {

    FragmentOutput result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      // tanh is evaluated in float for accuracy; the conversion back to T is
      // made explicit to avoid a silent narrowing when T is reduced-precision.
      T tanh_val = static_cast<T>(tanh(float(accum[i])));
      result[i] = alpha_ * tanh_val + beta_ * source[i];
    }

    return result;
  }

  /// Source-free form: called when is_source_needed() returns false
  /// (beta == 0), so the source fragment is never read from memory.
  CUTLASS_HOST_DEVICE
  FragmentOutput operator()(FragmentAccumulator const &accum) const {

    FragmentOutput result;

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
      T tanh_val = static_cast<T>(tanh(float(accum[i])));
      result[i] = alpha_ * tanh_val;
    }

    return result;
  }
};

See Also

  • Gemm - Uses epilogue operators for output transformation
  • Array - Fragment type used by epilogue operators
  • TensorRef - Source and destination tensor references

Build docs developers (and LLMs) love