Overview
The train_flux.py script trains Flux Dev models using the MaxDiffusion framework. Flux training has been tested on TPU v5p with support for flash attention and bfloat16 precision.
Command-line usage
Basic example
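A typical invocation might look like the following. This is an illustrative sketch only: the config file path and the exact override keys follow MaxDiffusion's key=value convention, but should be checked against your checkout before use.

```shell
# Illustrative launch command; config path and key names are assumptions.
python src/maxdiffusion/train_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  run_name=flux-dev-test \
  output_dir=gs://my-bucket/flux-output \
  per_device_batch_size=1 \
  max_train_steps=1500 \
  attention=flash \
  weights_dtype=bfloat16
```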
Performance
Expected results on 1024 x 1024 images with flash attention and bfloat16:

| Model | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
|---|---|---|---|---|---|
| Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
Configuration parameters
Run configuration
Name for this training run. Used for organizing outputs and metrics.
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
Local or GCS directory for model outputs.
Save the final checkpoint after training completes.
Model configuration
HuggingFace model identifier or local path to pretrained Flux model.
HuggingFace model identifier for CLIP text encoder.
HuggingFace model identifier for T5-XXL text encoder.
Flux model variant name. Options: flux-dev, flux-schnell.
Path to a specific transformer checkpoint to load.
Model revision/branch to use from HuggingFace.
Data type for model weights. Options: float32, bfloat16.
Data type for layer activations. Options: float32, bfloat16.
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
Load weights from PyTorch format.
If true, randomly initialize Flux weights to train from scratch. Otherwise, load from pretrained model.
Flux-specific parameters
Maximum text sequence length for T5 encoder.
Enable time shifting for Flux flow matching.
Base shift parameter for Flux.
Maximum shift parameter for Flux.
Offload T5 encoder after text encoding to save memory.
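The base and max shift parameters control a resolution-dependent time shift in the flow-matching noise schedule: more image tokens push sampled timesteps toward the noisier end. A minimal sketch of the commonly used formulation (modeled on the reference Flux implementation; the sequence-length constants and defaults here are assumptions, not read from this script):

```python
import math

def compute_mu(image_seq_len, base_shift=0.5, max_shift=1.15,
               base_seq_len=256, max_seq_len=4096):
    """Interpolate the shift exponent linearly from the image token count."""
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return base_shift + slope * (image_seq_len - base_seq_len)

def time_shift(mu, t):
    """Shift a timestep t in (0, 1] toward 1 when mu > 0; identity at mu = 0."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0))

# A 1024x1024 image patchifies to (1024/16)^2 = 4096 tokens.
mu = compute_mu(image_seq_len=4096)
shifted = [time_shift(mu, t) for t in (0.25, 0.5, 0.75)]
```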
Attention configuration
Attention mechanism to use. Options: dot_product, flash, cudnn_flash_te.
Whether to split attention head dimensions for sharding.
Use uniform sequence sharding for both self and cross attention.
Pass segment IDs to attention to avoid attending to padding tokens. Improves quality when padding is significant.
Custom block sizes for flash attention. On v6e (Trillium), use larger blocks:
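As a template only, such an override might look like the fragment below. The key names mirror the TPU splash-attention BlockSizes fields, but both the keys and the values shown are assumptions to illustrate the shape of the setting; consult the config file shipped with MaxDiffusion for the exact spelling and recommended v6e values.

```yaml
# Illustrative only: key names follow the splash-attention BlockSizes
# fields; values must be tuned per hardware and sequence length.
flash_block_sizes: {
  "block_q": 512,
  "block_kv_compute": 512,
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "block_q_dq": 512,
  "block_kv_dq": 512
}
```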
Training hyperparameters
Learning rate for the optimizer.
Scale learning rate by the number of GPUs/TPUs and batch size.
Maximum number of training steps. Takes priority over num_train_epochs.
Number of training epochs.
Maximum number of training samples to use. -1 means use all samples.
Batch size per device.
Fraction of total steps to use for learning rate warmup.
Random seed for reproducibility.
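Taken together, the scaling and warmup options determine the effective learning-rate schedule. A minimal sketch under stated assumptions: the scaling convention shown (multiplying by the global batch size) and the linear warmup are common defaults, not necessarily the exact rule this script uses.

```python
def lr_at_step(step, base_lr, max_steps, warmup_fraction,
               scale_lr=False, num_devices=1, per_device_batch_size=1):
    """Linear warmup to the (optionally scaled) peak learning rate."""
    peak = base_lr
    if scale_lr:
        # Assumed convention: scale the peak LR with the global batch size.
        peak = base_lr * num_devices * per_device_batch_size
    warmup_steps = max(1, int(max_steps * warmup_fraction))
    if step < warmup_steps:
        return peak * (step + 1) / warmup_steps
    return peak
```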
Optimizer parameters
Exponential decay rate for first moment estimates.
Exponential decay rate for second moment estimates.
Small constant for numerical stability.
Weight decay coefficient for AdamW optimizer.
Maximum gradient norm for gradient clipping.
Loss and noise schedule
SNR-weighted loss gamma parameter. Set to -1.0 to disable.
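This gamma corresponds to min-SNR loss weighting, which clamps each timestep's signal-to-noise ratio before forming the loss weight. A minimal sketch of the usual formulation (assumed; the script's exact implementation is not shown here):

```python
def snr_loss_weight(snr, gamma):
    """Min-SNR weighting: weight = min(snr, gamma) / snr.

    gamma = -1.0 disables weighting (weight 1.0 for every timestep).
    """
    if gamma < 0:
        return 1.0
    return min(snr, gamma) / snr
```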
Configuration for biasing timestep sampling during training.
- strategy: Bias strategy. Options: none, earlier, later, range
- multiplier: Bias multiplier (2.0 doubles weight, 0.5 halves it)
- begin: Start timestep for range strategy
- end: End timestep for range strategy
- portion: Fraction of timesteps to bias
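A sketch of how such bias parameters are typically turned into per-timestep sampling weights (modeled on the convention used in similar diffusion training scripts; the details are assumptions, not this script's verified behavior):

```python
def timestep_bias_weights(num_timesteps, strategy="none", multiplier=1.0,
                          begin=0, end=0, portion=0.25):
    """Per-timestep sampling weights, normalized to sum to 1."""
    weights = [1.0] * num_timesteps
    if strategy == "later":
        n = int(num_timesteps * portion)
        for i in range(num_timesteps - n, num_timesteps):
            weights[i] *= multiplier
    elif strategy == "earlier":
        n = int(num_timesteps * portion)
        for i in range(n):
            weights[i] *= multiplier
    elif strategy == "range":
        for i in range(begin, min(end, num_timesteps)):
            weights[i] *= multiplier
    total = sum(weights)
    return [w / total for w in weights]
```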
Override parameters for the diffusion scheduler.
- _class_name: Scheduler class name (default: FlaxEulerDiscreteScheduler)
- prediction_type: Prediction type (default: epsilon)
- rescale_zero_terminal_snr: Whether to rescale zero terminal SNR
- timestep_spacing: Timestep spacing strategy (default: trailing)
Dataset configuration
HuggingFace dataset identifier.
Local directory containing training data. Either this or dataset_name must be set.
Dataset split to use for training.
Dataset format type. Options: tfrecord, hf, tf, grain, synthetic.
Cache image latents and text encoder outputs to reduce memory and speed up training.
Path to save transformed dataset when caching is enabled.
Name of the image column in the dataset.
Name of the caption column in the dataset.
Image resolution for training.
Whether to center crop images before resizing.
Whether to randomly flip images horizontally.
Shuffle the dataset during training.
Parallelism and sharding
Hardware type. Options: tpu, gpu.
Logical mesh axes for parallelism.
Data parallelism across DCN.
FSDP parallelism across DCN. -1 for auto-sharding.
Tensor parallelism across DCN.
Data parallelism within ICI. -1 for auto-sharding.
FSDP parallelism within ICI.
Tensor parallelism within ICI.
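The product of all axis sizes must equal the number of devices, so a -1 "auto" entry is resolved from the remaining axes. A sketch with a hypothetical helper (MaxDiffusion's actual auto-sharding logic may differ, e.g. in how it handles multiple -1 entries):

```python
def resolve_mesh_shape(num_devices, axis_sizes):
    """Replace a single -1 entry so the product equals num_devices."""
    product = 1
    auto_axes = []
    for i, size in enumerate(axis_sizes):
        if size == -1:
            auto_axes.append(i)
        else:
            product *= size
    if not auto_axes:
        assert product == num_devices, "axis sizes must multiply to device count"
        return list(axis_sizes)
    assert len(auto_axes) == 1, "this sketch resolves only one -1 axis"
    assert num_devices % product == 0, "device count not divisible by fixed axes"
    resolved = list(axis_sizes)
    resolved[auto_axes[0]] = num_devices // product
    return resolved
```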
Checkpointing
Save checkpoint every N samples. -1 disables checkpointing.
Enable one replica to read checkpoint and broadcast to others.
Metrics and logging
Save metrics such as loss and TFLOPS to GCS.
Write metrics to GCS.
TensorBoard flush period.
Profiling
Enable JAX profiler.
Skip first N steps when profiling to exclude compilation.
Number of steps to profile.
Profiler configuration.
Generation parameters
Prompt for test image generation during training.
Secondary prompt for dual text encoder.
Negative prompt for guidance.
Enable classifier-free guidance.
Classifier-free guidance scale for Flux.
Number of denoising steps for test generation.
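Classifier-free guidance combines the unconditional and conditional model predictions, extrapolating past the conditional one as the scale grows above 1. A minimal element-wise sketch of the standard formula (plain lists here for illustration; the script operates on JAX arrays):

```python
def apply_cfg(uncond, cond, guidance_scale):
    """Standard CFG: pred = uncond + scale * (cond - uncond), element-wise."""
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]
```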
Expected outputs
Checkpoints
When save_final_checkpoint is true or checkpoint_every is set, checkpoints are saved to:
- Flux transformer weights
- Optimizer state
- Training step information
Metrics
Training metrics are saved to:
- Training loss
- Learning rate
- TFLOPS per device
- Step time