
Overview

The generate_flux.py script runs inference for Flux models (dev and schnell). Flux is a flow-matching, transformer-based image generation model that improves on earlier diffusion models in both quality and speed. The script supports LoRA loading, multiple sharding strategies, and optimized flash attention.

Usage

Flux Schnell (fast)

python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_test \
  output_dir=/tmp/ \
  prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" \
  per_device_batch_size=1

Flux Dev (high quality)

python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_test \
  output_dir=/tmp/ \
  prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" \
  per_device_batch_size=1

With LoRA

python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  split_head_dim=True \
  lora_config='{"lora_model_name_or_path" : ["/path/to/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'

FSDP sharding (keeps all components in HBM)

python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_schnell.yml \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  offload_encoders=False

GPU with fused attention

NVTE_FUSED_ATTN=1 HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  split_head_dim=True \
  per_device_batch_size=1 \
  attention="cudnn_flash_te" \
  hardware=gpu

Configuration parameters

Model parameters

pretrained_model_name_or_path
string
default:"black-forest-labs/FLUX.1-dev"
Base model path. Use black-forest-labs/FLUX.1-dev or black-forest-labs/FLUX.1-schnell.
flux_name
string
default:"flux-dev"
Flux variant identifier. Options: flux-dev, flux-schnell.
clip_model_name_or_path
string
default:"ariG23498/clip-vit-large-patch14-text-flax"
CLIP text encoder model path.
t5xxl_model_name_or_path
string
default:"ariG23498/t5-v1-1-xxl-flax"
T5-XXL text encoder model path.

Generation parameters

prompt
string
Primary text prompt for image generation.
prompt_2
string
Secondary prompt (can be same as primary or different for prompt blending).
num_inference_steps
integer
default:50
Number of denoising steps. Flux-dev typically uses 28-50, schnell uses 4.
guidance_scale
float
Guidance scale for prompt adherence. Flux works well with lower values (3-5).
resolution
integer
default:1024
Output image resolution. Flux is optimized for 1024x1024.
max_sequence_length
integer
default:512
Maximum sequence length for text encoders.
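Resolution determines the number of image tokens the transformer processes. As a sketch of the usual Flux conventions (the VAE downsamples by 8x and the transformer packs 2x2 latent patches into one token; factors here are assumptions, not read from generate_flux.py):

```python
# Assumed Flux conventions: 8x VAE downsampling, 2x2 latent patch packing.
def image_token_count(resolution: int, vae_factor: int = 8, patch: int = 2) -> int:
    latent_size = resolution // vae_factor   # e.g. 1024 -> 128 latent pixels per side
    return (latent_size // patch) ** 2       # tokens after 2x2 packing

print(image_token_count(1024))  # 4096
print(image_token_count(512))   # 1024
```

Under these assumptions, the default 1024x1024 resolution yields 4096 image tokens, on top of the up-to-512 text tokens set by max_sequence_length.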

Flow matching parameters

time_shift
boolean
default:true
Enable the time-shift schedule, which biases sampling toward higher (noisier) timesteps for higher-signal images such as large resolutions.
base_shift
float
Base shift parameter for time schedule.
max_shift
float
Maximum shift parameter for time schedule.
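The base_shift/max_shift pair typically parameterizes a linear interpolation of the shift with the image token count, as in similar Flux pipelines (e.g. diffusers' calculate_shift). A hedged sketch; the function and parameter names are illustrative, not necessarily those used in generate_flux.py:

```python
import math

def compute_mu(image_seq_len, base_seq=256, max_seq=4096,
               base_shift=0.5, max_shift=1.15):
    # Linearly interpolate the shift between base_shift and max_shift
    # according to the number of image tokens.
    m = (max_shift - base_shift) / (max_seq - base_seq)
    return image_seq_len * m + (base_shift - m * base_seq)

def shift_sigma(mu, sigma):
    # Warp a timestep sigma in (0, 1] toward the high-noise end;
    # mu = 0 leaves the schedule unchanged.
    return math.exp(mu) / (math.exp(mu) + (1 / sigma - 1))

mu = compute_mu(4096)  # 1024x1024 image -> max_shift
print(shift_sigma(mu, 0.5))
```

With mu = 0 the schedule is the identity; larger mu pushes each sigma higher, spending more of the step budget at noisy timesteps.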

Performance parameters

weights_dtype
string
default:"bfloat16"
Data type for model weights. Always use bfloat16 for Flux.
activations_dtype
string
default:"bfloat16"
Data type for activations. Always use bfloat16 for Flux.
attention
string
default:"flash"
Attention implementation. Options: flash (TPU), cudnn_flash_te (GPU), dot_product.
split_head_dim
boolean
default:true
Split attention head dimensions for better parallelization.
offload_encoders
boolean
default:true
Offload the text encoders to host memory after encoding to free HBM.
flash_block_sizes
object
Custom block sizes for flash attention. Use optimized values for TPU v6e:
flash_block_sizes: {
  "block_q": 1536,
  "block_kv_compute": 1536,
  "block_kv": 1536,
  "block_q_dkv": 1536,
  "block_kv_dkv": 1536,
  "block_kv_dkv_compute": 1536,
  "block_q_dq": 1536,
  "block_kv_dq": 1536
}

LoRA parameters

lora_config
object
Configuration for loading LoRA adapters. Supports multiple LoRAs. Properties:
  • lora_model_name_or_path (array): Paths to LoRA files
  • weight_name (array): Filenames of LoRA weights
  • adapter_name (array): Unique adapter names
  • scale (array): Scale factors (0.0-1.0)
  • from_pt (array): Load from PyTorch format
Example (multiple LoRAs):
{
  "lora_model_name_or_path": ["/path/anime.safetensors", "/path/photo.safetensors"],
  "weight_name": ["anime.safetensors", "photo.safetensors"],
  "adapter_name": ["anime", "realistic"],
  "scale": [0.6, 0.6],
  "from_pt": ["true", "true"]
}
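Since lora_config is passed on the command line as a JSON string, it can be handy to build it programmatically rather than hand-escaping quotes. A small sketch (paths and adapter names are placeholders):

```python
import json
import shlex

# Placeholder paths/names; substitute your own LoRA files.
lora_config = {
    "lora_model_name_or_path": ["/path/to/anime_lora.safetensors"],
    "weight_name": ["anime_lora.safetensors"],
    "adapter_name": ["anime"],
    "scale": [0.8],
    "from_pt": ["true"],
}

# shlex.quote makes the JSON safe to splice into a shell command line.
arg = "lora_config=" + shlex.quote(json.dumps(lora_config))
print(arg)
```

The printed argument can be appended verbatim to the generate_flux.py invocation shown above.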

Parallelism parameters

ici_data_parallelism
integer
Data parallelism degree for the DDP strategy. Set to -1 to let this axis absorb all remaining devices.
ici_fsdp_parallelism
integer
default:1
FSDP parallelism. Set to -1 to shard model across devices (keeps everything in HBM).
per_device_batch_size
integer
default:1
Batch size per device.
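A -1 on one of the ici_* axes is conventionally resolved to "all devices not claimed by the other axes". A sketch of that convention (this mirrors the behavior described above, not the exact implementation in generate_flux.py):

```python
def resolve_mesh(num_devices, data=-1, fsdp=1):
    """Resolve (-1)-valued mesh axes so data * fsdp == num_devices."""
    axes = [data, fsdp]
    fixed = 1
    for a in axes:
        if a != -1:
            fixed *= a
    # Each -1 axis absorbs the devices left over after the fixed axes.
    axes = [num_devices // fixed if a == -1 else a for a in axes]
    assert axes[0] * axes[1] == num_devices, "mesh must cover all devices"
    return tuple(axes)

print(resolve_mesh(4, data=1, fsdp=-1))  # (1, 4): model sharded over 4 chips
print(resolve_mesh(8, data=-1, fsdp=1))  # (8, 1): pure data parallelism
```

This is why the FSDP examples above pair ici_data_parallelism=1 with ici_fsdp_parallelism=-1: the FSDP axis takes every available chip.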

System parameters

run_name
string
required
Unique run identifier.
output_dir
string
default:""
Output directory for generated images.
seed
integer
default:0
Random seed for reproducibility.
jax_cache_dir
string
default:""
JAX compilation cache directory.
hardware
string
default:"tpu"
Hardware type. Options: tpu, gpu.

Output

Images are saved with the naming pattern flux_{i}.png.
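Given that pattern, the expected file paths for a run are easy to enumerate. A minimal sketch (output_paths is an illustrative helper, not part of the script):

```python
import os

def output_paths(output_dir: str, num_images: int):
    # Follows the flux_{i}.png naming pattern documented above.
    return [os.path.join(output_dir, f"flux_{i}.png") for i in range(num_images)]

print(output_paths("/tmp", 2))  # ['/tmp/flux_0.png', '/tmp/flux_1.png']
```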

Performance data

Expected results on 1024x1024 images with flash attention and bfloat16:
| Model        | Accelerator | Sharding | Batch Size | Steps | Time (seconds) |
|--------------|-------------|----------|------------|-------|----------------|
| Flux-dev     | v4-8        | DDP      | 4          | 28    | 23             |
| Flux-schnell | v4-8        | DDP      | 4          | 4     | 2.2            |
| Flux-dev     | v6e-4       | DDP      | 4          | 28    | 5.5            |
| Flux-schnell | v6e-4       | DDP      | 4          | 4     | 0.8            |
| Flux-schnell | v6e-4       | FSDP     | 4          | 4     | 1.2            |
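The table reports wall-clock time per batch, so throughput is simply batch size divided by time. A quick check using two rows from the table:

```python
# (model, accelerator, batch_size, seconds) taken from the table above.
rows = [
    ("Flux-dev", "v4-8", 4, 23.0),
    ("Flux-schnell", "v6e-4", 4, 0.8),
]
for model, chip, batch, seconds in rows:
    print(f"{model} on {chip}: {batch / seconds:.2f} images/sec")
```

Schnell's 4-step schedule on v6e-4 reaches roughly 5 images/sec at batch 4, versus about 0.17 images/sec for 28-step dev on v4-8.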

Implementation details

Flux uses a unique architecture:
  • Dual text encoders (CLIP and T5-XXL) for rich text understanding
  • Flow matching instead of traditional diffusion
  • Transformer-based architecture with rotary positional embeddings
  • Packing-based latent representation
  • Time shift scheduling for improved quality
See source: src/maxdiffusion/generate_flux.py
