MaxDiffusion models use ConfigMixin (from maxdiffusion.configuration_utils) for configuration management. All models inherit configuration capabilities for loading and saving.

ConfigMixin

Base configuration class that provides:
  • Configuration storage and access via config attribute
  • Saving configurations to JSON
  • Loading configurations from pretrained models

Key methods

from_config

Instantiates a model from a configuration dictionary. Parameters:
  • config (dict): Configuration dictionary with model parameters
  • **kwargs (any): Additional parameters that override values in the config
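A minimal sketch of from_config; the partial config dict and values are illustrative, with unspecified fields falling back to the model's defaults:

from maxdiffusion.models import FlaxUNet2DConditionModel

# Instantiate directly from a config dict; kwargs take precedence
# over values in the dict (sample_size here ends up as 32)
config = {"in_channels": 4, "out_channels": 4, "sample_size": 64}
model = FlaxUNet2DConditionModel.from_config(config, sample_size=32)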

save_config

Saves the model configuration to a directory. Parameters:
  • save_directory (str): Directory in which to write the config.json file
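A short sketch of saving a configuration to disk (the directory path is arbitrary):

from maxdiffusion.models import FlaxUNet2DConditionModel

model = FlaxUNet2DConditionModel.from_config({"sample_size": 64})

# Writes ./my_unet_config/config.json with the registered values
model.save_config("./my_unet_config")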

Common model parameters

These parameters are common across most MaxDiffusion models:

Data types

  • dtype (jnp.dtype, default jnp.float32): The dtype for activations and intermediate computations
  • weights_dtype (jnp.dtype, default jnp.float32): The dtype for model weights/parameters
  • precision (jax.lax.Precision): JAX precision for matmul operations. Options: None, jax.lax.Precision.DEFAULT, jax.lax.Precision.HIGH, jax.lax.Precision.HIGHEST
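For example, a mixed-precision setup might keep weights in float32 while running activations in bfloat16. This is a sketch; the model class and values are illustrative:

import jax
import jax.numpy as jnp
from maxdiffusion.models import FlaxUNet2DConditionModel

model = FlaxUNet2DConditionModel.from_config(
    {"sample_size": 64},
    dtype=jnp.bfloat16,         # activations and intermediates
    weights_dtype=jnp.float32,  # parameters stay in full precision
    precision=jax.lax.Precision.DEFAULT,
)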

Architecture

  • in_channels (int): Number of input channels
  • out_channels (int): Number of output channels
  • sample_size (int): Size of input samples

Attention

  • attention_kernel (str, default "dot_product"): Attention mechanism to use. Options:
      • dot_product: Standard scaled dot-product attention
      • flash: Flash attention for improved efficiency
  • flash_min_seq_length (int, default 4096): Minimum sequence length required to apply flash attention
  • flash_block_sizes (BlockSizes): Block sizes for flash attention; overrides the default block sizes
  • use_memory_efficient_attention (bool, default False): Enable memory-efficient attention (an alternative to flash attention)
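As a sketch, enabling flash attention and lowering the sequence-length threshold so it also applies at smaller resolutions (values illustrative):

from maxdiffusion.models import FlaxUNet2DConditionModel

model = FlaxUNet2DConditionModel.from_config(
    {"sample_size": 64},
    attention_kernel="flash",
    flash_min_seq_length=1024,
)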

Normalization

  • norm_num_groups (int, default 32): Number of groups for group normalization layers

Regularization

  • dropout (float, default 0.0): Dropout probability for regularization
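As with the other common parameters, these can be overridden at construction time. A small sketch (values illustrative; norm_num_groups must evenly divide the layer channel counts):

from maxdiffusion.models import FlaxUNet2DConditionModel

model = FlaxUNet2DConditionModel.from_config(
    {"sample_size": 64},
    norm_num_groups=16,
    dropout=0.1,
)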

Model-specific configurations

UNet configuration

See UNet model reference for UNet-specific parameters including:
  • down_block_types, up_block_types
  • block_out_channels
  • layers_per_block
  • cross_attention_dim

VAE configuration

See VAE model reference for VAE-specific parameters including:
  • latent_channels
  • scaling_factor
  • block_out_channels

Transformer configuration

See Transformer model reference for transformer-specific parameters including:
  • num_layers, num_single_layers
  • num_attention_heads
  • attention_head_dim
  • joint_attention_dim

Configuration decorators

@flax_register_to_config

Decorator that registers a Flax model class with ConfigMixin, capturing the module's dataclass fields as its configuration.

from maxdiffusion.configuration_utils import ConfigMixin, flax_register_to_config
import flax.linen as nn

@flax_register_to_config
class MyModel(nn.Module, ConfigMixin):
    config_name = "config.json"  # file name used when saving the config
    in_channels: int = 3
    out_channels: int = 3

    def setup(self):
        # Access config via self.config
        channels = self.config.in_channels
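Once registered, the dataclass defaults and any constructor overrides are readable through config; a small sketch using the MyModel class above:

model = MyModel(in_channels=1)
print(model.config.in_channels)   # 1 (constructor override)
print(model.config.out_channels)  # 3 (dataclass default)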

Example: Loading a model config

import jax.numpy as jnp
from maxdiffusion.models import FlaxUNet2DConditionModel

# Load from pretrained; the Flax from_pretrained returns the model
# definition together with its parameters
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet"
)

# Access configuration
print(model.config.in_channels)  # 4
print(model.config.attention_head_dim)  # 8

# Override config values
model, params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    attention_kernel="flash",  # Override to use flash attention
    dtype=jnp.bfloat16  # Override activation dtype
)
