Overview
cute::TiledMma represents a tiled matrix multiply-accumulate operation that partitions an MMA across multiple threads and values. It builds on MMA_Atom (hardware-level MMA instructions) by tiling them in the M, N, and K dimensions.
Class Template
template <class MMA_Atom,
class AtomLayoutMNK,
class PermutationMNK = Tile<Underscore,Underscore,Underscore>>
struct TiledMMA : MMA_Atom
{
using Atom = MMA_Atom;
using AtomShape_MNK = typename MMA_Atom::Shape_MNK;
using AtomThrID = typename MMA_Atom::ThrID;
using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{}));
ThrLayoutVMNK thr_layout_vmnk_;
CUTE_HOST_DEVICE constexpr
TiledMMA(MMA_Atom const& mma_atom = {},
AtomLayoutMNK const& thr_layout_mnk = {});
};
Template Parameters
MMA_Atom: The atomic MMA operation (e.g., SM80_16x8x16_F16F16F16F16_TN for Ampere Tensor Cores).
AtomLayoutMNK: The MNK-tiling of the Atom. Specifies how many atoms are tiled in the M, N, and K dimensions.
PermutationMNK: Optional permutations to apply to each MNK mode before tiling. A Tile type; default is Tile<Underscore,Underscore,Underscore> (i.e., no permutation).
Source Location
include/cute/atom/mma_atom.hpp:202-457
Member Types
Value Types
using ValTypeD = typename Traits::ValTypeD; // Accumulator type
using ValTypeA = typename Traits::ValTypeA; // A operand type
using ValTypeB = typename Traits::ValTypeB; // B operand type
using ValTypeC = typename Traits::ValTypeC; // C operand type
Fragment Types
using FrgTypeD = /* ... */; // D fragment type
using FrgTypeA = /* ... */; // A fragment type
using FrgTypeB = /* ... */; // B fragment type
using FrgTypeC = /* ... */; // C fragment type
Layout Types
using Shape_MNK = /* ... */; // Shape of a single atom (M, N, K)
using LayoutC_TV = /* ... */; // (thread, value) -> C coordinate
using LayoutA_TV = /* ... */; // (thread, value) -> A coordinate
using LayoutB_TV = /* ... */; // (thread, value) -> B coordinate
Member Functions
Tensor Partitioning
thrfrg_C()
template <class CTensor>
CUTE_HOST_DEVICE constexpr auto
thrfrg_C(CTensor&& ctensor) const;
Tiles a C tensor from shape (M,N,...) to shape ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))).
Parameters:
ctensor: The C/D tensor to partition
Returns: Partitioned tensor with thread and fragment modes separated
Example:
auto gmem_C = make_tensor(ptr_C, make_shape(128, 128));
auto thrfrg_C_view = tiled_mma.thrfrg_C(gmem_C);  // returns a partitioned tensor, not a layout
// Shape: ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN)))
thrfrg_A()
template <class ATensor>
CUTE_HOST_DEVICE constexpr auto
thrfrg_A(ATensor&& atensor) const;
Tiles an A tensor from shape (M,K,...) to shape ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))).
Parameters:
atensor: The A operand tensor to partition
Returns: Partitioned tensor with thread and fragment modes
thrfrg_B()
template <class BTensor>
CUTE_HOST_DEVICE constexpr auto
thrfrg_B(BTensor&& btensor) const;
Tiles a B tensor from shape (N,K,...) to shape ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))).
Parameters:
btensor: The B operand tensor to partition
Returns: Partitioned tensor with thread and fragment modes
Thread Slicing
get_slice()
template <class ThrIdx>
CUTE_HOST_DEVICE constexpr auto
get_slice(ThrIdx const& thr_idx) const;
Returns a ThrMMA object for a specific thread index.
Parameters:
thr_idx: Thread index within the TiledMMA
Returns: ThrMMA<TiledMMA, ThrCoord> for per-thread operations
Example:
int tid = threadIdx.x;
auto thr_mma = tiled_mma.get_slice(tid);
get_thread_slice()
template <class ThrIdx>
CUTE_HOST_DEVICE constexpr auto
get_thread_slice(ThrIdx const& thr_idx) const;
Alias for get_slice(). Returns per-thread MMA object.
Layout Accessors
get_layoutC_TV()
CUTE_HOST_DEVICE constexpr auto
get_layoutC_TV() const;
Returns the (thread_idx, value_idx) -> (M, N) layout for the C matrix.
get_layoutA_TV()
CUTE_HOST_DEVICE constexpr auto
get_layoutA_TV() const;
Returns the (thread_idx, value_idx) -> (M, K) layout for the A matrix.
get_layoutB_TV()
CUTE_HOST_DEVICE constexpr auto
get_layoutB_TV() const;
Returns the (thread_idx, value_idx) -> (N, K) layout for the B matrix.
Utility
tile_size_mnk()
template <int I>
CUTE_HOST_DEVICE constexpr auto
tile_size_mnk() const;
Returns the total tile size in the M (I=0), N (I=1), or K (I=2) dimension.
Example:
auto M_tile = tiled_mma.tile_size_mnk<0>(); // Total M extent
auto N_tile = tiled_mma.tile_size_mnk<1>(); // Total N extent
auto K_tile = tiled_mma.tile_size_mnk<2>(); // Total K extent
ThrMMA - Per-Thread MMA
template <class TiledMMA, class ThrCoord>
struct ThrMMA : TiledMMA
{
ThrCoord thr_vmnk_;
};
Represents the MMA operation for a single thread.
ThrMMA Methods
partition_C()
template <class CTensor>
CUTE_HOST_DEVICE constexpr auto
partition_C(CTensor&& ctensor) const;
Partitions the C tensor for this thread.
Returns: Per-thread view of C with shape (FrgV,(RestM,RestN,...))
partition_A()
template <class ATensor>
CUTE_HOST_DEVICE constexpr auto
partition_A(ATensor&& atensor) const;
Partitions the A tensor for this thread.
Returns: Per-thread view of A with shape (FrgV,(RestM,RestK,...))
partition_B()
template <class BTensor>
CUTE_HOST_DEVICE constexpr auto
partition_B(BTensor&& btensor) const;
Partitions the B tensor for this thread.
Returns: Per-thread view of B with shape (FrgV,(RestN,RestK,...))
partition_fragment_C()
template <class CTensor>
CUTE_HOST_DEVICE constexpr auto
partition_fragment_C(CTensor&& ctensor) const;
Partitions C and creates a register fragment.
Returns: Register tensor fragment for C accumulation
partition_fragment_A()
template <class ATensor>
CUTE_HOST_DEVICE constexpr auto
partition_fragment_A(ATensor&& atensor) const;
Partitions A and creates a register fragment.
partition_fragment_B()
template <class BTensor>
CUTE_HOST_DEVICE constexpr auto
partition_fragment_B(BTensor&& btensor) const;
Partitions B and creates a register fragment.
Factory Functions
make_tiled_mma()
template <class MMA_Op,
class MMAThrLayout = Layout<Shape<_1,_1,_1>>,
class Permutations = Tile<Underscore,Underscore,Underscore>>
CUTE_HOST_DEVICE constexpr auto
make_tiled_mma(MMA_Atom<MMA_Op> const& mma_atom,
MMAThrLayout const& thr_layout = {},
Permutations const& permutations = {});
Creates a TiledMMA from an MMA atom and thread layout.
Parameters:
mma_atom: The atomic MMA operation
thr_layout: Thread layout in (M, N, K) specifying the tiling
permutations: Optional permutations for each mode
Example:
using namespace cute;
// Single SM80 Tensor Core MMA
// Note: the alias must not be named MMA_Atom — that would conflict with the
// cute::MMA_Atom class template brought in by `using namespace cute`.
using AtomF16 = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
// Tile 2x2 in M and N, 1 in K
auto tiled_mma = make_tiled_mma(
AtomF16{},
make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}))
);
// This creates a 32x16x16 MMA using 4 atoms (2x2x1)
Common MMA Atoms
Ampere (SM80) Tensor Cores
// FP16 Tensor Core MMA
SM80_16x8x16_F16F16F16F16_TN // 16x8x16 FP16 output
SM80_16x8x8_F32F16F16F32_TN // 16x8x8 FP32 output
// TF32 Tensor Core MMA
SM80_16x8x8_F32TF32TF32F32_TN // 16x8x8 TF32 -> FP32
// INT8 Tensor Core MMA
SM80_16x8x32_S32S8S8S32_TN // 16x8x32 INT8 -> INT32
Hopper (SM90) Tensor Cores
// Warp-group MMA (WGMMA)
SM90_64x8x16_F16F16F16_SS // 64x8x16; "SS" = both A and B operands sourced from shared memory
SM90_64x16x16_F16F16F16_SS // 64x16x16
SM90_64x32x16_F16F16F16_SS // 64x32x16
// With FP8
SM90_64x16x32_F16E4M3E4M3_SS // FP8 E4M3 inputs
SM90_64x16x32_F16E5M2E5M2_SS // FP8 E5M2 inputs
Usage Examples
Basic TiledMMA Setup
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
using namespace cute;
// Define MMA atom for SM80 Tensor Cores
// Define the MMA atom for SM80 Tensor Cores (alias renamed: `using MMA_Atom = MMA_Atom<...>`
// would redeclare the cute::MMA_Atom template name and fail to compile)
using AtomF16 = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
// Create a 2x2 tiled MMA (32x16x16)
auto tiled_mma = make_tiled_mma(
AtomF16{},
make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}))
);
// Get thread-level MMA
int tid = threadIdx.x;
auto thr_mma = tiled_mma.get_slice(tid);
Partitioning Tensors
__device__ void gemm_partition_example(
half_t* A_ptr, half_t* B_ptr, half_t* C_ptr,
int M, int N, int K)
{
// Create global memory tensors
auto gA = make_tensor(A_ptr, make_shape(M, K));
auto gB = make_tensor(B_ptr, make_shape(N, K));
auto gC = make_tensor(C_ptr, make_shape(M, N));
// Setup TiledMMA
// Alias renamed to avoid redeclaring the cute::MMA_Atom template name
using AtomF16 = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
auto tiled_mma = make_tiled_mma(
AtomF16{},
make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}))
);
// Get per-thread MMA
auto thr_mma = tiled_mma.get_slice(threadIdx.x);
// Partition tensors for this thread
auto tCgA = thr_mma.partition_A(gA); // Thread's A data
auto tCgB = thr_mma.partition_B(gB); // Thread's B data
auto tCgC = thr_mma.partition_C(gC); // Thread's C data
// Create register fragments
auto tCrA = thr_mma.partition_fragment_A(gA);
auto tCrB = thr_mma.partition_fragment_B(gB);
auto tCrC = thr_mma.partition_fragment_C(gC);
// Clear accumulator
clear(tCrC);
// Load and compute...
}
Complete GEMM Kernel Outline
template <class TiledMMA>
__global__ void gemm_kernel(
half_t* A, half_t* B, float* C,
int M, int N, int K)
{
// Shared memory
__shared__ half_t smem_A[...];
__shared__ half_t smem_B[...];
// Global memory tensors
auto gA = make_tensor(A, make_shape(M, K));
auto gB = make_tensor(B, make_shape(N, K));
auto gC = make_tensor(C, make_shape(M, N));
// Shared memory tensors
auto sA = make_tensor(make_smem_ptr(smem_A), ...);
auto sB = make_tensor(make_smem_ptr(smem_B), ...);
// Setup TiledMMA
TiledMMA tiled_mma;
auto thr_mma = tiled_mma.get_slice(threadIdx.x);
// Partition
auto tCgC = thr_mma.partition_C(gC);
auto tCsA = thr_mma.partition_A(sA);
auto tCsB = thr_mma.partition_B(sB);
// Fragments
auto tCrA = thr_mma.partition_fragment_A(sA);
auto tCrB = thr_mma.partition_fragment_B(sB);
auto tCrC = thr_mma.partition_fragment_C(gC);
clear(tCrC);
// Main loop
for (int k = 0; k < K; k += TILE_K) {
// Load the current A/B k-tiles into shared memory
// (illustrative only: real code slices a TILE_K-wide tile, e.g. with local_tile,
// rather than indexing a single k column as gA(_, k) does)
copy(gA(_, k), sA);
copy(gB(_, k), sB);
__syncthreads();
// Load into registers and compute
copy(tCsA, tCrA);
copy(tCsB, tCrB);
// MMA operation: D = A * B + C
gemm(tiled_mma, tCrA, tCrB, tCrC);
__syncthreads();
}
// Store result
copy(tCrC, tCgC);
}
Multi-stage Pipeline
// Use with async copy and multi-stage pipeline
for (int pipe = 0; pipe < NUM_STAGES-1; ++pipe) {
copy(tiled_copy, gA_stage[pipe], sA_stage[pipe]);
copy(tiled_copy, gB_stage[pipe], sB_stage[pipe]);
cp_async_fence();
}
for (int k = 0; k < NUM_K_TILES; ++k) {
int read_stage = k % NUM_STAGES;
int write_stage = (k + NUM_STAGES - 1) % NUM_STAGES;
cp_async_wait<NUM_STAGES-2>();
__syncthreads();
// Compute on read_stage
copy(tCsA[read_stage], tCrA);
copy(tCsB[read_stage], tCrB);
gemm(tiled_mma, tCrA, tCrB, tCrC);
// Load next tile into write_stage
if (k + NUM_STAGES - 1 < NUM_K_TILES) {
copy(tiled_copy, gA_stage[write_stage], sA_stage[write_stage]);
copy(tiled_copy, gB_stage[write_stage], sB_stage[write_stage]);
cp_async_fence();
}
}
See Also