Tensor Core Programming
Tensor Cores are specialized hardware units in modern NVIDIA GPUs that dramatically accelerate matrix multiply-accumulate (MMA) operations. This page explains how CUTLASS leverages Tensor Cores for peak performance.
What are Tensor Cores?
Tensor Cores are dedicated matrix processing units that can perform multiple multiply-accumulate operations in a single instruction. They are the key to achieving peak throughput for deep learning and HPC workloads.
Key Characteristics:
Perform matrix operations on small tiles (e.g., 16×8×16)
Operate at warp granularity (32 threads collaborate)
Deliver 8-16× higher throughput than CUDA cores for matrix math
Support various data types: FP64, FP32, TF32, FP16, BF16, FP8, INT8, INT4
Architecture Evolution
Volta (SM70) - First Generation
Shape: 8×8×4 (M×N×K) at the PTX mma.sync level; 16×16×16 via the wmma API
Data Types: FP16 input, FP16/FP32 accumulation
Instructions: wmma API
Turing (SM75) - Second Generation
Shapes: 16×8×8, 8×8×4
Data Types: FP16, INT8, INT4, INT1
Instructions: Enhanced wmma and mma.sync
Ampere (SM80) - Third Generation
New Shapes: 16×8×8, 16×8×16
New Types: BF16, TF32 (19-bit format)
Instructions: mma.sync.aligned
Features: Async copy (cp.async), structured sparsity (2:4)
Hopper (SM90) - Fourth Generation
New Shapes: 64×64×16, 64×128×16, 64×192×16
New Types: FP8 (E4M3, E5M2)
Instructions: wgmma (warpgroup MMA)
Features: TMA (Tensor Memory Accelerator), Thread Block Clusters
Throughput: Up to ~2000 TFLOPS (FP8)
Blackwell (SM100) - Fifth Generation
Enhanced Shapes: Larger warpgroup operations
New Types: FP4, MXFP formats
Features: Enhanced TMA, distributed shared memory
Throughput: Up to ~4000 TFLOPS (FP4)
Tensor Core Throughput Comparison
| Architecture   | FP16 TFLOPS | FP32 TFLOPS | FP8 TFLOPS |
|----------------|-------------|-------------|------------|
| Volta V100     | 125         | -           | -          |
| Turing T4      | 65          | -           | -          |
| Ampere A100    | 312         | 156 (TF32)  | -          |
| Hopper H100    | 989         | 495 (TF32)  | 1979       |
| Blackwell B200 | 2500        | 1250 (TF32) | 5000       |
Tensor Core operations are exposed through MMA (Matrix Multiply-Accumulate) instructions. Here’s an example from Ampere (SM80):
template <>
struct Mma<
  gemm::GemmShape<16, 8, 8>,    // Output tile: 16×8, K = 8
  32,                           // Threads per instruction (one warp)
  bfloat16_t,                   // A element type
  layout::RowMajor,             // A layout
  bfloat16_t,                   // B element type
  layout::ColumnMajor,          // B layout
  float,                        // C/D element type
  layout::RowMajor,             // C/D layout
  OpMultiplyAdd>                // Operation
{
  using FragmentA = Array<bfloat16_t, 4>;  // A fragment per thread
  using FragmentB = Array<bfloat16_t, 2>;  // B fragment per thread
  using FragmentC = Array<float, 4>;       // C/D fragment per thread

  CUTLASS_HOST_DEVICE
  void operator()(
    FragmentC &d,         // Output
    FragmentA const &a,   // Input A
    FragmentB const &b,   // Input B
    FragmentC const &c    // Input C (accumulator)
  ) const {
    // Reinterpret the fragments as the raw registers the PTX instruction expects
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile(
      "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
      "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
      : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
      : "r"(A[0]), "r"(A[1]), "r"(B[0]),
        "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
  }
};
Reference: include/cutlass/arch/mma_sm80.h:76
The instruction name encodes:
m16n8k8: Matrix dimensions (16×8 output, K=8)
row.col: A is row-major, B is column-major
f32.bf16.bf16.f32: Output type, A type, B type, accumulator type
Fragment Layout
Tensor Core operands are distributed across threads in a warp. Understanding fragment layouts is crucial:
Thread-to-Fragment Mapping (16×8×8 example)
Matrix A (16×8):
- Shape per thread: 4 elements (FragmentA)
- Distribution: Each thread holds elements from multiple rows
- Thread 0: [a0, a8, a16, a24]
- Thread 1: [a1, a9, a17, a25]
- ...
Matrix B (8×8):
- Shape per thread: 2 elements (FragmentB)
- Distribution: Each thread holds elements from multiple columns
- Thread 0: [b0, b8]
- Thread 1: [b1, b9]
- ...
Matrix C/D (16×8):
- Shape per thread: 4 elements (FragmentC)
- Distribution: Output elements mapped to threads
CUTLASS handles fragment distribution automatically when you use the provided templates. You rarely need to compute layouts manually!
Warpgroup MMA (Hopper SM90+)
Hopper introduced warpgroup-scoped MMA instructions that enable larger, more efficient operations:
// Warpgroup MMA operates on 4 warps (128 threads) simultaneously
// Shape: 64×64×16 per instruction
template <class... Args>
CUTE_HOST_DEVICE
void wgmma_m64n64k16(Args const&... args) {
  asm volatile(
    "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 "
    "{%0, %1, ..., %31}, %32, %33, p, 1, 1, 0;\n"
    : ...  // 32 output registers per thread (64×64 floats / 128 threads)
    : ...  // A descriptor, B descriptor
  );
}
Key Differences
| Feature    | Warp MMA (SM80) | Warpgroup MMA (SM90) |
|------------|-----------------|----------------------|
| Threads    | 32 (1 warp)     | 128 (4 warps)        |
| Max Shape  | 16×8×16         | 64×192×16            |
| Memory     | Shared memory   | TMA descriptors      |
| Scheduling | Software        | Hardware pipelining  |
Using Tensor Cores in CUTLASS
CUTLASS provides high-level abstractions for Tensor Core programming:
Method 1: CUTLASS Templates (Recommended)
#include "cutlass/gemm/device/gemm.h"
using Gemm = cutlass::gemm::device::Gemm<
  cutlass::half_t,                  // Element A
  cutlass::layout::RowMajor,        // Layout A
  cutlass::half_t,                  // Element B
  cutlass::layout::ColumnMajor,     // Layout B
  cutlass::half_t,                  // Element C
  cutlass::layout::RowMajor,        // Layout C
  float,                            // Element Accumulator
  cutlass::arch::OpClassTensorOp,   // Use Tensor Cores
  cutlass::arch::Sm80               // Target architecture
>;
Method 2: CuTe MMA Atoms
CuTe provides composable MMA atoms for fine-grained control:
#include "cute/atom/mma_atom.hpp"

// Define the MMA atom for the SM80 16x8x8 BF16 operation
// (the alias must not shadow cute::MMA_Atom itself)
using MmaAtom = cute::MMA_Atom<SM80_16x8x8_F32BF16BF16F32_TN>;

// Compose it into a tiled MMA
using TiledMma = cute::TiledMMA<
  MmaAtom,
  Layout<Shape<_2, _4, _1>>,  // Atom layout: 2×4×1 atoms across warps
  Layout<Shape<_1, _2, _1>>   // Value layout
>;

// Partition operands for this thread and perform the MMA
// (thr_mma = tiled_mma.get_thread_slice(threadIdx.x); sA/sB/sC are smem tiles)
auto tCrA = thr_mma.partition_fragment_A(sA);
auto tCrB = thr_mma.partition_fragment_B(sB);
auto tCrC = thr_mma.partition_fragment_C(sC);
gemm(tiled_mma, tCrA, tCrB, tCrC);  // Executes Tensor Core instructions
Data Type Support
Different Tensor Core generations support different types:
FP16 (Half Precision)
// Most widely supported
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementAccumulator = float;  // Higher-precision accumulation
BF16 (Brain Float16)
// SM80+, better dynamic range than FP16
using ElementA = cutlass::bfloat16_t;
using ElementB = cutlass::bfloat16_t;
using ElementAccumulator = float;
TF32 (TensorFloat32)
// SM80+, FP32 input with 19-bit precision (1 sign + 8 exponent + 10 mantissa)
using ElementA = cutlass::tfloat32_t;
using ElementB = cutlass::tfloat32_t;
using ElementAccumulator = float;
FP8 (8-bit Float)
// SM89+, highest throughput
using ElementA = cutlass::float_e4m3_t;  // E4M3 format
using ElementB = cutlass::float_e5m2_t;  // E5M2 format
using ElementAccumulator = float;
Block-Scaled Types (SM100+)
// FP4, MXFP4, MXFP6, MXFP8 with per-block scaling
using ElementA = cutlass::nvfp4_t;
using ElementB = cutlass::nvfp4_t;
using ElementAccumulator = float;
// Requires separate scale tensors
Data Type Precision vs Throughput
Higher Throughput (Lower Precision):
FP4: 4× FP16 throughput
FP8: 2× FP16 throughput
INT8: 2× FP16 throughput
Higher Precision (Lower Throughput):
FP64: 1/16× FP16 throughput
FP32: 1/2× FP16 throughput (via TF32)
FP16/BF16: Baseline throughput
Choose based on your accuracy requirements!
Structured Sparsity (SM80+)
Ampere introduced 2:4 structured sparsity support:
// For every 4 elements along K, exactly 2 are non-zero
// Provides up to 2× effective throughput
using Gemm = cutlass::gemm::device::SparseGemm<
  cutlass::half_t,
  cutlass::layout::RowMajor,
  cutlass::half_t,
  cutlass::layout::ColumnMajor,
  cutlass::half_t,
  cutlass::layout::RowMajor,
  float,
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80
  // Additional sparse configuration...
>;
Matrix A is compressed, and a metadata tensor indicates which elements are non-zero.
Performance Best Practices
1. Tile Size Selection
Choose tile sizes that are multiples of Tensor Core operation sizes:
// Good: multiples of the Tensor Core shape (16×8)
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;

// Bad: not aligned to Tensor Core boundaries -- inefficient!
// using ThreadblockShape = cutlass::gemm::GemmShape<100, 100, 20>;
2. Maximizing Occupancy
// Use enough warps per threadblock for high occupancy
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>;
// Warps per threadblock: (128/32) × (128/64) = 4 × 2 = 8 warps
3. Double Buffering
// Overlap data loading with computation
constexpr int kStages = 2;  // Double buffering
// Or more stages for deeply pipelined kernels (SM80+):
// constexpr int kStages = 4;
4. Async Copy (SM80+)
// Use cp.async to overlap global→shared memory copies
using GlobalToSharedCopyA =
    Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, ElementA>;
Performance Checklist:
✓ Tile sizes are multiples of Tensor Core shapes
✓ High occupancy (8+ warps per SM)
✓ Multi-stage pipeline (2-4 stages)
✓ Vectorized memory accesses (128-bit when possible)
✓ Async copy for SM80+ targets
✓ TMA for SM90+ targets
Common Pitfalls
1. Misaligned Data
// Bad: offsetting a pointer so it is no longer 16-byte aligned
cutlass::half_t *A = base + 1;  // 2-byte offset breaks 128-bit vectorized access
// Good: keep base pointers and leading dimensions aligned to 16 bytes
// (cudaMalloc already returns sufficiently aligned allocations)
2. Wrong Layout
// MMA instruction expects specific layouts
// Check that your tensor layouts match the MMA atom requirements
// For SM80_16x8x8_F32FP16FP16F32_TN:
// - A: Row-major (T = transposed in MMA terminology)
// - B: Column-major (N = non-transposed)
3. Incorrect Fragment Distribution
// Let CUTLASS handle fragment distribution
// Don't try to manually shuffle data between threads!
// Good: Use CUTLASS abstractions
auto tCrA = partition_fragment_A (mma, tCsA);
// Bad: Manual shuffling (error-prone)
// for (int t = 0; t < 32; ++t) { /* complex logic */ }
Debugging Tensor Core Kernels
#include "cute/util/print.hpp"

// Print tensor shapes and layouts
if (thread0()) {
  print("MMA shape: ");
  print(typename TiledMMA::Shape_MNK{});
  print("\n");
  print("Thread layout: ");
  print(typename TiledMMA::ThrLayout{});
  print("\n");
}
Real-World Example
Complete Tensor Core GEMM snippet:
// SM80 FP16 Tensor Core GEMM
using MmaOp = SM80_16x8x16_F32F16F16F32_TN;
using TiledMma = TiledMMA<
  MMA_Atom<MmaOp>,
  Layout<Shape<_2, _2, _1>>  // 2×2 warp arrangement
>;

// Partition inputs for this thread's slice of the tiled MMA
auto tCrA = thr_mma.partition_fragment_A(gA(_, _, k_tile));
auto tCrB = thr_mma.partition_fragment_B(gB(_, _, k_tile));
auto tCrC = partition_fragment_C(thr_mma, make_shape(M, N));
clear(tCrC);

// Main loop over K
for (int k = 0; k < K; k += kTileK) {
  // Load to registers
  copy(tCgA(_, _, k), tCrA);
  copy(tCgB(_, _, k), tCrB);
  // Tensor Core MMA
  gemm(tiled_mma, tCrA, tCrB, tCrC);
}

// Store output
copy(tCrC, tCgC);
Next Steps
- Memory Layouts: Optimize data layouts for Tensor Cores
- CuTe Library: Use CuTe abstractions for MMA operations
- Examples: Explore Tensor Core code examples
- GEMM Operations: Build complete GEMM kernels