The FlaxUNet2DConditionModel is a conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a denoised sample.

FlaxUNet2DConditionModel

Model architecture

The UNet architecture consists of:
  • Input projection: Convolutional layer that projects input channels to initial feature dimensions
  • Time embedding: Sinusoidal position embeddings for timestep conditioning
  • Down blocks: Sequence of downsampling blocks with optional cross-attention
  • Mid block: Middle block with cross-attention for processing bottleneck features
  • Up blocks: Sequence of upsampling blocks with optional cross-attention
  • Output projection: Final convolution layer with group normalization
Located in src/maxdiffusion/models/unet_2d_condition_flax.py:54
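The time-embedding step above uses the standard transformer-style sinusoidal formulation. As an illustrative sketch only (the exact maxdiffusion code may differ in scaling, ordering, and dtype), assuming an embedding dimension matching the first entry of block_out_channels:

```python
import numpy as np

def sinusoidal_time_embedding(timesteps, dim=320, max_period=10000.0):
    """Map integer timesteps of shape (batch,) to embeddings of shape (batch, dim).

    Hypothetical reference implementation of a sinusoidal position embedding;
    `dim=320` mirrors the first block_out_channels entry, not the actual code.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(timesteps, dtype=np.float64)[:, None] * freqs[None, :]
    # Concatenate sine and cosine components along the feature axis.
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)
```

At timestep 0 the sine half is all zeros and the cosine half all ones, which gives the network an unambiguous anchor for the start of the diffusion schedule.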

Configuration parameters

  • sample_size (int, default: 32): The size of the input sample
  • in_channels (int, default: 4): The number of channels in the input sample
  • out_channels (int, default: 4): The number of channels in the output
  • down_block_types (Tuple[str], default: ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")): The tuple of downsample blocks to use
  • up_block_types (Tuple[str], default: ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")): The tuple of upsample blocks to use
  • block_out_channels (Tuple[int], default: (320, 640, 1280, 1280)): The tuple of output channels for each block
  • layers_per_block (int, default: 2): The number of layers per block
  • attention_head_dim (int | Tuple[int], default: 8): The dimension of the attention heads
  • cross_attention_dim (int, default: 1280): The dimension of the cross-attention features
  • dropout (float, default: 0.0): Dropout probability for the down, up, and bottleneck blocks
  • attention_kernel (str, default: "dot_product"): The attention mechanism to use. Options: dot_product, flash
  • flash_min_seq_length (int, default: 4096): The minimum sequence length required to apply flash attention
  • mesh (jax.sharding.Mesh): JAX mesh for distributed computation (required if attention_kernel is set to flash)
  • quant (AqtQuantization): Configures AQT quantization from github.com/google/aqt
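The interplay between attention_kernel and flash_min_seq_length can be sketched as a simple dispatch rule: flash attention only pays off on long sequences, so short sequences fall back to the plain dot-product kernel. The following is a hypothetical illustration, not the maxdiffusion implementation; function names and the numpy reference kernel are invented for clarity:

```python
import numpy as np

FLASH_MIN_SEQ_LENGTH = 4096  # documented default for flash_min_seq_length

def select_kernel(attention_kernel: str, seq_length: int) -> str:
    """Hypothetical dispatcher: use flash attention only when requested
    AND the sequence is long enough to benefit from it."""
    if attention_kernel == "flash" and seq_length >= FLASH_MIN_SEQ_LENGTH:
        return "flash"
    return "dot_product"

def dot_product_attention(q, k, v):
    """Plain softmax(Q K^T / sqrt(d)) V reference kernel."""
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    # Numerically stable softmax over the key axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

Under this sketch, a request for flash attention on a 1024-token sequence would silently run the dot-product kernel instead, which matches the documented role of flash_min_seq_length as a lower bound.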

Methods

__call__

Forward pass through the UNet model.

Parameters:
  • sample (jnp.ndarray): Noisy input tensor of shape (batch, channel, height, width)
  • timesteps (jnp.ndarray | float | int): Timesteps for the diffusion process
  • encoder_hidden_states (jnp.ndarray): Encoder hidden states of shape (batch_size, sequence_length, hidden_size)
  • added_cond_kwargs (dict): Additional embeddings that are added to the embeddings passed along to the UNet blocks
  • train (bool, default: False): When False (inference), deterministic functions are used and dropout is disabled
Returns:
  • sample (jnp.ndarray): The hidden-states output conditioned on the encoder_hidden_states input; the output of the last layer of the model, with shape (batch_size, num_channels, height, width)
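The shape contract of __call__ can be made concrete with a stub. This is not the real model (no UNet math runs here); it only asserts the documented input shapes, using the documented defaults in_channels=4, out_channels=4, and cross_attention_dim=1280, and returns an output of the documented shape:

```python
import numpy as np

IN_CHANNELS, OUT_CHANNELS, CROSS_ATTENTION_DIM = 4, 4, 1280  # documented defaults

def unet_call_stub(sample, timesteps, encoder_hidden_states):
    """Illustrative stand-in for FlaxUNet2DConditionModel.__call__ that
    checks the documented I/O contract and returns a sample-shaped array."""
    batch, channels, height, width = sample.shape
    assert channels == IN_CHANNELS, "sample must have in_channels channels"
    assert encoder_hidden_states.shape[0] == batch
    assert encoder_hidden_states.shape[2] == CROSS_ATTENTION_DIM
    # timesteps may be a scalar or one value per batch element.
    assert np.shape(timesteps) in ((), (batch,))
    # The real model would denoise here; the stub only preserves shapes:
    # (batch_size, num_channels, height, width).
    return np.zeros((batch, OUT_CHANNELS, height, width))
```

Note that the spatial resolution is preserved end to end: the down blocks shrink it and the up blocks restore it, so the denoised output matches the noisy input's height and width.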

FlaxUNet2DConditionOutput

Output class for the UNet model located at src/maxdiffusion/models/unet_2d_condition_flax.py:41.
  • sample (jnp.ndarray): The hidden-states output conditioned on the encoder_hidden_states input; the output of the last layer of the model
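Structurally, the output class is just a single-field container. A minimal sketch using a plain Python dataclass (the real class at the path above is a Flax/JAX output dataclass, so this only mirrors its shape, not its base class):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class UNet2DConditionOutputSketch:
    """Illustrative mirror of FlaxUNet2DConditionOutput: one `sample`
    field holding the (batch, channels, height, width) denoised output."""
    sample: np.ndarray
```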

JAX features

As a Flax module, the model supports JAX's core transformations:
  • Just-In-Time (JIT) compilation
  • Automatic differentiation
  • Vectorization (vmap)
  • Parallelization (pmap)
