FluxTransformer2DModel
The Flux transformer model for text-to-image generation.

Model architecture

The Flux transformer consists of:

- Patch embedding: Converts input into patches
- Time and text embeddings: Combined timestep and text projection embeddings
- Position embeddings: Rotary position embeddings (RoPE)
- Double blocks: MMDiT blocks with separate image and text streams
- Single blocks: Single-stream transformer blocks
- Output projection: Projects back to patch space
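The data flow through these components can be sketched with stand-in operations. This is an illustrative shape-level sketch only, not the real maxdiffusion implementation; block counts and dimensions are made up for the example:

```python
# Shape-level sketch of the Flux transformer data flow (hypothetical
# dimensions; attention/MLP are replaced with stand-in operations).
import jax
import jax.numpy as jnp

batch, img_seq, txt_seq, dim = 1, 16, 8, 64

key = jax.random.PRNGKey(0)
img = jax.random.normal(key, (batch, img_seq, dim))  # patchified latents
txt = jax.random.normal(key, (batch, txt_seq, dim))  # projected text states

# Double blocks: image and text keep separate streams but attend jointly.
for _ in range(2):
    joint = jnp.concatenate([txt, img], axis=1)      # stand-in for joint attention
    txt, img = joint[:, :txt_seq], joint[:, txt_seq:]

# Single blocks: both streams are fused into one sequence.
h = jnp.concatenate([txt, img], axis=1)
for _ in range(2):
    h = h + 0.0                                      # stand-in for attention + MLP

# Output projection keeps only the image tokens.
out = h[:, txt_seq:]
assert out.shape == (batch, img_seq, dim)
```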
Located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:298.
Configuration parameters
- patch_size: Patch size used to turn the input data into small patches
- in_channels: The number of channels in the input
- num_layers: The number of MMDiT (double-stream) blocks to use
- num_single_layers: The number of single-stream DiT blocks to use
- attention_head_dim: The number of channels in each attention head
- num_attention_heads: The number of heads to use for multi-head attention
- joint_attention_dim: The number of encoder_hidden_states dimensions to use
- pooled_projection_dim: Number of dimensions to use when projecting the pooled_projections
- guidance_embeds: Whether to use guidance embeddings
- axes_dims_rope: Dimensions for rotary position embeddings (RoPE)
- attention_kernel: Attention mechanism to use
- flash_min_seq_length: Minimum sequence length required to apply flash attention
- mesh: JAX mesh for distributed computation
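As an illustration of how these parameters fit together, the values below follow the published Flux.1 architecture using diffusers-style parameter names; treat them as an example configuration, not a dump of maxdiffusion defaults:

```python
# Example configuration dict with Flux.1-style values (illustrative;
# parameter names follow the diffusers convention for this model).
config = {
    "patch_size": 1,
    "in_channels": 64,
    "num_layers": 19,               # MMDiT double-stream blocks
    "num_single_layers": 38,        # single-stream blocks
    "attention_head_dim": 128,
    "num_attention_heads": 24,      # 24 * 128 = 3072 hidden size
    "joint_attention_dim": 4096,    # T5 text encoder hidden size
    "pooled_projection_dim": 768,   # CLIP pooled embedding size
    "guidance_embeds": True,        # guidance-distilled variants
    "axes_dims_rope": (16, 56, 56), # RoPE dims per position axis
}
```

The three axes_dims_rope entries correspond to the three columns of the position IDs passed to __call__, and they sum to the per-head dimension (16 + 56 + 56 = 128).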
Methods
__call__
Forward pass through the Flux transformer.
Parameters:
- hidden_states: Input latent representations
- encoder_hidden_states: Text encoder hidden states
- pooled_projections: Pooled text embeddings
- timestep: Timestep for the diffusion process
- img_ids: Image position IDs
- txt_ids: Text position IDs
- guidance: Guidance scale values

Returns: Predicted noise or velocity with shape matching the input
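The shapes involved in a call can be sketched as follows. All dimensions here are hypothetical placeholders chosen for illustration; only the relationships between them (token count, RoPE axes, matching output shape) are the point:

```python
# Hypothetical argument shapes for a __call__ invocation (illustrative).
import jax.numpy as jnp

batch, h, w, c = 1, 32, 32, 64            # packed latent: h*w tokens of c channels
txt_len, joint_dim, pooled_dim = 512, 4096, 768

hidden_states = jnp.zeros((batch, h * w, c))
encoder_hidden_states = jnp.zeros((batch, txt_len, joint_dim))
pooled_projections = jnp.zeros((batch, pooled_dim))
timestep = jnp.array([1.0])
img_ids = jnp.zeros((h * w, 3))           # one column per RoPE axis
txt_ids = jnp.zeros((txt_len, 3))
guidance = jnp.array([3.5])
# The output would match hidden_states.shape: (batch, h * w, c).
```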
FluxTransformerBlock
Double-stream transformer block (MMDiT architecture), located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:153.
Features
- Dual processing streams for images and text
- Adaptive layer normalization
- Cross-attention between modalities
- MLP feedforward networks
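The adaptive layer normalization in these blocks can be sketched as below: the conditioning embedding predicts a per-channel shift and scale that modulate the normalized token states. The projection (w, b) and all shapes are hypothetical, not the actual maxdiffusion module:

```python
# Sketch of adaptive layer normalization (adaLN): the time/text embedding
# predicts a per-channel shift and scale applied after normalization.
import jax.numpy as jnp

def ada_layer_norm(x, emb, w, b):
    # (w, b) is a hypothetical linear projection emitting (shift, scale)
    shift, scale = jnp.split(emb @ w + b, 2, axis=-1)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / jnp.sqrt(var + 1e-6)
    return normed * (1 + scale[:, None, :]) + shift[:, None, :]

x = jnp.ones((1, 4, 8))    # (batch, seq, dim) token states
emb = jnp.ones((1, 16))    # combined time/text conditioning embedding
w = jnp.zeros((16, 16))    # projects the embedding to 2 * dim outputs
b = jnp.zeros((16,))
out = ada_layer_norm(x, emb, w, b)
```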
FluxSingleTransformerBlock
Single-stream transformer block, located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:53.
Features
- Combined image and text processing
- Adaptive layer normalization
- Attention with RoPE
- Gated MLP activation
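The gated residual update in a single-stream block can be sketched as follows. This assumes the common Flux-style pattern of concatenating the attention and GELU-activated MLP branches before a gated projection; the function name and shapes are hypothetical:

```python
# Sketch of a gated single-stream update: attention output and a
# GELU-activated MLP branch are concatenated, projected back to the
# model dimension, and added to the residual under a learned gate.
import jax
import jax.numpy as jnp

def single_block_update(x, attn_out, mlp_hidden, proj_w, gate):
    fused = jnp.concatenate([attn_out, jax.nn.gelu(mlp_hidden)], axis=-1)
    return x + gate * (fused @ proj_w)

dim, mlp_dim = 8, 32
x = jnp.ones((1, 4, dim))           # residual stream
attn_out = jnp.ones((1, 4, dim))    # attention branch output
mlp_hidden = jnp.ones((1, 4, mlp_dim))
proj_w = jnp.zeros((dim + mlp_dim, dim))
gate = jnp.ones((dim,))             # from the conditioning embedding
out = single_block_update(x, attn_out, mlp_hidden, proj_w, gate)
```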
Transformer2DModelOutput
Output dataclass for transformer models, located at src/maxdiffusion/models/flux/transformers/transformer_flux_flax.py:41.

Holds the hidden states output with shape (batch_size, num_channels, height, width).
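The shape of such an output dataclass can be sketched with a plain Python dataclass (the real one may use a Flax/JAX-compatible dataclass decorator; the field name sample follows the diffusers convention and the class name here is a stand-in):

```python
# Sketch of the output container (illustrative; not the real class).
from dataclasses import dataclass
import jax.numpy as jnp

@dataclass
class Transformer2DModelOutputSketch:
    sample: jnp.ndarray  # (batch_size, num_channels, height, width)

out = Transformer2DModelOutputSketch(sample=jnp.zeros((1, 16, 64, 64)))
```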