Flux is a next-generation diffusion model offering exceptional image quality with optimized inference on TPUs and GPUs.

Quick start

You must have permissions to access Flux models on HuggingFace before running inference.
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_test \
  output_dir=/tmp/ \
  prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" \
  per_device_batch_size=1

Model variants

Flux schnell

Fast 4-step generation for real-time applications:
  • Steps: 4 inference steps
  • Speed: 0.8-2.2 seconds on TPU v6e/v4
  • Quality: High quality with minimal steps
  • Use case: Real-time generation, prototyping

Flux dev

Higher quality 28-step generation:
  • Steps: 28 inference steps
  • Speed: 5.5-23 seconds on TPU v6e/v4
  • Quality: Exceptional detail and coherence
  • Use case: Production applications, fine art

Performance benchmarks

Benchmarks on 1024×1024 images with flash attention and bfloat16:
| Model | Accelerator | Sharding | Batch size | Steps | Time (sec) |
|---|---|---|---|---|---|
| Flux dev | v4-8 | DDP | 4 | 28 | 23.0 |
| Flux schnell | v4-8 | DDP | 4 | 4 | 2.2 |
| Flux dev | v6e-4 | DDP | 4 | 28 | 5.5 |
| Flux schnell | v6e-4 | DDP | 4 | 4 | 0.8 |
| Flux schnell | v6e-4 | FSDP | 4 | 4 | 1.2 |

Trillium optimizations

TPU v6e (Trillium) supports optimized flash attention block sizes for faster inference.

Enable Trillium optimizations

Uncomment the flash_block_sizes configuration:
  • Flux dev: src/maxdiffusion/configs/base_flux_dev.yml#60
  • Flux schnell: src/maxdiffusion/configs/base_flux_schnell.yml#68
Example configuration:
flash_block_sizes: {
  "block_q": 512,
  "block_kv_compute": 512, 
  "block_kv": 512,
  "block_q_dkv": 512,
  "block_kv_dkv": 512,
  "block_kv_dkv_compute": 512,
  "block_q_dq": 512,
  "block_kv_dq": 512
}

Sharding strategies

Data parallelism (DDP)

Default sharding strategy that replicates the model:
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_ddp \
  output_dir=/tmp/ \
  prompt="A serene mountain landscape at sunset" \
  ici_data_parallelism=4

FSDP with encoder offloading disabled

Shard model parameters while keeping all components in HBM:
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_fsdp \
  output_dir=/tmp/ \
  prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" \
  per_device_batch_size=1 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  offload_encoders=False
Benefits:
  • Keeps text encoders, VAE, and transformer in HBM at all times
  • Reduces host-device data transfer overhead
  • Improves performance for memory-constrained scenarios

LoRA support

Flux supports LoRA adapters for style transfer and fine-tuning.

Single LoRA

Load a single LoRA adapter:
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_lora \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  split_head_dim=True \
  lora_config='{"lora_model_name_or_path" : ["/home/user/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'

Multiple LoRAs

Combine multiple LoRA adapters:
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_multi_lora \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  split_head_dim=True \
  lora_config='{"lora_model_name_or_path" : ["/home/user/anime_lora.safetensors", "/home/user/amateurphoto-v6-forcu.safetensors"], "weight_name" : ["anime_lora.safetensors","amateurphoto-v6-forcu.safetensors"], "adapter_name" : ["anime","realistic"], "scale": [0.6, 0.6], "from_pt": ["true","true"]}'
LoRA parameters:
| Parameter | Description |
|---|---|
| lora_model_name_or_path | Path to LoRA weights |
| weight_name | Safetensors filename |
| adapter_name | Unique adapter identifier |
| scale | LoRA influence (0.0-1.0) |
| from_pt | Convert from PyTorch format |
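Because lora_config is passed as a single JSON string on the command line, it can be easier to build it programmatically and serialize it. A minimal sketch — the adapter path and names below are placeholders, not files shipped with the repo:

```python
import json

# Hypothetical adapter path and names, for illustration only.
lora_config = {
    "lora_model_name_or_path": ["/home/user/anime_lora.safetensors"],
    "weight_name": ["anime_lora.safetensors"],
    "adapter_name": ["anime"],
    "scale": [0.8],
    "from_pt": ["true"],
}

# Serialize for use as lora_config='...' on the command line.
lora_config_arg = json.dumps(lora_config)
print(lora_config_arg)
```

Note that each key holds a list, with one entry per adapter; to stack adapters, append to every list in parallel (as in the multi-LoRA example above).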

Tested LoRA collections

Not all LoRA formats have been tested. If a specific LoRA doesn’t load, please report it.

GPU support with TransformerEngine

Flux supports fused attention on GPUs via TransformerEngine.

Installation

cd maxdiffusion
pip install -U "jax[cuda12]"
pip install -r requirements.txt
pip install --upgrade torch torchvision
pip install "transformer_engine[jax]"
pip install .

Run inference

NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_gpu \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  split_head_dim=True \
  per_device_batch_size=1 \
  attention="cudnn_flash_te" \
  hardware=gpu

Parameters

| Parameter | Description | Default |
|---|---|---|
| prompt | Primary text prompt | Required |
| prompt_2 | Secondary prompt (optional) | Same as prompt |
| num_inference_steps | Denoising steps | 4 (schnell), 28 (dev) |
| guidance_scale | CFG strength | 3.5 |
| resolution | Image size (height/width) | 1024 |
| per_device_batch_size | Batch size per device | 1 |
| seed | Random seed | 0 |
| max_sequence_length | Max tokens for T5 encoder | 512 |
| split_head_dim | Split attention heads | False |
| offload_encoders | Offload text encoders to CPU | True |

Architecture

Flux uses a novel architecture:
  • Transformer: Flow-matching transformer instead of UNet
  • Text encoders: Dual encoders (CLIP ViT-L/14 + T5-XXL)
  • VAE: Autoencoder with 8x spatial compression
  • Flow matching: Continuous normalizing flows for generation
  • Attention: Flash attention for efficient self-attention
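The 8x VAE compression and the 2×2 latent packing (described below) together determine the transformer's sequence length. A quick back-of-the-envelope check for a 1024×1024 image:

```python
height = width = 1024  # output image resolution
vae_factor = 8         # 8x spatial compression by the VAE

# 1024x1024 pixels -> 128x128 latents
latent_h, latent_w = height // vae_factor, width // vae_factor

# 2x2 packing turns each 2x2 latent patch into one transformer token
seq_len = (latent_h // 2) * (latent_w // 2)
print(seq_len)  # 4096 tokens
```

This is why the time-shift calibration below uses x2=4096: it is the sequence length of a full-resolution 1024×1024 generation.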

Implementation details

The Flux pipeline (src/maxdiffusion/generate_flux.py) implements:

Latent packing

Flux uses 2×2 spatial packing (generate_flux.py:160-171):
def pack_latents(latents, batch_size, num_channels_latents, height, width):
  latents = jnp.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2))
  latents = jnp.permute_dims(latents, (0, 2, 4, 1, 3, 5))
  latents = jnp.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4))
  return latents
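To see the packing in action, here is a small shape check using NumPy as a stand-in for jax.numpy (np.transpose corresponds to jnp.permute_dims; num_channels_latents=16 is Flux's latent channel count):

```python
import numpy as np

def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # NumPy mirror of the jax.numpy code above.
    latents = np.reshape(latents, (batch_size, num_channels_latents, height // 2, 2, width // 2, 2))
    latents = np.transpose(latents, (0, 2, 4, 1, 3, 5))
    latents = np.reshape(latents, (batch_size, (height // 2) * (width // 2), num_channels_latents * 4))
    return latents

# A 1024x1024 image becomes 128x128 latents with 16 channels...
latents = np.zeros((1, 16, 128, 128), dtype=np.float32)

# ...which pack into 64*64 = 4096 tokens of dimension 16*4 = 64.
packed = pack_latents(latents, batch_size=1, num_channels_latents=16, height=128, width=128)
print(packed.shape)  # (1, 4096, 64)
```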

Time shifting

Flux uses time shift for better high-frequency details (generate_flux.py:127-134):
def time_shift(mu: float, sigma: float, t: Array):
  return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# Apply shift based on sequence length
lin_function = get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15)
mu = lin_function(latents.shape[1])
timesteps = time_shift(mu, 1.0, timesteps)
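get_lin_function fits a line through (x1, y1) and (x2, y2), so mu grows with sequence length and longer sequences get a stronger shift. A self-contained sketch with scalar timesteps — get_lin_function is reconstructed here from its call signature, so treat it as an assumption:

```python
import math

def get_lin_function(x1=256, y1=0.5, x2=4096, y2=1.15):
    # Assumed implementation: linear interpolation through (x1, y1) and (x2, y2).
    slope = (y2 - y1) / (x2 - x1)
    return lambda x: slope * (x - x1) + y1

def time_shift(mu, sigma, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# Full 1024x1024 sequence length (4096 tokens) maps to the maximum shift.
mu = get_lin_function()(4096)
print(round(mu, 2))  # 1.15

# Shifting with mu > 0 pushes mid-range timesteps toward 1,
# spending more of the schedule at high noise levels:
print(time_shift(mu, 1.0, 0.5))
```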

Dual text encoding

Combines CLIP and T5-XXL embeddings (generate_flux.py:246-273):
def encode_prompt(prompt, prompt_2, clip_tokenizer, clip_text_encoder, t5_tokenizer, t5_text_encoder):
  # CLIP pooled embeddings
  pooled_prompt_embeds = get_clip_prompt_embeds(
      prompt=prompt, 
      tokenizer=clip_tokenizer, 
      text_encoder=clip_text_encoder
  )
  
  # T5-XXL embeddings for sequence conditioning
  prompt_embeds = get_t5_prompt_embeds(
      prompt=prompt_2,
      tokenizer=t5_tokenizer,
      text_encoder=t5_text_encoder,
      max_sequence_length=512
  )
  
  text_ids = jnp.zeros((prompt_embeds.shape[1], 3))
  return prompt_embeds, pooled_prompt_embeds, text_ids
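The resulting arrays have fixed shapes for a given batch and max_sequence_length. A NumPy stand-in for illustration, assuming CLIP ViT-L/14's 768-dim pooled output and T5-XXL's 4096-dim hidden states (standard for those encoders, but verify against your checkpoints):

```python
import numpy as np

batch, max_seq = 1, 512

# T5-XXL sequence embeddings condition every transformer token (d_model = 4096).
prompt_embeds = np.zeros((batch, max_seq, 4096), dtype=np.float32)

# CLIP contributes one pooled vector per prompt (projection dim 768).
pooled_prompt_embeds = np.zeros((batch, 768), dtype=np.float32)

# Positional ids for text tokens are all zeros, as in encode_prompt above.
text_ids = np.zeros((prompt_embeds.shape[1], 3), dtype=np.float32)

print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
```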

Output

Generated images are saved as flux_{i}.png. The pipeline reports:
  • Compile time: Initial JAX compilation
  • Inference time: Generation duration

Next steps

Flux training

Fine-tune Flux on custom datasets

Wan video

Text-to-video generation

LoRA training

Train custom Flux LoRA adapters

SDXL inference

Alternative high-quality model
