Automatic Mixed Precision (AMP) allows you to train models faster and with less memory by automatically using lower precision (FP16/BF16) for operations that can tolerate it, while maintaining FP32 precision for operations that require it.

Why Mixed Precision?

Mixed precision training provides several benefits:
  • Faster Training: FP16 math runs on Tensor Cores and is typically 2-3x faster on modern GPUs (Volta, Turing, Ampere, and newer)
  • Reduced Memory: FP16 tensors use half the memory of FP32, allowing larger batch sizes
  • Maintained Accuracy: Critical operations remain in FP32 to prevent numerical instability

Basic AMP Usage

Step 1: Import AMP Components

Import the autocast context manager and GradScaler:
import torch
from torch.amp import autocast, GradScaler
Step 2: Create a GradScaler

The GradScaler prevents gradient underflow by scaling the loss:
# Create scaler once at the beginning of training
scaler = GradScaler(device='cuda')
Step 3: Wrap the Forward Pass with Autocast

Use autocast to automatically cast operations to FP16:
for epoch in range(num_epochs):
    for batch, target in dataloader:
        optimizer.zero_grad()

        # Autocast wraps the forward pass
        with autocast(device_type='cuda', dtype=torch.float16):
            output = model(batch)
            loss = criterion(output, target)

        # Scale loss and call backward
        scaler.scale(loss).backward()

        # Unscale gradients and step optimizer
        scaler.step(optimizer)

        # Update scaler for next iteration
        scaler.update()

Understanding Autocast

The autocast context manager automatically chooses the optimal precision for each operation:

Supported Precisions

# FP16: best throughput on NVIDIA GPUs (Volta and newer)
with autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
BF16 (bfloat16) offers the same dynamic range as FP32 but with reduced precision. It’s more numerically stable than FP16 and is recommended when available.
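For comparison, a minimal BF16 sketch. It uses CPU autocast here purely so it runs without a GPU; on CUDA the same pattern applies with device_type='cuda':

```python
import torch

# BF16 has FP32's dynamic range, so loss scaling is usually unnecessary.
# CPU autocast supports bfloat16, which makes this sketch runnable anywhere.
a = torch.randn(4, 4)
b = torch.randn(4, 4)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    out = torch.matmul(a, b)
print(out.dtype)  # torch.bfloat16
```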

Operation Precision Rules

Autocast applies different precisions based on operation type:
with autocast(device_type='cuda', dtype=torch.float16):
    # Matrix multiplications -> FP16 (fast)
    out1 = torch.matmul(a, b)
    
    # Convolutions -> FP16 (fast)
    out2 = F.conv2d(input, weight)
    
    # Softmax -> FP32 (accurate)
    out3 = F.softmax(logits, dim=-1)
    
    # Layer norm -> FP32 (accurate)
    out4 = F.layer_norm(input, normalized_shape)
    
    # Loss functions -> FP32 (accurate)
    loss = F.cross_entropy(output, target)
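These rules can be observed directly by checking output dtypes. A small sketch, using CPU autocast with bfloat16 so it runs without a GPU (on CUDA, the same split applies with float16):

```python
import torch
import torch.nn.functional as F

a = torch.randn(8, 8)
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    mm = torch.matmul(a, a)       # autocast-eligible op -> lower precision
    sm = F.softmax(a, dim=-1)     # precision-sensitive op -> stays FP32
print(mm.dtype, sm.dtype)  # torch.bfloat16 torch.float32
```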

GradScaler Deep Dive

The GradScaler prevents gradient underflow in FP16 by scaling the loss:

How It Works

Step 1: Scale the Loss

Multiply loss by a scale factor before backward pass:
# Scale loss (typically by 2^16)
scaled_loss = scaler.scale(loss)
scaled_loss.backward()
Step 2: Unscale Gradients

Before optimizer step, unscale gradients and check for infinities:
# Unscale gradients (happens inside scaler.step)
scaler.step(optimizer)  # Only steps if no inf/nan
Step 3: Update the Scale Factor

Adjust scale factor based on whether infinities were found:
# Increase scale if no inf/nan, decrease if found
scaler.update()
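The three steps above can be sketched by hand without GradScaler (illustration only; the variable names are ours):

```python
import torch

scale = 2.0 ** 16                  # step 1: pick a scale factor
w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()
(loss * scale).backward()          # backward on the scaled loss
grad = w.grad / scale              # step 2: unscale before the optimizer step
# d/dw sum(w^2) = 2w; scaling by a power of two is exact in floating point
print(torch.allclose(grad, 2 * w.detach()))  # True
```

GradScaler automates step 3 on top of this: it skips the optimizer step and shrinks the scale when the unscaled gradients contain inf/NaN.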

GradScaler Configuration

scaler = GradScaler(
    device='cuda',
    init_scale=2.**16,      # Initial scale factor
    growth_factor=2.0,       # Multiply scale by this if no inf/nan
    backoff_factor=0.5,      # Multiply scale by this if inf/nan found
    growth_interval=2000,    # Steps before increasing scale
    enabled=True             # Enable/disable scaler
)

Gradient Clipping with AMP

When using gradient clipping with AMP, unscale gradients first:
for batch, target in dataloader:
    optimizer.zero_grad()
    
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(batch)
        loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    
    # Unscale gradients before clipping
    scaler.unscale_(optimizer)
    
    # Now clip gradients (in FP32)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    # Step optimizer
    scaler.step(optimizer)
    scaler.update()
Always call scaler.unscale_() before gradient clipping; otherwise the threshold is applied to scaled gradients, so the effective max_norm is wrong by the scale factor.

Multiple Optimizers

Handling multiple optimizers with AMP:
scaler = GradScaler(device='cuda')
opt1 = torch.optim.Adam(model1.parameters())
opt2 = torch.optim.SGD(model2.parameters())

for batch, target in dataloader:
    opt1.zero_grad()
    opt2.zero_grad()
    
    with autocast(device_type='cuda', dtype=torch.float16):
        output1 = model1(batch)
        output2 = model2(output1)
        loss = criterion(output2, target)
    
    scaler.scale(loss).backward()
    
    # Step both optimizers
    scaler.step(opt1)
    scaler.step(opt2)
    
    # Update scaler once
    scaler.update()

Gradient Accumulation

Combine AMP with gradient accumulation for large effective batch sizes:
scaler = GradScaler(device='cuda')
accumulation_steps = 4
optimizer.zero_grad()  # clear gradients once before accumulating

for i, (batch, target) in enumerate(dataloader):
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(batch)
        loss = criterion(output, target)
        # Scale loss by accumulation steps
        loss = loss / accumulation_steps
    
    scaler.scale(loss).backward()
    
    # Only step every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
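Dividing the loss by accumulation_steps is what makes the accumulated gradient match a single full-batch pass. A quick check of that equivalence (plain FP32 on CPU, no AMP; all names illustrative):

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(5, 1)
x = torch.randn(8, 5)
y = torch.randn(8, 1)
crit = torch.nn.MSELoss()

# One pass over the full batch
model.zero_grad()
crit(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Four micro-batches, loss divided by the number of accumulation steps
model.zero_grad()
for xb, yb in zip(x.chunk(4), y.chunk(4)):
    (crit(model(xb), yb) / 4).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True
```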

Custom Autograd Functions

For custom autograd functions, use decorators to control precision:
from torch.amp import custom_fwd, custom_bwd

class MyFunction(torch.autograd.Function):
    @staticmethod
    @custom_fwd(device_type='cuda')
    def forward(ctx, input, weight):
        # Forward runs in autocast precision
        ctx.save_for_backward(input, weight)
        return input.mm(weight)
    
    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad_output):
        # Backward runs in same precision as forward
        input, weight = ctx.saved_tensors
        grad_input = grad_output.mm(weight.t())
        grad_weight = input.t().mm(grad_output)
        return grad_input, grad_weight
Force inputs to FP32 in custom functions:
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, weight):
    # input and weight are cast to FP32
    return input.mm(weight)

Distributed Training with AMP

Combine AMP with DDP for distributed mixed-precision training:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.amp import autocast, GradScaler

# Initialize process group
dist.init_process_group(backend='nccl')

# Wrap model with DDP
model = DDP(model.cuda(), device_ids=[local_rank])

# Create scaler
scaler = GradScaler(device='cuda')

for batch, target in dataloader:
    optimizer.zero_grad()
    
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(batch)
        loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
AMP works seamlessly with DDP. Gradient scaling is handled per-process, and gradient synchronization happens correctly.

Performance Tips

Choose the Right Precision

# FP16: highest throughput; best for inference and well-conditioned models
# Risks: narrow dynamic range can cause underflow/overflow
with autocast(device_type='cuda', dtype=torch.float16):
    output = model(input)
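A common pattern is to prefer BF16 when the hardware supports it and fall back to FP16 otherwise. A sketch (the helper name is ours):

```python
import torch

def pick_amp_dtype():
    # BF16 shares FP32's dynamic range, so prefer it when supported
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16

print(pick_amp_dtype())
```

Remember that FP16 (unlike BF16) generally needs a GradScaler during training.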

Disable AMP for Debugging

# Disable AMP with a flag
use_amp = True  # Set to False for debugging

with autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
    output = model(input)

Profile Mixed Precision

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with autocast(device_type='cuda', dtype=torch.float16):
        output = model(input)
        loss = criterion(output, target)
    loss.backward()

print(prof.key_averages().table(sort_by="cuda_time_total"))

Troubleshooting

Common Issues:
  1. Loss becomes NaN: Try BF16 instead of FP16, or reduce learning rate
  2. Gradients overflow: Decrease GradScaler’s init_scale
  3. Slower than FP32: Ensure GPU supports Tensor Cores (compute capability ≥ 7.0)
  4. Memory not reduced: Check that large activations are in FP16/BF16
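For issue 3, the GPU's compute capability can be checked directly (a small sketch; it simply prints the result rather than changing any settings):

```python
import torch

# Tensor Cores require compute capability >= 7.0 (Volta and newer)
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability {major}.{minor}, Tensor Cores: {major >= 7}")
else:
    print("CUDA not available")
```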

Debugging NaN Losses

# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)

# Check for NaN/Inf in loss
if torch.isnan(loss) or torch.isinf(loss):
    print("Loss is NaN or Inf!")
    # Try reducing learning rate or using BF16
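This check can be wrapped into a small guard that fails fast instead of silently training on bad values (the helper name is ours):

```python
import torch

def assert_finite(loss: torch.Tensor) -> torch.Tensor:
    # Raise immediately rather than letting NaN/Inf propagate through training
    if not torch.isfinite(loss).all():
        raise RuntimeError(f"non-finite loss: {loss}")
    return loss

assert_finite(torch.tensor(1.0))             # passes through unchanged
# assert_finite(torch.tensor(float('nan')))  # would raise RuntimeError
```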
