The CuTe DSL provides Python decorators to define GPU kernels and JIT-compiled host functions. These decorators transform Python functions into CUDA kernels or optimized host code.

Decorators

@cute.kernel

Defines a GPU kernel function that will be JIT-compiled and executed on the device.
import cutlass.cute as cute

@cute.kernel
def elementwise_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    shape: cute.Shape,
    thr_layout: cute.Layout,
    val_layout: cute.Layout,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    
    # Kernel implementation
    blkA = gA[((None, None), bidx)]
    blkB = gB[((None, None), bidx)]
    blkC = gC[((None, None), bidx)]
    
    # Perform computation (copy setup shown; partitioning, load, add,
    # and store are elided in this fragment)
    copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type)
    tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout)
Parameters:
  gA (cute.Tensor): Input tensor A
  gB (cute.Tensor): Input tensor B
  gC (cute.Tensor): Output tensor C
  shape (cute.Shape): Shape of the tensors
  thr_layout (cute.Layout): Thread layout for partitioning
  val_layout (cute.Layout): Value layout for vectorization
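The `((None, None), bidx)` slice in the kernel selects block `bidx`'s tile of a zipped-divided tensor. The same index math can be sketched in plain Python for a flat 1-D view (the helper name is ours, and there is no CuTe dependency):

```python
def tile_indices(total_elems, tile_size, bidx):
    """Element indices owned by block `bidx` after dividing a flat
    array into tiles of `tile_size` (a 1-D analogue of zipped divide)."""
    start = bidx * tile_size
    return list(range(start, min(start + tile_size, total_elems)))

# Block 2 of a 100-element array split into 32-element tiles
print(tile_indices(100, 32, 2))   # elements 64..95
print(tile_indices(100, 32, 3))   # ragged last tile: 96..99
```

The real zipped divide generalizes this to multiple modes, but per mode the arithmetic is the same start/extent computation.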
Launching Kernels:
kernel_obj = elementwise_add_kernel(gA, gB, gC, shape, thr_layout, val_layout)
kernel_obj.launch(grid=(num_blocks, 1, 1), block=(num_threads, 1, 1))

@cute.jit

Defines a JIT-compiled host function that can call GPU kernels.
@cute.jit
def elementwise_add(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
    stream: cuda.CUstream,
):
    # Setup layouts
    thr_layout = cute.make_layout((4, 32), stride=(32, 1))
    val_layout = cute.make_layout((4, 4), stride=(4, 1))

    # Tile the tensors so each thread block owns one (16, 128) tile
    # (4x32 threads, each covering a 4x4 block of values)
    tiler_mn = (4 * 4, 32 * 4)
    gA = cute.zipped_divide(mA, tiler_mn)
    gB = cute.zipped_divide(mB, tiler_mn)
    gC = cute.zipped_divide(mC, tiler_mn)
    grid_size = cute.size(gC, mode=[1])  # one block per tile

    # Call kernel
    elementwise_add_kernel(
        gA, gB, gC,
        mA.shape,
        thr_layout,
        val_layout,
    ).launch(
        grid=(grid_size, 1, 1),
        block=(128, 1, 1),   # 4 x 32 threads
        stream=stream,
    )
Parameters:
  mA (cute.Tensor): Input tensor A
  mB (cute.Tensor): Input tensor B
  mC (cute.Tensor): Output tensor C
  stream (cuda.CUstream): CUDA stream for execution

Launch Configuration

Kernel launches use the .launch() method with grid and block dimensions:
kernel_obj.launch(
    grid=(grid_x, grid_y, grid_z),   # CTA grid dimensions
    block=(block_x, block_y, block_z), # Thread block dimensions
    smem=shared_memory_bytes,          # Optional: shared memory size
    stream=cuda_stream                 # Optional: CUDA stream
)
Parameters:
  grid (tuple[int, int, int]): Grid dimensions (number of thread blocks in x, y, z)
  block (tuple[int, int, int]): Block dimensions (number of threads per block in x, y, z)
  smem (int, default None): Shared memory size in bytes (auto-calculated if None)
  stream (cuda.CUstream, default: default stream): CUDA stream for asynchronous execution
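A common way to choose grid dimensions is ceiling division of the problem size by the block footprint. A minimal sketch in plain Python (the helper name is ours):

```python
def make_launch_config(n, block_x=128):
    """Return (grid, block) tuples covering n elements with 1-D blocks."""
    grid_x = (n + block_x - 1) // block_x   # ceiling division
    return (grid_x, 1, 1), (block_x, 1, 1)

grid, block = make_launch_config(1000)
# 1000 elements / 128 threads per block -> 8 blocks;
# the last block is only partially full, hence the in-kernel bounds check
```

Because the grid rounds up, the total thread count can exceed `n`, which is why kernels guard element accesses with an `idx < n` test.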

Compile-Time Constants

Use cutlass.Constexpr for values known at compile time:
@cute.jit
def gemm_wrapper(
    mA: cute.Tensor,
    mB: cute.Tensor,
    mC: cute.Tensor,
    epilogue_op: cutlass.Constexpr = None,
    stream: cuda.CUstream = None,
):
    # epilogue_op is evaluated at compile time
    if cutlass.const_expr(epilogue_op is None):
        # Branch eliminated at compile time
        pass
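Because `epilogue_op` is a `Constexpr`, the tracer sees its concrete value at compile time and folds the `if` away, so the generated kernel contains only one branch's code. A rough pure-Python analogy of that specialization (our own sketch, not the DSL's machinery):

```python
def specialize(epilogue_op=None):
    """'Compile' a function: the branch on epilogue_op is resolved once,
    here, rather than on every call (as const_expr does at trace time)."""
    if epilogue_op is None:
        return lambda x: x                 # branch eliminated: identity
    return lambda x: epilogue_op(x)        # branch eliminated: fused op

f_plain = specialize()
f_relu = specialize(lambda x: max(x, 0))
```

Each specialized function carries no trace of the other branch, mirroring how a `const_expr` branch leaves no runtime conditional in the kernel.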

Thread and Block Indexing

Access CUDA thread and block indices within kernels:
@cute.kernel
def my_kernel(data: cute.Tensor):
    tidx, tidy, tidz = cute.arch.thread_idx()
    bidx, bidy, bidz = cute.arch.block_idx()
    
    # Global thread index (block_dim gives the runtime block shape)
    bdimx, _, _ = cute.arch.block_dim()
    global_tid = bidx * bdimx + tidx
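The `bidx * block_dim + tidx` formula can be checked in plain Python: over a grid of `num_blocks * block_dim` threads it produces every global index exactly once, in order.

```python
def global_tids(num_blocks, block_dim):
    """Enumerate the global thread index of every (bidx, tidx) pair."""
    return [bidx * block_dim + tidx
            for bidx in range(num_blocks)
            for tidx in range(block_dim)]

tids = global_tids(num_blocks=4, block_dim=8)
# 32 unique, contiguous indices: 0..31
```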

Complete Example

import cutlass
import cutlass.cute as cute
import cuda.bindings.driver as cuda

@cute.kernel
def vector_add_kernel(
    gA: cute.Tensor,
    gB: cute.Tensor,
    gC: cute.Tensor,
    n: cutlass.Int32,
):
    tidx, _, _ = cute.arch.thread_idx()
    bidx, _, _ = cute.arch.block_idx()
    
    bdimx, _, _ = cute.arch.block_dim()
    idx = bidx * bdimx + tidx
    if idx < n:
        gC[idx] = gA[idx] + gB[idx]

@cute.jit
def vector_add(
    A: cute.Tensor,
    B: cute.Tensor,
    C: cute.Tensor,
    stream: cuda.CUstream,
):
    n = cute.size(A)
    num_threads = 256
    num_blocks = (n + num_threads - 1) // num_threads
    
    vector_add_kernel(A, B, C, n).launch(
        grid=(num_blocks, 1, 1),
        block=(num_threads, 1, 1),
        stream=stream
    )
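The launch math above can be sanity-checked without a GPU: a plain-Python model of the same blocked loop (a hypothetical, list-based stand-in for the tensors) shows the `idx < n` guard handling the ragged last block.

```python
def vector_add_ref(A, B, num_threads=256):
    """CPU model of the blocked vector-add launch."""
    n = len(A)
    num_blocks = (n + num_threads - 1) // num_threads
    C = [0] * n
    for bidx in range(num_blocks):          # grid loop
        for tidx in range(num_threads):     # block loop
            idx = bidx * num_threads + tidx
            if idx < n:                     # last block may overrun n
                C[idx] = A[idx] + B[idx]
    return C
```

With `n = 300` and 256 threads per block, two blocks are launched and the guard discards the 212 out-of-range threads in the second block.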
