The CuTe DSL provides high-level operations for matrix multiplication, memory transfers, and tensor manipulation.
GEMM Operations
gemm
Performs a General Matrix Multiply-Accumulate (GEMM) operation: D = A * B + C.
cute.gemm(
    atom=mma_atom,
    d=dest_tensor,
    a=tensor_a,
    b=tensor_b,
    c=accum_tensor,
)
Parameters:
atom: MMA atom defining the computation pattern (e.g., warp-level or warpgroup-level MMA)
d: Destination tensor (output)
a (Union[Tensor, List[Tensor]], required): First source tensor, or [tensor, scale_factor] for block-scaled GEMM
b (Union[Tensor, List[Tensor]], required): Second source tensor, or [tensor, scale_factor] for block-scaled GEMM
c: Accumulator tensor (may alias d)
Dispatch Rules:
# Dispatch [1]: (V) x (V) => (V)
# Dispatch [2]: (M) x (N) => (M,N)
# Dispatch [3]: (M,K) x (N,K) => (M,N)
# Dispatch [4]: (V,M) x (V,N) => (V,M,N)
# Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N)
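The following plain-Python model makes the shape rules concrete for dispatches [2] and [3]. It assumes each atom invocation is a scalar multiply-add, as for a universal FMA atom where the V mode is trivial. It is an illustration, not the CuTe implementation, and `gemm_model` is a hypothetical name. Because B is indexed (N,K), each output element is D[m][n] = C[m][n] + sum over k of A[m][k] * B[n][k].

```python
def gemm_model(a, b, c):
    """Scalar model of cute.gemm dispatches [2] and [3] (illustrative only).

    a: M values (dispatch [2]) or an M x K nested list (dispatch [3])
    b: N values (dispatch [2]) or an N x K nested list (dispatch [3]); B is indexed (N, K)
    c: M x N nested list holding the accumulator; it is not modified
    Returns D as a new M x N nested list.
    """
    # Dispatch [2], (M) x (N) => (M,N), is the K == 1 special case of dispatch [3].
    if not isinstance(a[0], list):
        a = [[x] for x in a]
        b = [[x] for x in b]
    M, N, K = len(a), len(b), len(a[0])
    assert len(c) == M and all(len(row) == N for row in c), "C must be M x N"
    assert all(len(row) == K for row in a + b), "A and B must share the K extent"
    return [
        [c[m][n] + sum(a[m][k] * b[n][k] for k in range(K)) for n in range(N)]
        for m in range(M)
    ]


# Dispatch [3]: (M,K) x (N,K) => (M,N) with M = N = K = 2
print(gemm_model([[1, 2], [3, 4]], [[5, 6], [7, 8]], [[0, 0], [0, 0]]))  # [[17, 23], [39, 53]]
# Dispatch [2]: (M) x (N) => (M,N), an outer product accumulated into C
print(gemm_model([1, 2], [3, 4, 5], [[1, 1, 1], [1, 1, 1]]))  # [[4, 5, 6], [7, 9, 11]]
```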
Example:
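# Create an MMA atom
mma_atom = cute.make_mma_atom(
    cute.nvgpu.warp.MMA_F32F32F32_M16N8K8_TN()
)
# Build a tiled MMA and take this thread's slice
tiled_mma = cute.make_tiled_mma(mma_atom)
tidx, _, _ = cute.arch.thread_idx()
thr_mma = tiled_mma.get_slice(tidx)
# Partition the operand tiles (blockA, blockB and blockC are assumed to be defined)
thrA = thr_mma.partition_A(blockA)
thrB = thr_mma.partition_B(blockB)
thrC = thr_mma.partition_C(blockC)
# Perform GEMM: thrC = thrA * thrB + thrC, shapes (V,M,K) x (V,N,K) => (V,M,N)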
cute.gemm(mma_atom, thrC, thrA, thrB, thrC)
Copy Operations
copy
Performs a copy from a source tensor to a destination tensor using a copy atom or a tiled copy.
cute.copy(copy_atom, src, dst)
copy_atom: Copy atom defining the copy pattern
Example:
# Create a copy atom from the universal (generic) copy op
copy_atom = cute.make_copy_atom(
    cute.nvgpu.CopyUniversalOp(),
    cute.Float32,
)
# Create a tiled copy (thr_layout and val_layout are assumed to be defined)
tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
# Get this thread's slice
tidx, _, _ = cute.arch.thread_idx()
thr_copy = tiled_copy.get_slice(tidx)
# Partition source and destination (blockA and fragmentA are assumed to be defined)
thrA_src = thr_copy.partition_S(blockA)
thrA_dst = thr_copy.partition_D(fragmentA)
# Execute copy
cute.copy(copy_atom, thrA_src, thrA_dst)
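As a rough mental model of the thread/value split that `make_tiled_copy_tv` sets up, the sketch below assumes compact column-major thread and value layouts, which is CuTe's default ordering. Under that assumption each thread owns one contiguous `val_shape` block of the tile. Real layouts can be arbitrary, so treat `owned_coords` (a hypothetical helper, not part of CuTe) as an illustration only.

```python
def owned_coords(tid, thr_shape, val_shape):
    """Hypothetical helper: (row, col) tile coordinates owned by thread `tid`.

    Simplified model assuming compact column-major thread and value layouts,
    where each thread owns one contiguous val_shape block of a tile of shape
    (thr_shape[0] * val_shape[0], thr_shape[1] * val_shape[1]).
    """
    tm_count, _ = thr_shape
    vm_count, vn_count = val_shape
    # Column-major thread index -> thread grid coordinate
    tm, tn = tid % tm_count, tid // tm_count
    # Enumerate this thread's values in column-major order
    return [
        (tm * vm_count + vm, tn * vn_count + vn)
        for vn in range(vn_count)
        for vm in range(vm_count)
    ]


# 4 x 2 threads, each owning a 1 x 4 strip of values: the tile is 4 x 8
print(owned_coords(0, (4, 2), (1, 4)))  # [(0, 0), (0, 1), (0, 2), (0, 3)]
print(owned_coords(5, (4, 2), (1, 4)))  # [(1, 4), (1, 5), (1, 6), (1, 7)]
```

In this model every tile element is owned by exactly one thread. A contiguous per-thread block is also what makes wider, vectorized copies possible.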
basic_copy
Performs a basic element-wise copy without a copy atom (for simple cases).
cute.basic_copy(src, dst)
dst: Destination tensor (must have the same size as src)
Example:
# Simple element-wise copy
fragment = cute.make_fragment_like(tensor)
cute.basic_copy(tensor, fragment)
basic_copy_if
Performs predicated element-wise copy.
cute.basic_copy_if(pred, src, dst)
pred: Predicate tensor (boolean values)
Example:
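# Copy only in-bounds elements (assumes `import cutlass`; src, dst and shape are defined)
coord_tensor = cute.make_identity_tensor(shape)
pred = cute.make_fragment(coord_tensor.shape, cutlass.Boolean)
for i in range(cute.size(pred)):
    # In-bounds predicate: True where the coordinate lies inside shape
    pred[i] = cute.elem_less(coord_tensor[i], shape)
cute.basic_copy_if(pred, src, dst)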
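The predicate is True exactly where a coordinate from the identity tensor lies inside the valid shape. As a result, elements past the edge of the source are never read. The plain-Python model below copies a small matrix into a larger tile buffer. It is not the CuTe API, and `predicated_copy` is a hypothetical name.

```python
def predicated_copy(src, tile_shape, fill=0):
    """Copy a (possibly smaller) src matrix into a tile_shape buffer.

    Each tile coordinate (m, n) plays the role of an identity-tensor entry.
    The predicate is True only when every component is below src's extent,
    so out-of-bounds elements of src are never read. dst starts filled with
    `fill`, and masked-off elements are left untouched.
    """
    valid_shape = (len(src), len(src[0]))
    dst = [[fill] * tile_shape[1] for _ in range(tile_shape[0])]
    for m in range(tile_shape[0]):
        for n in range(tile_shape[1]):
            pred = m < valid_shape[0] and n < valid_shape[1]  # in-bounds test
            if pred:
                dst[m][n] = src[m][n]
    return dst


# A 3 x 3 matrix copied into a 4 x 4 tile: the last row and column keep the fill value
print(predicated_copy([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (4, 4)))
```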
autovec_copy
Performs an element-wise copy with automatic vectorization.
cute.autovec_copy(src, dst)
Automatically vectorizes the copy based on tensor alignment and element type.
prefetch
Prefetches data into cache.
cute.prefetch(tensor, cache_level=cute.CacheEvictionPriority.EVICT_FIRST)
cache_level (CacheEvictionPriority, default: EVICT_FIRST): Cache eviction priority
Tensor Operations
print_tensor
Prints tensor contents (for debugging).
cute.print_tensor(tensor)
Example:
import cutlass.cute as cute

@cute.kernel
def debug_kernel(tensor: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    # Print from a single thread so the output is not repeated per thread
    if tidx == 0:
        cute.print_tensor(tensor)
printf
Prints formatted output from device code.
cute.printf("Thread {}: value = {}\n", thread_id, value)
Example:
@cute.kernel
def debug_kernel():
    tidx, _, _ = cute.arch.thread_idx()
    cute.printf("Thread {}: starting computation\n", tidx)
Reduction Operations
any_
Returns True if any element in the tensor is non-zero.
result = cute.any_(tensor)
all_
Returns True if all elements in the tensor are non-zero.
result = cute.all_(tensor)
Conditional Operations
where
Element-wise conditional selection.
result = cute.where(condition, true_value, false_value)
true_value: Values to select where condition is True
false_value: Values to select where condition is False
Example:
# Clamp values
result = cute.where(tensor > max_val, max_val, tensor)
result = cute.where(result < min_val, min_val, result)
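Element-wise, `where` acts like a per-element ternary. The plain-Python model below mirrors the two-step clamp above, using lists in place of CuTe tensors. It is not the CuTe API, and scalars are broadcast to the length of the condition.

```python
def where(cond, true_vals, false_vals):
    """Per-element select; scalar operands are broadcast to len(cond)."""
    n = len(cond)

    def expand(x):
        return x if isinstance(x, list) else [x] * n

    t, f = expand(true_vals), expand(false_vals)
    return [t[i] if cond[i] else f[i] for i in range(n)]


def clamp(values, min_val, max_val):
    # Step 1: replace values above max_val
    result = where([v > max_val for v in values], max_val, values)
    # Step 2: replace values below min_val
    return where([v < min_val for v in result], min_val, result)


print(clamp([-3.0, 0.5, 7.0, 5.0], 0.0, 5.0))  # [0.0, 0.5, 5.0, 5.0]
```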
Control Flow
for_generate
Generates loop iterations in device code.
from cutlass.cutlass_dsl import for_generate, yield_out
for i in for_generate(start, end):
    # Loop body
    tensor[i] = compute(i)
    yield_out()  # Required at end of loop body
if_generate
Generates conditional branches in device code.
from cutlass.cutlass_dsl import if_generate
if_generate(
    condition,
    lambda: true_branch(),
    lambda: false_branch(),  # Optional
)
Complete GEMM Example
import cutlass.cute as cute

@cute.kernel
def gemm_kernel(
    tiled_mma: cute.TiledMma,
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    # Get this thread's slice of the tiled MMA
    thr_mma = tiled_mma.get_slice(tidx)
    # Select this block's tiles (assumes mA, mB and mC were pre-tiled so that
    # the second top-level mode indexes block tiles)
    blkA = mA[((None, None, None), bidx)]
    blkB = mB[((None, None, None), bidx)]
    blkC = mC[((None, None), bidx)]
    # Partition the tiles among threads
    thrA = thr_mma.partition_A(blkA)
    thrB = thr_mma.partition_B(blkB)
    thrC = thr_mma.partition_C(blkC)
    # Number of K tiles: the extent of the mode indexed by k in the main loop
    K_tiles = cute.size(thrA, mode=[2])
    # Allocate register fragments
    fragA = cute.make_fragment_like(thrA[None, None, 0])
    fragB = cute.make_fragment_like(thrB[None, None, 0])
    fragC = cute.make_fragment_like(thrC)
    # Load the initial accumulator values from C
    cute.basic_copy(thrC, fragC)
    # Main loop over K tiles
    for k in range(K_tiles):
        # Load the current K slice of A and B into registers
        cute.basic_copy(thrA[None, None, k], fragA)
        cute.basic_copy(thrB[None, None, k], fragB)
        # Accumulate: fragC = fragA * fragB + fragC
        cute.gemm(tiled_mma, fragC, fragA, fragB, fragC)
    # Store the result back to C
    cute.basic_copy(fragC, thrC)
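The main loop relies on one identity: accumulating each K tile's partial product into the same accumulator gives the full product. The plain-Python check below follows the (M,K) x (N,K) convention of the dispatch rules, with the accumulator initialized from C and stored back after the loop. It assumes K is a multiple of the tile size, and `tiled_k_gemm` is a hypothetical name.

```python
def tiled_k_gemm(a, b, c, tile_k):
    """Accumulate D[m][n] = C[m][n] + sum_k A[m][k] * B[n][k] one K tile at a time."""
    M, N, K = len(a), len(b), len(a[0])
    assert K % tile_k == 0, "this sketch assumes K is a multiple of tile_k"
    acc = [row[:] for row in c]  # register accumulator, initialized from C
    for kt in range(K // tile_k):  # main loop over K tiles
        ks = range(kt * tile_k, (kt + 1) * tile_k)  # k indices of this tile
        for m in range(M):
            for n in range(N):
                acc[m][n] += sum(a[m][k] * b[n][k] for k in ks)
    return acc  # corresponds to storing the fragment back to C


a = [[1, 2, 3, 4], [5, 6, 7, 8]]              # M x K = 2 x 4
b = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 1, 1, 1]]  # N x K = 3 x 4
c = [[0, 0, 0], [0, 0, 0]]
print(tiled_k_gemm(a, b, c, tile_k=2))  # [[4, 6, 10], [12, 14, 26]]
```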
See Also