Schedulers control the denoising process in diffusion models. MaxDiffusion provides several Flax-based schedulers.

FlaxDDIMScheduler

Denoising Diffusion Implicit Models (DDIM) scheduler for faster sampling. Located in src/maxdiffusion/schedulers/scheduling_ddim_flax.py:66

Configuration parameters

num_train_timesteps (int, default: 1000): Number of diffusion steps used to train the model.
beta_start (float, default: 0.0001): Starting beta value of the noise schedule.
beta_end (float, default: 0.02): Final beta value of the noise schedule.
beta_schedule (str, default: "linear"): The beta schedule, a mapping from a beta range to a sequence of betas. Options: linear, scaled_linear, squaredcos_cap_v2.
trained_betas (jnp.ndarray, optional): Array of betas passed directly, bypassing beta_start and beta_end.
set_alpha_to_one (bool, default: True): The final step has no previous alpha product. If True, that value is fixed to 1; otherwise the alpha product at step 0 is used.
steps_offset (int, default: 0): Offset added to the inference timesteps.
prediction_type (str, default: "epsilon"): What the model predicts. Options: epsilon (the noise), sample (the denoised sample directly), v_prediction (the velocity v).
timestep_spacing (str, default: "leading"): How inference timesteps are spaced. Options: leading, trailing.
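The beta_schedule and set_alpha_to_one options can be illustrated with a minimal sketch. It assumes MaxDiffusion uses the same formulas as the Hugging Face diffusers schedulers it was ported from; make_betas and make_alphas_cumprod are hypothetical helpers, not part of the MaxDiffusion API.

```python
import math

import jax.numpy as jnp


def make_betas(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02,
               beta_schedule="linear"):
    """Hypothetical helper: build betas for each beta_schedule option (diffusers-style)."""
    if beta_schedule == "linear":
        return jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
    if beta_schedule == "scaled_linear":
        # Linear in sqrt(beta) space, then squared.
        return jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
                            dtype=jnp.float32) ** 2
    if beta_schedule == "squaredcos_cap_v2":
        # Cosine schedule: derive betas from alpha_bar(t); beta_start/beta_end are unused.
        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = [
            min(1 - alpha_bar((i + 1) / num_train_timesteps) / alpha_bar(i / num_train_timesteps), 0.999)
            for i in range(num_train_timesteps)
        ]
        return jnp.array(betas, dtype=jnp.float32)
    raise ValueError(f"unknown beta_schedule: {beta_schedule}")


def make_alphas_cumprod(betas, set_alpha_to_one=True):
    """Hypothetical helper: cumulative alpha products plus the value used after the last step."""
    alphas_cumprod = jnp.cumprod(1.0 - betas)
    # The final step has no previous alpha product: use 1 or the product at step 0.
    final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else alphas_cumprod[0]
    return alphas_cumprod, final_alpha_cumprod
```

The cumulative products decrease monotonically, so larger timesteps correspond to noisier samples.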

Methods

create_state

Creates the scheduler state.
Returns:
state (DDIMSchedulerState): Initialized scheduler state.

set_timesteps

Sets the discrete timesteps for the diffusion chain.
Parameters:
state (DDIMSchedulerState): Current scheduler state.
num_inference_steps (int): Number of diffusion steps for inference.
Returns:
state (DDIMSchedulerState): Updated scheduler state with timesteps.
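The difference between the leading and trailing timestep_spacing options can be sketched as below. This assumes the diffusers-style spacing rules; ddim_timesteps is a hypothetical helper, and in this sketch steps_offset only applies to leading spacing.

```python
import jax.numpy as jnp


def ddim_timesteps(num_train_timesteps, num_inference_steps,
                   timestep_spacing="leading", steps_offset=0):
    """Hypothetical helper: discrete inference timesteps, largest first."""
    if timestep_spacing == "leading":
        # Integer stride starting from 0, then reversed and offset.
        step_ratio = num_train_timesteps // num_inference_steps
        timesteps = (jnp.arange(num_inference_steps) * step_ratio)[::-1] + steps_offset
    elif timestep_spacing == "trailing":
        # Stride counted back from the last training step (steps_offset is ignored).
        step_ratio = num_train_timesteps / num_inference_steps
        timesteps = jnp.round(jnp.arange(num_train_timesteps, 0, -step_ratio)) - 1
    else:
        raise ValueError(f"unknown timestep_spacing: {timestep_spacing}")
    return timesteps.astype(jnp.int32)
```

With 1000 training steps and 50 inference steps, leading spacing yields 980, 960, ..., 0, while trailing spacing yields 999, 979, ..., 19, so trailing sampling starts from the noisiest training timestep.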

step

Performs one denoising step.
Parameters:
state (DDIMSchedulerState): Current scheduler state.
model_output (jnp.ndarray): Direct output of the learned diffusion model.
timestep (int): Current discrete timestep in the diffusion chain.
sample (jnp.ndarray): Current sample being denoised.
eta (float, default: 0.0): Weight of the noise term (η in the DDIM paper); 0.0 gives deterministic DDIM sampling.
Returns an output object with the following fields:
prev_sample (jnp.ndarray): Sample at the previous timestep.
state (DDIMSchedulerState): Updated scheduler state.
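The arithmetic behind one DDIM step for prediction_type="epsilon" can be sketched as follows. This is a standalone illustration of the DDIM update rule, not MaxDiffusion's step implementation: ddim_step is hypothetical, it assumes leading spacing for the previous timestep, and it omits the random term std_dev_t * z, so it matches full DDIM only for eta=0.

```python
import jax.numpy as jnp


def ddim_step(model_output, timestep, sample, alphas_cumprod,
              final_alpha_cumprod, num_inference_steps, eta=0.0):
    """Hypothetical sketch of one DDIM update (epsilon prediction)."""
    num_train_timesteps = alphas_cumprod.shape[0]
    prev_timestep = timestep - num_train_timesteps // num_inference_steps

    alpha_prod_t = alphas_cumprod[timestep]
    # Past the first training step, fall back to final_alpha_cumprod (see set_alpha_to_one).
    alpha_prod_t_prev = jnp.where(prev_timestep >= 0,
                                  alphas_cumprod[prev_timestep], final_alpha_cumprod)

    # Estimate the clean sample x_0 from the predicted noise.
    pred_original_sample = (sample - jnp.sqrt(1 - alpha_prod_t) * model_output) / jnp.sqrt(alpha_prod_t)

    # sigma_t(eta) from the DDIM paper; eta=0 makes the update deterministic.
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * jnp.sqrt(variance)

    # Direction pointing to x_t, recombined with the x_0 estimate.
    pred_sample_direction = jnp.sqrt(1 - alpha_prod_t_prev - std_dev_t**2) * model_output
    return jnp.sqrt(alpha_prod_t_prev) * pred_original_sample + pred_sample_direction
```

If the model predicts the true noise exactly, this step moves a noised sample from timestep t onto the noise level of the previous timestep, and at the final step (with set_alpha_to_one=True) it returns the clean estimate.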

FlaxFlowMatchScheduler

Flow matching scheduler for continuous-time diffusion, used in WAN models. Located in src/maxdiffusion/schedulers/scheduling_flow_match_flax.py:70

Configuration parameters

num_train_timesteps (int, default: 1000): Number of training timesteps.
shift (float, default: 3.0): Shift parameter for flow matching (typically 3.0 for 480p, 5.0 for 720p).
sigma_max (float, default: 1.0): Maximum sigma value.
sigma_min (float, default: 0.003/1.002 ≈ 0.002994): Minimum sigma value.
inverse_timesteps (bool, default: False): Whether to use inverse timesteps.
extra_one_step (bool, default: False): Whether to add an extra step.
reverse_sigmas (bool, default: False): Whether to reverse the sigma schedule.
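To see why a larger shift is used at higher resolution, consider the time-shift formula σ' = s·σ / (1 + (s − 1)·σ). This is an assumption: it is the shift commonly used in flow-matching samplers such as the reference Wan code, and scheduling_flow_match_flax.py should be checked for the exact form. shift_sigmas is a hypothetical helper.

```python
import jax.numpy as jnp


def shift_sigmas(sigmas, shift):
    """Hypothetical helper: time-shift sigmas in [0, 1]; shift > 1 pushes them toward 1."""
    return shift * sigmas / (1 + (shift - 1) * sigmas)


sigmas = jnp.linspace(1.0, 0.0, 5)   # [1.0, 0.75, 0.5, 0.25, 0.0]
at_480p = shift_sigmas(sigmas, 3.0)  # approximately [1.0, 0.9, 0.75, 0.5, 0.0]
at_720p = shift_sigmas(sigmas, 5.0)  # approximately [1.0, 0.9375, 0.8333, 0.625, 0.0]
```

Under this formula the endpoints stay fixed, and a larger shift spends more of the schedule at high noise levels, where the coarse structure of larger frames is formed.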

Methods

set_timesteps

Sets timesteps with optional time shifting.
Parameters:
state (FlowMatchSchedulerState): Current scheduler state.
num_inference_steps (int, default: 100): Number of diffusion steps.
denoising_strength (float, default: 1.0): Strength of the denoising process.
training (bool, default: False): Whether the scheduler is being used for training.
shift (float, optional): Shift value that overrides the configured shift.
Returns:
state (FlowMatchSchedulerState): Updated scheduler state.
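How denoising_strength and extra_one_step might shape the schedule can be sketched as follows. This assumes the construction used by common flow-matching schedulers (linearly spaced sigmas from a start level down to sigma_min, then time-shifted, with timesteps = sigmas × num_train_timesteps); flow_match_schedule is hypothetical, omits inverse_timesteps and reverse_sigmas, and may differ from MaxDiffusion's code.

```python
import jax.numpy as jnp


def flow_match_schedule(num_inference_steps=100, num_train_timesteps=1000,
                        shift=3.0, sigma_max=1.0, sigma_min=0.003 / 1.002,
                        denoising_strength=1.0, extra_one_step=False):
    """Hypothetical helper returning (sigmas, timesteps), noisiest first."""
    # denoising_strength < 1 starts the schedule from a lower noise level.
    sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
    if extra_one_step:
        # Space num_inference_steps + 1 points, then drop the last one.
        sigmas = jnp.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
    else:
        sigmas = jnp.linspace(sigma_start, sigma_min, num_inference_steps)
    # Time shift (assumed form): pushes sigmas toward 1 when shift > 1.
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    timesteps = sigmas * num_train_timesteps
    return sigmas, timesteps
```

For example, flow_match_schedule(num_inference_steps=50) starts at sigma 1.0 (timestep 1000) and decreases monotonically.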

step

Performs one flow matching step.
Parameters:
state (FlowMatchSchedulerState): Current scheduler state.
model_output (jnp.ndarray): Model-predicted velocity.
timestep (jnp.ndarray): Current timestep.
sample (jnp.ndarray): Current sample.
to_final (bool, default: False): Whether this is the final step.
Returns:
prev_sample (jnp.ndarray): Sample at the previous timestep.
state (FlowMatchSchedulerState): Updated scheduler state.
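A flow matching step with a predicted velocity is typically a single Euler step of the flow ODE. The sketch below assumes the interpolation x_σ = (1 − σ)·x₀ + σ·noise with velocity noise − x₀, and that the step integrates to σ = 0 after the last step or when to_final=True; flow_match_step and next_sigma are hypothetical helpers, not MaxDiffusion functions.

```python
import jax.numpy as jnp


def flow_match_step(model_output, sigma, sigma_next, sample):
    """Hypothetical sketch: one Euler step from noise level sigma to sigma_next."""
    # model_output is the predicted velocity dx/dsigma.
    return sample + model_output * (sigma_next - sigma)


def next_sigma(sigmas, index, to_final=False):
    """Hypothetical helper: target sigma for the step at position index."""
    # After the last scheduled step (or when to_final=True), integrate to sigma = 0.
    if to_final or index + 1 >= sigmas.shape[0]:
        return 0.0
    return sigmas[index + 1]
```

Because the assumed path is a straight line between x₀ and the noise, an exact velocity makes each Euler step land exactly on the path at the next noise level.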

add_noise

Adds noise to samples according to the flow matching schedule.
Parameters:
state (FlowMatchSchedulerState): Current scheduler state.
original_samples (jnp.ndarray): Original clean samples.
noise (jnp.ndarray): Noise to add.
timesteps (jnp.ndarray): Timesteps corresponding to noise levels.
Returns:
noisy_samples (jnp.ndarray): Noisy samples.
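Adding noise under flow matching can be sketched as a linear interpolation between clean samples and noise. This assumes the interpolation (1 − σ)·x₀ + σ·noise and that each timestep is mapped to the sigma of the nearest scheduled timestep; flow_match_add_noise and its schedule arguments are hypothetical, standing in for values held in the scheduler state.

```python
import jax.numpy as jnp


def flow_match_add_noise(original_samples, noise, timesteps,
                         schedule_timesteps, schedule_sigmas):
    """Hypothetical sketch: interpolate clean samples toward noise per batch element."""
    # Index of the scheduled timestep closest to each requested timestep.
    idx = jnp.argmin(jnp.abs(schedule_timesteps[None, :] - timesteps[:, None]), axis=1)
    sigma = schedule_sigmas[idx]
    # Broadcast sigma over the non-batch dimensions.
    sigma = sigma.reshape((-1,) + (1,) * (original_samples.ndim - 1))
    return (1 - sigma) * original_samples + sigma * noise
```

Under this assumed interpolation, a training loss would regress the model output onto the velocity noise − original_samples.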

Other schedulers

MaxDiffusion also supports:

FlaxEulerDiscreteScheduler

Euler discrete scheduler for efficient sampling.

FlaxUniPCMultistepScheduler

UniPC multistep scheduler for fast high-quality sampling, used in WAN models.

FlaxDPMSolverMultistepScheduler

DPM-Solver++ for fast sampling with fewer steps.

FlaxLMSDiscreteScheduler

Linear multistep scheduler.

FlaxPNDMScheduler

Pseudo numerical methods for diffusion models.

Example: Using a scheduler

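import jax
import jax.numpy as jnp
from maxdiffusion.schedulers import FlaxDDIMScheduler

# Create scheduler
scheduler = FlaxDDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
)

# Create state
state = scheduler.create_state()

# Set timesteps for inference
state = scheduler.set_timesteps(
    state,
    num_inference_steps=50,
)

# Placeholders for illustration only: replace with your UNet apply
# function and text-encoder embeddings.
def model(latents, t, encoder_hidden_states):
    return jnp.zeros_like(latents)

encoder_hidden_states = jnp.zeros((1, 77, 768))

# Start from Gaussian noise
key = jax.random.PRNGKey(0)
latents = jax.random.normal(key, shape=(1, 4, 64, 64))

# Denoising loop
for t in state.timesteps:
    # Predict the noise at timestep t
    noise_pred = model(latents, t, encoder_hidden_states)

    # Compute the sample at the previous timestep
    output = scheduler.step(state, noise_pred, t, latents)
    latents = output.prev_sample
    state = output.state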

Scheduler selection guide

  • DDIM: Fast sampling; deterministic when eta=0.0
  • Euler: Simple and efficient; a good default choice
  • DPM-Solver++: High-quality samples in relatively few steps
  • Flow matching: For continuous-time flow models such as WAN
  • UniPC: Fast convergence; used with WAN video models

FlaxRectifiedFlowScheduler

Rectified flow scheduler for continuous normalizing flows, used in LTX-Video models. It supports multistep sampling with configurable shift parameters for tuning video generation quality. Located in src/maxdiffusion/schedulers/scheduling_rectified_flow.py
