MaxDiffusion includes built-in profiling capabilities to help you analyze and optimize training and inference performance.

Overview

The profiler captures detailed performance metrics, including:
  • Step time and throughput (TFLOP/s)
  • Compilation time
  • Memory usage patterns
  • Device utilization
  • Communication overhead

Configuration parameters

Profiling is controlled by three parameters in your config file:
enable_profiler: False
skip_first_n_steps_for_profiler: 5
profiler_steps: 10

enable_profiler

  • Type: boolean
  • Default: False
  • Description: Master switch to enable or disable profiling
Set to True to activate profiling:
enable_profiler: True

skip_first_n_steps_for_profiler

  • Type: integer
  • Default: 5
  • Description: Number of initial training steps to skip before profiling begins
This parameter is important because:
  • Early steps include compilation overhead
  • Step times are unstable during warmup
  • Skipping these steps provides more accurate performance measurements
skip_first_n_steps_for_profiler: 3  # Start profiling after 3 steps

profiler_steps

  • Type: integer
  • Default: 10
  • Description: Number of steps to profile after the skip period
profiler_steps: 5  # Profile 5 steps after warmup

Example usage

Training with profiling

Here’s a complete example from Wan 2.1 training that shows profiling in action:
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  run_name=${RUN_NAME} \
  output_dir=${OUTPUT_DIR} \
  train_data_dir=${DATASET_DIR} \
  max_train_steps=1000 \
  enable_profiler=True \
  skip_first_n_steps_for_profiler=3 \
  profiler_steps=3 \
  per_device_batch_size=0.25 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=4
Expected output:
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
To see full metrics 'tensorboard --logdir=gs://bucket/tensorboard/'
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120
Note how steps 0 and 1 have much longer times (they include compilation and warmup), while step 2 onward shows stable performance. This is why the example sets skip_first_n_steps_for_profiler=3.
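The logged TFLOP/s/device appears to be the "Calculated TFLOPs per pass" figure divided by the measured step time. A quick sanity check using the numbers from the log above (a sketch, not MaxDiffusion's actual metrics code):

```python
tflops_per_pass = 4893.2719  # "Calculated TFLOPs per pass" from the log
step_time_s = 36.014         # stable step time after warmup (step 2)

throughput = tflops_per_pass / step_time_s
print(f"{throughput:.3f} TFLOP/s/device")  # reproduces the logged 135.871
```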

Inference profiling

python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=50 \
  enable_profiler=True \
  skip_first_n_steps_for_profiler=2 \
  profiler_steps=5 \
  run_name=wan-inference-profiling \
  output_dir=gs://your-bucket

Profiling workflow

1. Enable profiling

Modify your config or pass parameters on the command line:
enable_profiler=True \
skip_first_n_steps_for_profiler=5 \
profiler_steps=10

2. Run your workload

Execute training or inference as normal. The profiler will:
  1. Skip the first N steps (default: 5)
  2. Collect profiling data for the next M steps (default: 10)
  3. Save profiling traces to your output directory
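Conceptually, the skip/collect window works as follows. The helper below is hypothetical (it is not MaxDiffusion's trainer code) but illustrates which step indices fall inside the profiling window:

```python
def profiled_steps(max_train_steps, skip_first_n, profiler_steps):
    """Return the step indices the profiler would capture.

    Profiling starts after the first `skip_first_n` steps and runs for
    `profiler_steps` steps (or until training ends, whichever is first).
    """
    start = skip_first_n
    stop = min(skip_first_n + profiler_steps, max_train_steps)
    return list(range(start, stop))

# With the defaults (skip 5, profile 10), steps 5 through 14 are captured:
print(profiled_steps(max_train_steps=1000, skip_first_n=5, profiler_steps=10))
```

Internally, JAX-based trainers typically bracket this window with jax.profiler.start_trace and jax.profiler.stop_trace, which write the trace files to the output directory.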

3. Analyze results

View metrics in TensorBoard:
tensorboard --logdir=gs://your-bucket/your-run-name/tensorboard/
Or examine profiler traces using JAX profiling tools.

Best practices

Skip warmup steps

Always skip initial steps to avoid skewed measurements:
# Good: Skips compilation and warmup
skip_first_n_steps_for_profiler: 5

# Bad: Includes compilation overhead
skip_first_n_steps_for_profiler: 0

Profile sufficient steps

Capture enough steps to identify patterns:
# Good: Enough data to see trends
profiler_steps: 10

# Limited: May miss performance variations
profiler_steps: 2

Production vs development

Disable profiling in production to avoid overhead:
# Development/tuning
enable_profiler: True

# Production
enable_profiler: False

Common profiling scenarios

Optimize step time

enable_profiler: True
skip_first_n_steps_for_profiler: 5
profiler_steps: 20  # Longer profile for detailed analysis

Quick performance check

enable_profiler: True
skip_first_n_steps_for_profiler: 3
profiler_steps: 5  # Short profile for quick feedback

Multi-slice profiling

enable_profiler: True
skip_first_n_steps_for_profiler: 10  # More warmup for distributed setup
profiler_steps: 15

Understanding profiler output

The profiler generates:
  1. Console metrics: Step time, TFLOP/s, loss values
  2. TensorBoard logs: Detailed metrics over time
  3. JAX traces: Low-level performance data for expert analysis

Key metrics to monitor

  • TFLOP/s/device: Higher is better; indicates compute efficiency
  • Step time: Lower is better; total time per training step
  • Step time stability: Should stabilize after warmup
  • Loss values: Verify training is progressing correctly
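Step-time stability can be checked directly from the console output. A small sketch (using the log lines from the training example above; the parsing pattern is an assumption about the log format):

```python
import re

# Console output from the training example above.
log = """\
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120
"""

times = [float(m) for m in re.findall(r"seconds: ([\d.]+)", log)]
stable = times[2:]  # drop the warmup/compilation steps
spread = (max(stable) - min(stable)) / min(stable)
print(f"post-warmup step-time spread: {spread:.4%}")  # well under 1% here
```

A large spread after the warmup steps suggests interference (host contention, input-pipeline stalls, or communication overhead) worth investigating in the traces.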
