Flux is a transformer-based diffusion model for high-quality text-to-image generation. MaxDiffusion supports fine-tuning Flux Dev on TPU v5p.
Flux fine-tuning has only been tested on TPU v5p. Other accelerators may not be supported.
Expected results on 1024 x 1024 images with flash attention and bfloat16:

| Model | Accelerator | Sharding Strategy | Per Device Batch Size | Global Batch Size | Step Time (secs) |
|---|---|---|---|---|---|
| Flux-dev | v5p-8 | DDP | 1 | 4 | 1.31 |
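As a sanity check on the benchmark row above, training throughput follows directly from the global batch size and step time:

```python
# Implied throughput for Flux-dev on v5p-8 (numbers from the table above).
global_batch_size = 4
step_time_s = 1.31

throughput = global_batch_size / step_time_s
print(f"{throughput:.2f} images/sec")  # ~3.05 images/sec
```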
## Basic training

```bash
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://<your-gcs-bucket>/" \
  save_final_checkpoint=True \
  jax_cache_dir="/tmp/jax_cache"
```
## Configuration

The base config is located at `src/maxdiffusion/configs/base_flux_dev.yml`.
### Key parameters

| Parameter | Default | Description |
|---|---|---|
| `pretrained_model_name_or_path` | `black-forest-labs/FLUX.1-dev` | Base Flux model |
| `clip_model_name_or_path` | `ariG23498/clip-vit-large-patch14-text-flax` | CLIP text encoder |
| `t5xxl_model_name_or_path` | `ariG23498/t5-v1-1-xxl-flax` | T5-XXL text encoder |
| `flux_name` | `flux-dev` | Flux model variant |
| `weights_dtype` | `bfloat16` | Weight precision |
| `activations_dtype` | `bfloat16` | Activation precision |
| `attention` | `flash` | Attention mechanism |
| `resolution` | 1024 | Training image resolution |
| `per_device_batch_size` | 1 | Batch size per device |
| `learning_rate` | 1.e-5 | Initial learning rate |
| `max_train_steps` | 1500 | Maximum training steps |
### Flux-specific parameters

```yaml
flux_name: "flux-dev"
max_sequence_length: 512
time_shift: True
base_shift: 0.5
max_shift: 1.15
offload_encoders: True  # Offload T5 encoder to save memory
```
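`base_shift` and `max_shift` control how far the timestep schedule is shifted as a function of image size. A minimal sketch of the linear interpolation commonly used in Flux pipelines (the exact MaxDiffusion implementation may differ):

```python
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    # Linearly interpolate between base_shift (small images) and
    # max_shift (large images) based on the number of image tokens.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# A 1024x1024 image -> 128x128 latent -> 64x64 patches -> 4096 tokens.
print(f"{calculate_shift(4096):.2f}")  # 1.15
```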
### Dataset configuration

```yaml
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tfrecord'  # Options: tfrecord, hf, tf, grain, synthetic
cache_latents_text_encoder_outputs: True
dataset_save_location: '/tmp/pokemon-gpt4-captions_xl'
resolution: 1024
image_column: 'image'
caption_column: 'text'
```
### Training parameters

```yaml
learning_rate: 1.e-5
scale_lr: False
max_train_steps: 1500
num_train_epochs: 1
warmup_steps_fraction: 0.1
per_device_batch_size: 1
```
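`warmup_steps_fraction` is applied to `max_train_steps`, so the defaults above give 150 warmup steps. An illustrative sketch of a linear warmup (the actual MaxDiffusion schedule may differ after warmup):

```python
max_train_steps = 1500
warmup_steps_fraction = 0.1
learning_rate = 1e-5

warmup_steps = int(max_train_steps * warmup_steps_fraction)  # 150

def lr_at(step):
    # Linear warmup to the peak learning rate, then hold (assumed constant).
    if step < warmup_steps:
        return learning_rate * step / warmup_steps
    return learning_rate

print(lr_at(75))   # half the peak LR
print(lr_at(150))  # peak LR, 1e-05
```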
### Optimizer configuration

```yaml
adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 0  # No weight decay for Flux
max_grad_norm: 1.0
```
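These parameters map onto the standard Adam update with global-norm gradient clipping. A self-contained NumPy sketch of the math (MaxDiffusion builds its optimizer with a JAX library; this is only for illustration):

```python
import numpy as np

def adam_update(param, grad, m, v, step,
                lr=1e-5, b1=0.9, b2=0.999, eps=1e-8, max_grad_norm=1.0):
    # Clip by global norm first (max_grad_norm above).
    norm = np.linalg.norm(grad)
    if norm > max_grad_norm:
        grad = grad * (max_grad_norm / norm)
    # Standard Adam moment updates with bias correction.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    m_hat = m / (1 - b1 ** step)
    v_hat = v / (1 - b2 ** step)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

p, m, v = np.array([1.0]), np.zeros(1), np.zeros(1)
p, m, v = adam_update(p, np.array([10.0]), m, v, step=1)
print(p)  # first step moves by roughly lr, regardless of gradient scale
```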
### Parallelism configuration

```yaml
# Mesh axes
mesh_axes: ['data', 'fsdp', 'context', 'tensor']

# ICI parallelism (within host)
ici_data_parallelism: -1  # Auto-shard
ici_fsdp_parallelism: 1
ici_context_parallelism: 1
ici_tensor_parallelism: 1

# DCN parallelism (across hosts)
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1  # Auto-shard
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
```
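A value of `-1` means "use all devices left over after the explicitly-set axes are accounted for". A sketch of how the mesh shape resolves (illustrative, not the actual MaxDiffusion code):

```python
import math

def resolve_mesh(total_devices, parallelism):
    # Replace the single -1 entry with the remaining device count.
    known = math.prod(p for p in parallelism if p != -1)
    return [total_devices // known if p == -1 else p for p in parallelism]

# A v5p-8 slice exposes 4 chips; ici axes are [data, fsdp, context, tensor].
print(resolve_mesh(4, [-1, 1, 1, 1]))  # [4, 1, 1, 1]
```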
### Logical axis rules

Flux uses specialized sharding rules for transformer layers:

```yaml
logical_axis_rules:
  - ['batch', 'data']
  - ['activation_batch', ['data', 'fsdp']]
  - ['activation_heads', 'tensor']
  - ['activation_kv', 'tensor']
  - ['mlp', 'tensor']  # MLP layers sharded on tensor axis
  - ['embed', 'fsdp']
  - ['heads', 'tensor']
```
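Each rule maps a logical tensor axis to a mesh axis (or a list of mesh axes); logical axes with no matching rule are replicated. A toy lookup illustrating the semantics:

```python
logical_axis_rules = [
    ('batch', 'data'),
    ('activation_batch', ('data', 'fsdp')),
    ('mlp', 'tensor'),
    ('embed', 'fsdp'),
    ('heads', 'tensor'),
]

def mesh_axes_for(logical_axis, rules):
    # First matching rule wins; None means the axis is replicated.
    for name, mesh_axes in rules:
        if name == logical_axis:
            return mesh_axes
    return None

print(mesh_axes_for('mlp', logical_axis_rules))       # tensor
print(mesh_axes_for('unlisted', logical_axis_rules))  # None
```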
## Advanced features

### Train from scratch

```yaml
train_new_flux: False  # Set to True to initialize random weights
```

### Text encoder training

Text encoder training is currently not supported for Flux.
### Flash attention block sizes

Flux uses the default flash attention block sizes. For v6e (Trillium), which has larger vmem, use:

```yaml
flash_block_sizes:
  block_q: 1536
  block_kv_compute: 1536
  block_kv: 1536
  block_q_dkv: 1536
  block_kv_dkv: 1536
  block_kv_dkv_compute: 1536
  block_q_dq: 1536
  block_kv_dq: 1536
```
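Larger blocks trade vmem for fewer grid iterations. A back-of-envelope check on the tile footprint (assuming bf16 at 2 bytes/element; actual vmem usage also includes inputs and accumulators):

```python
def tile_mib(block_q, block_kv, bytes_per_elem=2):
    # Approximate size of one attention score tile (q x kv) held in vmem.
    return block_q * block_kv * bytes_per_elem / 2**20

print(tile_mib(1536, 1536))  # 4.5 MiB per tile
print(tile_mib(512, 512))    # 0.5 MiB with smaller blocks
```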
### Memory optimization

Offload encoders:

```yaml
offload_encoders: True  # Offload T5 encoder after text encoding
```

Precision settings:

```yaml
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'
precision: "DEFAULT"
```
### Scheduler configuration

```yaml
diffusion_scheduler_config:
  _class_name: 'FlaxEulerDiscreteScheduler'
  prediction_type: 'epsilon'
  rescale_zero_terminal_snr: False
  timestep_spacing: 'trailing'
```
### Synthetic data for benchmarking

For benchmarking without downloading datasets:

```yaml
dataset_type: 'synthetic'
# synthetic_num_samples: null  # null for infinite samples
# Optional dimension overrides:
resolution: 512
```
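Conceptually, a synthetic loader just emits random tensors of the configured shape, so no I/O is involved. A sketch under assumed shapes (the 8x VAE downsampling factor and 16 latent channels are illustrative assumptions, not the exact MaxDiffusion pipeline values):

```python
import numpy as np

def synthetic_batch(per_device_batch_size=1, resolution=512):
    # Random stand-ins for real pixel data and cached latents.
    latent_hw = resolution // 8
    pixel_values = np.random.randn(per_device_batch_size, resolution, resolution, 3)
    latents = np.random.randn(per_device_batch_size, latent_hw, latent_hw, 16)
    return {'pixel_values': pixel_values, 'latents': latents}

batch = synthetic_batch()
print(batch['latents'].shape)  # (1, 64, 64, 16)
```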
## Generate images from checkpoint

After training, generate images with your fine-tuned Flux model:

```bash
python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name="test-flux-train" \
  output_dir="gs://<your-gcs-bucket>/" \
  jax_cache_dir="/tmp/jax_cache"
```
### Generation parameters

```yaml
prompt: "A magical castle in the middle of a forest, artistic drawing"
prompt_2: "A magical castle in the middle of a forest, artistic drawing"
negative_prompt: "purple, red"
do_classifier_free_guidance: True
guidance_scale: 3.5
guidance_rescale: 0.0
num_inference_steps: 50
```
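With `do_classifier_free_guidance` enabled, the model is evaluated on both the conditional and unconditional (negative) prompts and the two predictions are combined; `guidance_scale` controls the strength. The standard formulation, for illustration (the exact path in MaxDiffusion may differ):

```python
import numpy as np

def classifier_free_guidance(noise_uncond, noise_cond, guidance_scale=3.5):
    # Extrapolate from the unconditional prediction toward the
    # conditional one; scale 1.0 recovers the conditional prediction.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

u, c = np.array([0.0]), np.array([1.0])
print(classifier_free_guidance(u, c))                      # [3.5]
print(classifier_free_guidance(u, c, guidance_scale=1.0))  # [1.]
```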
## Profiling

Enable performance profiling:

```yaml
enable_profiler: True
skip_first_n_steps_for_profiler: 5
profiler_steps: 10
profiler: ""  # Leave empty for default profiler
```
## Checkpointing

```yaml
checkpoint_every: -1  # -1 to disable periodic checkpoints
save_final_checkpoint: True  # Save checkpoint at end of training
enable_single_replica_ckpt_restoring: False
```
## LoRA support

Flux supports LoRA adapters:

```yaml
lora_config:
  lora_model_name_or_path: []
  weight_name: []
  adapter_name: []
  scale: []
  from_pt: []
```

Example configuration:

```yaml
lora_config:
  lora_model_name_or_path: ["ByteDance/Hyper-SD"]
  weight_name: ["Hyper-SDXL-2steps-lora.safetensors"]
  adapter_name: ["hyper-sdxl"]
  scale: [0.7]
  from_pt: [True]
```
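A LoRA adapter adds a scaled low-rank update to each target weight matrix; `scale` is the multiplier on that update. A NumPy sketch of the merged weight (illustrative, not the MaxDiffusion loading code):

```python
import numpy as np

def apply_lora(W, lora_A, lora_B, scale=0.7):
    # Merged weight: W + scale * (B @ A). The adapter rank r is
    # A.shape[0], with r much smaller than min(W.shape).
    return W + scale * (lora_B @ lora_A)

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 64))
A = rng.standard_normal((4, 64))   # rank-4 adapter
B = rng.standard_normal((64, 4))
W_merged = apply_lora(W, A, B, scale=0.7)
print(W_merged.shape)  # (64, 64)
```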
## Custom datasets

### HuggingFace dataset

```bash
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  dataset_name="your-username/your-dataset" \
  dataset_type="hf"
```

### TFRecords

```bash
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  dataset_type="tfrecord" \
  train_data_dir="gs://your-bucket/tfrecords/"
```

### Local directory

```bash
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  dataset_type="tf" \
  train_data_dir="/path/to/images"
```
## Monitoring

View training metrics:

```bash
tensorboard --logdir=gs://your-bucket/test-flux-train/tensorboard/
```
## Best practices

- **Use v5p for Flux**: fine-tuning has only been tested on TPU v5p
- **Enable encoder offloading**: set `offload_encoders=True` to save memory
- **Use bfloat16**: always use bfloat16 for weights and activations
- **Flash attention**: use `attention=flash` for optimal performance
- **Monitor HBM usage**: Flux is memory-intensive; reduce the per-device batch size if you run out of memory
- **Save final checkpoint**: set `save_final_checkpoint=True` to preserve trained weights