Overview
The generate_flux.py script provides inference for Flux models (dev and schnell). Flux is a flow-matching diffusion model that improves on earlier diffusion models in both quality and speed. The script supports LoRA loading, multiple sharding strategies, and optimized flash attention.
Usage
Flux Schnell (fast)
Flux Dev (high quality)
With LoRA
FSDP sharding (keeps all components in HBM)
GPU with fused attention
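The usage modes above can be sketched as command lines. The config file paths and override syntax below are assumptions modeled on maxdiffusion conventions, so verify them against the repository before use:

```shell
# Hypothetical invocations -- config paths and override names are assumptions.

# Flux Schnell (fast, 4 steps)
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml \
  run_name=flux_schnell_run prompt="a photo of a cat" num_inference_steps=4

# Flux Dev (high quality, more steps)
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name=flux_dev_run prompt="a photo of a cat" num_inference_steps=28

# FSDP sharding: shard the model across devices, keeping all components in HBM
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_dev.yml \
  run_name=flux_fsdp_run ici_fsdp_parallelism=-1 ici_data_parallelism=1
```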
Configuration parameters
Model parameters
Base model path. Use black-forest-labs/FLUX.1-dev or black-forest-labs/FLUX.1-schnell.
Flux variant identifier. Options: flux-dev, flux-schnell.
CLIP text encoder model path.
T5-XXL text encoder model path.
Generation parameters
Primary text prompt for image generation.
Secondary prompt (can be same as primary or different for prompt blending).
Number of denoising steps. Flux-dev typically uses 28-50, schnell uses 4.
Guidance scale for prompt adherence. Flux works well with lower values (3-5).
Output image resolution. Flux is optimized for 1024x1024.
Maximum sequence length for text encoders.
Flow matching parameters
Enable the time-shift schedule to favor high timesteps for higher-signal images.
Base shift parameter for time schedule.
Maximum shift parameter for time schedule.
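The time-shift schedule can be sketched as follows. The exact wiring in generate_flux.py may differ; this is the shift function commonly used with Flux, where the shift exponent mu is interpolated between the base and max shift parameters from the packed latent sequence length. The default values and token bounds below are assumptions taken from common Flux implementations:

```python
import math

def compute_mu(seq_len, base_shift=0.5, max_shift=1.15,
               min_tokens=256, max_tokens=4096):
    """Linearly interpolate the shift exponent from the latent sequence length.

    base_shift/max_shift correspond to the base/max shift parameters above;
    the token bounds are assumptions from common Flux implementations.
    """
    slope = (max_shift - base_shift) / (max_tokens - min_tokens)
    return base_shift + slope * (seq_len - min_tokens)

def time_shift(t, mu, sigma=1.0):
    """Shift a timestep t in (0, 1] toward 1 (more noise) when mu > 0."""
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma)

# A 1024x1024 image packs to a 4096-token latent sequence, so mu hits
# max_shift and mid-range timesteps are pushed toward the high-noise end.
mu = compute_mu(4096)
shifted = time_shift(0.5, mu)  # greater than 0.5
```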
Performance parameters
Data type for model weights. Always use bfloat16 for Flux.
Data type for activations. Always use bfloat16 for Flux.
Attention implementation. Options: flash (TPU), cudnn_flash_te (GPU), dot_product.
Split attention head dimensions for better parallelization.
Offload T5 encoder to CPU after encoding to save memory.
Custom block sizes for flash attention. Use values tuned for the target accelerator (e.g. TPU v6e).
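A sketch of what such an override might look like. The key names and values below are assumptions modeled on MaxText-style flash block-size configs, not the script's verified defaults:

```yaml
# Hypothetical flash_block_sizes override -- key names and values are assumptions.
flash_block_sizes: {
  "block_q": 256,
  "block_kv_compute": 256,
  "block_kv": 256,
  "block_q_dkv": 256,
  "block_kv_dkv": 256,
  "block_kv_dkv_compute": 256,
  "block_q_dq": 256,
  "block_kv_dq": 256
}
```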
LoRA parameters
Configuration for loading LoRA adapters. Supports multiple LoRAs. Properties:
- lora_model_name_or_path (array): Paths to LoRA files
- weight_name (array): Filenames of LoRA weights
- adapter_name (array): Unique adapter names
- scale (array): Scale factors (0.0-1.0)
- from_pt (array): Load from PyTorch format
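Putting the properties together, a lora_config entry might look like the following. The repository path and file name are placeholders for illustration, not real artifacts:

```yaml
lora_config: {
  "lora_model_name_or_path": ["some-user/my-flux-lora"],  # placeholder repo/path
  "weight_name": ["my_flux_lora.safetensors"],            # placeholder file name
  "adapter_name": ["my_lora"],
  "scale": [0.8],
  "from_pt": [true]
}
```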
Parallelism parameters
Data parallelism for DDP strategy. Use -1 for auto-sharding.
FSDP parallelism. Set to -1 to shard model across devices (keeps everything in HBM).
Batch size per device.
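The difference between the two strategies can be sketched with a JAX device mesh. The axis names ("data", "fsdp") are illustrative, and eight CPU devices are simulated so the sketch runs without a TPU or GPU:

```python
import os
# Simulate 8 devices on CPU so the sketch runs anywhere.
# This must be set before jax is first imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("data", "fsdp"))

# DDP: the batch is split along "data"; weights would be fully replicated.
batch = jax.device_put(np.zeros((8, 64)), NamedSharding(mesh, P("data")))

# FSDP: the weight matrix is split along "fsdp", so each device holds
# only a 1/4 slice in HBM instead of a full replica.
weights = jax.device_put(np.zeros((64, 64)), NamedSharding(mesh, P("fsdp")))
```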
System parameters
Unique run identifier.
Output directory for generated images.
Random seed for reproducibility.
JAX compilation cache directory.
Hardware type. Options: tpu, gpu.
Output
Images are saved with the naming pattern flux_{i}.png.
Performance data
Expected results on 1024x1024 images with flash attention and bfloat16:

| Model | Accelerator | Sharding | Batch Size | Steps | Time (seconds) |
|---|---|---|---|---|---|
| Flux-dev | v4-8 | DDP | 4 | 28 | 23 |
| Flux-schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux-dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux-schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux-schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
Implementation details
Flux uses a distinctive architecture:
- Dual text encoders (CLIP and T5-XXL) for rich text understanding
- Flow matching instead of traditional diffusion
- Transformer-based architecture with rotary positional embeddings
- Packing-based latent representation
- Time shift scheduling for improved quality
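To make the flow-matching point concrete, here is a minimal Euler sampler for a linear interpolation path. This is an illustrative toy, not the script's actual sampler:

```python
import numpy as np

def euler_flow_step(x, v, sigma, sigma_next):
    """One Euler step of flow matching: move along the predicted velocity field."""
    return x + (sigma_next - sigma) * v

# Toy check: with the ideal velocity v = (noise - data) for a linear
# interpolation path, integrating sigma from 1 down to 0 carries pure
# noise back to the data point.
data = np.array([1.0, 2.0])
noise = np.array([0.0, 0.0])
sigmas = np.linspace(1.0, 0.0, 5)  # e.g. 4 denoising steps, as in schnell
x = noise.copy()
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    v = noise - data  # a real sampler would query the transformer here
    x = euler_flow_step(x, v, s, s_next)
```

In the real script, the velocity comes from the Flux transformer conditioned on the text embeddings, and the sigma grid is the (optionally time-shifted) schedule described above.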