This guide covers training a DDPM model on CIFAR-10, a more challenging dataset with 32×32 color images. The implementation includes exponential moving average (EMA), gradient accumulation, checkpointing, and both DDPM and DDIM sampling.
Prerequisites
Install the required dependencies:
pip install torch torchvision matplotlib tqdm
Training configuration
The CIFAR-10 training uses a production-ready configuration:
epochs = 2000
batch_size = 256
image_size = 32
channels = 3 # RGB color images
Model architecture
The model uses a deeper U-Net with dropout for regularization:
diffusion = DiffusionProcessCIFAR(
    image_size=32,
    channels=3,
    dropout_p=0.1,
    device=device,
)
Data preparation
Dataset and augmentation
CIFAR-10 images are augmented with random horizontal flips and normalized to [-1, 1]:
src/training/train_diffusion_cifar.py
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])

dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
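The Normalize transform with mean 0.5 and std 0.5 maps each channel from [0, 1] to [-1, 1] via (x - 0.5) / 0.5, which is why the sampling code later rescales images with (x + 1) / 2. A pure-Python sanity check of the two mappings:

```python
def normalize_pixel(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """The per-channel mapping applied by transforms.Normalize."""
    return (x - mean) / std

def denormalize_pixel(y: float) -> float:
    """Inverse mapping used before saving samples: [-1, 1] -> [0, 1]."""
    return (y + 1) / 2

print(normalize_pixel(0.0))  # -1.0
print(normalize_pixel(1.0))  # 1.0
print(denormalize_pixel(normalize_pixel(0.25)))  # 0.25
```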
Optimized dataloader
The dataloader is configured for high throughput:
src/training/train_diffusion_cifar.py
num_workers = 16

loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=(device.type == "cuda"),
    persistent_workers=True,
    prefetch_factor=2,
)
persistent_workers=True keeps worker processes alive between epochs, reducing startup overhead.
Training features
Exponential moving average (EMA)
EMA maintains a smoothed version of model weights for better sample quality:
# EMA is applied automatically during training
# The EMA model is used for sampling and checkpointing
samples = diffusion.sample(num_samples=16) # Uses EMA model
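The EMA update itself is not shown in the script excerpt; a minimal sketch of the standard rule follows (the decay of 0.999 is a typical value, and buffer handling is omitted; the project's actual implementation may differ):

```python
import copy

import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999) -> None:
    """Blend current weights into the EMA copy: ema = decay * ema + (1 - decay) * w."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
# ... after each optimizer step during training:
ema_update(ema_model, model)
```

Because the decay is close to 1, the EMA weights change slowly and average out the noise of individual gradient steps, which is what produces the cleaner samples.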
Gradient accumulation
Gradient accumulation enables larger effective batch sizes on limited memory:
for x, _ in loader:
    x = x.to(device)
    loss = diffusion.train_step(x)  # accumulates gradients internally
    epoch_loss += loss
    if diffusion.accum_steps == 0:  # an optimizer step just occurred
        scheduler.step()
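train_step hides the accumulation bookkeeping; a minimal sketch of the usual pattern (the function and argument names here are illustrative, not the project's actual attributes):

```python
import torch

def accumulation_step(model, optimizer, loss_fn, x, target, step, accum_steps=4):
    """One micro-batch: scale the loss by 1/accum_steps so the summed
    gradients match a single large-batch step; step every accum_steps."""
    loss = loss_fn(model(x), target) / accum_steps
    loss.backward()  # gradients accumulate in .grad across calls
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return loss.item() * accum_steps  # report the unscaled loss
```

With accum_steps=4 and batch_size=64, four micro-batches are accumulated before each optimizer step, giving an effective batch size of 256.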
Learning rate schedule
Cosine annealing with warmup provides stable training:
src/training/train_diffusion_cifar.py
import math

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step: int) -> float:
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

steps_per_epoch = len(loader) // diffusion.grad_accumulation_steps
total_steps = epochs * steps_per_epoch
warmup_steps = int(0.05 * total_steps)  # 5% warmup
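As a standalone check of the schedule's shape, the same lr_lambda can be evaluated directly (the 100-step warmup and 1000-step total below are illustrative):

```python
import math

def lr_lambda(current_step, num_warmup_steps=100, num_training_steps=1000):
    """Linear warmup to 1.0, then cosine decay to 0.0."""
    if current_step < num_warmup_steps:
        return current_step / max(1, num_warmup_steps)
    progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

print(lr_lambda(0))     # 0.0 (start of warmup)
print(lr_lambda(100))   # 1.0 (peak, end of warmup)
print(lr_lambda(550))   # ~0.5 (halfway through the cosine decay)
print(lr_lambda(1000))  # 0.0 (fully annealed)
```

The returned multiplier scales the optimizer's base learning rate, so the actual rate ramps from 0 to the base value over the first 5% of steps and then decays to 0.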
Running training
python src/training/train_diffusion_cifar.py
With environment variables
EPOCHS=500 python src/training/train_diffusion_cifar.py
RESUME_FROM_BEST=1 EPOCHS=3000 python src/training/train_diffusion_cifar.py
Training progress is printed every epoch:
Epoch 1/2000 | loss=0.1234
Epoch 2/2000 | loss=0.1198
...
Every 25 epochs, the script generates:
Noising visualization at timesteps [0, 200, 400, 600, 800, 999]
16 samples from the EMA model
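The noising visualization applies the standard DDPM forward process q(x_t | x_0). A minimal sketch, assuming a linear beta schedule with 1000 timesteps (the project's actual schedule may differ):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)           # linear beta schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)  # cumulative product of alphas

def noise_image(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) * x0, (1 - a_bar_t) * I)."""
    a_bar = alpha_bars[t]
    eps = torch.randn_like(x0)
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps

x0 = torch.randn(1, 3, 32, 32)  # stand-in for a normalized CIFAR-10 image
for t in [0, 200, 400, 600, 800, 999]:
    xt = noise_image(x0, t)  # progressively noisier versions of x0
```

At t=0 the image is nearly unchanged (alpha_bar close to 1), while at t=999 alpha_bar is nearly 0 and x_t is essentially pure Gaussian noise, which is why the visualization uses those six timesteps.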
Checkpointing system
Automatic checkpointing
Checkpoints are saved every 25 epochs and when a new best loss is achieved:
src/training/train_diffusion_cifar.py
import os

def save_checkpoint(
    diffusion, optimizer, scheduler, epoch,
    loss_history, best_loss, wait,
    checkpoint_dir, is_best=False,
):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": diffusion.model.state_dict(),
        "ema_model_state_dict": diffusion.ema_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "loss_history": loss_history,
        "best_loss": best_loss,
        "wait": wait,
    }
    # Write to checkpoint_best.pt on a new best loss, otherwise checkpoint_latest.pt
    name = "checkpoint_best.pt" if is_best else "checkpoint_latest.pt"
    checkpoint_path = os.path.join(checkpoint_dir, name)
    torch.save(checkpoint, checkpoint_path)
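Restoring from one of these files mirrors the save. A sketch of the resume path (the field names follow the checkpoint dictionary above; the function name and return shape are illustrative, not the script's exact code):

```python
import torch

def load_checkpoint(path, diffusion, optimizer, scheduler, device="cpu"):
    """Restore model, EMA, optimizer, and scheduler state from a checkpoint."""
    ckpt = torch.load(path, map_location=device)
    diffusion.model.load_state_dict(ckpt["model_state_dict"])
    diffusion.ema_model.load_state_dict(ckpt["ema_model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    scheduler.load_state_dict(ckpt["scheduler_state_dict"])
    # Resume from the epoch after the saved one
    return ckpt["epoch"] + 1, ckpt["loss_history"], ckpt["best_loss"], ckpt["wait"]
```

Restoring the optimizer and scheduler state, not just the weights, matters: AdamW's moment estimates and the cosine schedule's step counter both affect how training continues.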
Checkpoint files
Checkpoints are saved to $WORK/stable-diffusion-cifar/checkpoints/:
checkpoint_latest.pt - Most recent checkpoint
checkpoint_best.pt - Best loss checkpoint
checkpoint_epoch{N}.pt - Periodic checkpoints
Resume training
Resume from the best checkpoint:
RESUME_FROM_BEST=1 python src/training/train_diffusion_cifar.py
Or specify a checkpoint path:
RESUME_FROM=/path/to/checkpoint.pt python src/training/train_diffusion_cifar.py
Environment variables
Customize training behavior with environment variables:
| Variable | Default | Description |
|---|---|---|
| EPOCHS | 2000 | Total epochs to train |
| PATIENCE | 0 | Early stopping patience |
| EARLY_STOP | 0 | Enable early stopping (1) |
| RESUME_FROM | None | Checkpoint path to resume from |
| RESUME_FROM_BEST | 1 | Resume from best checkpoint |
| WORK | ~ | Working directory for outputs |
When resuming with EPOCHS, the value is the total epoch count, not additional epochs. If you stopped at epoch 1000 and want to train to epoch 3000, set EPOCHS=3000.
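A sketch of how such settings are typically read from the environment (the helpers below are illustrative, not the script's actual parsing code):

```python
import os

def env_int(name: str, default: int, environ=os.environ) -> int:
    """Read an integer setting from the environment, falling back to a default."""
    return int(environ.get(name, default))

def env_flag(name: str, default: str = "0", environ=os.environ) -> bool:
    """Treat "1" as enabled, anything else as disabled."""
    return environ.get(name, default) == "1"

epochs = env_int("EPOCHS", 2000)        # total epoch count, not additional epochs
early_stop = env_flag("EARLY_STOP")
resume_from = os.environ.get("RESUME_FROM")  # None if unset
```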
Output locations
All outputs are saved to $WORK/stable-diffusion-cifar/:
$WORK/stable-diffusion-cifar/
├── checkpoints/
│ ├── checkpoint_latest.pt
│ ├── checkpoint_best.pt
│ └── checkpoint_epoch{N}.pt
├── cifar_samples/
│ ├── beta_schedule_cifar.png
│ ├── noising_epoch{N}.png
│ ├── samples_epoch{N}.png
│ ├── training_curve_cifar.png
│ ├── DDPM_CIFAR.png
│ └── DDIM_CIFAR.png
└── best_model_cifar.pt
Sampling methods
The trained model supports both DDPM and DDIM sampling:
DDPM sampling (1000 steps)
final_samples = diffusion.sample(num_samples=16)
utils.save_image(
    torch.clamp((final_samples + 1) / 2, 0, 1),
    "DDPM_CIFAR.png",
    nrow=4,
)
DDIM sampling (50 steps)
DDIM provides faster sampling with comparable quality:
src/training/train_diffusion_cifar.py
final_ddim = diffusion.sample_ddim(num_samples=16, ddim_steps=50)
utils.save_image(
    torch.clamp((final_ddim + 1) / 2, 0, 1),
    "DDIM_CIFAR.png",
    nrow=4,
)
DDIM sampling is 20× faster than DDPM (50 steps vs 1000) with minimal quality loss.
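The deterministic (eta = 0) DDIM update behind the fast schedule can be sketched as follows, where a_bar is the cumulative product of alphas and eps_pred is the model's noise prediction (a simplified standalone version, not the project's exact sample_ddim):

```python
import torch

def ddim_step(x_t: torch.Tensor, eps_pred: torch.Tensor,
              a_bar_t: torch.Tensor, a_bar_prev: torch.Tensor) -> torch.Tensor:
    """One deterministic DDIM step (eta = 0): predict x0 from the noise
    estimate, then re-project it to the previous (possibly distant) timestep."""
    x0_pred = (x_t - (1 - a_bar_t).sqrt() * eps_pred) / a_bar_t.sqrt()
    return a_bar_prev.sqrt() * x0_pred + (1 - a_bar_prev).sqrt() * eps_pred
```

Because each step jumps directly between two points on the trajectory rather than reversing one noise increment at a time, DDIM can skip most of the 1000 timesteps.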
GPU optimizations
src/training/train_diffusion_cifar.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
Memory management
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
This lets the CUDA caching allocator grow existing memory segments instead of reserving new ones, reducing fragmentation-related out-of-memory errors.
Expected results
With default hyperparameters:
- Training time: 15-20 hours on A100 GPU for 2000 epochs
- Memory usage: ~10-15 GB VRAM with batch_size=256
- Final loss: ~0.005-0.010 MSE
- Sample quality: Sharp CIFAR-10 images with correct colors and shapes
Training curve
Expect the loss to:
- Drop rapidly in the first 100 epochs
- Gradually decrease until ~500 epochs
- Slowly improve with diminishing returns after 1000 epochs
Troubleshooting
Out of memory errors
Reduce the batch size, or increase the gradient accumulation steps while lowering the per-step batch size; this cuts peak memory while keeping the same effective batch size.
Poor sample quality
Ensure you’re using the EMA model for sampling. The raw model weights produce noisier samples.
Training divergence
If loss increases dramatically:
- Check learning rate (default: 1e-4 with AdamW)
- Reduce gradient accumulation
- Add gradient clipping
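Gradient clipping fits right before the optimizer step. A minimal sketch using torch.nn.utils.clip_grad_norm_ (the max norm of 1.0 is an illustrative choice, not the project's setting):

```python
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def clipped_step(loss: torch.Tensor) -> None:
    """Backprop, clip the global gradient norm to 1.0, then step."""
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```

Clipping the global norm bounds the size of any single update, which tames the occasional huge gradients that can otherwise trigger divergence.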
Next steps
- Scale training to HPC clusters: HPC SLURM Guide
- Explore the diffusion model implementation:
src/models/diffusion_cifar.py
- Experiment with DDIM sampling:
src/models/diffusion_cifar.py:sample_ddim()