cute::TiledMMA

Overview

cute::TiledMMA represents a tiled matrix multiply-accumulate operation that partitions an MMA across multiple threads and values. It builds on MMA_Atom (a single hardware-level MMA instruction) by tiling copies of the atom along the M, N, and K dimensions.

Class Template

// Abridged declaration of cute::TiledMMA; the full definition lives in
// include/cute/atom/mma_atom.hpp.
template <class MMA_Atom,
          class AtomLayoutMNK,
          class PermutationMNK = Tile<Underscore,Underscore,Underscore>>
struct TiledMMA : MMA_Atom
{
  using Atom           = MMA_Atom;
  using AtomShape_MNK  = typename MMA_Atom::Shape_MNK;   // (M,N,K) extent of one atom
  using AtomThrID      = typename MMA_Atom::ThrID;       // thread-ID layout of one atom
  
  // Thread layout of the whole tiled MMA: the atom's thread layout (V mode)
  // tiled by AtomLayoutMNK (the M, N, K modes).
  using ThrLayoutVMNK  = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{}));
  
  ThrLayoutVMNK thr_layout_vmnk_;
  
  // Both arguments have `= {}` defaults, so a TiledMMA whose atom and atom
  // layout are default-constructible can be created as `TiledMMA{}`.
  CUTE_HOST_DEVICE constexpr
  TiledMMA(MMA_Atom const& mma_atom = {}, 
           AtomLayoutMNK const& thr_layout_mnk = {});
};

Template Parameters

MMA_Atom
MMA_Atom<MMAOperation>
The atomic MMA operation (e.g., SM80_16x8x16_F16F16F16F16_TN for Ampere Tensor Cores).
AtomLayoutMNK
Layout
The MNK-tiling of the Atom. Specifies how many atoms are tiled in M, N, and K dimensions.
PermutationMNK
Tile
Default: Tile<_,_,_> (no permutation).
Optional permutation applied to each of the M, N, and K modes before tiling.

Source Location

include/cute/atom/mma_atom.hpp:202-457

Member Types

Value Types

// Operand element types, read from the atom's Traits
// (presumably the MMA_Traits of the MMA operation — confirm in mma_atom.hpp).
using ValTypeD = typename Traits::ValTypeD;  // Accumulator type
using ValTypeA = typename Traits::ValTypeA;  // A operand type
using ValTypeB = typename Traits::ValTypeB;  // B operand type
using ValTypeC = typename Traits::ValTypeC;  // C operand type

Fragment Types

// Per-operand fragment types (definitions elided in this summary).
using FrgTypeD = /* ... */;  // D fragment type
using FrgTypeA = /* ... */;  // A fragment type
using FrgTypeB = /* ... */;  // B fragment type
using FrgTypeC = /* ... */;  // C fragment type

Layout Types

// Atom shape and thread-value layouts (definitions elided in this summary).
// Each *_TV layout maps a (thread, value) pair to a coordinate of its operand.
using Shape_MNK   = /* ... */;  // Shape of a single atom (M, N, K)
using LayoutC_TV  = /* ... */;  // (thread, value) -> C coordinate
using LayoutA_TV  = /* ... */;  // (thread, value) -> A coordinate
using LayoutB_TV  = /* ... */;  // (thread, value) -> B coordinate

Member Functions

Tensor Partitioning

thrfrg_C()

// Splits a C/D tensor (M,N,...) into thread modes and per-thread fragment modes.
template <class CTensor>
CUTE_HOST_DEVICE constexpr auto
thrfrg_C(CTensor&& ctensor) const;
Tiles a C tensor from shape (M,N,...) to shape ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))).
Parameters:
  • ctensor: The C/D tensor to partition
Returns: The partitioned tensor, with the thread modes (mode 0) separated from the per-thread fragment modes (mode 1).
Example:
auto gmem_C = make_tensor(ptr_C, make_shape(128, 128));
auto thr_layout = tiled_mma.thrfrg_C(gmem_C);
// Shape: ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN)))

thrfrg_A()

// Splits an A tensor (M,K,...) into thread modes and per-thread fragment modes.
template <class ATensor>
CUTE_HOST_DEVICE constexpr auto
thrfrg_A(ATensor&& atensor) const;
Tiles an A tensor from shape (M,K,...) to shape ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))).
Parameters:
  • atensor: The A operand tensor to partition
Returns: The partitioned tensor, with the thread modes (mode 0) separated from the per-thread fragment modes (mode 1).

thrfrg_B()

// Splits a B tensor (N,K,...) into thread modes and per-thread fragment modes.
template <class BTensor>
CUTE_HOST_DEVICE constexpr auto
thrfrg_B(BTensor&& btensor) const;
Tiles a B tensor from shape (N,K,...) to shape ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))).
Parameters:
  • btensor: The B operand tensor to partition
Returns: The partitioned tensor, with the thread modes (mode 0) separated from the per-thread fragment modes (mode 1).

Thread Slicing

get_slice()

// Returns the ThrMMA view for one thread of this TiledMMA.
template <class ThrIdx>
CUTE_HOST_DEVICE constexpr auto
get_slice(ThrIdx const& thr_idx) const;
Returns a ThrMMA object for a specific thread index.
Parameters:
  • thr_idx: Thread index within the TiledMMA
Returns: ThrMMA<TiledMMA, ThrCoord> for per-thread operations.
Example:
// Each thread requests its own slice of the TiledMMA by its linear thread index.
int const tid = threadIdx.x;
auto thr_mma  = tiled_mma.get_slice(tid);

get_thread_slice()

// Same as get_slice(): returns the ThrMMA view for one thread.
template <class ThrIdx>
CUTE_HOST_DEVICE constexpr auto
get_thread_slice(ThrIdx const& thr_idx) const;
Alias for get_slice(). Returns per-thread MMA object.

Layout Accessors

get_layoutC_TV()

// (thread_idx, value_idx) -> (M, N) layout of the C tile.
CUTE_HOST_DEVICE constexpr auto
get_layoutC_TV() const;
Returns the (thread_idx, value_idx) -> (M, N) layout for the C matrix.

get_layoutA_TV()

// (thread_idx, value_idx) -> (M, K) layout of the A tile.
CUTE_HOST_DEVICE constexpr auto
get_layoutA_TV() const;
Returns the (thread_idx, value_idx) -> (M, K) layout for the A matrix.

get_layoutB_TV()

// (thread_idx, value_idx) -> (N, K) layout of the B tile.
CUTE_HOST_DEVICE constexpr auto
get_layoutB_TV() const;
Returns the (thread_idx, value_idx) -> (N, K) layout for the B matrix.

Utility

tile_size_mnk()

// Total tile extent along mode I: 0 = M, 1 = N, 2 = K.
template <int I>
CUTE_HOST_DEVICE constexpr auto
tile_size_mnk() const;
Returns the total tile size in the M (I=0), N (I=1), or K (I=2) dimension.
Example:
// Full extent of the TiledMMA tile along each of M, N and K.
auto const M_tile = tiled_mma.tile_size_mnk<0>();  // Total M extent
auto const N_tile = tiled_mma.tile_size_mnk<1>();  // Total N extent
auto const K_tile = tiled_mma.tile_size_mnk<2>();  // Total K extent

ThrMMA - Per-Thread MMA

// Per-thread view of a TiledMMA: inherits the TiledMMA and stores this
// thread's coordinate (thr_vmnk_) in the TiledMMA's V,M,N,K thread layout.
template <class TiledMMA, class ThrCoord>
struct ThrMMA : TiledMMA
{
  ThrCoord thr_vmnk_;
};
Represents the MMA operation for a single thread.

ThrMMA Methods

partition_C()

// This thread's view of C: thread modes sliced away, fragment modes kept.
template <class CTensor>
CUTE_HOST_DEVICE constexpr auto
partition_C(CTensor&& ctensor) const;
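Partitions the C tensor for this thread. Returns: Per-thread view of C with shape (FrgV,RestM,RestN,...), often annotated (MMA,MMA_M,MMA_N). The thread modes are sliced away and the rest modes are flattened into the top level.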

partition_A()

// This thread's view of A: thread modes sliced away, fragment modes kept.
template <class ATensor>
CUTE_HOST_DEVICE constexpr auto
partition_A(ATensor&& atensor) const;
Partitions the A tensor for this thread. Returns: Per-thread view of A with shape (FrgV,RestM,RestK,...), often annotated (MMA,MMA_M,MMA_K).

partition_B()

// This thread's view of B: thread modes sliced away, fragment modes kept.
template <class BTensor>
CUTE_HOST_DEVICE constexpr auto
partition_B(BTensor&& btensor) const;
Partitions the B tensor for this thread. Returns: Per-thread view of B with shape (FrgV,RestN,RestK,...), often annotated (MMA,MMA_N,MMA_K).

partition_fragment_C()

// Partitions C for this thread and returns a register fragment of that shape.
template <class CTensor>
CUTE_HOST_DEVICE constexpr auto
partition_fragment_C(CTensor&& ctensor) const;
Partitions C and creates a register fragment. Returns: Register tensor fragment for C accumulation

partition_fragment_A()

// Partitions A for this thread and returns a register fragment of that shape.
template <class ATensor>
CUTE_HOST_DEVICE constexpr auto
partition_fragment_A(ATensor&& atensor) const;
Partitions A and creates a register fragment.

partition_fragment_B()

// Partitions B for this thread and returns a register fragment of that shape.
template <class BTensor>
CUTE_HOST_DEVICE constexpr auto
partition_fragment_B(BTensor&& btensor) const;
Partitions B and creates a register fragment.

Factory Functions

make_tiled_mma()

// Builds a TiledMMA from an atom, an (M,N,K) layout of atoms (default: one
// atom, Shape<_1,_1,_1>) and optional per-mode permutations (default: none).
template <class MMA_Op,
          class MMAThrLayout = Layout<Shape<_1,_1,_1>>,
          class Permutations = Tile<Underscore,Underscore,Underscore>>
CUTE_HOST_DEVICE constexpr auto
make_tiled_mma(MMA_Atom<MMA_Op> const& mma_atom,
               MMAThrLayout     const& thr_layout   = {},
               Permutations     const& permutations = {});
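Creates a TiledMMA from an MMA atom and a layout of atoms. Parameters:
  • mma_atom: The atomic MMA operation
  • thr_layout: Layout over (M, N, K) giving how many copies of the atom are tiled along each mode. Despite its name it counts atoms, not threads: the total thread count is the atom's thread count times size(thr_layout).
  • permutations: Optional permutations for each mode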
Example:
using namespace cute;

// Single SM80 Tensor Core MMA
using MMA_Atom = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;

// Tile 2x2 in M and N, 1 in K
auto tiled_mma = make_tiled_mma(
  MMA_Atom{},
  make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}))
);
// This creates a 32x16x16 MMA using 4 atoms (2x2x1)

Common MMA Atoms

Ampere (SM80) Tensor Cores

// Naming: SM80_<M>x<N>x<K>_<D><A><B><C>_TN gives the instruction shape, then
// the D, A, B and C element types in that order.

// FP16 Tensor Core MMA
SM80_16x8x16_F16F16F16F16_TN    // 16x8x16 FP16 output
SM80_16x8x8_F32F16F16F32_TN     // 16x8x8  FP16 inputs, FP32 output

// TF32 Tensor Core MMA
SM80_16x8x8_F32TF32TF32F32_TN   // 16x8x8 TF32 -> FP32

// INT8 Tensor Core MMA
SM80_16x8x32_S32S8S8S32_TN      // 16x8x32 INT8 -> INT32

Hopper (SM90) Tensor Cores

// Warp-group MMA (WGMMA): one atom is executed by a warpgroup of 128 threads.
// "_SS": both A and B are read from shared memory; the accumulators live in
// registers ("_RS" variants read A from registers instead).
// The F16 ops are templates on the major-ness of A and B (GMMA::Major::K or ::MN).
SM90_64x8x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::K>    // 64x8x16
SM90_64x16x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::K>   // 64x16x16
SM90_64x32x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::K>   // 64x32x16

// With FP8: only K-major A and B are supported, hence the "_TN" suffix
SM90_64x16x32_F16E4M3E4M3_SS_TN    // FP8 E4M3 inputs, F16 accumulator
SM90_64x16x32_F16E5M2E5M2_SS_TN    // FP8 E5M2 inputs, F16 accumulator

Usage Examples

Basic TiledMMA Setup

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// Define MMA atom for SM80 Tensor Cores. The alias must not be called
// `MMA_Atom`: at namespace scope it would clash with cute::MMA_Atom brought in
// by the using-directive above and make later uses of the name ambiguous.
using SM80_MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;

__global__ void tiled_mma_setup_example()
{
  // Create a 2x2x1 tiled MMA: 2x2 atoms of 16x8x16 -> a 32x16x16 tile,
  // executed by 4 warps (128 threads).
  auto tiled_mma = make_tiled_mma(
    SM80_MMA{},
    make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}))
  );

  // Get thread-level MMA (threadIdx is only available in device code)
  int tid = threadIdx.x;
  auto thr_mma = tiled_mma.get_slice(tid);
}

Partitioning Tensors

// Illustrative only: assumes a CTA of size(tiled_mma) = 128 threads and that
// M and N are multiples of the 128x128 CTA tile (no predication is done here).
__device__ void gemm_partition_example(
  half_t* A_ptr, half_t* B_ptr, half_t* C_ptr,
  int M, int N, int K)
{
  // Full global-memory tensors: A is (M,K), B is (N,K), C is (M,N)
  auto mA = make_tensor(make_gmem_ptr(A_ptr), make_shape(M, K));
  auto mB = make_tensor(make_gmem_ptr(B_ptr), make_shape(N, K));
  auto mC = make_tensor(make_gmem_ptr(C_ptr), make_shape(M, N));

  // Select this CTA's tiles. The tile shape must be static: register
  // fragments are sized at compile time, so partition_fragment_* cannot be
  // applied to the dynamically-shaped full matrices above.
  auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<32>{});     // (BLK_M,BLK_N,BLK_K)
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);             // (m,n,k)
  auto gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  auto gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  auto gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Setup TiledMMA: 2x2x1 atoms of 16x8x16 -> 32x16x16 tile, 128 threads
  using SM80_MMA = MMA_Atom<SM80_16x8x16_F16F16F16F16_TN>;
  auto tiled_mma = make_tiled_mma(
    SM80_MMA{},
    make_layout(make_shape(Int<2>{}, Int<2>{}, Int<1>{}))
  );

  // Get per-thread MMA
  auto thr_mma = tiled_mma.get_slice(threadIdx.x);

  // Partition tensors for this thread
  auto tCgA = thr_mma.partition_A(gA);  // (MMA,MMA_M,MMA_K,k)
  auto tCgB = thr_mma.partition_B(gB);  // (MMA,MMA_N,MMA_K,k)
  auto tCgC = thr_mma.partition_C(gC);  // (MMA,MMA_M,MMA_N)

  // Create register fragments for a single k-tile (static shapes)
  auto tCrA = thr_mma.partition_fragment_A(gA(_,_,0));  // (MMA,MMA_M,MMA_K)
  auto tCrB = thr_mma.partition_fragment_B(gB(_,_,0));  // (MMA,MMA_N,MMA_K)
  auto tCrC = thr_mma.partition_fragment_C(gC);         // (MMA,MMA_M,MMA_N)

  // Clear accumulator
  clear(tCrC);

  // Load and compute...
}

Complete GEMM Kernel Outline

// Outline only. Assumes: TiledMMA accumulates in FP32 and has 128 threads, the
// kernel is launched with size(TiledMMA{}) threads per CTA, and M, N, K are
// multiples of the CTA tile (no predication is done).
template <class TiledMMA>
__global__ void gemm_kernel(
  half_t* A, half_t* B, float* C,
  int M, int N, int K)
{
  // Static CTA tile (BLK_M,BLK_N,BLK_K); must be divisible by the TiledMMA tile
  auto cta_tiler = make_shape(Int<128>{}, Int<128>{}, Int<32>{});

  // Shared memory: one (BLK_M,BLK_K) tile of A and one (BLK_N,BLK_K) tile of B
  using SmemLayoutA = Layout<Shape<_128,_32>>;
  using SmemLayoutB = Layout<Shape<_128,_32>>;
  __shared__ half_t smem_A[cosize_v<SmemLayoutA>];
  __shared__ half_t smem_B[cosize_v<SmemLayoutB>];

  // Global memory tensors: A is (M,K), B is (N,K), C is (M,N)
  auto mA = make_tensor(make_gmem_ptr(A), make_shape(M, K));
  auto mB = make_tensor(make_gmem_ptr(B), make_shape(N, K));
  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M, N));

  // This CTA's tiles; the trailing mode of gA/gB enumerates the k-tiles
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);
  auto gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  auto gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  auto gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory tensors
  auto sA = make_tensor(make_smem_ptr(smem_A), SmemLayoutA{});       // (BLK_M,BLK_K)
  auto sB = make_tensor(make_smem_ptr(smem_B), SmemLayoutB{});       // (BLK_N,BLK_K)

  // Setup TiledMMA
  TiledMMA tiled_mma;
  auto thr_mma = tiled_mma.get_slice(threadIdx.x);

  // Split the gmem->smem copies across the CTA so that each thread copies its
  // own sub-tile, rather than every thread copying the whole tile.
  auto copy_thr = Layout<Shape<_32,_4>>{};                            // 128 threads
  CUTE_STATIC_ASSERT_V(size(copy_thr) == size(tiled_mma));
  auto tAgA = local_partition(gA, copy_thr, threadIdx.x);             // (THR_M,THR_K,k)
  auto tAsA = local_partition(sA, copy_thr, threadIdx.x);             // (THR_M,THR_K)
  auto tBgB = local_partition(gB, copy_thr, threadIdx.x);             // (THR_N,THR_K,k)
  auto tBsB = local_partition(sB, copy_thr, threadIdx.x);             // (THR_N,THR_K)

  // Partition for the MMA
  auto tCgC = thr_mma.partition_C(gC);                                // (MMA,MMA_M,MMA_N)
  auto tCsA = thr_mma.partition_A(sA);                                // (MMA,MMA_M,MMA_K)
  auto tCsB = thr_mma.partition_B(sB);                                // (MMA,MMA_N,MMA_K)

  // Register fragments (static shapes, derived from the static smem tiles)
  auto tCrA = thr_mma.partition_fragment_A(sA);
  auto tCrB = thr_mma.partition_fragment_B(sB);
  auto tCrC = thr_mma.partition_fragment_C(gC);

  clear(tCrC);

  // Main loop over k-tiles (not over individual K columns)
  int const k_tile_count = size<2>(tAgA);
  for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
    // Load this k-tile of A and B into shared memory
    copy(tAgA(_,_,k_tile), tAsA);
    copy(tBgB(_,_,k_tile), tBsB);
    __syncthreads();

    // Load into registers and compute
    copy(tCsA, tCrA);
    copy(tCsB, tCrB);

    // MMA operation: C += A * B
    gemm(tiled_mma, tCrA, tCrB, tCrC);

    // Don't overwrite shared memory until every thread has finished reading it
    __syncthreads();
  }

  // Store result
  copy(tCrC, tCgC);
}

Multi-stage Pipeline

// Multi-stage cp.async pipeline (sketch). Assumed partitions; names follow the
// CuTe tutorials, so adapt them to your kernel:
//   tAgA, tBgB : (CPY,CPY_M/N,CPY_K,k_tile)  gmem source, indexed by k-tile
//   tAsA, tBsB : (CPY,CPY_M/N,CPY_K,PIPE)    smem destination, indexed by stage
//   tCsA, tCsB : (MMA,MMA_M/N,MMA_K,PIPE)    MMA view of the same smem buffers
// Note: tensor[i] is element indexing in CuTe; slice a stage with (_,_,_,i).
// Requires NUM_STAGES >= 2.

// Prologue: prefetch k-tiles 0 .. NUM_STAGES-2 into stages 0 .. NUM_STAGES-2.
// Commit one group per stage, even an empty one, so group g always holds k-tile g.
for (int pipe = 0; pipe < NUM_STAGES-1; ++pipe) {
  if (pipe < NUM_K_TILES) {
    copy(tiled_copy, tAgA(_,_,_,pipe), tAsA(_,_,_,pipe));
    copy(tiled_copy, tBgB(_,_,_,pipe), tBsB(_,_,_,pipe));
  }
  cp_async_fence();
}

for (int k = 0; k < NUM_K_TILES; ++k) {
  int read_stage  = k % NUM_STAGES;
  int write_stage = (k + NUM_STAGES - 1) % NUM_STAGES;
  int k_next      = k + NUM_STAGES - 1;   // k-tile prefetched in this iteration

  // (NUM_STAGES-1) + k groups have been committed and group g holds k-tile g,
  // so leaving at most NUM_STAGES-2 groups in flight guarantees k-tile k landed.
  cp_async_wait<NUM_STAGES-2>();
  // Also ensures every thread finished reading write_stage (last iteration's
  // read_stage) before it is overwritten below.
  __syncthreads();

  // Compute on read_stage
  copy(tCsA(_,_,_,read_stage), tCrA);
  copy(tCsB(_,_,_,read_stage), tCrB);
  gemm(tiled_mma, tCrA, tCrB, tCrC);

  // Load gmem k-tile k_next (a k-tile index, not a stage index) into write_stage
  if (k_next < NUM_K_TILES) {
    copy(tiled_copy, tAgA(_,_,_,k_next), tAsA(_,_,_,write_stage));
    copy(tiled_copy, tBgB(_,_,_,k_next), tBsB(_,_,_,write_stage));
  }
  // Commit unconditionally (an empty group in the final iterations). Otherwise
  // the wait above stops covering k-tile k once prefetching ends.
  cp_async_fence();
}

See Also

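  • MMA_Atom: the hardware-level MMA instruction wrapper that TiledMMA tiles
  • TiledCopy / make_tiled_copy: the analogous tiling abstraction for data movement
  • cute::gemm: executes a TiledMMA on partitioned A, B, and C fragments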