The FlaxStableDiffusionPipeline is a Flax-based pipeline for text-to-image generation using Stable Diffusion.

FlaxStableDiffusionPipeline

Located in src/maxdiffusion/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py:77

Components

vae (FlaxAutoencoderKL): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder (FlaxCLIPTextModel): Frozen text encoder (clip-vit-large-patch14).
tokenizer (CLIPTokenizer): A CLIPTokenizer to tokenize text.
unet (FlaxUNet2DConditionModel): A FlaxUNet2DConditionModel to denoise the encoded image latents.
scheduler (SchedulerMixin): A scheduler used in combination with unet to denoise the encoded image latents. Can be one of FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler, or FlaxDPMSolverMultistepScheduler.
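The vae and unet components also fix the spatial sizes involved: the UNet denoises latents that the VAE decodes to pixels at vae_scale_factor times their height and width, which is why __call__ defaults to unet.config.sample_size * vae_scale_factor pixels. The sketch below illustrates that arithmetic under stated assumptions: Stable Diffusion v1-style config values (four VAE block_out_channels entries, four latent channels, a UNet sample_size of 64), the upstream diffusers formula vae_scale_factor = 2 ** (len(block_out_channels) - 1), and the channels-first latent layout and divisibility check used by the diffusers Flax pipeline. The helper names are hypothetical, not part of the maxdiffusion API.

```python
import jax
import jax.numpy as jnp

# Assumed Stable Diffusion v1-style config values (illustrative, not read from a checkpoint).
VAE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512)
UNET_SAMPLE_SIZE = 64
LATENT_CHANNELS = 4


def vae_scale_factor(block_out_channels=VAE_BLOCK_OUT_CHANNELS):
    # Every VAE down block except the last halves the spatial resolution.
    return 2 ** (len(block_out_channels) - 1)


def latent_shape(batch_size, height=None, width=None):
    """Hypothetical helper: (batch, channels, h, w) latent shape for a requested image size."""
    factor = vae_scale_factor()
    # Default image size mirrors unet.config.sample_size * vae_scale_factor.
    height = height if height is not None else UNET_SAMPLE_SIZE * factor
    width = width if width is not None else UNET_SAMPLE_SIZE * factor
    # Assumed check, modeled on the upstream diffusers Flax pipeline.
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")
    return (batch_size, LATENT_CHANNELS, height // factor, width // factor)


# Sample initial Gaussian latents for two images at the default 512x512 size.
shape = latent_shape(batch_size=2)
latents = jax.random.normal(jax.random.PRNGKey(0), shape, dtype=jnp.float32)
print(shape)  # (2, 4, 64, 64)
```

A latents array passed to __call__ would presumably need this shape (plus a leading device axis when jit=True); treat that as an assumption to verify against the pipeline source.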

Methods

prepare_inputs

Tokenizes text prompts.

Parameters:
prompt (str | List[str]): The prompt or prompts to guide image generation.

Returns:
input_ids (np.ndarray): Tokenized input IDs.
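prepare_inputs returns a fixed-width array of token IDs. In the upstream diffusers version of this pipeline, prompts are tokenized with padding to the tokenizer's model_max_length (77 for CLIP) and truncated beyond it, so every row has the same length; the assumption here is that maxdiffusion keeps that behavior. The sketch below reproduces only the padding and truncation with a hypothetical toy whitespace vocabulary and made-up special-token IDs; the real CLIPTokenizer uses byte-pair encoding and its own IDs.

```python
import numpy as np

# Hypothetical special-token IDs; CLIPTokenizer's real IDs differ.
BOS_ID, EOS_ID, PAD_ID = 1, 2, 0
MODEL_MAX_LENGTH = 77  # CLIP text encoders use a 77-token context


def toy_prepare_inputs(prompt, vocab):
    """Hypothetical stand-in for prepare_inputs: returns a (batch, 77) int32 array.

    Words are looked up in a toy vocabulary instead of CLIP's BPE.
    """
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    rows = []
    for text in prompts:
        # Assign new words IDs starting at 3 so they never collide with special tokens.
        ids = [vocab.setdefault(word, len(vocab) + 3) for word in text.lower().split()]
        # Truncate so BOS + tokens + EOS fits in MODEL_MAX_LENGTH.
        ids = [BOS_ID] + ids[: MODEL_MAX_LENGTH - 2] + [EOS_ID]
        # Pad every row to the same fixed width.
        rows.append(ids + [PAD_ID] * (MODEL_MAX_LENGTH - len(ids)))
    return np.asarray(rows, dtype=np.int32)


vocab = {}
ids = toy_prepare_inputs(["a photo of an astronaut", "a horse"], vocab)
print(ids.shape)  # (2, 77)
```

The fixed width is what makes the result a rectangular np.ndarray that can later be split along its batch axis for multi-device generation.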

__call__

Generates images from text prompts.

Parameters:
prompt_ids (jnp.array): Tokenized prompt IDs, as returned by prepare_inputs.
params (dict | FrozenDict): Model parameters for all pipeline components.
prng_seed (jax.Array): JAX PRNG key used for generation, for example jax.random.PRNGKey(0).
num_inference_steps (int, default: 50): Number of denoising steps. More steps usually yield higher-quality images at the cost of slower inference.
height (int, optional): Height of the generated image in pixels. Defaults to unet.config.sample_size * vae_scale_factor.
width (int, optional): Width of the generated image in pixels. Defaults to unet.config.sample_size * vae_scale_factor.
guidance_scale (float, default: 7.5): Classifier-free guidance scale. Higher values encourage images more closely linked to the text prompt, usually at the expense of image quality. Guidance is enabled when guidance_scale > 1.
latents (jnp.ndarray, optional): Pre-generated noisy latents sampled from a Gaussian distribution, used as inputs for image generation.
neg_prompt_ids (jnp.ndarray, optional): Tokenized negative prompt IDs for classifier-free guidance.
jit (bool, default: False): Whether to run the pmap versions of the generation functions. When True, replicate params, shard prompt_ids, and split prng_seed into one key per device, as in the example below.
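Returns an output object with the following attributes:
images (np.ndarray): Generated images as NumPy arrays. With jit=True, the array has a leading device axis, which the example below reshapes away.
nsfw_content_detected (bool): Whether NSFW content was detected (always False in the current implementation).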
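guidance_scale and neg_prompt_ids both feed classifier-free guidance. In the upstream diffusers Flax implementation this pipeline derives from, each denoising step runs the UNet once on a doubled batch (unconditional embeddings, taken from neg_prompt_ids or from empty prompts when none are given, stacked with the prompt embeddings), splits the prediction in two, and combines the halves as uncond + guidance_scale * (text - uncond). The sketch below shows only that combination step in JAX, with a hypothetical toy_unet standing in for FlaxUNet2DConditionModel and toy embedding sizes; it omits scheduler input scaling and is not the pipeline's own code.

```python
import jax
import jax.numpy as jnp


def toy_unet(latents, context):
    # Hypothetical UNet stand-in: any function of latents and text embeddings
    # is enough to illustrate the guidance arithmetic.
    return latents * jnp.mean(context, axis=(1, 2))[:, None, None, None]


def guided_noise_prediction(latents, uncond_embeds, text_embeds, guidance_scale):
    """Classifier-free guidance: one UNet call on a doubled batch, then recombine."""
    # Unconditional half first, conditional half second.
    context = jnp.concatenate([uncond_embeds, text_embeds], axis=0)
    latents_input = jnp.concatenate([latents, latents], axis=0)
    noise_pred = toy_unet(latents_input, context)
    noise_uncond, noise_text = jnp.split(noise_pred, 2, axis=0)
    return noise_uncond + guidance_scale * (noise_text - noise_uncond)


key_lat, key_txt = jax.random.split(jax.random.PRNGKey(0))
latents = jax.random.normal(key_lat, (2, 4, 8, 8))
uncond = jnp.zeros((2, 77, 16))  # toy "empty prompt" embeddings
text = jax.random.normal(key_txt, (2, 77, 16))

guided = guided_noise_prediction(latents, uncond, text, guidance_scale=7.5)
# With guidance_scale == 1 the result reduces to the text-conditioned prediction.
plain = guided_noise_prediction(latents, uncond, text, guidance_scale=1.0)
```

In this formulation guidance doubles the UNet batch size rather than the number of UNet calls per step.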

Example usage

```python
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from maxdiffusion import FlaxStableDiffusionPipeline

# Load pipeline
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="bf16",
    dtype=jnp.bfloat16,
)

prompt = "a photo of an astronaut riding a horse on mars"

# Prepare inputs: one prompt per device
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# Replicate params; give each device its own PRNG key and slice of the prompts
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

# Generate
images = pipeline(
    prompt_ids,
    params,
    prng_seed,
    num_inference_steps,
    jit=True,
).images

# Merge the device and per-device batch axes, then convert to PIL images
images = pipeline.numpy_to_pil(
    np.asarray(images.reshape((num_samples,) + images.shape[-3:]))
)
```
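The sharding steps in the example are shape bookkeeping. flax.training.common_utils.shard reshapes each array from (batch, ...) to (local_device_count, batch // local_device_count, ...), replicate copies params to every device, and the final reshape merges the device and per-device axes of the output images. The sketch below mimics the shard and unshard steps in plain NumPy so it runs without accelerators; the helper names and the four-device count are hypothetical, and the zero-filled images array only stands in for real pipeline output.

```python
import numpy as np


def shard_like_flax(x, num_devices):
    """Hypothetical mimic of flax shard for one array:
    (batch, ...) -> (num_devices, batch // num_devices, ...)."""
    if x.shape[0] % num_devices:
        raise ValueError("batch size must be divisible by the device count")
    return x.reshape((num_devices, -1) + x.shape[1:])


def unshard_images(images):
    """Merge the leading (devices, per_device) axes of a (devices, per_device, H, W, C) array."""
    return images.reshape((-1,) + images.shape[-3:])


num_devices = 4  # assumption: e.g. jax.device_count() on a 4-chip host
prompt_ids = np.zeros((8, 77), dtype=np.int32)  # 8 prompts, 77 tokens each
sharded_ids = shard_like_flax(prompt_ids, num_devices)  # (4, 2, 77)

# Placeholder for pipeline output: one 64x64 RGB image per prompt on each device.
images = np.zeros(sharded_ids.shape[:2] + (64, 64, 3), dtype=np.float32)
flat_images = unshard_images(images)  # (8, 64, 64, 3)
```

Because the example builds the prompt list as num_samples * [prompt] with num_samples = jax.device_count(), its batch always divides evenly; a hand-built prompt list would need a length that is a multiple of the device count.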
