Flux is a transformer-based diffusion model for high-quality text-to-image generation. MaxDiffusion supports fine-tuning Flux Dev on TPU v5p.
Flux fine-tuning has only been tested on TPU v5p. Other accelerators may not be supported.

Performance benchmarks

Expected results on 1024 x 1024 images with flash attention and bfloat16:
| Model | Accelerator | Sharding Strategy | Per-Device Batch Size | Global Batch Size | Step Time (s) |
|---|---|---|---|---|---|
| Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
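From the row above, throughput can be derived as global batch size divided by step time. A minimal sketch (the device count is inferred from the table, since a global batch of 4 with a per-device batch of 1 implies 4 JAX devices):

```python
per_device_batch_size = 1
num_devices = 4            # implied by the global batch size of 4 in the table
global_batch_size = per_device_batch_size * num_devices
step_time_s = 1.31         # benchmark step time from the table

images_per_sec = global_batch_size / step_time_s
print(f"{images_per_sec:.2f} images/sec")  # ~3.05 images/sec
```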

Basic training

python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://<your-gcs-bucket>/" \
  save_final_checkpoint=True \
  jax_cache_dir="/tmp/jax_cache"

Configuration

The base config is located at src/maxdiffusion/configs/base_flux_dev.yml.

Key parameters

| Parameter | Default | Description |
|---|---|---|
| pretrained_model_name_or_path | black-forest-labs/FLUX.1-dev | Base Flux model |
| clip_model_name_or_path | ariG23498/clip-vit-large-patch14-text-flax | CLIP text encoder |
| t5xxl_model_name_or_path | ariG23498/t5-v1-1-xxl-flax | T5-XXL text encoder |
| flux_name | flux-dev | Flux model variant |
| weights_dtype | bfloat16 | Weight precision |
| activations_dtype | bfloat16 | Activation precision |
| attention | flash | Attention mechanism |
| resolution | 1024 | Training image resolution |
| per_device_batch_size | 1 | Batch size per device |
| learning_rate | 1.e-5 | Initial learning rate |
| max_train_steps | 1500 | Maximum training steps |

Flux-specific parameters

flux_name: "flux-dev"
max_sequence_length: 512
time_shift: True
base_shift: 0.5
max_shift: 1.15
offload_encoders: True  # Offload T5 encoder to save memory
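To see how base_shift and max_shift interact, here is a sketch of the resolution-dependent timestep shift used by Flux-style flow-matching models: the shift value mu is linearly interpolated between base_shift and max_shift according to the image token sequence length, then applied as an exponential warp to the timestep. The 256 and 4096 sequence-length endpoints are assumptions based on the common Flux scheduler defaults; check the scheduler code for the exact constants.

```python
import math

def compute_mu(image_seq_len, base_shift=0.5, max_shift=1.15,
               base_seq_len=256, max_seq_len=4096):
    # Linearly interpolate the shift between base_shift and max_shift
    # according to the number of image tokens.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def shift_timestep(t, mu):
    # Exponential time shift applied to a timestep/sigma in (0, 1].
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1))

# 1024x1024 image -> 128x128 latent -> 2x2 patches -> 64*64 = 4096 tokens,
# so at the default training resolution mu reaches max_shift.
mu = compute_mu(64 * 64)
```

Larger mu pushes more of the sampled timesteps toward the noisy end of the schedule, which matters more at high resolutions.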

Dataset configuration

dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tfrecord'  # Options: tfrecord, hf, tf, grain, synthetic
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
resolution: 1024
image_column: 'image'
caption_column: 'text'

Training parameters

learning_rate: 1.e-5
scale_lr: False
max_train_steps: 1500
num_train_epochs: 1
warmup_steps_fraction: 0.1
per_device_batch_size: 1
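With these defaults, warmup_steps_fraction resolves to an absolute step count of 0.1 × 1500 = 150. A minimal sketch of the resulting schedule, assuming a linear warmup to the base rate (the exact post-warmup decay shape should be checked against the config):

```python
learning_rate = 1e-5
max_train_steps = 1500
warmup_steps_fraction = 0.1

warmup_steps = int(warmup_steps_fraction * max_train_steps)  # 150 steps

def lr_at(step):
    # Linear warmup to the base learning rate, then constant
    # (the warmup shape here is an illustrative assumption).
    if step < warmup_steps:
        return learning_rate * (step + 1) / warmup_steps
    return learning_rate
```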

Optimizer configuration

adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 0  # No weight decay for Flux
max_grad_norm: 1.0
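max_grad_norm clips by the global L2 norm across all gradients, scaling every gradient uniformly rather than clipping each element. A pure-Python sketch of the idea (the real implementation operates on JAX pytrees, e.g. via optax.clip_by_global_norm):

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    # Scale all gradients down uniformly when their combined L2 norm
    # exceeds max_norm; leave them untouched otherwise.
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return grads
    scale = max_norm / global_norm
    return [g * scale for g in grads]
```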

Parallelism configuration

# Mesh axes
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# ICI parallelism (within host)
ici_data_parallelism: -1  # Auto-shard
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1

# DCN parallelism (across hosts)
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1  # Auto-shard
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1

Logical axis rules

Flux uses specialized sharding rules for transformer layers:
logical_axis_rules:
  - ['batch', 'data']
  - ['activation_batch', ['data','fsdp']]
  - ['activation_heads', 'tensor']
  - ['activation_kv', 'tensor']
  - ['mlp','tensor']  # MLP layers sharded on tensor axis
  - ['embed','fsdp']
  - ['heads', 'tensor']

Advanced features

Train from scratch

train_new_flux: False  # Set to True to initialize random weights

Text encoder training

Text encoder training is currently not supported for Flux.

Flash attention block sizes

Flux uses default flash attention block sizes:
flash_block_sizes: {}
For v6e (Trillium) with larger vmem, use:
flash_block_sizes:
  block_q: 1536
  block_kv_compute: 1536
  block_kv: 1536
  block_q_dkv: 1536
  block_kv_dkv: 1536
  block_kv_dkv_compute: 1536
  block_q_dq: 1536
  block_kv_dq: 1536

Memory optimization

Offload encoders:
offload_encoders: True  # Offload T5 encoder after text encoding
Precision settings:
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
precision: "DEFAULT"

Scheduler configuration

diffusion_scheduler_config:
  _class_name: 'FlaxEulerDiscreteScheduler'
  prediction_type: 'epsilon'
  rescale_zero_terminal_snr: False
  timestep_spacing: 'trailing'

Synthetic data for benchmarking

For benchmarking without downloading datasets:
dataset_type: 'synthetic'
# synthetic_num_samples: null  # null for infinite samples

# Optional dimension overrides:
resolution: 512
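The synthetic pipeline only needs to emit tensors of the right shape, so throughput numbers are comparable to real-data runs minus input I/O. A sketch of the per-device shapes involved; the 8x VAE downsampling, 16 latent channels, and 4096-dim T5-XXL embeddings are assumptions based on the standard Flux autoencoder and text encoder:

```python
def synthetic_batch_shapes(per_device_batch_size=1, resolution=512,
                           max_sequence_length=512):
    # Per-device tensor shapes a synthetic Flux input pipeline would emit.
    latent_hw = resolution // 8   # assumed 8x VAE spatial downsampling
    return {
        "pixel_values": (per_device_batch_size, resolution, resolution, 3),
        "latents": (per_device_batch_size, latent_hw, latent_hw, 16),
        "text_embeds": (per_device_batch_size, max_sequence_length, 4096),
    }
```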

Generate images from checkpoint

After training, generate images with your fine-tuned Flux model:
python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://<your-gcs-bucket>/" \
  jax_cache_dir="/tmp/jax_cache"

Generation parameters

prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
guidance_rescale: 0.0
num_inference_steps: 50
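guidance_scale controls classifier-free guidance: at each denoising step the model is run with and without the prompt, and the prediction is pushed away from the unconditional output toward the conditioned one. A minimal element-wise sketch (guidance_rescale=0.0 above disables the additional std-rescaling correction, so it is omitted here):

```python
def apply_cfg(uncond, cond, guidance_scale=3.5):
    # Classifier-free guidance: extrapolate from the unconditional
    # prediction toward the prompt-conditioned prediction.
    return [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]
```

A scale of 1.0 reduces to the conditioned prediction alone; larger values follow the prompt more aggressively at some cost in diversity.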

Profiling

Enable performance profiling:
enable_profiler: True
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
profiler: ""  # Leave empty for default profiler

Checkpointing

checkpoint_every: -1  # -1 to disable periodic checkpoints
save_final_checkpoint: True  # Save checkpoint at end of training
enable_single_replica_ckpt_restoring: False

LoRA support

Flux supports LoRA adapters:
lora_config:
  lora_model_name_or_path: []
  weight_name: []
  adapter_name: []
  scale: []
  from_pt: []
Example configuration:
lora_config:
  lora_model_name_or_path: ["ByteDance/Hyper-SD"]
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"]
  adapter_name: ["hyper-sdxl"]
  scale: [0.7]
  from_pt: [True]
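The scale entry weights how strongly each adapter modifies the base weights: a LoRA adapter stores two low-rank factors A and B, and merging computes W' = W + scale · (B @ A). A plain nested-list sketch of that merge (the real loader applies this per layer over JAX arrays):

```python
def merge_lora(w, a, b, scale=0.7):
    # Merged weight W' = W + scale * (B @ A), where A (rank x cols) and
    # B (rows x rank) are the low-rank adapter factors.
    rows, cols, rank = len(b), len(a[0]), len(a)
    merged = [row[:] for row in w]   # copy so the base weights are untouched
    for i in range(rows):
        for j in range(cols):
            merged[i][j] += scale * sum(b[i][r] * a[r][j] for r in range(rank))
    return merged
```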

Custom datasets

HuggingFace dataset

python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  dataset_name="your-username/your-dataset" \
  dataset_type="hf"

TFRecords

python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  dataset_type="tfrecord" \
  train_data_dir="gs://your-bucket/tfrecords/"

Local directory

python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  dataset_type="tf" \
  train_data_dir="/path/to/images"

Monitoring

View training metrics:
tensorboard --logdir=gs://your-bucket/test-flux-train/tensorboard/

Best practices

  1. Use v5p for Flux: Only tested on TPU v5p
  2. Enable encoder offloading: Set offload_encoders=True to save memory
  3. Use bfloat16: Always use bfloat16 for weights and activations
  4. Flash attention: Use attention=flash for optimal performance
  5. Monitor memory: Flux is memory-intensive; reduce the per-device batch size if you hit HBM limits
  6. Save final checkpoint: Set save_final_checkpoint=True to preserve trained weights
