MaxDiffusion provides transformer-based diffusion models, including the Flux transformer architecture.

FluxTransformer2DModel

The Flux transformer model for text-to-image generation.

Model architecture

The Flux transformer consists of:
  • Patch embedding: Converts input into patches
  • Time and text embeddings: Combined timestep and text projection embeddings
  • Position embeddings: Rotary position embeddings (RoPE)
  • Double blocks: MMDiT blocks with separate image and text streams
  • Single blocks: Single-stream transformer blocks
  • Output projection: Projects back to patch space
Located in src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:298
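The time embedding mentioned above is, in diffusion transformers generally, a sinusoidal projection of the scalar timestep that is then mixed with the text projection. A minimal numpy sketch of the sinusoidal part (the exact dimension and frequency schedule used by MaxDiffusion are assumptions here):

```python
import numpy as np

def sinusoidal_timestep_embedding(t, dim=256, max_period=10000.0):
    """Map scalar timesteps (B,) to sinusoidal features (B, dim).

    Standard diffusion-style embedding; MaxDiffusion's exact dim and
    frequency schedule may differ.
    """
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)  # (half,)
    args = t[:, None] * freqs[None, :]                            # (B, half)
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)  # (B, dim)

emb = sinusoidal_timestep_embedding(np.array([0.0, 500.0, 999.0]))
print(emb.shape)  # (3, 256)
```

In the full model this embedding is typically passed through a small MLP and added to the pooled text projection before conditioning the blocks.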

Configuration parameters

  • patch_size (int, default: 1): Patch size used to split the input data into small patches
  • in_channels (int, default: 64): Number of channels in the input
  • num_layers (int, default: 19): Number of MMDiT (double-stream) blocks to use
  • num_single_layers (int, default: 38): Number of single-stream DiT blocks to use
  • attention_head_dim (int, default: 128): Number of channels in each attention head
  • num_attention_heads (int, default: 24): Number of heads to use for multi-head attention
  • joint_attention_dim (int, default: 4096): Number of encoder_hidden_states dimensions to use
  • pooled_projection_dim (int, default: 768): Number of dimensions to use when projecting the pooled_projections
  • guidance_embeds (bool, default: False): Whether to use guidance embeddings
  • axes_dims_rope (Tuple[int], default: (16, 56, 56)): Per-axis dimensions for rotary position embeddings
  • attention_kernel (str, default: "dot_product"): Attention mechanism to use
  • flash_min_seq_length (int, default: 4096): Minimum sequence length required to apply flash attention
  • mesh (jax.sharding.Mesh): JAX mesh for distributed computation
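The axes_dims_rope default (16, 56, 56) splits the 128-channel head dimension across three position axes, e.g. a text/sequence index and two spatial axes. A simplified numpy sketch of multi-axis rotary angles (the exact RoPE layout in MaxDiffusion is an assumption):

```python
import numpy as np

def rope_angles(ids, axes_dims=(16, 56, 56), theta=10000.0):
    """Per-axis rotary angles for multi-axis position ids.

    ids: (seq, num_axes) positions; axis i contributes axes_dims[i]
    channels (half of them angles, applied as cos/sin pairs), so the
    axes sum to the attention head dim (16 + 56 + 56 = 128 here).
    Simplified sketch; MaxDiffusion's exact RoPE may differ.
    """
    parts = []
    for axis, dim in enumerate(axes_dims):
        half = dim // 2
        freqs = 1.0 / theta ** (np.arange(half) / half)   # (half,)
        parts.append(ids[:, axis, None] * freqs[None, :]) # (seq, half)
    return np.concatenate(parts, axis=-1)                 # (seq, sum(dims)//2)

# 4x4 grid of image positions with a zero text-axis component: (16, 3) ids
hw = np.stack(np.meshgrid(np.arange(4), np.arange(4), indexing="ij"), -1)
ids = np.concatenate([np.zeros((16, 1)), hw.reshape(-1, 2)], axis=1)
angles = rope_angles(ids)
print(angles.shape)  # (16, 64): cos/sin pairs cover all 128 head channels
```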

Methods

__call__

Forward pass through the Flux transformer.

Parameters:
  • hidden_states (jnp.ndarray): Input latent representations
  • encoder_hidden_states (jnp.ndarray): Text encoder hidden states
  • pooled_projections (jnp.ndarray): Pooled text embeddings
  • timestep (jnp.ndarray): Timestep for the diffusion process
  • img_ids (jnp.ndarray): Image position IDs
  • txt_ids (jnp.ndarray): Text position IDs
  • guidance (jnp.ndarray): Guidance scale values

Returns:
  • sample (jnp.ndarray): Predicted noise or velocity with shape matching the input
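With the configuration defaults above, the call inputs line up along these shapes. The batch size and sequence lengths below are illustrative assumptions, not fixed by the API, and the commented-out call is hypothetical:

```python
import numpy as np

B, img_seq, txt_seq = 1, 4096, 512  # illustrative sizes

hidden_states = np.zeros((B, img_seq, 64))            # in_channels = 64
encoder_hidden_states = np.zeros((B, txt_seq, 4096))  # joint_attention_dim = 4096
pooled_projections = np.zeros((B, 768))               # pooled_projection_dim = 768
timestep = np.zeros((B,))
img_ids = np.zeros((img_seq, 3))                      # one id per rotary axis
txt_ids = np.zeros((txt_seq, 3))
guidance = np.full((B,), 3.5)                         # guidance scale values

# Hypothetical call (requires MaxDiffusion and initialized params):
# sample = model.apply(params, hidden_states, encoder_hidden_states,
#                      pooled_projections, timestep, img_ids, txt_ids, guidance)
print(hidden_states.shape, encoder_hidden_states.shape)
```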

FluxTransformerBlock

Double-stream transformer block (MMDiT architecture) located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:153.

Features

  • Dual processing streams for images and text
  • Adaptive layer normalization
  • Cross-attention between modalities
  • MLP feedforward networks
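The adaptive layer normalization listed above conditions each stream on the combined time/text embedding by modulating the normalized activations. A simplified numpy sketch (the projection names and the exact shift/scale/gate split are assumptions):

```python
import numpy as np

def adaptive_layer_norm(x, cond, w, b, eps=1e-6):
    """Adaptive LayerNorm: normalize x, then shift and scale it with
    values projected from the conditioning vector (simplified sketch).

    x:    (B, seq, d) token activations
    cond: (B, d_cond) combined time/text embedding
    w, b: projection producing concatenated [shift, scale], each (B, d)
    """
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    x_norm = (x - mu) / np.sqrt(var + eps)
    shift, scale = np.split(cond @ w + b, 2, axis=-1)
    return x_norm * (1 + scale[:, None, :]) + shift[:, None, :]

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))
out = adaptive_layer_norm(x, rng.normal(size=(2, 16)),
                          rng.normal(size=(16, 16)) * 0.01, np.zeros(16))
print(out.shape)  # (2, 5, 8)
```

In the double-stream block, separate modulation parameters are typically produced for the image and text streams from the same conditioning vector.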

FluxSingleTransformerBlock

Single-stream transformer block located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:53.

Features

  • Combined image and text processing
  • Adaptive layer normalization
  • Attention with RoPE
  • Gated MLP activation
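The gated MLP activation listed above can be sketched as a GELU-gated hidden layer. Whether MaxDiffusion uses exactly this gating variant (GEGLU-style) is an assumption; the sketch only illustrates the general pattern:

```python
import numpy as np

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def gated_mlp(x, w_in, w_gate, w_out):
    """Gated MLP: the hidden activation is the elementwise product of a
    linear branch and a GELU-activated gate branch (assumed variant).
    """
    return (gelu(x @ w_gate) * (x @ w_in)) @ w_out

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))
y = gated_mlp(x, rng.normal(size=(8, 32)), rng.normal(size=(8, 32)),
              rng.normal(size=(32, 8)))
print(y.shape)  # (2, 4, 8)
```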

Transformer2DModelOutput

Output dataclass for transformer models located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:41.
  • sample (jnp.ndarray): The hidden states output with shape (batch_size, num_channels, height, width)
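As a dataclass, the output is accessed by field name rather than by index. A plain-Python sketch of the shape contract (the real class holds a jnp.ndarray and may be a Flax struct):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class Transformer2DModelOutput:
    """Sketch of the output container; the real field is a jnp.ndarray."""
    sample: np.ndarray  # (batch_size, num_channels, height, width)

out = Transformer2DModelOutput(sample=np.zeros((1, 64, 32, 32)))
print(out.sample.shape)  # (1, 64, 32, 32)
```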
