ConfigMixin for configuration management. All models inherit configuration capabilities for loading and saving.
ConfigMixin
Base configuration class that provides:
- Configuration storage and access via the config attribute
- Saving configurations to JSON
- Loading configurations from pretrained models
Key methods
from_config
Instantiates a model from a configuration dictionary.
Parameters:
Configuration dictionary with model parameters
Additional parameters to override config values
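The override behavior described above can be illustrated with a minimal, self-contained sketch. This is not MaxDiffusion's implementation: the class `TinyModel` and its fields are hypothetical, and the rule that unknown keys are ignored is an assumption made for the example.

```python
from dataclasses import dataclass, fields


@dataclass
class TinyModel:
    """Hypothetical stand-in for a model class; not part of MaxDiffusion."""
    in_channels: int = 4
    out_channels: int = 4
    dropout: float = 0.0

    @classmethod
    def from_config(cls, config: dict, **kwargs) -> "TinyModel":
        # Keyword arguments take priority over values in the config dictionary.
        merged = {**config, **kwargs}
        # Assumption: keys the class does not declare are silently dropped.
        accepted = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in merged.items() if k in accepted})


config = {"in_channels": 3, "out_channels": 8, "unused_key": "ignored"}
model = TinyModel.from_config(config, dropout=0.1)
print(model)
```

Here `dropout=0.1` overrides the default because it is passed as an additional parameter, while the remaining values come from the dictionary.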
save_config
Saves the model configuration to a directory.
Parameters:
Directory to save the config.json file
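As a rough sketch of what saving a configuration to a directory involves, the example below writes a dictionary to `config.json` and reads it back. The helper names `save_config_dict` and `load_config_dict` are hypothetical and stand in for the real method; only the `config.json` file name comes from the description above.

```python
import json
import os
import tempfile


def save_config_dict(config: dict, save_directory: str) -> str:
    """Hypothetical helper: write `config` to <save_directory>/config.json."""
    os.makedirs(save_directory, exist_ok=True)
    path = os.path.join(save_directory, "config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, sort_keys=True)
    return path


def load_config_dict(save_directory: str) -> dict:
    """Hypothetical counterpart that reads the same file back."""
    path = os.path.join(save_directory, "config.json")
    with open(path, encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    original = {"in_channels": 4, "dropout": 0.0}
    saved_path = save_config_dict(original, tmp)
    restored = load_config_dict(tmp)

print(restored)
```

Because the configuration is stored as plain JSON, only JSON-serializable values (numbers, strings, lists, dictionaries, booleans, null) survive the round trip unchanged.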
Common model parameters
These parameters are common across most MaxDiffusion models:
Data types
The dtype for activations and intermediate computations
The dtype for model weights/parameters
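A small JAX sketch may help show how the two dtypes above can differ. It assumes a common mixed-precision convention (parameters stored in one dtype and cast to the compute dtype before use); whether MaxDiffusion follows exactly this convention is not confirmed here, and the variable names are illustrative.

```python
import jax.numpy as jnp

weights_dtype = jnp.float32   # storage dtype for parameters (illustrative choice)
compute_dtype = jnp.bfloat16  # dtype for activations (illustrative choice)

# Parameters are created and stored in the weights dtype.
kernel = jnp.ones((8, 4), dtype=weights_dtype)

# Inputs and parameters are cast to the compute dtype before the matmul.
x = jnp.ones((2, 8), dtype=compute_dtype)
y = jnp.dot(x, kernel.astype(compute_dtype))

print(kernel.dtype, y.dtype)
```

Keeping weights in a wider dtype than the activations trades memory for numerical headroom in the stored parameters.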
JAX precision for matmul operations. Options: None, jax.lax.Precision.DEFAULT, jax.lax.Precision.HIGH, jax.lax.Precision.HIGHEST
Architecture
Number of input channels
Number of output channels
Size of input samples
Attention
Attention mechanism to use. Options:
- dot_product: Standard scaled dot-product attention
- flash: Flash attention for improved efficiency
Minimum sequence length required to apply flash attention
Block sizes for flash attention. Overrides default block sizes
Enable memory efficient attention (alternative to flash attention)
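For reference, standard scaled dot-product attention can be sketched in a few lines of JAX. The function `choose_kernel` is a hypothetical illustration of how a minimum sequence length might gate flash attention; the fallback to `dot_product` for short sequences is an assumption, not documented behavior.

```python
import jax
import jax.numpy as jnp


def dot_product_attention(q, k, v):
    """Scaled dot-product attention for arrays shaped (batch, seq, head_dim)."""
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    scores = jnp.einsum("bqd,bkd->bqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bqk,bkd->bqd", weights, v)


def choose_kernel(requested: str, seq_len: int, flash_min_seq_len: int) -> str:
    """Hypothetical rule: use flash only when the sequence is long enough."""
    if requested == "flash" and seq_len >= flash_min_seq_len:
        return "flash"
    return "dot_product"


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (1, 16, 8))
k = jax.random.normal(key_k, (1, 16, 8))
v = jax.random.normal(key_v, (1, 16, 8))
out = dot_product_attention(q, k, v)
print(out.shape, choose_kernel("flash", seq_len=16, flash_min_seq_len=4096))
```

The dot-product version materializes the full `(seq, seq)` score matrix, which is the memory cost that flash-style kernels are designed to avoid.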
Normalization
Number of groups for group normalization layers
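To show what the number of groups controls, here is a minimal JAX sketch of group normalization without the learned scale and bias. The function name, the channels-last layout, and the `eps` default are assumptions for illustration.

```python
import jax.numpy as jnp


def group_norm(x, num_groups: int, eps: float = 1e-5):
    """Normalize x of shape (batch, height, width, channels) per channel group."""
    b, h, w, c = x.shape
    if c % num_groups != 0:
        raise ValueError("channels must be divisible by num_groups")
    grouped = x.reshape(b, h, w, num_groups, c // num_groups)
    # Statistics are shared across spatial positions and channels in a group.
    mean = grouped.mean(axis=(1, 2, 4), keepdims=True)
    var = grouped.var(axis=(1, 2, 4), keepdims=True)
    normed = (grouped - mean) / jnp.sqrt(var + eps)
    return normed.reshape(b, h, w, c)


x = jnp.arange(2 * 4 * 4 * 8, dtype=jnp.float32).reshape(2, 4, 4, 8)
y = group_norm(x, num_groups=4)
print(y.shape)
```

The channel count must be divisible by the number of groups, so this setting constrains the valid channel widths of every normalized layer.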
Regularization
Dropout probability for regularization
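The effect of the dropout probability can be sketched with inverted dropout in JAX. This is a generic illustration under assumed names (`dropout`, `deterministic`), not the library's own layer.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate: float, key, deterministic: bool = False):
    """Inverted dropout: zero elements with probability `rate`, rescale the rest."""
    if deterministic or rate == 0.0:
        return x
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    # Surviving activations are scaled so the expected value is unchanged.
    return jnp.where(mask, x / keep_prob, 0.0)


x = jnp.ones((4, 8))
train_out = dropout(x, rate=0.5, key=jax.random.PRNGKey(0))
eval_out = dropout(x, rate=0.5, key=jax.random.PRNGKey(0), deterministic=True)
print(train_out)
```

With a rate of 0.0 or in deterministic mode the input passes through unchanged, which is the usual behavior at inference time.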
Model-specific configurations
UNet configuration
See UNet model reference for UNet-specific parameters including:
- down_block_types, up_block_types
- block_out_channels
- layers_per_block
- cross_attention_dim
VAE configuration
See VAE model reference for VAE-specific parameters including:
- latent_channels
- scaling_factor
- block_out_channels
Transformer configuration
See Transformer model reference for transformer-specific parameters including:
- num_layers, num_single_layers
- num_attention_heads
- attention_head_dim
- joint_attention_dim