The FlaxControlNetModel adds spatial conditioning to Stable Diffusion models through additional inputs like edge maps, depth maps, or pose keypoints.

FlaxControlNetModel

Model architecture

ControlNet extends the UNet with:
  • Conditioning embedding: Processes control images (e.g., edge maps, depth maps)
  • Zero-initialized convolutions: Ensure the ControlNet contributes zero residual at initialization, so training starts from the pretrained model's behavior
  • Parallel down/mid blocks: Mirror the UNet structure
  • Residual connections: Output residuals are added to the UNet activations
Located in src/maxdiffusion/models/controlnet_flax.py:100
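The effect of the zero-initialized convolutions can be illustrated with a minimal NumPy sketch (a hypothetical 1x1 convolution, not the actual maxdiffusion code): with all-zero weights, the residual added to the UNet is exactly zero, so at the start of training the combined model reproduces the pretrained UNet unchanged.

```python
import numpy as np

def conv1x1(x, weight, bias):
    # x: (batch, height, width, channels_in); weight: (channels_in, channels_out)
    return x @ weight + bias

rng = np.random.default_rng(0)
unet_activation = rng.normal(size=(1, 8, 8, 320))
controlnet_features = rng.normal(size=(1, 8, 8, 320))

# Zero-initialized projection: weights and bias both start at zero
zero_w = np.zeros((320, 320))
zero_b = np.zeros(320)

residual = conv1x1(controlnet_features, zero_w, zero_b)
combined = unet_activation + residual

# At initialization the residual is zero, so the UNet output is unchanged
assert np.allclose(residual, 0.0)
assert np.allclose(combined, unet_activation)
```

As training updates the zero convolutions away from zero, the control signal is blended in gradually rather than disrupting the pretrained weights.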

Configuration parameters

  • sample_size (int, default: 32): The size of the input sample
  • in_channels (int, default: 4): The number of channels in the input sample
  • down_block_types (Tuple[str], default: ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")): The tuple of downsample blocks to use
  • block_out_channels (Tuple[int], default: (320, 640, 1280, 1280)): The tuple of output channels for each block
  • layers_per_block (int, default: 2): The number of layers per block
  • attention_head_dim (int | Tuple[int], default: 8): The dimension of the attention heads
  • cross_attention_dim (int, default: 1280): The dimension of the cross-attention features
  • controlnet_conditioning_channel_order (str, default: "rgb"): The channel order of the conditioning image; converted to RGB if it is BGR
  • conditioning_embedding_out_channels (Tuple[int], default: (16, 32, 96, 256)): The tuple of output channels for each block in the conditioning_embedding layer
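For reference, the defaults listed above can be collected into a plain Python dict (an illustrative summary only; the library takes these as individual constructor arguments, and this dict is not an object it exposes):

```python
# Default FlaxControlNetModel configuration, mirroring the parameter list above.
controlnet_defaults = {
    "sample_size": 32,
    "in_channels": 4,
    "down_block_types": (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    "block_out_channels": (320, 640, 1280, 1280),
    "layers_per_block": 2,
    "attention_head_dim": 8,
    "cross_attention_dim": 1280,
    "controlnet_conditioning_channel_order": "rgb",
    "conditioning_embedding_out_channels": (16, 32, 96, 256),
}

# Sanity check: one down-block type per entry in block_out_channels
assert len(controlnet_defaults["down_block_types"]) == len(
    controlnet_defaults["block_out_channels"]
)
```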

Methods

__call__

Forward pass through ControlNet.

Parameters:
  • sample (jnp.ndarray): Noisy inputs tensor of shape (batch, channel, height, width)
  • timesteps (jnp.ndarray | float | int): Timesteps for the diffusion process
  • encoder_hidden_states (jnp.ndarray): Encoder hidden states of shape (batch_size, sequence_length, hidden_size)
  • controlnet_cond (jnp.ndarray): The conditioning input tensor (e.g., edge map, depth map) of shape (batch, channel, height, width)
  • conditioning_scale (float, default: 1.0): The scale factor applied to the ControlNet outputs
  • train (bool, default: False): When False, deterministic functions are used and dropout is disabled

Returns:
  • down_block_res_samples (jnp.ndarray): Tuple of residual samples from the down blocks, added to the corresponding UNet activations
  • mid_block_res_sample (jnp.ndarray): Residual sample from the mid block, added to the UNet mid-block activation
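How these returned residuals are consumed can be sketched with NumPy stand-ins (shapes and the scaling step are illustrative of the standard ControlNet scheme; this is not the maxdiffusion UNet code): each down-block residual and the mid-block residual are multiplied by conditioning_scale before being added to the matching UNet activations.

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-ins for the ControlNet outputs (shapes chosen for illustration)
down_block_res_samples = tuple(
    rng.normal(size=(1, 4, 4, c)) for c in (320, 640, 1280)
)
mid_block_res_sample = rng.normal(size=(1, 2, 2, 1280))

conditioning_scale = 0.5

# Scale every residual by conditioning_scale
scaled_down = tuple(r * conditioning_scale for r in down_block_res_samples)
scaled_mid = mid_block_res_sample * conditioning_scale

# The UNet adds each scaled residual to its matching activation
unet_mid_activation = rng.normal(size=(1, 2, 2, 1280))
conditioned_mid = unet_mid_activation + scaled_mid

assert len(scaled_down) == len(down_block_res_samples)
assert np.allclose(scaled_mid, mid_block_res_sample * 0.5)
```

Setting conditioning_scale to 0.0 removes the control signal entirely, while 1.0 applies the residuals at full strength.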

FlaxControlNetOutput

Output dataclass for ControlNet located at src/maxdiffusion/models/controlnet_flax.py:30.
  • down_block_res_samples (jnp.ndarray): Residuals from the down blocks
  • mid_block_res_sample (jnp.ndarray): Residual from the mid block

FlaxControlNetConditioningEmbedding

Conditioning embedding module located at src/maxdiffusion/models/controlnet_flax.py:43.

Architecture

Processes control images through:
  • Initial convolution
  • Series of convolution blocks with downsampling
  • Zero-initialized output convolution
Parameters:
  • conditioning_embedding_channels (int): Output channels for the conditioning embedding
  • block_out_channels (Tuple[int], default: (16, 32, 96, 256)): Channels for each block in the conditioning embedding
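The downsampling arithmetic can be checked directly. Assuming the standard ControlNet conditioning-embedding design of one stride-2 convolution between consecutive blocks (an assumption about this implementation, inferred from the architecture list above), the default block_out_channels = (16, 32, 96, 256) gives three downsampling steps, reducing each spatial dimension of the control image by a factor of 8:

```python
# Illustrative check of the conditioning embedding's total downsampling,
# assuming one stride-2 convolution between consecutive blocks.
block_out_channels = (16, 32, 96, 256)

num_downsamples = len(block_out_channels) - 1   # 3 stride-2 convolutions
factor = 2 ** num_downsamples                   # overall spatial reduction

control_image_size = 512
latent_size = control_image_size // factor

assert factor == 8
assert latent_size == 64
```

This factor of 8 is what lets a full-resolution control image (e.g., 512x512) align with the Stable Diffusion latent grid (64x64).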
