The CuTe DSL provides high-level operations for matrix multiplication, memory transfers, and tensor manipulation.

GEMM Operations

gemm

Performs a general matrix multiply-accumulate (GEMM): D = A * B + C.
cute.gemm(
    atom=mma_atom,
    d=dest_tensor,
    a=tensor_a,
    b=tensor_b,
    c=accum_tensor
)
atom (MmaAtom, required): MMA atom defining the computation pattern (e.g., warp-level or warpgroup-level MMA).
d (Tensor, required): Destination tensor (output).
a (Union[Tensor, List[Tensor]], required): First source tensor, or [tensor, scale_factor] for block-scaled GEMM.
b (Union[Tensor, List[Tensor]], required): Second source tensor, or [tensor, scale_factor] for block-scaled GEMM.
c (Tensor, required): Accumulator tensor (may alias d).
Dispatch Rules:
# Dispatch [1]: (V) x (V) => (V)
# Dispatch [2]: (M) x (N) => (M,N)
# Dispatch [3]: (M,K) x (N,K) => (M,N)
# Dispatch [4]: (V,M) x (V,N) => (V,M,N)
# Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
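
To make the shape notation concrete, here is a minimal pure-Python reference for dispatch [3]. It runs on the host and is not CuTe DSL code. It assumes the usual CuTe reading of the notation: both operands carry K as their last mode, so B is indexed as (N,K) and the contraction runs over K. The name `gemm_reference` is hypothetical and exists only for illustration.

```python
def gemm_reference(a, b, c):
    """Host-side reference for dispatch [3]: (M,K) x (N,K) => (M,N).

    a holds M rows of K values, b holds N rows of K values,
    c holds M rows of N values.
    Returns d with d[m][n] = sum_k a[m][k] * b[n][k] + c[m][n].
    """
    m_size, n_size = len(a), len(b)
    k_size = len(a[0])
    assert all(len(row) == k_size for row in a + b), "A and B must share K"
    return [
        [
            sum(a[m][k] * b[n][k] for k in range(k_size)) + c[m][n]
            for n in range(n_size)
        ]
        for m in range(m_size)
    ]


A = [[1, 2, 3], [4, 5, 6]]                         # (M=2, K=3)
B = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]   # (N=4, K=3)
C = [[10] * 4 for _ in range(2)]                   # (M=2, N=4)
D = gemm_reference(A, B, C)
```

Dispatches [4] and [5] add a leading V mode that the MMA atom consumes; the per-(M,N) arithmetic stays the same.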
Example:
# Create MMA atom
mma_atom = cute.make_mma_atom(
    cute.nvgpu.warp.MMA_F32F32F32_M16N8K8_TN()
)

# Build the tiled MMA and take this thread's slice
tiled_mma = cute.make_tiled_mma(mma_atom)
thr_mma = tiled_mma.get_slice(thread_idx)

# Partition operands
thrA = thr_mma.partition_A(blockA)
thrB = thr_mma.partition_B(blockB)
thrC = thr_mma.partition_C(blockC)

# Perform GEMM
cute.gemm(mma_atom, thrC, thrA, thrB, thrC)

Copy Operations

copy

Copies data from a source tensor to a destination tensor using the pattern defined by a copy atom (a single atom or a tiled copy).
cute.copy(copy_atom, src, dst)
copy_atom (CopyAtom, required): Copy atom defining the copy pattern.
src (Tensor, required): Source tensor.
dst (Tensor, required): Destination tensor.
Example:
# Create copy atom for global memory
copy_atom = cute.make_copy_atom(
    cute.nvgpu.CopyUniversalOp(),
    element_type=cute.Float32
)

# Create tiled copy
tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)

# Get thread slice
thr_copy = tiled_copy.get_slice(thread_idx)

# Partition tensors
thrA_src = thr_copy.partition_S(blockA)
thrA_dst = thr_copy.partition_D(fragmentA)

# Execute copy
cute.copy(copy_atom, thrA_src, thrA_dst)

basic_copy

Performs basic element-wise copy without atoms (for simple cases).
cute.basic_copy(src, dst)
src (Tensor, required): Source tensor.
dst (Tensor, required): Destination tensor (must have the same size as src).
Example:
# Simple element-wise copy
fragment = cute.make_fragment_like(tensor)
cute.basic_copy(tensor, fragment)

basic_copy_if

Performs predicated element-wise copy.
cute.basic_copy_if(pred, src, dst)
pred (Tensor, required): Predicate tensor (boolean values).
src (Tensor, required): Source tensor.
dst (Tensor, required): Destination tensor.
Example:
# Copy only valid elements
coord_tensor = cute.make_identity_tensor(shape)
pred = coord_tensor < shape  # True where the coordinate is in bounds

cute.basic_copy_if(pred, src, dst)
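
The predicated semantics can be sketched on the host in plain Python. This is an illustrative model, not the library's implementation: `copy_if_reference` is a hypothetical name, and the sketch assumes that destination elements whose predicate is false are left untouched.

```python
def copy_if_reference(pred, src, dst):
    """Copy src[i] into dst[i] only where pred[i] is true.

    Elements of dst with a false predicate keep their previous value
    (an assumption of this sketch).
    """
    assert len(pred) == len(src) == len(dst), "all tensors must match in size"
    for i, keep in enumerate(pred):
        if keep:
            dst[i] = src[i]
    return dst


tile = 8    # elements covered by one tile
valid = 5   # elements that actually exist in the source
src = list(range(100, 100 + tile))
dst = [0] * tile
pred = [i < valid for i in range(tile)]   # in-bounds predicate
copy_if_reference(pred, src, dst)
```

This is the typical guard for a partial tile at the edge of a tensor: the last three slots are never read from out-of-bounds memory.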

autovec_copy

Performs vectorized copy with automatic vectorization.
cute.autovec_copy(src, dst)
Automatically vectorizes the copy based on tensor alignment and element type.
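
The idea of choosing a vector width from alignment and element type can be illustrated with a small host-side sketch. Everything here is an assumption for illustration: `pick_vector_bits` is a hypothetical helper, the 128-bit cap is an assumed maximum access width, and the actual selection logic inside `autovec_copy` may differ.

```python
def pick_vector_bits(alignment_bytes, element_bits, contiguous_elems, max_bits=128):
    """Return the widest power-of-two access width (in bits) that fits.

    A width is accepted when the pointer alignment is a multiple of it and
    the contiguous element count divides evenly into vectors of that width.
    Falls back to one element per access.
    """
    bits = max_bits
    while bits > element_bits:
        elems_per_vector = bits // element_bits
        aligned = (alignment_bytes * 8) % bits == 0
        divides = contiguous_elems % elems_per_vector == 0
        if aligned and divides:
            return bits
        bits //= 2
    return element_bits


wide = pick_vector_bits(alignment_bytes=16, element_bits=32, contiguous_elems=8)
narrow = pick_vector_bits(alignment_bytes=4, element_bits=16, contiguous_elems=8)
scalar = pick_vector_bits(alignment_bytes=16, element_bits=32, contiguous_elems=3)
```

Under these assumptions, a 16-byte-aligned Float32 tensor with 8 contiguous elements can move 4 elements per access, while weaker alignment or an odd element count forces narrower accesses.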

prefetch

Prefetches data into cache.
cute.prefetch(tensor, cache_level=cute.CacheEvictionPriority.EVICT_FIRST)
tensor (Tensor, required): Tensor to prefetch.
cache_level (CacheEvictionPriority, default: EVICT_FIRST): Cache eviction priority.

Tensor Operations

print_tensor

Prints tensor contents (for debugging).
cute.print_tensor(tensor)
Example:
@cute.kernel
def debug_kernel(tensor: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.print_tensor(tensor)

printf

Prints formatted output from device code.
cute.printf("Thread {}: value = {}\n", thread_id, value)
Example:
@cute.kernel
def debug_kernel():
    tidx, _, _ = cute.arch.thread_idx()
    cute.printf("Thread {}: starting computation\n", tidx)

Reduction Operations

any_

Returns True if any element in the tensor is non-zero.
result = cute.any_(tensor)

all_

Returns True if all elements in the tensor are non-zero.
result = cute.all_(tensor)
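
As a host-side model of the two reductions, the following sketch treats an element as true when it is non-zero, as described above. The function names are hypothetical and the code is plain Python, not CuTe DSL.

```python
def any_reference(values):
    """True if at least one element is non-zero."""
    return any(v != 0 for v in values)


def all_reference(values):
    """True if every element is non-zero."""
    return all(v != 0 for v in values)


mask = [0, 0, 3, 0]
has_nonzero = any_reference(mask)
all_nonzero = all_reference(mask)
```

A common use is to reduce a predicate tensor before a copy: if every element is in bounds, an unpredicated copy is enough.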

Conditional Operations

where

Element-wise conditional selection.
result = cute.where(condition, true_value, false_value)
condition
Tensor
required
Boolean condition tensor
true_value
Tensor
required
Values to select when condition is True
false_value
Tensor
required
Values to select when condition is False
Example:
# Clamp values
result = cute.where(tensor > max_val, max_val, tensor)
result = cute.where(result < min_val, min_val, result)
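
The selection rule and the two-step clamp can be checked with a plain Python model. The helper names are hypothetical. The sketch also expands the scalar bounds to full-length lists explicitly, since whether `cute.where` broadcasts scalars the way the example above implies is assumed here rather than confirmed.

```python
def where_reference(condition, true_value, false_value):
    """Pick true_value[i] where condition[i] holds, else false_value[i]."""
    return [t if c else f for c, t, f in zip(condition, true_value, false_value)]


def clamp_reference(values, min_val, max_val):
    """Clamp with two where() passes, mirroring the example above."""
    n = len(values)
    # First pass: cap values above max_val.
    upper = where_reference([v > max_val for v in values], [max_val] * n, values)
    # Second pass: raise values below min_val.
    return where_reference([v < min_val for v in upper], [min_val] * n, upper)


clamped = clamp_reference([-5, 0, 3, 12], min_val=0, max_val=10)
```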

Control Flow

for_generate

Generates loop iterations in device code.
from cutlass.cutlass_dsl import for_generate, yield_out

for i in for_generate(start, end):
    # Loop body
    tensor[i] = compute(i)
    yield_out()  # Required at end of loop body

if_generate

Generates conditional branches in device code.
from cutlass.cutlass_dsl import if_generate

if_generate(
    condition,
    lambda: true_branch(),
    lambda: false_branch()  # Optional
)

Complete GEMM Example

import cutlass.cute as cute

@cute.kernel
def gemm_kernel(
    tiled_mma: cute.TiledMma,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    
    # Get thread's portion of MMA
    thr_mma = tiled_mma.get_slice(tidx)
    
    # Partition block tensors
    blkA = mA[((None, None, None), bidx)]
    blkB = mB[((None, None, None), bidx)]
    blkC = mC[((None, None), bidx)]
    
    # Partition to threads
    thrA = thr_mma.partition_A(blkA)
    thrB = thr_mma.partition_B(blkB)
    thrC = thr_mma.partition_C(blkC)
    
    # Allocate register fragments
    fragA = cute.make_fragment_like(thrA[None, None, 0])
    fragB = cute.make_fragment_like(thrB[None, None, 0])
    fragC = cute.make_fragment_like(thrC)
    
    # Copy accumulator
    cute.basic_copy(thrC, fragC)
    
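    # Main loop over K tiles. K_tiles is not defined in this kernel, so the
    # count is read from the K mode (mode 2) of the partitioned A operand.
    for k in range(cute.size(thrA, mode=[2])):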
        # Load A and B
        cute.basic_copy(thrA[None, None, k], fragA)
        cute.basic_copy(thrB[None, None, k], fragB)
        
        # Perform GEMM
        cute.gemm(tiled_mma.atom, fragC, fragA, fragB, fragC)
    
    # Store result
    cute.basic_copy(fragC, thrC)
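
The structure of the kernel's main loop (load the accumulator, accumulate one K tile at a time, store the result) can be verified on the host with a pure-Python reference. This is a sketch under stated assumptions: `tiled_gemm_reference` and `k_tile` are hypothetical names, B is indexed as (N,K), and thread partitioning is ignored entirely.

```python
def tiled_gemm_reference(a, b, c, k_tile):
    """Compute D = A * B^T + C by accumulating over K in tiles of k_tile."""
    m_size, n_size, k_size = len(a), len(b), len(a[0])
    assert k_size % k_tile == 0, "K must be a multiple of the tile size"

    # Start from a copy of C, like basic_copy(thrC, fragC) in the kernel.
    acc = [row[:] for row in c]

    for k0 in range(0, k_size, k_tile):
        # Load one K tile of A and B, like the per-iteration fragment copies.
        frag_a = [row[k0:k0 + k_tile] for row in a]
        frag_b = [row[k0:k0 + k_tile] for row in b]
        # Accumulate in place, like gemm(atom, fragC, fragA, fragB, fragC).
        for m in range(m_size):
            for n in range(n_size):
                acc[m][n] += sum(x * y for x, y in zip(frag_a[m], frag_b[n]))
    return acc


A = [[1, 2, 3, 4], [5, 6, 7, 8]]                 # (M=2, K=4)
B = [[1, 1, 1, 1], [1, 0, 1, 0], [0, 0, 0, 1]]   # (N=3, K=4)
C = [[0] * 3 for _ in range(2)]                  # (M=2, N=3)
D_tiled = tiled_gemm_reference(A, B, C, k_tile=2)
D_full = tiled_gemm_reference(A, B, C, k_tile=4)
```

Because the accumulator aliases the destination on every iteration, the result does not depend on the tile size; only the amount of data held in fragments at once changes.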
