FlaxControlNetModel adds spatial conditioning to Stable Diffusion models through additional inputs like edge maps, depth maps, or pose keypoints.
FlaxControlNetModel
Model architecture
ControlNet extends UNet with:
- Conditioning embedding: Processes control images (e.g., edge maps, depth maps)
- Zero-initialized convolutions: Ensures training starts from the pretrained model
- Parallel down/mid blocks: Mirror the UNet structure
- Residual connections: Output residuals added to UNet activations
src/maxdiffusion/models/controlnet_flax.py:100
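The effect of the zero-initialized convolutions listed above can be illustrated with a small NumPy sketch. It is not the code from `controlnet_flax.py`: the 1x1 kernel, the tensor shapes, and the helper name `conv1x1` are assumptions chosen for illustration. The point it shows is that a projection whose weights and bias start at zero emits an all-zero residual, so adding it to a UNet activation leaves the pretrained behavior untouched at the first training step.

```python
import numpy as np


def conv1x1(x, weight, bias):
    """Apply a 1x1 convolution to an NCHW tensor (hypothetical helper).

    x: (batch, in_ch, height, width), weight: (out_ch, in_ch), bias: (out_ch,)
    """
    return np.einsum("oc,bchw->bohw", weight, x) + bias[None, :, None, None]


rng = np.random.default_rng(0)
unet_activation = rng.normal(size=(1, 4, 8, 8))     # stand-in for a UNet feature map
controlnet_feature = rng.normal(size=(1, 4, 8, 8))  # stand-in for a ControlNet block output

# Zero-initialized projection: weights and bias both start at exactly zero.
zero_weight = np.zeros((4, 4))
zero_bias = np.zeros(4)

residual = conv1x1(controlnet_feature, zero_weight, zero_bias)
combined = unet_activation + residual

residual_is_zero = bool(np.all(residual == 0.0))
unet_unchanged = bool(np.array_equal(combined, unet_activation))
```

Once training updates the projection weights, the residual becomes non-zero and starts steering the UNet toward the control image.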
Configuration parameters
The size of the input sample
The number of channels in the input sample
The tuple of downsample blocks to use. Defaults to ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
The tuple of output channels for each block
The number of layers per block
The dimension of the attention heads
The dimension of the cross attention features
The channel order of conditional image. Will convert to rgb if it’s bgr
The tuple of output channels for each block in the conditioning_embedding layer
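The channel-order conversion described above amounts to reversing the channel axis of the conditioning image. The sketch below assumes the `(batch, channel, height, width)` layout used elsewhere on this page. The function name `to_rgb` and its `channel_order` argument are hypothetical, and NumPy stands in for `jax.numpy`, whose `flip` takes the same arguments.

```python
import numpy as np


def to_rgb(controlnet_cond, channel_order="rgb"):
    """Return an NCHW conditioning image in RGB channel order (hypothetical helper)."""
    if channel_order == "bgr":
        # Reversing axis 1 swaps the first and last colour channels.
        return np.flip(controlnet_cond, axis=1)
    if channel_order == "rgb":
        return controlnet_cond
    raise ValueError(f"unknown channel order: {channel_order!r}")


# Build a tiny BGR image whose channels are filled with 0 (B), 1 (G), 2 (R).
bgr_image = np.stack([np.full((2, 2), float(v)) for v in (0, 1, 2)])[None]
rgb_image = to_rgb(bgr_image, channel_order="bgr")

first_channel_value = float(rgb_image[0, 0, 0, 0])  # red now comes first
last_channel_value = float(rgb_image[0, 2, 0, 0])   # blue now comes last
```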
Methods
__call__
Forward pass through ControlNet.
Parameters:
Noisy inputs tensor of shape (batch, channel, height, width)
Timesteps for the diffusion process
Encoder hidden states of shape (batch_size, sequence_length, hidden_size)
The conditional input tensor (e.g., edge map, depth map) of shape (batch, channel, height, width)
The scale factor for ControlNet outputs
Use deterministic functions and disable dropout when not training
Returns:
Tuple of residual samples from down blocks to be added to UNet
Residual sample from mid block to be added to UNet
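How the scale factor and the returned residuals interact can be sketched as follows. This is an illustration under assumptions, not the library's code: the helper names, the tensor shapes, and the choice to scale every residual by the same factor are hypothetical, and NumPy arrays stand in for JAX arrays.

```python
import numpy as np


def apply_conditioning_scale(down_residuals, mid_residual, conditioning_scale=1.0):
    """Multiply every ControlNet residual by one scale factor (hypothetical helper)."""
    scaled_down = tuple(r * conditioning_scale for r in down_residuals)
    return scaled_down, mid_residual * conditioning_scale


def add_residuals(unet_down, unet_mid, down_residuals, mid_residual):
    """Add ControlNet residuals to the matching UNet activations (hypothetical helper)."""
    if len(unet_down) != len(down_residuals):
        raise ValueError("UNet and ControlNet must expose the same number of down samples")
    new_down = tuple(u + r for u, r in zip(unet_down, down_residuals))
    return new_down, unet_mid + mid_residual


# Toy activations: UNet features are all 1.0, ControlNet residuals are all 2.0.
shapes = [(1, 4, 8, 8), (1, 8, 4, 4)]
unet_down = tuple(np.ones(s) for s in shapes)
unet_mid = np.ones((1, 8, 2, 2))
ctrl_down = tuple(np.full(s, 2.0) for s in shapes)
ctrl_mid = np.full((1, 8, 2, 2), 2.0)

scaled_down, scaled_mid = apply_conditioning_scale(ctrl_down, ctrl_mid, conditioning_scale=0.5)
new_down, new_mid = add_residuals(unet_down, unet_mid, scaled_down, scaled_mid)
```

A scale of 0.0 would disable the control signal entirely, while values above 1.0 strengthen it.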
FlaxControlNetOutput
Output dataclass for ControlNet located at src/maxdiffusion/models/controlnet_flax.py:30.
Residuals from down blocks
Residual from mid block
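The two fields above suggest a container shaped roughly like the one below. The field names `down_block_res_samples` and `mid_block_res_sample` are assumptions borrowed from the usual ControlNet naming and are not confirmed by this page. The standard-library `dataclass` is used here as a stand-in for whatever dataclass helper the module actually uses.

```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass(frozen=True)
class ControlNetOutputSketch:
    """Illustrative stand-in for the output container (field names assumed)."""

    down_block_res_samples: Tuple[np.ndarray, ...]  # residuals from down blocks
    mid_block_res_sample: np.ndarray                # residual from mid block


output = ControlNetOutputSketch(
    down_block_res_samples=(np.zeros((1, 4, 8, 8)), np.zeros((1, 8, 4, 4))),
    mid_block_res_sample=np.zeros((1, 8, 2, 2)),
)

# Callers typically unpack both parts and hand them to the UNet.
down_shapes = [r.shape for r in output.down_block_res_samples]
mid_shape = output.mid_block_res_sample.shape
```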
FlaxControlNetConditioningEmbedding
Conditioning embedding module located at src/maxdiffusion/models/controlnet_flax.py:43.
Architecture
Processes control images through:
- Initial convolution
- Series of convolution blocks with downsampling
- Zero-initialized output convolution
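To make the three stages above concrete, the following pure-Python sketch traces the channel count and spatial size of a control image through such a stack. Every number in it is an assumption for illustration: 3x3 kernels with padding 1, one stride-2 convolution per block transition, block channels `(16, 32, 96, 256)`, and 320 output channels. The page itself does not state these values.

```python
def conv_out_size(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_conditioning_embedding(height, width,
                                 block_out_channels=(16, 32, 96, 256),
                                 embedding_channels=320):
    """List (stage, channels, height, width) for each layer of the sketch."""
    stages = []
    # Initial convolution keeps the resolution.
    h, w = conv_out_size(height), conv_out_size(width)
    stages.append(("conv_in", block_out_channels[0], h, w))
    # Each block: one resolution-preserving conv, then a stride-2 conv.
    for i in range(len(block_out_channels) - 1):
        ch_in, ch_out = block_out_channels[i], block_out_channels[i + 1]
        stages.append((f"block{i}_conv", ch_in, h, w))
        h, w = conv_out_size(h, stride=2), conv_out_size(w, stride=2)
        stages.append((f"block{i}_down", ch_out, h, w))
    # Zero-initialized output convolution keeps the resolution.
    stages.append(("conv_out (zero-init)", embedding_channels, h, w))
    return stages


stages = trace_conditioning_embedding(512, 512)
final_stage = stages[-1]
```

Under these assumptions a 512x512 control image is reduced by a factor of 8 per side, which matches the latent resolution that Stable Diffusion's UNet operates on.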
Output channels for conditioning embedding
Channels for each block in the conditioning embedding