CuTe: CUDA Template Library
CuTe (CUDA Templates) is the foundational abstraction layer introduced in CUTLASS 3.0. It provides powerful tools for describing and manipulating tensors of threads and data with compile-time layouts.
What is CuTe?
CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. It provides:
Layouts: mappings from logical coordinates to linear indices (fully static when shapes and strides are compile-time constants)
Tensors: data structures combining engines (storage) with layouts
Algorithms: operations on tensors (copy, GEMM, partition, etc.)
CuTe represents a paradigm shift in GPU programming - instead of manually computing indices and strides, you describe the logical structure of your data and let CuTe handle the mechanical bookkeeping.
Core Abstractions
Shapes and Strides
CuTe uses tuple-based shapes and strides to describe multidimensional data:
```cpp
// Shape: logical size in each dimension
template <class... Shapes>
using Shape = cute::tuple<Shapes...>;

// Stride: memory offset between elements
template <class... Strides>
using Stride = cute::tuple<Strides...>;

// Coordinate: position within a shape
template <class... Coords>
using Coord = cute::tuple<Coords...>;

// Helper functions
auto shape  = make_shape(128, 128);   // 2D shape
auto stride = make_stride(1, 128);    // Column-major
auto coord  = make_coord(10, 5);      // Position (10, 5)
```
Reference: include/cute/layout.hpp:47
Layouts
A Layout combines shape and stride to map logical coordinates to linear memory offsets:
```cpp
template <class Shape, class Stride = LayoutLeft::Apply<Shape>>
struct Layout : private cute::tuple<Shape, Stride> {
  // Map a coordinate to a linear index
  CUTE_HOST_DEVICE auto operator()(Coord const& coord) const -> Index;

  // Access shape and stride
  CUTE_HOST_DEVICE auto shape() const -> Shape const&;
  CUTE_HOST_DEVICE auto stride() const -> Stride const&;
};
```
Reference: include/cute/layout.hpp:98
Example: Row-Major vs Column-Major
```cpp
// Row-major layout for a 4×8 matrix
auto layout_row = make_layout(
    make_shape(4, 8),    // 4 rows, 8 columns
    make_stride(8, 1)    // Row-major: stride-1 in the column dimension
);

// Column-major layout for a 4×8 matrix
auto layout_col = make_layout(
    make_shape(4, 8),    // 4 rows, 8 columns
    make_stride(1, 4)    // Column-major: stride-1 in the row dimension
);

// Map coordinate (2, 3) to a linear index
int idx_row = layout_row(make_coord(2, 3));  // 2*8 + 3 = 19
int idx_col = layout_col(make_coord(2, 3));  // 2 + 3*4 = 14
```
When shapes and strides are compile-time constants, this mapping is evaluated entirely at compile time, enabling zero-overhead abstractions.
Hierarchical Layouts
CuTe supports nested, hierarchical layouts for representing complex data structures:
```cpp
// Hierarchical shape: 2D threadblock tile subdivided into warp tiles
auto shape = make_shape(
    make_shape(4, 8),     // Outer: a 4×8 grid of warp tiles
    make_shape(32, 32)    // Inner: each warp tile is 32×32 elements
);
// Total extent: (4*32) × (8*32) = 128×256 elements

// Access with hierarchical coordinates
auto coord = make_coord(
    make_coord(2, 3),     // Warp (2, 3)
    make_coord(10, 15)    // Element (10, 15) within the warp
);
```
Tensors
Tensors combine an engine (storage) with a layout (indexing):
```cpp
template <class Engine, class Layout>
struct Tensor {
  using iterator     = typename Engine::iterator;
  using value_type   = typename Engine::value_type;
  using element_type = typename Engine::element_type;
  using reference    = typename Engine::reference;

  // Access the data pointer
  CUTE_HOST_DEVICE auto data() const -> iterator;

  // Access the layout
  CUTE_HOST_DEVICE auto layout() const -> Layout const&;

  // Access the shape
  CUTE_HOST_DEVICE auto shape() const -> decltype(layout().shape());

  // Index into the tensor
  template <class Coord>
  CUTE_HOST_DEVICE auto operator()(Coord const& c) -> reference;
};
```
Reference: include/cute/tensor_impl.hpp:135
Tensor Engines
Engines define how tensors store data:
ArrayEngine (Owning)
Allocates and owns data storage:
```cpp
template <class T, size_t N>
struct ArrayEngine {
  using Storage = array_aligned<T, N>;
  Storage storage_;   // Owned storage
};

// Example: a tensor owning 256 floats with a static 16×16 layout
Tensor<ArrayEngine<float, 256>, Layout<Shape<_16, _16>>> tensor;
```
Reference: include/cute/tensor_impl.hpp:70
ViewEngine (Non-Owning)
References existing data without ownership:
```cpp
template <class Iterator>
struct ViewEngine {
  Iterator storage_;   // Pointer to existing data
};

// Example: view over existing memory
float* data = get_shared_memory();
auto layout = make_layout(make_shape(128, 128));
auto tensor = make_tensor(data, layout);   // Non-owning view
```
Reference: include/cute/tensor_impl.hpp:106
Tensor Operations
Creating Tensors
```cpp
// Create a tensor from a pointer and a layout
float* ptr = ...;
auto layout = make_layout(make_shape(64, 64));
auto tensor = make_tensor(ptr, layout);

// Create a tensor with automatic (owning) storage
auto tensor_local = make_tensor<float>(make_shape(32, 32));

// Create a 1-D contiguous tensor (default compact layout)
auto tensor_1d = make_tensor(ptr, make_shape(256));
```
Indexing Tensors
```cpp
auto tensor = make_tensor(ptr, make_layout(
    make_shape(128, 128),
    make_stride(128, 1)    // Row-major
));

// Linear indexing
float val = tensor(42);       // 43rd element in coordinate order

// Multidimensional indexing
float val2 = tensor(10, 20);  // Row 10, column 20

// Hierarchical indexing
float val3 = tensor(make_coord(make_coord(2, 3), make_coord(5, 7)));
```
Partitioning Tensors
Partitioning divides tensors across threads or threadblocks:
```cpp
// Partition a tensor across threads
auto tensor_global = make_tensor(ptr, layout);
auto thread_tensor = local_partition(
    tensor_global,
    make_layout(make_shape(32)),   // Layout of 32 threads
    thread_idx                     // This thread's index
);
// Each thread now owns a slice of the original tensor
```
Tiling Tensors
Tiling restructures tensors into hierarchical tile views:
```cpp
// Tile a 128×128 tensor into 32×32 tiles
auto tensor = make_tensor(ptr, make_shape(128, 128));
auto tiled  = zipped_divide(tensor, make_shape(32, 32));
// Result shape: ((32, 32), (4, 4))
//   First mode:  the 32×32 elements within one tile
//   Second mode: the 4×4 grid of tiles
```
Understanding Partition vs Tile
Partition: distributes tensor elements across threads or threadblocks
Tile: restructures a tensor into a hierarchical layout
Both can be combined, e.g. to partition a tiled tensor across threads
Copy Operations
CuTe provides high-level copy abstractions:
```cpp
// Copy between tensors with automatic vectorization
auto src = make_tensor(src_ptr, layout);
auto dst = make_tensor(dst_ptr, layout);
copy(src, dst);   // Vectorizes automatically when possible
```
Async Copy (SM80+)
```cpp
// Asynchronous global → shared memory copy
using CopyOp = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
auto copy_op = Copy_Atom<CopyOp, float>{};
copy(copy_op, src_tensor, dst_tensor);
cp_async_fence();     // Commit the outstanding cp.async operations
cp_async_wait<0>();   // Wait for completion
```
TMA (Tensor Memory Accelerator, SM90+)
```cpp
// Hardware-accelerated bulk tensor copy
auto tma = make_tma_copy(
    SM90_TMA_LOAD{},
    src_tensor,    // Global-memory tensor
    smem_layout    // Destination shared-memory layout
);
// The resulting TMA atom is then partitioned and issued via copy(...)
```
Layout Algebra
CuTe layouts support algebraic composition:
Composition
```cpp
// Compose two layouts: composition(A, B) applies B first, then A,
// i.e. composition(A, B)(c) == A(B(c))
auto composed = composition(layout_a, layout_b);
```
Product
// Cartesian product of layouts
auto prod = make_layout (
make_shape ( 8 , 16 ),
make_stride ( 16 , 1 )
);
Complement
```cpp
// Find the "rest" of a layout: a layout covering the codomain
// indices up to max_size that `layout` does not touch
auto comp = complement(layout, max_size);
```
Layout algebra enables powerful compile-time transformations that would be error-prone to write manually.
Integration with CUTLASS
CuTe is deeply integrated into CUTLASS 3.x:
```cpp
// CUTLASS 3.x GEMM using CuTe layouts
using GmemLayoutA = Layout<Shape<_128, _64>, Stride<_64, _1>>;   // Row-major
using SmemLayoutA = Layout<Shape<_128, _64>, Stride<_1, _128>>;  // Column-major

// Data movement is expressed with CuTe copy atoms
using GmemCopyA = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, float>;
```
Why CuTe?
Traditional GPU programming:
```cpp
// Manual index computation (error-prone!)
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N;
int col = idx % N;
int offset = row * lda + col;
data[offset] = value;
```
With CuTe:
```cpp
// Logical, composable operations
auto tensor = make_tensor(data, make_shape(M, N));
tensor(row, col) = value;   // CuTe handles the index arithmetic
```
Benefits of CuTe:
Type-safe indexing with compile-time checking
Composable abstractions (partition, tile, slice)
Zero-overhead when fully static
Dramatically reduced boilerplate code
Easier to reason about correctness
CuTe DSL (Python)
CUTLASS 4.0 introduced CuTe DSL, a Python interface to CuTe concepts:
```python
import cutlass.cute as cute

# Define layouts in Python
shape = cute.make_shape(128, 128)
stride = cute.make_stride(1, 128)
layout = cute.make_layout(shape, stride)

# Create tensors
tensor = cute.make_tensor(ptr, layout)

# Copy operations
cute.copy(src_tensor, dst_tensor)
```
Real-World Example
Here’s how CuTe simplifies GEMM implementation:
```cpp
// Load an A tile from global to shared memory
auto gA = make_tensor(
    make_gmem_ptr(A_ptr),
    make_layout(make_shape(M, K), make_stride(K, 1))
);
auto sA = make_tensor(
    make_smem_ptr(shared_A),
    make_layout(make_shape(128, 64))
);

// Partition across the threads of the threadblock
auto tAgA = local_partition(gA, thread_layout, threadIdx.x);
auto tAsA = local_partition(sA, thread_layout, threadIdx.x);

// Copy with automatic vectorization and coalescing
copy(tAgA, tAsA);
```
Without CuTe, this would require hundreds of lines of manual indexing logic!
Next Steps
Tensor Cores Learn how CuTe interfaces with Tensor Cores
Memory Layouts Deep dive into layout strategies
GEMM Operations See CuTe in action for GEMM
Examples Explore CuTe code examples