The FlaxStableDiffusionXLPipeline extends Stable Diffusion with higher-resolution generation and dual text encoders.

FlaxStableDiffusionXLPipeline

Located in src/maxdiffusion/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py:43

Components

  • text_encoder (FlaxCLIPTextModel): First CLIP text encoder.
  • text_encoder_2 (FlaxCLIPTextModel): Second CLIP text encoder for improved text understanding.
  • vae (FlaxAutoencoderKL): Variational Auto-Encoder (VAE) model to encode and decode images.
  • tokenizer (CLIPTokenizer): First tokenizer.
  • tokenizer_2 (CLIPTokenizer): Second tokenizer, corresponding to text_encoder_2.
  • unet (FlaxUNet2DConditionModel): UNet to denoise the encoded image latents with additional conditioning.
  • scheduler (SchedulerMixin): A scheduler to be used in combination with unet to denoise the encoded image latents.

Methods

prepare_inputs

Tokenizes text prompts with both tokenizers.

Parameters:
  • prompt (str | List[str]): The prompt or prompts to guide image generation.

Returns:
  • inputs (jnp.ndarray): Stacked tokenized input IDs from both tokenizers, with shape (batch_size, 2, sequence_length).
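
The stacked layout described above can be illustrated with plain arrays. This is a minimal sketch, not the pipeline's own code: the token IDs are made up, the sequence length of 77 is an assumption, and NumPy stands in for jax.numpy (jnp.stack takes the same arguments).

```python
import numpy as np

# Hypothetical token IDs standing in for the outputs of tokenizer and tokenizer_2.
# A fixed sequence length of 77 is assumed here for illustration.
batch_size, sequence_length = 2, 77
ids_1 = np.arange(batch_size * sequence_length, dtype=np.int32).reshape(
    batch_size, sequence_length
)
ids_2 = ids_1 + 1000  # pretend the second tokenizer produced different IDs

# Stack along a new axis 1 so every prompt carries both ID sequences.
inputs = np.stack([ids_1, ids_2], axis=1)
print(inputs.shape)  # (2, 2, 77)

# Each text encoder can later read its own slice of the stacked array.
ids_for_encoder_1 = inputs[:, 0, :]
ids_for_encoder_2 = inputs[:, 1, :]
```

Keeping both sequences in one array means a single prompt_ids value can be passed around (and sharded across devices) instead of two.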

get_embeddings

Encodes prompts using dual text encoders.

Parameters:
  • prompt_ids (jnp.array): Tokenized prompt IDs from both tokenizers.
  • params (dict): Model parameters.

Returns:
  • prompt_embeds (jnp.ndarray): Concatenated embeddings from both text encoders.
  • text_embeds (jnp.ndarray): Pooled text embeddings from the second encoder.
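
A shape-level sketch of the two return values may help. It assumes the hidden sizes commonly associated with the SDXL base text encoders (768 and 1280) and concatenation along the feature axis; random arrays stand in for real encoder outputs, and NumPy stands in for jax.numpy.

```python
import numpy as np

batch_size, seq_len = 1, 77
hidden_1, hidden_2 = 768, 1280  # assumed hidden sizes of the two text encoders

rng = np.random.default_rng(0)
# Placeholder per-token hidden states from each encoder.
embeds_1 = rng.standard_normal((batch_size, seq_len, hidden_1)).astype(np.float32)
embeds_2 = rng.standard_normal((batch_size, seq_len, hidden_2)).astype(np.float32)
# Placeholder pooled (one vector per prompt) output of the second encoder.
pooled_2 = rng.standard_normal((batch_size, hidden_2)).astype(np.float32)

# Join per-token embeddings along the last (feature) axis.
prompt_embeds = np.concatenate([embeds_1, embeds_2], axis=-1)
text_embeds = pooled_2

print(prompt_embeds.shape)  # (1, 77, 2048)
print(text_embeds.shape)    # (1, 1280)
```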

__call__

Generates images from text prompts.

Parameters:
  • prompt_ids (jax.Array): Tokenized prompt IDs from both tokenizers.
  • params (dict | FrozenDict): Model parameters for all pipeline components.
  • prng_seed (jax.Array): PRNG key used for generation.
  • num_inference_steps (int, default: 50): The number of denoising steps.
  • guidance_scale (float | jax.Array, default: 7.5): Guidance scale for classifier-free guidance.
  • height (int): The height in pixels of the generated image. Defaults to unet.config.sample_size * vae_scale_factor.
  • width (int): The width in pixels of the generated image. Defaults to unet.config.sample_size * vae_scale_factor.
  • latents (jnp.array): Pre-generated noisy latents.
  • neg_prompt_ids (jnp.array): Tokenized negative prompt IDs.
  • output_type (str): Output type. Set to "latent" to return latents instead of decoded images.
  • jit (bool, default: False): Whether to run pmap versions of the generation functions.

Returns:
  • images (jnp.ndarray): Generated images or latents.
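
To make the role of guidance_scale concrete, here is a minimal sketch of the standard classifier-free guidance combination. It is not taken from the pipeline source: apply_guidance is a hypothetical helper, the tiny arrays are toy noise predictions, and NumPy stands in for jax.numpy.

```python
import numpy as np

def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Push the prediction away from the unconditional branch (hypothetical helper)."""
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)

# Toy noise predictions for the unconditional and text-conditioned branches.
noise_uncond = np.zeros((1, 4, 8, 8), dtype=np.float32)
noise_text = np.ones((1, 4, 8, 8), dtype=np.float32)

guided = apply_guidance(noise_uncond, noise_text, guidance_scale=7.5)
unguided = apply_guidance(noise_uncond, noise_text, guidance_scale=1.0)

print(float(guided[0, 0, 0, 0]))    # 7.5
print(float(unguided[0, 0, 0, 0]))  # 1.0
```

A scale of 1.0 reproduces the text-conditioned prediction unchanged; larger values weight the prompt more heavily.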

Key differences from Stable Diffusion

  • Dual text encoders: Uses two CLIP text encoders for improved text understanding
  • Higher resolution: Optimized for generating 1024x1024 images
  • Additional conditioning: Uses pooled text embeddings and time IDs for micro-conditioning
  • Concatenated embeddings: Text embeddings from both encoders are concatenated

Example usage

import jax
import jax.numpy as jnp
from maxdiffusion import FlaxStableDiffusionXLPipeline

# Load the pipeline and its parameters in bfloat16
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    dtype=jnp.bfloat16,
)

# Tokenize the prompt with both tokenizers
prompt = "a professional photograph of an astronaut riding a horse"
prompt_ids = pipeline.prepare_inputs(prompt)

# Generate a 1024x1024 image (jit is left at its default of False)
prng_seed = jax.random.PRNGKey(0)
images = pipeline(
    prompt_ids,
    params,
    prng_seed,
    num_inference_steps=50,
    height=1024,
    width=1024,
).images
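
A common follow-up step is converting the result to 8-bit pixels for saving. The sketch below assumes the decoded images are floats in [0, 1] with shape (batch, height, width, 3); the text above does not confirm that range or layout. to_uint8 is a hypothetical helper, and a synthetic array stands in for real pipeline output.

```python
import numpy as np

def to_uint8(images):
    """Convert float images assumed to lie in [0, 1] to uint8 (hypothetical helper)."""
    images = np.asarray(images, dtype=np.float32)
    return (np.clip(images, 0.0, 1.0) * 255.0).round().astype(np.uint8)

# Synthetic stand-in for pipeline output: 2 images of 4x4 pixels, channels last.
fake_images = np.linspace(0.0, 1.0, 2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
pixels = to_uint8(fake_images)

print(pixels.shape, pixels.dtype)  # (2, 4, 4, 3) uint8
```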
