Flux is a flow-matching, transformer-based text-to-image diffusion model offering high image quality, with optimized inference on TPUs and GPUs.
## Quick start
You must have permissions to access Flux models on HuggingFace before running inference.
The following example runs Flux schnell:
```shell
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_test \
  output_dir=/tmp/ \
  prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" \
  per_device_batch_size=1
```
## Model variants
### Flux schnell
Fast 4-step generation for real-time applications:
- **Steps**: 4 inference steps
- **Speed**: 0.8-2.2 seconds on TPU v6e/v4
- **Quality**: High quality with minimal steps
- **Use case**: Real-time generation, prototyping
### Flux dev
Higher quality 28-step generation:
- **Steps**: 28 inference steps
- **Speed**: 5.5-23 seconds on TPU v6e/v4
- **Quality**: Exceptional detail and coherence
- **Use case**: Production applications, fine art
Benchmarks on 1024×1024 images with flash attention and bfloat16:
| Model | Accelerator | Sharding | Batch size | Steps | Time (sec) |
|---|---|---|---|---|---|
| Flux dev | v4-8 | DDP | 4 | 28 | 23.0 |
| Flux schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |
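Throughput follows directly from the table: each run generates a full batch per invocation, so images per second is batch size divided by wall-clock time. A quick derivation:

```python
# Images per second implied by the benchmark table above:
# each entry is (model, accelerator, sharding, batch, steps, seconds).
benchmarks = [
    ("Flux dev",     "v4-8",  "DDP",  4, 28, 23.0),
    ("Flux schnell", "v4-8",  "DDP",  4,  4,  2.2),
    ("Flux dev",     "v6e-4", "DDP",  4, 28,  5.5),
    ("Flux schnell", "v6e-4", "DDP",  4,  4,  0.8),
    ("Flux schnell", "v6e-4", "FSDP", 4,  4,  1.2),
]

for model, accel, shard, batch, steps, secs in benchmarks:
    # Batch size divided by time per batch gives sustained throughput.
    print(f"{model} on {accel} ({shard}): {batch / secs:.2f} images/sec")
```

For example, Flux schnell on v6e-4 with DDP sustains 4 / 0.8 = 5 images per second.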
## Trillium optimizations
TPU v6e (Trillium) supports optimized flash attention block sizes for faster inference.
### Enable Trillium optimizations

Uncomment the `flash_block_sizes` configuration in:

- Flux dev: `src/maxdiffusion/configs/base_flux_dev.yml#60`
- Flux schnell: `src/maxdiffusion/configs/base_flux_schnell.yml#68`
Example configuration:
```yaml
flash_block_sizes: {
  "block_q": 512,
  "block_kv_compute": 512,
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "block_q_dq": 512,
  "block_kv_dq": 512
}
```
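Flash attention block sizes generally need to be powers of two and no larger than the packed sequence length (4096 tokens for a 1024×1024 image). A minimal sanity-check sketch; `check_block_sizes` is a hypothetical helper, not part of MaxDiffusion:

```python
# Hypothetical helper (not part of MaxDiffusion): verify each flash attention
# block size is a power of two and fits within the packed sequence length.
flash_block_sizes = {
    "block_q": 512, "block_kv_compute": 512, "block_kv": 512,
    "block_q_dkv": 512, "block_kv_dkv": 512, "block_kv_dkv_compute": 512,
    "block_q_dq": 512, "block_kv_dq": 512,
}

def check_block_sizes(blocks, seq_len):
    for name, size in blocks.items():
        # A power of two has exactly one bit set: size & (size - 1) == 0.
        assert size > 0 and (size & (size - 1)) == 0, f"{name} is not a power of two"
        assert size <= seq_len, f"{name} exceeds sequence length {seq_len}"

# 1024x1024 image -> 128x128 latents -> 64x64 = 4096 packed tokens.
check_block_sizes(flash_block_sizes, seq_len=4096)
```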
## Sharding strategies
### Data parallelism (DDP)
The default sharding strategy, which replicates the full model on every device:
```shell
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_ddp \
  output_dir=/tmp/ \
  prompt="A serene mountain landscape at sunset" \
  ici_data_parallelism=4
```
### FSDP with encoder offloading disabled
Shard model parameters while keeping all components in HBM:
```shell
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_fsdp \
  output_dir=/tmp/ \
  prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" \
  per_device_batch_size=1 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  offload_encoders=False
```
Benefits:

- Keeps text encoders, VAE, and transformer in HBM at all times
- Reduces host-device data transfer overhead
- Improves performance for memory-constrained scenarios
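To see why sharding matters for memory, a back-of-envelope estimate comparing DDP and FSDP per-device parameter memory. This assumes a transformer of roughly 12B parameters in bfloat16; the exact count depends on the checkpoint:

```python
# Per-device parameter memory: DDP keeps a full replica on every device,
# FSDP shards parameters across devices. (Assumes ~12B params, bfloat16.)
params = 12e9
bytes_per_param = 2  # bfloat16
devices = 4          # e.g. a v6e-4 slice

ddp_gb = params * bytes_per_param / 1e9             # full replica per device
fsdp_gb = params * bytes_per_param / devices / 1e9  # parameters sharded

print(f"DDP:  {ddp_gb:.1f} GB per device")
print(f"FSDP: {fsdp_gb:.1f} GB per device")
```

The ~4x reduction is what lets the transformer, text encoders, and VAE all stay resident in HBM, avoiding the host-device transfers that encoder offloading would otherwise require.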
## LoRA support
Flux supports LoRA adapters for style transfer and fine-tuning.
### Single LoRA
Load a single LoRA adapter:
```shell
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_lora \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  split_head_dim=True \
  lora_config='{"lora_model_name_or_path" : ["/home/user/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'
```
### Multiple LoRAs
Combine multiple LoRA adapters:
```shell
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_multi_lora \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  split_head_dim=True \
  lora_config='{"lora_model_name_or_path" : ["/home/user/anime_lora.safetensors", "/home/user/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
```
LoRA parameters:
| Parameter | Description |
|---|---|
| `lora_model_name_or_path` | Path to LoRA weights |
| `weight_name` | Safetensors filename |
| `adapter_name` | Unique adapter identifier |
| `scale` | LoRA influence (0.0-1.0) |
| `from_pt` | Convert from PyTorch format |
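Because `lora_config` is a JSON string with parallel lists, building it programmatically avoids quoting and alignment mistakes. A small sketch; the path and adapter name are illustrative:

```python
import json

# Build the lora_config flag value instead of hand-writing the JSON.
# The path and adapter name below are placeholders.
lora_config = {
    "lora_model_name_or_path": ["/home/user/anime_lora.safetensors"],
    "weight_name": ["anime_lora.safetensors"],
    "adapter_name": ["anime"],
    "scale": [0.8],
    "from_pt": ["true"],
}

# Every list must have the same length: entry i of each list describes adapter i.
lengths = {len(v) for v in lora_config.values()}
assert len(lengths) == 1, "all lora_config lists must be parallel"

flag = f"lora_config='{json.dumps(lora_config)}'"
print(flag)
```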
### Tested LoRA collections
Not all LoRA formats have been tested. If a specific LoRA doesn’t load, please report it.
## GPU inference

Flux supports fused attention on GPUs via TransformerEngine.
### Installation
```shell
cd maxdiffusion
pip install -U "jax[cuda12]"
pip install -r requirements.txt
pip install --upgrade torch torchvision
pip install "transformer_engine[jax]"
pip install .
```
### Run inference
```shell
NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_gpu \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  split_head_dim=True \
  per_device_batch_size=1 \
  attention="cudnn_flash_te" \
  hardware=gpu
```
## Parameters
| Parameter | Description | Default |
|---|---|---|
| `prompt` | Primary text prompt | Required |
| `prompt_2` | Secondary prompt (optional) | Same as `prompt` |
| `num_inference_steps` | Denoising steps | 4 (schnell), 28 (dev) |
| `guidance_scale` | CFG strength | 3.5 |
| `resolution` | Image size (height/width) | 1024 |
| `per_device_batch_size` | Batch size per device | 1 |
| `seed` | Random seed | 0 |
| `max_sequence_length` | Max tokens for T5 encoder | 512 |
| `split_head_dim` | Split attention heads | False |
| `offload_encoders` | Offload text encoders to CPU | True |
## Architecture
Flux uses a novel architecture:
- **Transformer**: Flow-matching transformer instead of a UNet
- **Text encoders**: Dual encoders (CLIP ViT-L/14 + T5-XXL)
- **VAE**: Autoencoder with 8x spatial compression
- **Flow matching**: Continuous normalizing flows for generation
- **Attention**: Flash attention for efficient self-attention
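These pieces determine the sequence length the transformer actually sees. For a 1024×1024 image, the 8x VAE compression and the 2x2 latent packing (described below) combine as follows:

```python
# Sequence length seen by the transformer for a 1024x1024 image:
# 8x VAE spatial compression, then 2x2 latent packing.
image_size = 1024
latent_size = image_size // 8       # VAE: 1024 -> 128
packed_size = latent_size // 2      # 2x2 packing: 128 -> 64
tokens = packed_size * packed_size  # 64 * 64 = 4096 tokens
print(tokens)  # 4096
```

This 4096-token count is the same value that appears as the upper endpoint of the time-shift interpolation in the implementation details.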
## Implementation details
The Flux pipeline (`src/maxdiffusion/generate_flux.py`) implements:
### Latent packing
Flux uses 2×2 spatial packing (generate_flux.py:160-171):
```python
def pack_latents(latents, batch_size, num_channels_latents, height, width):
    latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2))
    latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5))
    latents = jnp.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4))
    return latents
```
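The packing is invertible: undoing the final reshape, transposing back, and merging the 2x2 factors recovers the original latent grid. A round-trip sketch using NumPy for illustration (the pipeline's actual unpacking code may differ in detail):

```python
import numpy as np

# NumPy mirror of pack_latents: (B, C, H, W) -> (B, (H/2)*(W/2), C*4).
def pack_latents_np(latents, batch_size, num_channels_latents, height, width):
    latents = np.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2))
    latents = np.transpose(latents, (0, 2, 4, 1, 3, 5))
    return np.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4))

# Inverse: (B, (H/2)*(W/2), C*4) -> (B, C, H, W).
def unpack_latents_np(latents, batch_size, num_channels_latents, height, width):
    latents = np.reshape(latents, (batch_size, height // 2, width // 2, num_channels_latents, 2, 2))
    latents = np.transpose(latents, (0, 3, 1, 4, 2, 5))
    return np.reshape(latents, (batch_size, num_channels_latents, height, width))

# Round trip on a small tensor: (2, 16, 8, 8) -> (2, 16, 64) -> (2, 16, 8, 8).
x = np.arange(2 * 16 * 8 * 8, dtype=np.float32).reshape(2, 16, 8, 8)
packed = pack_latents_np(x, 2, 16, 8, 8)
restored = unpack_latents_np(packed, 2, 16, 8, 8)
```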
### Time shifting
Flux uses time shift for better high-frequency details (generate_flux.py:127-134):
```python
def time_shift(mu: float, sigma: float, t: Array):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# Apply shift based on sequence length
lin_function = get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15)
mu = lin_function(latents.shape[1])
timesteps = time_shift(mu, 1.0, timesteps)
```
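A worked example makes the effect concrete. Here `get_lin_function` is reconstructed as the linear interpolation implied by its arguments (the pipeline's version may differ in detail): at 4096 tokens it returns the upper endpoint, mu = 1.15, and the midpoint timestep t = 0.5 shifts upward:

```python
import math

# Reconstructed as a linear interpolation through (x1, y1) and (x2, y2);
# this is an assumption, not the pipeline's exact implementation.
def get_lin_function(x1, y1, x2, y2):
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * (x - x1) + y1

def time_shift(mu, sigma, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# A 1024x1024 image packs to 4096 tokens, so mu hits the endpoint 1.15.
mu = get_lin_function(256, 0.5, 4096, 1.15)(4096)
shifted = time_shift(mu, 1.0, 0.5)  # t=0.5 maps to roughly 0.76
```

Pushing timesteps toward 1 means the sampler spends more of its step budget near the noisy end of the trajectory, which is where high-frequency detail is resolved for long sequences.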
### Dual text encoding
Combines CLIP and T5-XXL embeddings (generate_flux.py:246-273):
```python
def encode_prompt(prompt, prompt_2, clip_tokenizer, clip_text_encoder, t5_tokenizer, t5_text_encoder):
    # CLIP pooled embeddings
    pooled_prompt_embeds = get_clip_prompt_embeds(
        prompt=prompt,
        tokenizer=clip_tokenizer,
        text_encoder=clip_text_encoder,
    )
    # T5-XXL embeddings for sequence conditioning
    prompt_embeds = get_t5_prompt_embeds(
        prompt=prompt_2,
        tokenizer=t5_tokenizer,
        text_encoder=t5_text_encoder,
        max_sequence_length=512,
    )
    text_ids = jnp.zeros((prompt_embeds.shape[1], 3))
    return prompt_embeds, pooled_prompt_embeds, text_ids
```
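The two encoders produce tensors of different shapes: CLIP contributes a single pooled vector per prompt, while T5-XXL contributes one embedding per token. Assuming the standard hidden sizes of CLIP ViT-L/14 (768) and T5-XXL (4096), the shapes for a batch of one look like:

```python
# Expected embedding shapes for one prompt, assuming standard hidden sizes
# (CLIP ViT-L/14 pooled: 768; T5-XXL hidden: 4096) -- checkpoint-dependent.
batch = 1
max_sequence_length = 512

pooled_prompt_embeds_shape = (batch, 768)                 # one vector per prompt
prompt_embeds_shape = (batch, max_sequence_length, 4096)  # one vector per token
text_ids_shape = (max_sequence_length, 3)                 # zero-filled position ids

print(pooled_prompt_embeds_shape, prompt_embeds_shape, text_ids_shape)
```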
## Output
Generated images are saved as `flux_{i}.png`. The pipeline reports:

- **Compile time**: initial JAX compilation
- **Inference time**: generation duration
## Next steps
- **Flux training**: fine-tune Flux on custom datasets
- **Wan video**: text-to-video generation
- **LoRA training**: train custom Flux LoRA adapters
- **SDXL inference**: an alternative high-quality model