The NVIDIA Ampere architecture (compute capability 8.0 and 8.6) introduced significant improvements to Tensor Core operations, including TF32, enhanced BF16 support, and asynchronous memory copy.
Supported GPUs
| GPU Model | Compute Capability | Min CUDA Toolkit |
|---|---|---|
| NVIDIA A100 Tensor Core GPU | 8.0 | 11.4 |
| NVIDIA A10 | 8.6 | 11.4 |
| NVIDIA GeForce RTX 30x0 series | 8.6 | 11.4 |
Ampere was the first architecture to support TensorFloat-32 (TF32), enabling FP32 acceleration through Tensor Cores without code changes.
Key Features
1. TensorFloat-32 (TF32)
TF32 is a 19-bit floating-point format (1 sign bit, 8 exponent bits, 10 mantissa bits) that provides the dynamic range of FP32 with the precision of FP16, delivering FP16-class Tensor Core performance.
// From include/cutlass/arch/mma_sm80.h
// Specialization: D (f32) = A (tf32) * B (tf32) + C (f32) using the
// warp-wide m16n8k4 Tensor Core MMA instruction (32 threads cooperate).
template <>
struct Mma<
  gemm::GemmShape<16, 8, 4>,
  32,
  tfloat32_t,
  layout::RowMajor,
  tfloat32_t,
  layout::ColumnMajor,
  float,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<16, 8, 4>;
  using ArchTag = arch::Sm80;

  /// Computes d = a * b + c; each fragment holds this thread's share of
  /// the warp-distributed operand tiles.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
    // Bind the fragment storage to raw registers for the inline PTX:
    // tf32 operands travel in 32-bit "r" registers, accumulators in "f".
    // (Without these bindings, D/A/B/C below would be undefined names.)
    uint32_t const *A = reinterpret_cast<uint32_t const *>(&a);
    uint32_t const *B = reinterpret_cast<uint32_t const *>(&b);
    float const *C = reinterpret_cast<float const *>(&c);
    float *D = reinterpret_cast<float *>(&d);

    asm volatile(
        "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 "
        "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
        : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]),
          "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#endif
  }
};
Benefits:
- 8x throughput vs FP32 on CUDA cores
- No explicit conversion required
- Maintains FP32 dynamic range
2. BFloat16 (BF16) Support
BF16 provides better numeric stability than FP16 for many workloads:
// BF16 x BF16 = FP32 accumulation
template <>
struct Mma<
gemm::GemmShape<16, 8, 8>,
32,
bfloat16_t,
layout::RowMajor,
bfloat16_t,
layout::ColumnMajor,
float,
layout::RowMajor,
OpMultiplyAdd> {
using ArchTag = arch::Sm80;
CUTLASS_HOST_DEVICE
void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
FragmentC const &c) const {
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B[0]),
"f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
#endif
}
};
3. FP64 Tensor Cores
Ampere introduced the first FP64 Tensor Core operations:
// From include/cutlass/arch/mma_sm80.h
// Matrix multiply-add operation: F64 = F64 * F64 + F64
// Specialization: D (f64) = A (f64) * B (f64) + C (f64) using the
// warp-wide m8n8k4 FP64 Tensor Core MMA (first FP64 Tensor Core op, SM80).
template <>
struct Mma<
  gemm::GemmShape<8, 8, 4>,
  32,
  double,
  layout::RowMajor,
  double,
  layout::ColumnMajor,
  double,
  layout::RowMajor,
  OpMultiplyAdd> {

  using Shape = gemm::GemmShape<8, 8, 4>;
  using ArchTag = arch::Sm80;

  /// Computes d = a * b + c; each thread contributes one element of A and
  /// one of B, and two accumulator elements.
  CUTLASS_HOST_DEVICE
  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
                  FragmentC const &c) const {
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
    // Bind the fragment storage to named registers for the inline PTX.
    // (Without these bindings, A/B/C/D below would be undefined names.)
    double const &A = reinterpret_cast<double const &>(a);
    double const &B = reinterpret_cast<double const &>(b);
    double const *C = reinterpret_cast<double const *>(&c);
    double *D = reinterpret_cast<double *>(&d);

    asm volatile(
        "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 "
        "{%0,%1}, {%2}, {%3}, {%4,%5};\n"
        : "=d"(D[0]), "=d"(D[1])
        : "d"(A), "d"(B), "d"(C[0]), "d"(C[1]));
#endif
  }
};
Shape: 8x8x4
- 8 rows × 8 columns output
- 4 elements in K dimension
- 2x throughput vs FP64 CUDA cores
4. Asynchronous Copy (cp.async)
Ampere introduced cp.async for overlapping memory copies with computation:
// From include/cutlass/arch/memory_sm80.h
// Asynchronously copies SizeInBytes bytes from global to shared memory,
// bypassing the register file (Ampere cp.async). CacheOperation::Always
// selects the .ca qualifier (cache at all levels).
template <int SizeInBytes>
struct cp_async<SizeInBytes, CacheOperation::Always> {
  /// @param smem_ptr    destination in shared memory
  /// @param global_ptr  source in global memory
  /// @param pred_guard  when false, the copy is skipped (predicated off)
  CUTLASS_DEVICE
  cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
#if CUDA_CP_ASYNC_ACTIVATED
    // cp.async only supports 4-, 8-, and 16-byte transfers; reject other
    // sizes at compile time instead of emitting an invalid instruction.
    static_assert(SizeInBytes == 4 || SizeInBytes == 8 || SizeInBytes == 16,
                  "cp.async size must be 4, 8, or 16 bytes");

    // Shared-memory destination must be expressed as a 32-bit state-space
    // address for the PTX operand.
    unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr);

    asm volatile(
        "{\n"
        "  .reg .pred p;\n"
        "  setp.ne.b32 p, %0, 0;\n"
        "  @p cp.async.ca.shared.global [%1], [%2], %3;\n"
        "}\n" ::"r"((int)pred_guard),
        "r"(smem_int_ptr), "l"(global_ptr), "n"(SizeInBytes));
#endif
  }
};
Usage Pattern:
// Issue async copies
cp_async<16>(smem_ptr, global_ptr, true);
cp_async_fence(); // Establish ordering
// Do computation
// ...
// Wait for copies to complete
cp_async_wait<0>(); // <0> = wait until ALL outstanding copy groups have completed
__syncthreads();
Always pair cp_async_fence() with cp_async_wait<N>() to ensure proper synchronization.
5. Integer and Sub-byte Precision
INT8 Operations
// S32 = S8 × S8 + S32
template <>
struct Mma<
gemm::GemmShape<16,8,16>,
32,
int8_t,
layout::RowMajor,
int8_t,
layout::ColumnMajor,
int,
layout::RowMajor,
OpMultiplyAddSaturate> {
using ArchTag = arch::Sm80;
CUTLASS_HOST_DEVICE
void operator()(
FragmentC &d, FragmentA const &a, FragmentB const &b, FragmentC const &c
) const {
#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED)
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite "
"{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n"
: "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3])
: "r"(A[0]), "r"(A[1]), "r"(B),
"r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]));
#endif
}
};
INT4 Operations
// S32 = S4 × S4 + S32 (16x8x64 shape)
// Abridged specialization: S32 = S4 * S4 + S32 with the m16n8k64 MMA.
// K grows to 64 because 4-bit operands pack eight values per 32-bit register.
// OpMultiplyAddSaturate maps to the .satfinite qualifier (clamps on overflow).
template <>
struct Mma<
gemm::GemmShape<16, 8, 64>,
32,
cutlass::int4b_t,
layout::RowMajor,
cutlass::int4b_t,
layout::ColumnMajor,
int,
layout::RowMajor,
OpMultiplyAddSaturate> {
using ArchTag = arch::Sm80;
// ... PTX: mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite
};
Binary (1-bit) Operations
// S32 = B1 AND B1 + S32 (16x8x256 shape)
// Abridged specialization: S32 = popcount(B1 AND B1) + S32 with the
// m16n8k256 MMA. K grows to 256 because 1-bit operands pack 32 per register.
// OpAndPopc maps to the .and.popc qualifier: elementwise AND followed by
// population-count accumulation (a XOR-based OpXorPopc variant also exists).
template <>
struct Mma<
gemm::GemmShape<16,8,256>,
32,
cutlass::uint1b_t,
layout::RowMajor,
cutlass::uint1b_t,
layout::ColumnMajor,
int32_t,
layout::RowMajor,
OpAndPopc> {
using ArchTag = arch::Sm80;
// ... PTX: mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc
};
Binary operations support both AND and XOR modes, useful for neural network quantization and specialized applications.
Instruction Shapes
| Data Type | Instruction Shape (MxNxK) | Accumulator |
|---|---|---|
| TF32 | 16x8x4, 16x8x8 | FP32 |
| BF16 | 16x8x8, 16x8x16 | FP32 |
| FP16 | 16x8x16 | FP16/FP32 |
| FP64 | 8x8x4 | FP64 |
| INT8 | 16x8x16, 16x8x32 | INT32 |
| INT4 | 16x8x64 | INT32 |
| INT1 | 16x8x256 | INT32 |
Multistage Pipeline
Ampere uses cp.async to build efficient multistage pipelines:
// Conceptual pipeline structure
// Conceptual multistage pipeline: a circular buffer of num_stages shared
// memory slots, where tile t occupies slot (t % num_stages).

// Prologue: fill every stage with in-flight async copies before computing.
// (In real code, offset advances per stage; kept symbolic here.)
for (int stage = 0; stage < num_stages; ++stage) {
  cp_async<16>(smem_ptr[stage], global_ptr + offset);
  cp_async_fence();  // Commit this stage's copies as one ordering group
}

// Main loop: compute on the oldest stage while newer copies are in flight.
for (int tile = 0; tile < num_tiles; ++tile) {
  // Allow up to (num_stages - 2) copy groups to remain outstanding; the
  // group feeding this tile is guaranteed complete after the wait.
  cp_async_wait<num_stages - 2>();
  __syncthreads();

  // Compute with the current stage's data.
  mma_operation(smem_ptr[tile % num_stages]);

  // Prefetch tile (tile + num_stages) into the slot this tile just
  // vacated: (tile + num_stages) % num_stages == tile % num_stages.
  if (tile + num_stages < num_tiles) {
    cp_async<16>(smem_ptr[tile % num_stages], global_ptr + offset);
    cp_async_fence();
  }
}
Code Example
Here’s a complete Ampere TF32 GEMM example:
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
// TF32 Tensor Core GEMM on SM80: D = alpha * A * B + beta * C.
// ElementA/B are float; with OpClassTensorOp on Sm80 the operands are
// consumed by TF32 Tensor Core instructions, so callers pass ordinary
// FP32 data with no explicit conversion.
using Gemm = cutlass::gemm::device::Gemm<
float, // ElementA
cutlass::layout::RowMajor, // LayoutA
float, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
float, // ElementC
cutlass::layout::RowMajor, // LayoutC
float, // ElementAccumulator
cutlass::arch::OpClassTensorOp, // Tensor Core operation
cutlass::arch::Sm80, // Architecture
cutlass::gemm::GemmShape<256, 128, 32>, // ThreadblockShape
cutlass::gemm::GemmShape<64, 64, 32>, // WarpShape
cutlass::gemm::GemmShape<16, 8, 8>, // InstructionShape (TF32)
cutlass::epilogue::thread::LinearCombination<
float, 128 / cutlass::sizeof_bits<float>::value,
float, float
>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3 // Stages
>;
// Runs a single TF32 GEMM. The original example referenced M, N, K, A, B,
// C, lda, ldb, ldc, alpha, and beta without declaring them; they are
// declared here so the example is self-contained.
int main() {
  // Problem size and epilogue scalars (example values).
  int M = 1024, N = 1024, K = 1024;
  float alpha = 1.0f;
  float beta = 0.0f;

  // Device pointers: in a real program these must point to device memory
  // allocated by the caller (e.g. via cudaMalloc) with the layouts below.
  float *A = nullptr; // M x K, row-major    -> lda = K
  float *B = nullptr; // K x N, column-major -> ldb = K
  float *C = nullptr; // M x N, row-major    -> ldc = N
  int lda = K, ldb = K, ldc = N;

  Gemm gemm_op;

  Gemm::Arguments args{
    {M, N, K},     // Problem size
    {A, lda},      // Tensor A
    {B, ldb},      // Tensor B
    {C, ldc},      // Tensor C (source accumulator)
    {C, ldc},      // Tensor D (destination, written in place over C)
    {alpha, beta}  // Epilogue scalars
  };

  cutlass::Status status = gemm_op(args);
  return status == cutlass::Status::kSuccess ? 0 : -1;
}
Compilation
Compile for Ampere architecture:
# For A100 (SM80)
nvcc -arch=sm_80 -std=c++17 example.cu -o example
# Using CMake
cmake .. -DCUTLASS_NVCC_ARCHS=80
# For both SM80 and SM86
cmake .. -DCUTLASS_NVCC_ARCHS="80;86"
Optimal Tile Sizes
For Ampere, recommended threadblock shapes:
- TF32/BF16: 128x128x32, 256x128x32
- FP16: 128x256x32, 256x128x64
- INT8: 128x256x64, 256x128x64
- FP64: 64x64x16, 128x64x16
Pipeline Stages
- Small kernels: 2-3 stages
- Large kernels: 4-5 stages
- Balance between latency hiding and shared memory usage
Shared Memory Configuration
// Request maximum shared memory per block
// By default, kernels may use at most 48 KB of dynamic shared memory;
// requesting more requires this explicit opt-in attribute.
// NOTE(review): `kernel` is the __global__ function being configured, and
// the per-block ceiling differs by chip (higher on A100 than on SM86
// parts) — confirm 96 KB fits the target GPU.
cudaFuncSetAttribute(
kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
98304 // 96 KB
);
Examples
CUTLASS provides numerous Ampere examples:
examples/14_ampere_tf32_tensorop_gemm/ - TF32 GEMM
examples/15_ampere_sparse_tensorop_gemm/ - Sparse Tensor Core operations
examples/18_ampere_fp64_tensorop_affine2_gemm/ - FP64 Tensor Cores
examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/ - 3xTF32 for accuracy
See Also