Dreambooth enables personalized fine-tuning of Stable Diffusion models using a small set of example images (typically 3-5). MaxDiffusion supports Dreambooth training for Stable Diffusion 1.x and 2.x models.

Supported models

  • Stable Diffusion 1.4
  • Stable Diffusion 2.x
Dreambooth training uses prior preservation to maintain model quality while learning the new concept.

Basic training

python src/maxdiffusion/dreambooth/train_dreambooth.py src/maxdiffusion/configs/base14.yml \
  class_data_dir=<your-class-dir> \
  instance_data_dir=<your-instance-dir> \
  instance_prompt="a photo of ohwx dog" \
  class_prompt="photo of a dog" \
  max_train_steps=150 \
  jax_cache_dir=<your-cache-dir> \
  activations_dtype=bfloat16 \
  weights_dtype=float32 \
  per_device_batch_size=1 \
  enable_profiler=False \
  precision=DEFAULT \
  cache_dreambooth_dataset=False \
  learning_rate=4e-6 \
  num_class_images=100 \
  run_name=<your-run-name> \
  output_dir=gs://<your-bucket-name>

Dataset preparation

Step 1: Prepare instance images

Create a directory with 3-5 images of your subject:
mkdir -p /path/to/instance_images
# Add your subject images (e.g., photos of your dog)
Each image should clearly show the subject you want the model to learn.
Step 2: Prepare class images (optional)

Create or generate 100-200 images of the same class:
mkdir -p /path/to/class_images
# Add general class images (e.g., random dog photos)
MaxDiffusion can automatically generate class images using num_class_images.
Step 3: Choose a unique identifier

Select a unique identifier token (e.g., ohwx, sks, xyz123) that doesn’t exist in the model’s vocabulary. This token is then used in your instance prompt: "a photo of ohwx dog"

Configuration

Dreambooth training extends the base Stable Diffusion configs with additional parameters.

Dreambooth-specific parameters

| Parameter | Default | Description |
| --- | --- | --- |
| instance_data_dir | '' | Directory with instance images (your subject) |
| class_data_dir | '' | Directory with class images (for prior preservation) |
| instance_prompt | '' | Prompt describing the instance (e.g., "a photo of ohwx dog") |
| class_prompt | '' | Prompt describing the class (e.g., "a photo of a dog") |
| num_class_images | 100 | Number of class images to generate if missing |
| prior_loss_weight | 1.0 | Weight for the prior preservation loss |
| cache_dreambooth_dataset | False | Cache the preprocessed dataset |

Training parameters

Recommended settings for Dreambooth:
learning_rate: 4e-6  # Lower than standard fine-tuning
max_train_steps: 150  # Fewer steps needed
per_device_batch_size: 1
activations_dtype: bfloat16
weights_dtype: float32
precision: DEFAULT

Model configuration

Use Stable Diffusion 1.4 or 2 Base:
# For SD 1.4
pretrained_model_name_or_path: 'CompVis/stable-diffusion-v1-4'
revision: 'flax'

# For SD 2 Base
pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
revision: 'main'

Example: Training on a custom subject

Dog example

python src/maxdiffusion/dreambooth/train_dreambooth.py \
  src/maxdiffusion/configs/base14.yml \
  instance_data_dir="/data/my_dog_photos" \
  class_data_dir="/data/class_dogs" \
  instance_prompt="a photo of ohwx dog" \
  class_prompt="a photo of a dog" \
  num_class_images=100 \
  max_train_steps=150 \
  learning_rate=4e-6 \
  per_device_batch_size=1 \
  output_dir="gs://my-bucket/dreambooth-dog" \
  run_name="my-dog-model" \
  jax_cache_dir="/tmp/jax_cache"

Person example

python src/maxdiffusion/dreambooth/train_dreambooth.py \
  src/maxdiffusion/configs/base14.yml \
  instance_data_dir="/data/my_photos" \
  class_data_dir="/data/class_people" \
  instance_prompt="a photo of sks person" \
  class_prompt="a photo of a person" \
  num_class_images=200 \
  max_train_steps=200 \
  learning_rate=2e-6 \
  per_device_batch_size=1 \
  output_dir="gs://my-bucket/dreambooth-person" \
  run_name="my-person-model"

Prior preservation

Prior preservation prevents the model from overfitting to your instance images by training on both:
  1. Instance images: Your specific subject (e.g., your dog)
  2. Class images: General examples of the class (e.g., random dogs)
The loss function balances the two terms:
total_loss = instance_loss + prior_loss_weight * class_loss
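The weighting can be sketched in a few lines of NumPy. This is an illustration of how prior_loss_weight combines the two MSE terms, not MaxDiffusion's actual training step, which operates on noise-prediction residuals inside JAX:

```python
import numpy as np

def dreambooth_loss(instance_pred, instance_target,
                    class_pred, class_target, prior_loss_weight=1.0):
    """Combine the instance loss with the weighted prior-preservation (class) loss."""
    instance_loss = np.mean((instance_pred - instance_target) ** 2)
    class_loss = np.mean((class_pred - class_target) ** 2)
    return instance_loss + prior_loss_weight * class_loss
```

With prior_loss_weight=0 this reduces to plain fine-tuning on the instance images; larger values pull the model back toward its behavior on the generic class.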

Adjusting prior preservation

Increase prior_loss_weight to preserve more of the original model:
prior_loss_weight: 2.0  # Stronger preservation
Decrease to allow more customization:
prior_loss_weight: 0.5  # More flexible to instance data

Dataset caching

For faster training iterations, cache the preprocessed dataset:
cache_dreambooth_dataset: True
dataset_save_location: "/tmp/dreambooth_cache"

Hyperparameter tuning

Learning rate

  • Too high (greater than 1e-5): Model forgets original capabilities, overfits to instances
  • Too low (less than 1e-7): Slow learning, may not capture subject
  • Recommended: 2e-6 to 5e-6

Training steps

  • Too few (less than 100): Subject not learned well
  • Too many (greater than 400): Overfitting, loss of diversity
  • Recommended: 150-250 steps

Number of instance images

  • Minimum: 3 images
  • Recommended: 5-10 images
  • Maximum: 20 images (beyond this, standard fine-tuning may be better)

Number of class images

  • Recommended: 100-200 images
  • More class images = better prior preservation
  • Fewer class images = faster training but possible overfitting

Generate images after training

Use your Dreambooth model to generate images:
python -m src.maxdiffusion.generate src/maxdiffusion/configs/base14.yml \
  run_name="my-dog-model" \
  output_dir="gs://my-bucket/dreambooth-dog" \
  prompt="a photo of ohwx dog on the beach" \
  from_pt=False \
  attention=dot_product

Best practices

Image quality

  1. Use high-quality, well-lit images
  2. Show subject from multiple angles
  3. Vary backgrounds and contexts
  4. Avoid heavily edited or filtered images

Prompt engineering

  1. Use a rare identifier token (ohwx, sks, xyz123)
  2. Include the class name in prompts (“dog”, “person”, “style”)
  3. Keep instance and class prompts similar in structure
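A tiny helper makes the parallel structure of the two prompts explicit (a hypothetical convenience function, not part of MaxDiffusion):

```python
def build_prompts(identifier, class_name):
    """Build instance and class prompts that share the same structure;
    only the rare identifier token differs between them."""
    instance_prompt = f"a photo of {identifier} {class_name}"
    class_prompt = f"a photo of a {class_name}"
    return instance_prompt, class_prompt
```

For example, `build_prompts("ohwx", "dog")` yields the prompt pair used throughout this page.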

Training tips

  1. Start with default hyperparameters
  2. Monitor the loss; it should decrease steadily
  3. Test generation every 50 steps to check progress
  4. Use GCS buckets for output to prevent data loss
  5. Enable profiling for the first run to check performance

Troubleshooting

Model overfits to training images

  • Decrease max_train_steps (try 100 instead of 150)
  • Increase prior_loss_weight (try 1.5 or 2.0)
  • Add more varied instance images

Subject not learned well

  • Increase max_train_steps (try 200 or 250)
  • Increase learning_rate (try 5e-6)
  • Reduce prior_loss_weight (try 0.5)
  • Use more instance images

Generated images lack diversity

  • Increase num_class_images (try 200)
  • Increase prior_loss_weight
  • Use more diverse instance images

Out of memory

  • Reduce per_device_batch_size to 1
  • Use bfloat16 for weights and activations
  • Disable cache_dreambooth_dataset

Advanced configuration

Optimizer settings

adam_b1: 0.9
adam_b2: 0.999
adam_eps: 1.e-8
adam_weight_decay: 1.e-2
max_grad_norm: 1.0
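max_grad_norm rescales gradients whenever their global L2 norm exceeds the limit. A pure-Python sketch of that transformation on a flat list of values (the training code applies the equivalent operation to JAX parameter pytrees):

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale gradients so their global L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(g * g for g in grads))
    if global_norm <= max_norm:
        return grads  # already within the limit; leave untouched
    scale = max_norm / global_norm
    return [g * scale for g in grads]
```

Clipping preserves the gradient direction and only shrinks its magnitude, which guards against occasional large updates destabilizing a short Dreambooth run.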

Data preprocessing

resolution: 512  # Match base model resolution
center_crop: True  # Crop images to square
random_flip: False  # Disable for asymmetric subjects

Checkpointing

checkpoint_every: 50  # Save every 50 steps

Monitoring

View training progress:
tensorboard --logdir=gs://my-bucket/dreambooth-dog/my-dog-model/tensorboard/
Key metrics to watch:
  • Training loss (should decrease)
  • Learning rate (follows schedule)
  • Step time (should stabilize after warmup)
