The FluxPipeline is a text-to-image generation pipeline using the Flux transformer architecture.

FluxPipeline

Located in src/maxdiffusion/pipelines/flux/flux_pipeline.py:42

Components

  • t5_encoder (FlaxT5EncoderModel): T5 text encoder for processing prompts
  • clip_encoder (FlaxCLIPTextModel): CLIP text encoder for pooled embeddings
  • vae (FlaxAutoencoderKL): Variational autoencoder for latent encoding/decoding
  • t5_tokenizer (AutoTokenizer): Tokenizer for the T5 encoder
  • clip_tokenizer (CLIPTokenizer): Tokenizer for the CLIP encoder
  • flux (FluxTransformer2DModel): Flux transformer model for denoising
  • scheduler (FlaxEulerDiscreteScheduler): Euler discrete scheduler
  • mesh (Mesh): JAX mesh for distributed computation

Methods

encode_prompt

Encodes text prompts using the T5 and CLIP encoders.

Parameters:
  • prompt (str | List[str]): The prompt or prompts to guide image generation
  • prompt_2 (str | List[str]): Optional second prompt (defaults to prompt if not provided)
  • num_images_per_prompt (int, default: 1): Number of images to generate per prompt
  • max_sequence_length (int, default: 512): Maximum sequence length for the T5 encoder

Returns:
  • prompt_embeds (jnp.ndarray): T5 text embeddings
  • pooled_prompt_embeds (jnp.ndarray): CLIP pooled embeddings
  • text_ids (jnp.ndarray): Text position IDs
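A small shape sketch makes the return contract concrete. This is illustrative only: NumPy stands in for jnp, and the embedding widths (4096 for T5-XXL, 768 for CLIP) and the (max_sequence_length, 3) text-ID shape are conventions typical of Flux implementations, not values read from this file.

```python
import numpy as np

def encode_prompt_shapes(batch_size, num_images_per_prompt,
                         max_sequence_length=512, t5_dim=4096, clip_dim=768):
    """Illustrative stand-in for encode_prompt: returns dummy arrays
    with the shapes the real method produces."""
    n = batch_size * num_images_per_prompt   # prompts are duplicated per image
    prompt_embeds = np.zeros((n, max_sequence_length, t5_dim))  # T5 token embeddings
    pooled_prompt_embeds = np.zeros((n, clip_dim))              # CLIP pooled vector
    text_ids = np.zeros((max_sequence_length, 3))               # position IDs for RoPE
    return prompt_embeds, pooled_prompt_embeds, text_ids

pe, ppe, tids = encode_prompt_shapes(batch_size=2, num_images_per_prompt=3)
```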

prepare_latents

Prepares initial latent tensors for generation.

Parameters:
  • batch_size (int): Batch size
  • num_channels_latents (int): Number of channels in the latent space
  • height (int): Height of generated images
  • width (int): Width of generated images
  • dtype (jnp.dtype): Data type for latents
  • rng (jax.random.PRNGKey): Random key for initialization

Returns:
  • latents (jnp.ndarray): Packed latent tensors
  • latent_image_ids (jnp.ndarray): Position IDs for latents
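The "packed" layout is, in common Flux implementations, a 2×2 spatial patching that folds each patch into the channel axis, turning a (B, C, H, W) latent into a (B, (H/2)·(W/2), C·4) token sequence. A minimal NumPy sketch of that scheme (assumed to match this pipeline's helper, not copied from it):

```python
import numpy as np

def pack_latents(latents):
    """Pack 2x2 spatial patches into the sequence/channel dims:
    (B, C, H, W) -> (B, (H//2)*(W//2), C*4)."""
    b, c, h, w = latents.shape
    x = latents.reshape(b, c, h // 2, 2, w // 2, 2)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # group each 2x2 patch with its channels
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

# A 1024px image has a 128x128 latent grid (VAE downsamples by 8),
# which packs into 4096 tokens of width 64 for a 16-channel latent.
packed = pack_latents(np.zeros((1, 16, 128, 128)))
```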

time_shift

Applies time shifting to timesteps based on sequence length.

Parameters:
  • latents (jnp.ndarray): Latent tensors
  • timesteps (jnp.ndarray): Original timesteps

Returns:
  • timesteps (jnp.ndarray): Shifted timesteps
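The time shift popularized by Flux interpolates a shift parameter mu linearly in the latent sequence length, then warps each timestep t through a logistic-style map so that longer sequences spend more steps at high noise levels. A sketch with commonly used defaults (the 256/4096 sequence bounds and 0.5/1.15 shift bounds are assumptions, not values from this file):

```python
import math

def calculate_mu(seq_len, base_len=256, max_len=4096,
                 base_shift=0.5, max_shift=1.15):
    # Linear interpolation of the shift parameter in sequence length.
    m = (max_shift - base_shift) / (max_len - base_len)
    return seq_len * m + (base_shift - base_len * m)

def time_shift(mu, t, sigma=1.0):
    # mu = 0 leaves t unchanged; larger mu pushes t toward 1 (more noise).
    return math.exp(mu) / (math.exp(mu) + (1.0 / t - 1.0) ** sigma)
```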

__call__

Generates images from the pipeline configuration.

Parameters:
  • timesteps (int): Number of denoising steps
  • flux_params (dict): Flux transformer parameters
  • vae_params (dict): VAE parameters

Returns:
  • images (jnp.ndarray): Generated images

Key features

  • Dual text encoders: Combines T5 and CLIP for rich text understanding
  • Rotary position embeddings: Uses RoPE for better positional encoding
  • Flow matching: Uses flow matching scheduler for efficient sampling
  • Packed latents: Packs latent dimensions for efficient processing
  • Time shifting: Dynamically adjusts timesteps based on resolution
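Of these, the flow-matching update is simple enough to sketch. The generic Euler rule (not this pipeline's exact scheduler code) moves the sample along the model's predicted velocity between consecutive sigmas:

```python
import numpy as np

def euler_flow_step(sample, velocity, sigma, sigma_next):
    """One Euler step of flow matching: x_next = x + (sigma_next - sigma) * v."""
    return sample + (sigma_next - sigma) * velocity

# Toy check: with a constant velocity pointing from the data point (0)
# to the noise sample, integrating sigma from 1 down to 0 transports
# the sample back to the data point.
x = np.ones(4)                        # pure "noise" at sigma = 1
v = x.copy()                          # constant toy velocity field
sigmas = np.linspace(1.0, 0.0, 11)
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    x = euler_flow_step(x, v, s, s_next)
# x is now ~0: the sample has reached the "data" endpoint
```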

Example usage

import jax
from jax.sharding import Mesh
from maxdiffusion import FluxPipeline
# HyperParameters and create_device_mesh are maxdiffusion config/device
# utilities; import them from wherever your checkout exposes them.

# Initialize pipeline with config
config = HyperParameters(
    pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev",
    prompt="a professional photograph of an astronaut",
    resolution=1024,
    guidance_scale=3.5,
    seed=42
)

# Create pipeline components
devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
rngs = jax.random.PRNGKey(config.seed)

pipeline = FluxPipeline(
    tokenizer=FluxPipeline.load_tokenizer(config),
    text_encoder=FluxPipeline.load_text_encoder(config),
    vae=FluxPipeline.load_vae(devices_array, mesh, rngs, config)[0],
    # ... other components
    mesh=mesh,
    config=config
)

# Generate (flux_params and vae_params come from loading the models)
images = pipeline(
    timesteps=50,
    flux_params=flux_params,
    vae_params=vae_params
)
