cutlass::gemm::device::Gemm

Overview

The cutlass::gemm::device::Gemm template provides a device-level interface to efficient CUTLASS GEMM kernels. It maps data types and structural parameters to specific CUTLASS components at compile time, and handles logical-to-kernel argument mapping and kernel launches at runtime. Header: cutlass/gemm/device/gemm.h

Template Signature

template <
    typename ElementA_,
    typename LayoutA_,
    typename ElementB_,
    typename LayoutB_,
    typename ElementC_,
    typename LayoutC_,
    typename ElementAccumulator_ = ElementC_,
    typename OperatorClass_ = arch::OpClassSimt,
    typename ArchTag_ = arch::Sm70,
    typename ThreadblockShape_ = /* default */,
    typename WarpShape_ = /* default */,
    typename InstructionShape_ = /* default */,
    typename EpilogueOutputOp_ = /* default */,
    typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
    int Stages = /* default */,
    int AlignmentA = /* default */,
    int AlignmentB = /* default */,
    bool SplitKSerial = false,
    typename Operator_ = /* default */,
    bool GatherA = false,
    bool GatherB = false,
    bool ScatterD = false,
    typename PermuteDLayout = layout::NoPermute
>
class Gemm;

Template Parameters

ElementA_
typename
Element type for A matrix operand. Examples: float, half_t, int8_t
LayoutA_
typename
Layout type for A matrix operand. Examples: layout::RowMajor, layout::ColumnMajor
ElementB_
typename
Element type for B matrix operand
LayoutB_
typename
Layout type for B matrix operand
ElementC_
typename
Element type for C and D matrix operands
LayoutC_
typename
Layout type for C and D matrix operands
ElementAccumulator_
typename
default:"ElementC_"
Element type for internal accumulation
OperatorClass_
typename
default:"arch::OpClassSimt"
Operator class tag. Examples: arch::OpClassSimt, arch::OpClassTensorOp
ArchTag_
typename
default:"arch::Sm70"
Tag indicating architecture to tune for. Minimum SM that supports the intended feature. Examples: arch::Sm70, arch::Sm75, arch::Sm80
ThreadblockShape_
typename
Threadblock-level tile size (concept: GemmShape). Defaults depend on configuration.
WarpShape_
typename
Warp-level tile size (concept: GemmShape)
InstructionShape_
typename
Instruction-level tile size (concept: GemmShape)
EpilogueOutputOp_
typename
Epilogue output operator. Example: epilogue::thread::LinearCombination
ThreadblockSwizzle_
typename
default:"threadblock::GemmIdentityThreadblockSwizzle<>"
Threadblock-level swizzling operator
Stages
int
Number of stages used in the pipelined mainloop
AlignmentA
int
Access granularity of A matrix in units of elements
AlignmentB
int
Access granularity of B matrix in units of elements
SplitKSerial
bool
default:"false"
If true, kernel supports split-K with serial reduction
Operator_
typename
Operation performed by GEMM
GatherA
bool
default:"false"
Gather operand A by using an index array
GatherB
bool
default:"false"
Gather operand B by using an index array
ScatterD
bool
default:"false"
Scatter result D by using an index array
PermuteDLayout
typename
default:"layout::NoPermute"
Permute result D

Member Types

using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;

Arguments Structure

struct Arguments {
    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;
    int const *gather_A_indices;
    int const *gather_B_indices;
    int const *scatter_D_indices;
    
    // Constructor
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ = 
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1,
      int const *gather_A_indices_ = nullptr,
      int const *gather_B_indices_ = nullptr,
      int const *scatter_D_indices_ = nullptr
    );
};

Member Functions

Constructor

Gemm();
Constructs the GEMM operator.

can_implement

static Status can_implement(Arguments const &args);
Determines whether the GEMM can execute the given problem. Parameters:
  • args: Arguments structure containing problem parameters
Returns: Status::kSuccess if the problem can be executed, error code otherwise

get_workspace_size

static size_t get_workspace_size(Arguments const &args);
Gets the workspace size in bytes required for the operation. Parameters:
  • args: Arguments structure
Returns: Size in bytes of required workspace

initialize

Status initialize(
  Arguments const &args,
  void *workspace = nullptr,
  cudaStream_t stream = nullptr
);
Initializes GEMM state from arguments. Parameters:
  • args: Arguments structure
  • workspace: Pointer to workspace memory
  • stream: CUDA stream
Returns: Status code

update

Status update(
  Arguments const &args,
  void *workspace = nullptr
);
Lightweight update given a subset of arguments.

run

Status run(cudaStream_t stream = nullptr);
Runs the kernel using initialized state.

operator()

Status operator()(cudaStream_t stream = nullptr);

Status operator()(
  Arguments const &args,
  void *workspace = nullptr,
  cudaStream_t stream = nullptr
);
Runs the kernel. The second overload combines initialization and execution.

Usage Example

From include/cutlass/gemm/device/gemm.h:87-114:
//
// Instantiate the CUTLASS GEMM operator.
//

cutlass::gemm::device::Gemm<
  float,
  cutlass::layout::ColumnMajor,
  float,
  cutlass::layout::ColumnMajor,
  float,
  cutlass::layout::ColumnMajor
> gemm_op;

//
// Launch the GEMM operation on the device
//

cutlass::Status status = gemm_op({
  {m, n, k},                          // GemmCoord problem_size,
  {A, lda},                           // TensorRef<float const, layout::ColumnMajor> ref_A,
  {B, ldb},                           // TensorRef<float const, layout::ColumnMajor> ref_B,
  {C, ldc},                           // TensorRef<float const, layout::ColumnMajor> ref_C,
  {D, ldd},                           // TensorRef<float, layout::ColumnMajor> ref_D,
  {alpha, beta}                       // EpilogueOutputOp::Params epilogue_op_params
});

Complete Example

From examples/00_basic_gemm/basic_gemm.cu:78-145:
using ColumnMajor = cutlass::layout::ColumnMajor;

using CutlassGemm = cutlass::gemm::device::Gemm<float,        // Data-type of A matrix
                                                ColumnMajor,  // Layout of A matrix
                                                float,        // Data-type of B matrix
                                                ColumnMajor,  // Layout of B matrix
                                                float,        // Data-type of C matrix
                                                ColumnMajor>; // Layout of C matrix

// Define a CUTLASS GEMM type
CutlassGemm gemm_operator;

// Construct the CUTLASS GEMM arguments object.
CutlassGemm::Arguments args({M, N, K},  // Gemm Problem dimensions
                            {A, lda},    // Tensor-ref for source matrix A
                            {B, ldb},    // Tensor-ref for source matrix B
                            {C, ldc},    // Tensor-ref for source matrix C
                            {C, ldc},    // Tensor-ref for destination matrix D
                            {alpha, beta}); // Scalars used in the Epilogue

//
// Launch the CUTLASS GEMM kernel.
//

cutlass::Status status = gemm_operator(args);

if (status != cutlass::Status::kSuccess) {
  return cudaErrorUnknown;
}

See Also

include/cutlass/gemm/device/gemm.h — header defining the cutlass::gemm::device::Gemm template
examples/00_basic_gemm/basic_gemm.cu — complete single-precision GEMM example using this interface