Overview
The cutlass::gemm::device::Gemm template provides a device-level interface to efficient CUTLASS GEMM kernels. It maps data types and structural parameters to specific CUTLASS components at compile time, and handles logical-to-kernel argument mapping and kernel launches at runtime.
Header: cutlass/gemm/device/gemm.h
Template Signature
template <
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename LayoutC_,
typename ElementAccumulator_ = ElementC_,
typename OperatorClass_ = arch::OpClassSimt,
typename ArchTag_ = arch::Sm70,
typename ThreadblockShape_ = /* default */,
typename WarpShape_ = /* default */,
typename InstructionShape_ = /* default */,
typename EpilogueOutputOp_ = /* default */,
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
int Stages = /* default */,
int AlignmentA = /* default */,
int AlignmentB = /* default */,
bool SplitKSerial = false,
typename Operator_ = /* default */,
bool GatherA = false,
bool GatherB = false,
bool ScatterD = false,
typename PermuteDLayout = layout::NoPermute
>
class Gemm;
Template Parameters
ElementA_
typename
Element type for A matrix operand. Examples: float, half_t, int8_t
LayoutA_
typename
Layout type for A matrix operand. Examples: layout::RowMajor, layout::ColumnMajor
ElementB_
typename
Element type for B matrix operand
LayoutB_
typename
Layout type for B matrix operand
ElementC_
typename
Element type for C and D matrix operands
LayoutC_
typename
Layout type for C and D matrix operands
ElementAccumulator_
typename
default: ElementC_
Element type for internal accumulation
OperatorClass_
typename
default: arch::OpClassSimt
Operator class tag. Examples: arch::OpClassSimt, arch::OpClassTensorOp
ArchTag_
typename
default: arch::Sm70
Tag indicating the architecture to tune for, i.e. the minimum SM architecture that supports the intended feature. Examples: arch::Sm70, arch::Sm75, arch::Sm80
ThreadblockShape_
typename
Threadblock-level tile size (concept: GemmShape). Defaults depend on configuration.
WarpShape_
typename
Warp-level tile size (concept: GemmShape)
InstructionShape_
typename
Instruction-level tile size (concept: GemmShape)
EpilogueOutputOp_
typename
Epilogue output operator. Example: epilogue::thread::LinearCombination
ThreadblockSwizzle_
typename
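default: threadblock::GemmIdentityThreadblockSwizzle<>
Threadblock-level swizzling operator
Stages
int
Number of stages used in the pipelined mainloop
AlignmentA
int
Access granularity of A matrix in units of elements
AlignmentB
int
Access granularity of B matrix in units of elements
SplitKSerial
bool
default: false
If true, kernel supports split-K with serial reduction
Operator_
typename
Operation performed by GEMM. Example: arch::OpMultiplyAdd
GatherA
bool
default: false
Gather operand A by using an index array (supplied via Arguments::gather_A_indices)
GatherB
bool
default: false
Gather operand B by using an index array (supplied via Arguments::gather_B_indices)
ScatterD
bool
default: false
Scatter result D by using an index array (supplied via Arguments::scatter_D_indices)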
PermuteDLayout
typename
default: layout::NoPermute
Permute result D
Member Types
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using TensorRefA = TensorRef<ElementA const, LayoutA>;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using TensorRefB = TensorRef<ElementB const, LayoutB>;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using TensorRefC = TensorRef<ElementC const, LayoutC>;
using TensorRefD = TensorRef<ElementC, LayoutC>;
Arguments Structure
struct Arguments {
GemmCoord problem_size;
TensorRef<ElementA const, LayoutA> ref_A;
TensorRef<ElementB const, LayoutB> ref_B;
TensorRef<ElementC const, LayoutC> ref_C;
TensorRef<ElementC, LayoutC> ref_D;
typename EpilogueOutputOp::Params epilogue;
int split_k_slices;
int const *gather_A_indices;
int const *gather_B_indices;
int const *scatter_D_indices;
// Constructor
CUTLASS_HOST_DEVICE
Arguments(
GemmCoord problem_size_,
TensorRef<ElementA const, LayoutA> ref_A_,
TensorRef<ElementB const, LayoutB> ref_B_,
TensorRef<ElementC const, LayoutC> ref_C_,
TensorRef<ElementC, LayoutC> ref_D_,
typename EpilogueOutputOp::Params epilogue_ =
typename EpilogueOutputOp::Params(),
int split_k_slices = 1,
int const *gather_A_indices_ = nullptr,
int const *gather_B_indices_ = nullptr,
int const *scatter_D_indices_ = nullptr
);
};
Member Functions
Constructor
Gemm();
Constructs the GEMM operator. Problem arguments are supplied later through initialize() or operator().
can_implement
static Status can_implement(Arguments const &args);
Determines whether the GEMM can execute the given problem.
Parameters:
args: Arguments structure containing problem parameters
Returns: Status::kSuccess if the problem can be executed, error code otherwise
get_workspace_size
static size_t get_workspace_size(Arguments const &args);
Gets the workspace size in bytes required for the operation.
Parameters:
args: Arguments structure
Returns: Size in bytes of required workspace
initialize
Status initialize(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr
);
Initializes GEMM state from arguments.
Parameters:
args: Arguments structure
workspace: Pointer to workspace memory of at least get_workspace_size(args) bytes (may be nullptr when no workspace is required)
stream: CUDA stream
Returns: Status code
update
Status update(
Arguments const &args,
void *workspace = nullptr
);
Lightweight update of an already-initialized operator from a subset of the arguments (such as new tensor references). Call initialize() first.
run
Status run(cudaStream_t stream = nullptr);
Launches the kernel on the given stream, using the state set by a prior call to initialize().
operator()
Status operator()(cudaStream_t stream = nullptr);
Status operator()(
Arguments const &args,
void *workspace = nullptr,
cudaStream_t stream = nullptr
);
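Runs the kernel. The first overload is equivalent to run(stream). The second calls initialize(args, workspace, stream) and, if that succeeds, run(stream), thereby combining initialization and execution.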
Usage Example
From include/cutlass/gemm/device/gemm.h:87-114:
//
// Instantiate the CUTLASS GEMM operator.
//
cutlass::gemm::device::Gemm<
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor,
float,
cutlass::layout::ColumnMajor
> gemm_op;
//
// Launch the GEMM operation on the device
//
cutlass::Status status = gemm_op({
{m, n, k}, // GemmCoord problem_size,
{A, lda}, // TensorRef<float, layout::ColumnMajor> ref_A,
{B, ldb}, // TensorRef<float, layout::ColumnMajor> ref_B,
{C, ldc}, // TensorRef<float, layout::ColumnMajor> ref_C,
{D, ldd}, // TensorRef<float, layout::ColumnMajor> ref_D,
{alpha, beta} // EpilogueOutputOp::Params epilogue_op_params
});
Complete Example
From examples/00_basic_gemm/basic_gemm.cu:78-145:
using ColumnMajor = cutlass::layout::ColumnMajor;
using CutlassGemm = cutlass::gemm::device::Gemm<float, // Data-type of A matrix
ColumnMajor, // Layout of A matrix
float, // Data-type of B matrix
ColumnMajor, // Layout of B matrix
float, // Data-type of C matrix
ColumnMajor>; // Layout of C matrix
// Define a CUTLASS GEMM type
CutlassGemm gemm_operator;
// Construct the CUTLASS GEMM arguments object.
CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions
{A, lda}, // Tensor-ref for source matrix A
{B, ldb}, // Tensor-ref for source matrix B
{C, ldc}, // Tensor-ref for source matrix C
{C, ldc}, // Tensor-ref for destination matrix D
{alpha, beta}); // Scalars used in the Epilogue
//
// Launch the CUTLASS GEMM kernel.
//
cutlass::Status status = gemm_operator(args);
if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}
See Also