
Overview

The train_flux.py script trains Flux Dev models using the MaxDiffusion framework. Flux training has been tested on TPU v5p with support for flash attention and bfloat16 precision.

Command-line usage

python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml [OPTIONS]

Basic example

python src/maxdiffusion/train_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://<your-gcs-bucket>/" \
  save_final_checkpoint=True \
  jax_cache_dir="/tmp/jax_cache"

Performance

Expected results on 1024 x 1024 images with flash attention and bfloat16:
| Model    | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
|----------|-------------|-------------------|-----------------------|-------------------|------------------|
| Flux-dev | v5p-8       | DDP               | 1                     | 4                 | 1.31             |
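For reference, the global batch size in the table follows from the per-device batch size and the device count. A rough sketch; the assumed device count of 4 for v5p-8 is an assumption here and should be verified with `jax.device_count()` on your slice:

```python
# Sketch: derive throughput figures from the performance table above.
# num_devices = 4 is an assumption for a v5p-8 slice (one JAX device per chip).
per_device_batch_size = 1
num_devices = 4
step_time_secs = 1.31

global_batch_size = per_device_batch_size * num_devices
images_per_sec = global_batch_size / step_time_secs

print(global_batch_size)         # 4, matching the table
print(round(images_per_sec, 2))  # 3.05
```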

Configuration parameters

Run configuration

run_name
string
required
Name for this training run. Used for organizing outputs and metrics.
base_output_directory
string
GCS bucket path for outputs (e.g., gs://my-bucket/). Checkpoints and metrics will be saved here.
output_dir
string
default:"sdxl-model-finetuned"
Local or GCS directory for model outputs.
save_final_checkpoint
boolean
default:false
Save the final checkpoint after training completes.

Model configuration

pretrained_model_name_or_path
string
default:"black-forest-labs/FLUX.1-dev"
HuggingFace model identifier or local path to pretrained Flux model.
clip_model_name_or_path
string
default:"ariG23498/clip-vit-large-patch14-text-flax"
HuggingFace model identifier for CLIP text encoder.
t5xxl_model_name_or_path
string
default:"ariG23498/t5-v1-1-xxl-flax"
HuggingFace model identifier for T5-XXL text encoder.
flux_name
string
default:"flux-dev"
Flux model variant name. Options: flux-dev, flux-schnell.
unet_checkpoint
string
default:""
Path to a specific transformer checkpoint to load.
revision
string
default:"refs/pr/95"
Model revision/branch to use from HuggingFace.
weights_dtype
string
default:"bfloat16"
Data type for model weights. Options: float32, bfloat16.
activations_dtype
string
default:"bfloat16"
Data type for layer activations. Options: float32, bfloat16.
precision
string
default:"DEFAULT"
JAX precision for matmul and conv operations. Options: DEFAULT, HIGH, HIGHEST.
from_pt
boolean
default:true
Load weights from PyTorch format.
train_new_flux
boolean
default:false
If true, randomly initialize Flux weights to train from scratch. Otherwise, load from pretrained model.

Flux-specific parameters

max_sequence_length
number
default:512
Maximum text sequence length for T5 encoder.
time_shift
boolean
default:true
Enable time shifting for Flux flow matching.
base_shift
number
Base shift parameter for Flux.
max_shift
number
Maximum shift parameter for Flux.
offload_encoders
boolean
default:true
Offload T5 encoder after text encoding to save memory.
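base_shift and max_shift feed the resolution-dependent timestep shifting used in Flux flow matching when time_shift is enabled. A minimal sketch of the commonly used formulation; the sequence-length constants and default shift values below are illustrative assumptions, not values taken from this config:

```python
import math

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Linearly interpolate the shift "mu" between base_shift and max_shift
    # as the number of image tokens grows from base_seq_len to max_seq_len.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def time_shift(mu, sigma):
    # Shift a timestep sigma in (0, 1] toward noisier levels; sigma = 1
    # (pure noise) is a fixed point of the transform.
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0))
```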

Attention configuration

attention
string
default:"flash"
Attention mechanism to use. Options: dot_product, flash, cudnn_flash_te.
split_head_dim
boolean
default:true
Whether to split attention head dimensions for sharding.
attention_sharding_uniform
boolean
default:true
Use uniform sequence sharding for both self and cross attention.
mask_padding_tokens
boolean
default:true
Pass segment IDs to attention to avoid attending to padding tokens. Improves quality when padding is significant.
flash_block_sizes
object
Custom block sizes for flash attention. On v6e (Trillium), use larger blocks:
flash_block_sizes: {
  "block_q" : 1536,
  "block_kv_compute" : 1536,
  "block_kv" : 1536,
  "block_q_dkv" : 1536,
  "block_kv_dkv" : 1536,
  "block_kv_dkv_compute" : 1536,
  "block_q_dq" : 1536,
  "block_kv_dq" : 1536
}

Training hyperparameters

learning_rate
number
Learning rate for the optimizer.
scale_lr
boolean
default:false
Scale learning rate by the number of GPUs/TPUs and batch size.
max_train_steps
number
default:1500
Maximum number of training steps. Takes priority over num_train_epochs.
num_train_epochs
number
default:1
Number of training epochs.
max_train_samples
number
Maximum number of training samples to use. -1 means use all samples.
per_device_batch_size
number
default:1
Batch size per device.
warmup_steps_fraction
number
Fraction of total steps to use for learning rate warmup.
seed
number
default:0
Random seed for reproducibility.
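A sketch of how scale_lr, warmup_steps_fraction, and max_train_steps might interact. Illustrative only; the actual schedule is built inside train_flux.py:

```python
def effective_learning_rate(step, base_lr, max_train_steps,
                            warmup_steps_fraction, scale_lr=False,
                            num_devices=1, per_device_batch_size=1):
    # Hypothetical sketch of how these options combine.
    lr = base_lr
    if scale_lr:
        # Scale by the global batch size (devices * per-device batch).
        lr *= num_devices * per_device_batch_size
    warmup_steps = int(warmup_steps_fraction * max_train_steps)
    if warmup_steps > 0 and step < warmup_steps:
        lr *= step / warmup_steps  # linear warmup from 0
    return lr
```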

Optimizer parameters

adam_b1
number
Exponential decay rate for first moment estimates.
adam_b2
number
Exponential decay rate for second moment estimates.
adam_eps
number
Small constant for numerical stability.
adam_weight_decay
number
default:0
Weight decay coefficient for AdamW optimizer.
max_grad_norm
number
Maximum gradient norm for gradient clipping.
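These parameters map onto the standard AdamW update with global-norm gradient clipping. A self-contained scalar sketch, not MaxDiffusion's actual optax implementation:

```python
def clip_by_global_norm(grads, max_grad_norm):
    # Rescale gradients so their global L2 norm is at most max_grad_norm.
    norm = sum(g * g for g in grads) ** 0.5
    scale = min(1.0, max_grad_norm / (norm + 1e-6))
    return [g * scale for g in grads]

def adamw_step(param, grad, m, v, step, lr,
               b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.0):
    # One AdamW update for a single scalar parameter (illustrative only).
    m = b1 * m + (1 - b1) * grad          # first moment (adam_b1)
    v = b2 * v + (1 - b2) * grad * grad   # second moment (adam_b2)
    m_hat = m / (1 - b1 ** step)          # bias correction
    v_hat = v / (1 - b2 ** step)
    param -= lr * (m_hat / (v_hat ** 0.5 + eps) + weight_decay * param)
    return param, m, v
```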

Loss and noise schedule

snr_gamma
number
SNR-weighted loss gamma parameter. Set to -1.0 to disable.
timestep_bias
object
Configuration for biasing timestep sampling during training.
  • strategy: Bias strategy. Options: none, earlier, later, range
  • multiplier: Bias multiplier (2.0 doubles weight, 0.5 halves it)
  • begin: Start timestep for range strategy
  • end: End timestep for range strategy
  • portion: Fraction of timesteps to bias
diffusion_scheduler_config
object
Override parameters for the diffusion scheduler.
  • _class_name: Scheduler class name (default: FlaxEulerDiscreteScheduler)
  • prediction_type: Prediction type (default: epsilon)
  • rescale_zero_terminal_snr: Whether to rescale zero terminal SNR
  • timestep_spacing: Timestep spacing strategy (default: trailing)
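A sketch of how snr_gamma (min-SNR loss weighting) and the timestep_bias options are typically applied; the exact behavior is defined in the training script:

```python
def min_snr_loss_weight(snr, snr_gamma):
    # Min-SNR weighting: cap each timestep's loss weight at snr_gamma.
    # snr_gamma = -1.0 disables the reweighting.
    if snr_gamma < 0:
        return 1.0
    return min(snr, snr_gamma) / snr

def timestep_bias_weights(num_timesteps, strategy="none",
                          multiplier=1.0, portion=0.25, begin=0, end=0):
    # Sampling weights over timesteps, mirroring the timestep_bias fields.
    weights = [1.0] * num_timesteps
    if strategy == "later":
        for t in range(int(num_timesteps * (1 - portion)), num_timesteps):
            weights[t] *= multiplier
    elif strategy == "earlier":
        for t in range(int(num_timesteps * portion)):
            weights[t] *= multiplier
    elif strategy == "range":
        for t in range(begin, end):
            weights[t] *= multiplier
    total = sum(weights)
    return [w / total for w in weights]  # normalized sampling distribution
```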

Dataset configuration

dataset_name
string
default:"diffusers/pokemon-gpt4-captions"
HuggingFace dataset identifier.
train_data_dir
string
default:""
Local directory containing training data. Either this or dataset_name must be set.
train_split
string
default:"train"
Dataset split to use for training.
dataset_type
string
default:"tfrecord"
Dataset format type. Options: tfrecord, hf, tf, grain, synthetic.
cache_latents_text_encoder_outputs
boolean
default:true
Cache image latents and text encoder outputs to reduce memory and speed up training.
dataset_save_location
string
default:"/tmp/pokemon-gpt4-captions_xl"
Path to save transformed dataset when caching is enabled.
image_column
string
default:"image"
Name of the image column in the dataset.
caption_column
string
default:"text"
Name of the caption column in the dataset.
resolution
number
default:1024
Image resolution for training.
center_crop
boolean
default:false
Whether to center crop images before resizing.
random_flip
boolean
default:false
Whether to randomly flip images horizontally.
enable_data_shuffling
boolean
default:true
Shuffle the dataset during training.
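When center_crop is enabled, images are typically cropped to a centered square before being resized to resolution x resolution. A minimal sketch of the crop-box arithmetic; the real preprocessing lives in the data pipeline:

```python
def center_crop_box(width, height):
    # Return the centered square crop box (left, top, right, bottom)
    # for an image of the given size; the square side is min(width, height).
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)
```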

Parallelism and sharding

hardware
string
default:"tpu"
Hardware type. Options: tpu, gpu.
mesh_axes
array
default:["data","fsdp","context","tensor"]
Logical mesh axes for parallelism.
dcn_data_parallelism
number
default:1
Data parallelism across DCN (the data-center network, i.e., between TPU slices).
dcn_fsdp_parallelism
number
FSDP parallelism across DCN. -1 for auto-sharding.
dcn_tensor_parallelism
number
default:1
Tensor parallelism across DCN.
ici_data_parallelism
number
Data parallelism within ICI (the inter-chip interconnect, i.e., within a slice). -1 for auto-sharding.
ici_fsdp_parallelism
number
default:1
FSDP parallelism within ICI.
ici_tensor_parallelism
number
default:1
Tensor parallelism within ICI.
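The -1 "auto" values can be read as "use whatever factor of the device count is left over once the fixed axes are assigned." A simplified sketch; the real mesh construction is handled by MaxDiffusion's device-mesh utilities:

```python
from math import prod

def resolve_mesh(parallelism, num_devices):
    # Replace a single -1 entry with the leftover factor so the product
    # of the mesh axes equals the total device count.
    fixed = prod(p for p in parallelism if p != -1)
    assert num_devices % fixed == 0, "mesh does not divide device count"
    return [num_devices // fixed if p == -1 else p for p in parallelism]
```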

Checkpointing

checkpoint_every
number
Save checkpoint every N samples. -1 disables checkpointing.
enable_single_replica_ckpt_restoring
boolean
default:false
Enable one replica to read checkpoint and broadcast to others.

Metrics and logging

write_metrics
boolean
default:true
Record training metrics such as loss and TFLOPS.
gcs_metrics
boolean
default:false
Write metrics to GCS.
log_period
number
default:100
How often (in steps) TensorBoard metrics are flushed.

Profiling

enable_profiler
boolean
default:false
Enable JAX profiler.
skip_first_n_steps_for_profiler
number
default:5
Skip first N steps when profiling to exclude compilation.
profiler_steps
number
default:10
Number of steps to profile.
profiler
string
default:""
Profiler configuration.
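Together, skip_first_n_steps_for_profiler and profiler_steps define a capture window that starts after the compilation-heavy first steps. A sketch:

```python
def is_profiled_step(step, skip_first_n_steps_for_profiler=5,
                     profiler_steps=10):
    # The profiler captures [skip, skip + profiler_steps), so the
    # compilation-dominated first steps are excluded from the trace.
    start = skip_first_n_steps_for_profiler
    return start <= step < start + profiler_steps
```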

Generation parameters

prompt
string
Prompt for test image generation during training.
prompt_2
string
Secondary prompt for dual text encoder.
negative_prompt
string
default:"purple, red"
Negative prompt for guidance.
do_classifier_free_guidance
boolean
default:true
Enable classifier-free guidance.
guidance_scale
number
Classifier-free guidance scale for Flux.
num_inference_steps
number
default:50
Number of denoising steps for test generation.
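guidance_scale controls the standard classifier-free guidance combination of the conditional and unconditional predictions. A sketch of that combination (note that Flux-dev is guidance-distilled, so the pipeline may pass guidance directly to the model instead; check generate_flux_pipeline.py for the actual behavior):

```python
def classifier_free_guidance(uncond_pred, cond_pred, guidance_scale):
    # Extrapolate from the unconditional prediction toward the
    # conditional one; scale 1.0 reduces to the conditional prediction.
    return uncond_pred + guidance_scale * (cond_pred - uncond_pred)
```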

Expected outputs

Checkpoints

When save_final_checkpoint is true or checkpoint_every is set, checkpoints are saved to:
{base_output_directory}/{run_name}/checkpoints/
Checkpoints include:
  • Flux transformer weights
  • Optimizer state
  • Training step information

Metrics

Training metrics are saved to:
{base_output_directory}/{run_name}/metrics/
Metrics include:
  • Training loss
  • Learning rate
  • TFLOPS per device
  • Step time

Tensorboard logs

Tensorboard logs are written to:
{base_output_directory}/{run_name}/tensorboard/
View logs with:
tensorboard --logdir=gs://your-bucket/{run_name}/tensorboard/

Generating images from trained checkpoint

After training, generate images with:
python src/maxdiffusion/generate_flux_pipeline.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://<your-gcs-bucket>/" \
  jax_cache_dir="/tmp/jax_cache"
