FlaxUNet2DConditionModel is a conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a denoised sample.
FlaxUNet2DConditionModel
Model architecture
The UNet architecture consists of:
- Input projection: Convolutional layer that projects input channels to initial feature dimensions
- Time embedding: Sinusoidal position embeddings for timestep conditioning
- Down blocks: Sequence of downsampling blocks with optional cross-attention
- Mid block: Middle block with cross-attention for processing bottleneck features
- Up blocks: Sequence of upsampling blocks with optional cross-attention
- Output projection: Final convolution layer with group normalization
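To make the time embedding step concrete, here is a minimal JAX sketch of a sinusoidal timestep embedding. The function name `sinusoidal_timestep_embedding`, the `max_period` value, the embedding width of 320, and the sin-then-cos ordering are illustrative assumptions, not taken from the model source, which may order or scale the features differently.

```python
import jax.numpy as jnp


def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000.0):
    """Map a 1-D array of timesteps to (batch, dim) sinusoidal features.

    Hypothetical helper for illustration; assumes `dim` is even.
    """
    half = dim // 2
    # Geometric frequency schedule starting at 1 and decaying toward 1/max_period.
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    # Outer product: one row per timestep, one column per frequency.
    args = timesteps[:, None].astype(jnp.float32) * freqs[None, :]
    # First half holds sines, second half holds cosines (assumed ordering).
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)


emb = sinusoidal_timestep_embedding(jnp.array([0, 10, 999]), dim=320)
print(emb.shape)  # (3, 320)
```

Each timestep becomes a fixed-length vector, so the same network weights can be conditioned on any point in the diffusion schedule.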
src/maxdiffusion/models/unet_2d_condition_flax.py:54
Configuration parameters
The size of the input sample
The number of channels in the input sample
The number of channels in the output
The tuple of downsample blocks to use. Defaults to ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
The tuple of upsample blocks to use. Defaults to ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
The tuple of output channels for each block
The number of layers per block
The dimension of the attention heads
The dimension of the cross attention features
Dropout probability for down, up and bottleneck blocks
Attention mechanism to use. Options: dot_product, flash
Minimum sequence length required to apply flash attention
JAX mesh for distributed computation (required if attention is set to flash)
Configures AQT quantization from github.com/google/aqt
Methods
__call__
Forward pass through the UNet model.
Parameters:
Noisy inputs tensor of shape (batch, channel, height, width)
Timesteps for the diffusion process
Encoder hidden states of shape (batch_size, sequence_length, hidden_size)
Additional embeddings that are added to the embeddings passed along to the UNet blocks
Use deterministic functions and disable dropout when not training
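The effect of the deterministic flag can be sketched in plain JAX. The `dropout` helper below is a hypothetical stand-in written for illustration, not the model's own dropout layer: it assumes inverted dropout, where kept activations are rescaled by `1 / (1 - rate)`, and it returns the input unchanged when `deterministic` is true.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, rng, deterministic):
    """Hypothetical inverted-dropout helper (illustration only)."""
    if deterministic or rate == 0.0:
        # Inference path: no randomness, input passes through unchanged.
        return x
    keep = 1.0 - rate
    # Sample a keep/drop mask with probability `keep` of keeping each element.
    mask = jax.random.bernoulli(rng, keep, x.shape)
    # Rescale kept values so the expected activation stays the same.
    return jnp.where(mask, x / keep, 0.0)


x = jnp.ones((4, 4))
rng = jax.random.PRNGKey(0)
eval_out = dropout(x, 0.5, rng, deterministic=True)    # identical to x
train_out = dropout(x, 0.5, rng, deterministic=False)  # entries are 0.0 or 2.0
```

In the training path the caller has to supply a fresh random key per step, which is why the deterministic path is the natural choice for sampling and evaluation.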
The hidden states output conditioned on encoder_hidden_states input. Output of last layer of model with shape (batch_size, num_channels, height, width)
FlaxUNet2DConditionOutput
Output class for the UNet model, located at src/maxdiffusion/models/unet_2d_condition_flax.py:41.
The hidden states output conditioned on encoder_hidden_states input. Output of last layer of model.
JAX features
The model supports inherent JAX features:
- Just-In-Time (JIT) compilation
- Automatic differentiation
- Vectorization (vmap)
- Parallelization (pmap)
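The four transformations listed above can be sketched on a toy function. `toy_denoise` and `loss_fn` are hypothetical stand-ins chosen to keep the example self-contained; they are not the UNet forward pass, and the real model would be invoked through its own parameters and apply function.

```python
import jax
import jax.numpy as jnp


def toy_denoise(scale, sample, timestep):
    # Hypothetical stand-in for a model forward pass; NOT the real UNet.
    return sample * scale / (1.0 + timestep)


def loss_fn(scale, sample, timestep, target):
    pred = toy_denoise(scale, sample, timestep)
    return jnp.mean((pred - target) ** 2)


sample = jnp.ones((2, 4, 8, 8))  # (batch, channel, height, width)
timesteps = jnp.array([0.0, 1.0])
target = jnp.zeros_like(sample)

# JIT compilation: trace once, then reuse the compiled function.
fast_denoise = jax.jit(toy_denoise)
out = fast_denoise(2.0, sample[0], timesteps[0])

# Automatic differentiation: gradient of the loss w.r.t. `scale`.
g = jax.grad(loss_fn)(2.0, sample[0], timesteps[0], target[0])

# Vectorization: map over the batch axis of `sample` and `timesteps`.
batched = jax.vmap(toy_denoise, in_axes=(None, 0, 0))
outs = batched(2.0, sample, timesteps)

# Parallelization: one shard per local device (a single shard on one CPU).
n = jax.local_device_count()
sharded = jnp.ones((n, 4, 8, 8))
pouts = jax.pmap(toy_denoise, in_axes=(None, 0, 0))(2.0, sharded, jnp.zeros(n))
```

The transformations compose, so a training step would typically wrap the gradient computation in `jax.jit` rather than compiling the forward pass alone.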