MaxDiffusion supports loading LoRA (Low-Rank Adaptation) adapters for efficient model customization during inference. This feature is available for both Flux and Wan models.

Flux LoRA

Flux supports loading LoRA adapters to customize image generation without retraining the full model.
Not all LoRA formats have been tested. If a specific LoRA doesn’t load, please report it to the team.

Tested LoRA models

Basic usage

First, download the LoRA file to a local directory:
# Example: downloading anime_lora.safetensors to /home/jfacevedo/
Then run inference with the LoRA configuration:
python src/maxdiffusion/generate_flux.py \
  src/maxdiffusion/configs/base_flux_dev.yml \
  jax_cache_dir=/tmp/cache_dir \
  run_name=flux_test \
  output_dir=/tmp/ \
  prompt='A cute corgi lives in a house made out of sushi, anime' \
  num_inference_steps=28 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=-1 \
  split_head_dim=True \
  lora_config='{"lora_model_name_or_path" : ["/home/jfacevedo/anime_lora.safetensors"], "weight_name" : ["anime_lora.safetensors"], "adapter_name" : ["anime"], "scale": [0.8], "from_pt": ["true"]}'

LoRA configuration parameters

The lora_config parameter accepts a JSON string with the following fields. Each field is a list, with one entry per loaded adapter:
  • lora_model_name_or_path: Path to the LoRA weights file
  • weight_name: Name of the weights file
  • adapter_name: Identifier for the adapter
  • scale: Strength of the LoRA effect (0.0 to 1.0, typically 0.6-0.8)
  • from_pt: Whether the weights were saved in PyTorch format (use “true” for safetensors files)
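
The lora_config string above can be built programmatically instead of hand-escaped on the command line. A minimal sketch (the path and adapter name are placeholders, not real files): every field is a parallel list, so entry i of each list describes adapter i.

```python
import json

# Each field is a parallel list: entry i of every list describes adapter i.
# The path and adapter name below are placeholders.
lora_config = {
    "lora_model_name_or_path": ["/home/user/anime_lora.safetensors"],
    "weight_name": ["anime_lora.safetensors"],
    "adapter_name": ["anime"],
    "scale": [0.8],
    "from_pt": ["true"],
}

# Serialize to the JSON string passed as the lora_config= argument.
lora_config_arg = json.dumps(lora_config)
print(lora_config_arg)
```

Passing the result of json.dumps avoids malformed JSON from manual quoting; wrap it in single quotes on the shell command line as in the example above.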

Wan LoRA

Wan models support LoRA adapters in ComfyUI and AI Toolkit formats for video generation customization.
Not all LoRA files have been tested. If a specific LoRA in one of these formats doesn’t load, please let the team know.

Setup

  1. Create a copy of the relevant config file (e.g., src/maxdiffusion/configs/base_wan_i2v_14b.yml)
  2. Update the prompt and LoRA details in the config
  3. Set enable_lora: True in the config
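
The LoRA-related portion of the copied config for steps 2 and 3 might look like the fragment below. The key names for the LoRA details are an assumption modeled on the Flux lora_config fields; check the shipped base config file for the authoritative schema.

```yaml
# Sketch of the LoRA-related fields in a copied base_wan_i2v_14b.yml.
# The lora_config key names mirror the Flux fields and are an assumption;
# consult the base config for the exact schema.
enable_lora: True          # step 3: turn on LoRA loading
prompt: "a description of the video to generate"
lora_config:
  lora_model_name_or_path: ["/path/to/wan_lora.safetensors"]
  weight_name: ["wan_lora.safetensors"]
  adapter_name: ["my_adapter"]
  scale: [0.8]
```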

Running inference

HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true \
  --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true \
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true \
  --xla_tpu_overlap_compute_collective_tc=true \
  --xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_14b.yml \
  jax_cache_dir=gs://jfacevedo-maxdiffusion/jax_cache/ \
  per_device_batch_size=0.125 \
  ici_data_parallelism=2 \
  ici_context_parallelism=2 \
  run_name=wan-lora-inference-testing-720p \
  output_dir=gs://jfacevedo-maxdiffusion \
  seed=118445 \
  enable_lora=True

How it works

MaxDiffusion’s LoRA implementation supports:
  • Standard LoRA: Low-rank decomposition with down and up projection matrices
  • Weight diffs: Direct weight modifications for fine-tuning
  • Bias diffs: Bias parameter adjustments
  • LoCON: LoRA for convolutional layers with kernel size > 1x1
The implementation uses device-side computation via JAX JIT for efficient merging of LoRA weights with base model parameters.
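
The standard-LoRA merge described above can be sketched as follows. Function and variable names are illustrative, not MaxDiffusion's actual internals: for a LoRA layer with down and up projections, the merged weight is the base weight plus scale * (up @ down), computed on-device under jax.jit.

```python
import jax
import jax.numpy as jnp

@jax.jit
def merge_lora(base_w, lora_down, lora_up, scale):
    # Standard LoRA: the low-rank update is up @ down, scaled and
    # added to the frozen base weight. Shapes (illustrative):
    #   base_w:    (out_dim, in_dim)
    #   lora_down: (rank, in_dim)
    #   lora_up:   (out_dim, rank)
    return base_w + scale * (lora_up @ lora_down)

# Toy example with a rank-2 adapter on an 8x16 weight.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
base = jax.random.normal(k1, (8, 16))
down = jax.random.normal(k2, (2, 16))
up = jax.random.normal(k3, (8, 2))
merged = merge_lora(base, down, up, 0.8)
```

Because the merge is JIT-compiled, the low-rank update and the addition run on the accelerator, so loading an adapter does not require materializing new full-size weight copies on the host.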
