ControlNet enables precise control over image generation using conditioning signals like edge maps, depth maps, and segmentation masks.

Installation

ControlNet uses OpenCV for image preprocessing. On Debian-based systems, install the system libraries OpenCV depends on:
apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
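The packages above are the native libraries OpenCV links against; the Python bindings themselves are typically provided by the opencv-python wheel (assuming it is not already pulled in by MaxDiffusion's requirements):

```shell
pip install opencv-python
```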

Quick start

python src/maxdiffusion/controlnet/generate_controlnet_replicated.py

Supported models

MaxDiffusion supports ControlNet for:
  • Stable Diffusion 1.4: Uses runwayml/stable-diffusion-v1-5 base model
  • Stable Diffusion XL: Uses SDXL base model with ControlNet

Architecture

ControlNet adds a trainable copy of the UNet's encoder blocks that processes the conditioning signal:
  • Control input: Edge map, depth map, or other conditioning signal
  • ControlNet: Trainable encoder that processes control input
  • Base model: Stable Diffusion UNet with injected control features
  • Conditioning scale: Adjustable influence of control signal
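The components above can be sketched numerically. In this hedged illustration (all names are hypothetical, not MaxDiffusion's API), the ControlNet's output residuals are scaled by the conditioning scale and added to the UNet's skip-connection features:

```python
import numpy as np

def inject_control(unet_residuals, control_residuals, conditioning_scale=1.0):
    """Add scaled ControlNet residuals to the UNet's skip-connection features."""
    return [u + conditioning_scale * c
            for u, c in zip(unet_residuals, control_residuals)]

# Toy skip-connection features at two resolutions
unet = [np.ones((1, 8, 8, 4)), np.ones((1, 4, 4, 8))]
ctrl = [np.full((1, 8, 8, 4), 0.5), np.full((1, 4, 4, 8), 0.5)]

full = inject_control(unet, ctrl, conditioning_scale=1.0)  # control fully applied
off = inject_control(unet, ctrl, conditioning_scale=0.0)   # control disabled
```

With a conditioning scale of 0.0 the UNet features pass through unchanged, which corresponds to the "no control" behavior described under Conditioning scale.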

SD 1.4 ControlNet

Basic usage

The SD 1.4 pipeline (src/maxdiffusion/controlnet/generate_controlnet_replicated.py) uses a Canny edge detector:
python src/maxdiffusion/controlnet/generate_controlnet_replicated.py

Configuration

Customize via config parameters:
config.prompt = "a photograph of a modern building"
config.negative_prompt = "blurry, distorted"
config.controlnet_image = "https://example.com/edge_map.png"
config.controlnet_model_name_or_path = "lllyasviel/sd-controlnet-canny"
config.controlnet_from_pt = True
config.controlnet_conditioning_scale = 1.0
config.num_inference_steps = 50
config.per_device_batch_size = 1
config.seed = 0

Implementation

The SD 1.4 pipeline loads ControlNet and base model (generate_controlnet_replicated.py:41-47):
# Load ControlNet
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    config.controlnet_model_name_or_path, 
    from_pt=config.controlnet_from_pt, 
    dtype=jnp.float32
)

# Load pipeline with ControlNet
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    config.pretrained_model_name_or_path, 
    controlnet=controlnet, 
    revision=config.revision, 
    dtype=jnp.float32
)
params["controlnet"] = controlnet_params

Inference

The pipeline processes control image and generates conditioned output (generate_controlnet_replicated.py:52-74):
# Prepare inputs
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)

# Replicate and shard
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
negative_prompt_ids = shard(negative_prompt_ids)
processed_image = shard(processed_image)

# Generate
output = pipe(
    prompt_ids=prompt_ids,
    image=processed_image,
    params=p_params,
    prng_seed=rng,
    num_inference_steps=config.num_inference_steps,
    neg_prompt_ids=negative_prompt_ids,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    jit=True,
).images

SDXL ControlNet

Basic usage

The SDXL pipeline (src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py) includes Canny edge detection:
python src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py

Edge detection

The SDXL pipeline applies Canny edge detection to input images (generate_controlnet_sdxl_replicated.py:44-49):
image = load_image(config.controlnet_image)
image = np.array(image)

# Apply Canny edge detection
image = cv2.Canny(image, 100, 200)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
image = Image.fromarray(image)
This creates a 3-channel edge map suitable for ControlNet conditioning.

Implementation

The SDXL pipeline loads weights in the configured activations dtype and then casts the parameters to bfloat16, keeping the scheduler state in its original precision (generate_controlnet_sdxl_replicated.py:51-63):
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    config.controlnet_model_name_or_path, 
    from_pt=config.controlnet_from_pt, 
    dtype=config.activations_dtype
)

pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
    config.pretrained_model_name_or_path, 
    controlnet=controlnet, 
    revision=config.revision, 
    dtype=config.activations_dtype
)

# Cast params to bfloat16 (except scheduler)
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state
params["controlnet"] = controlnet_params

Parameters

Parameter                        Description                          Default
prompt                           Text description of desired image    Required
negative_prompt                  Concepts to avoid                    Empty
controlnet_image                 URL or path to control image         Required
controlnet_model_name_or_path    ControlNet model checkpoint          Required
controlnet_from_pt               Convert from PyTorch format          True
controlnet_conditioning_scale    Control signal strength (0.0-2.0)    1.0
num_inference_steps              Denoising steps                      50
per_device_batch_size            Images per device                    1
seed                             Random seed                          0

Conditioning scale

The controlnet_conditioning_scale parameter controls how strongly the control signal influences generation:
  • 0.0: No control (standard generation)
  • 0.5: Light control, more creative freedom
  • 1.0: Balanced control (recommended)
  • 1.5: Strong control adherence
  • 2.0: Very strict control following

Example: Adjusting control strength

# Light control for more variation
config.controlnet_conditioning_scale = 0.5

# Strong control for precise structure
config.controlnet_conditioning_scale = 1.5

Control signal types

ControlNet supports various conditioning types:

Canny edges

  • Use case: Preserve structural composition
  • Model: lllyasviel/sd-controlnet-canny
  • Preprocessing: Canny edge detection (threshold 100, 200)

Depth maps

  • Use case: Control spatial depth and 3D structure
  • Model: lllyasviel/sd-controlnet-depth
  • Preprocessing: MiDaS depth estimation

Segmentation

  • Use case: Control object layout and positioning
  • Model: lllyasviel/sd-controlnet-seg
  • Preprocessing: Semantic segmentation

Human pose

  • Use case: Control human figure poses
  • Model: lllyasviel/sd-controlnet-openpose
  • Preprocessing: OpenPose skeleton detection

Custom control images

Provide custom edge maps or control signals:
config.controlnet_image = "/path/to/custom_edge_map.png"
Control images should:
  • Match the target generation resolution
  • Be grayscale or 3-channel (RGB)
  • Clearly define structural elements
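A minimal helper along these lines (hypothetical, not part of MaxDiffusion) can enforce the requirements above before the image is passed to the pipeline:

```python
from PIL import Image

def prepare_control_image(source, resolution=(512, 512)):
    """Load a control image, force 3-channel RGB, and resize to the target resolution."""
    image = Image.open(source).convert("RGB")
    return image.resize(resolution, Image.LANCZOS)
```

Matching the generation resolution up front avoids an implicit resize inside the pipeline, which can blur fine edge detail in the control signal.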

Multi-device inference

ControlNet replicates parameters and shards inputs across devices; pmap is applied internally when the pipeline is called with jit=True:
num_samples = jax.device_count() * config.per_device_batch_size
rng = jax.random.split(rng, jax.device_count())

# Replicate params across devices
p_params = replicate(params)

# Shard inputs
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
This distributes generation across all available devices.
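After generation, the output images carry a leading device axis from pmap. A hedged numpy sketch (array shapes are illustrative) of flattening that axis before converting to PIL images:

```python
import numpy as np

# Hypothetical pmap output: (num_devices, per_device_batch, height, width, channels)
num_devices, batch, h, w, c = 2, 1, 8, 8, 3
images = np.random.rand(num_devices, batch, h, w, c).astype(np.float32)

# Merge the device and batch axes into a single image axis
images = images.reshape((num_devices * batch, h, w, c))
```

In the actual pipelines, the equivalent flattening is followed by a numpy-to-PIL conversion before the images are saved.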

Output

Generated images are saved as generated_image.png; by default, only the first image in the batch is written. To save every image in the batch:
for i, image in enumerate(output_images):
  image.save(f"generated_image_{i}.png")

Examples

Building from edges

config.prompt = "a modern glass office building, blue sky, photorealistic"
config.negative_prompt = "blurry, cartoon, painting"
config.controlnet_image = "building_edges.png"
config.controlnet_conditioning_scale = 1.2

Portrait from pose

config.prompt = "portrait of a woman in a red dress, studio lighting"
config.negative_prompt = "deformed, blurry"
config.controlnet_model_name_or_path = "lllyasviel/sd-controlnet-openpose"
config.controlnet_image = "pose_skeleton.png"
config.controlnet_conditioning_scale = 1.0

Landscape from depth

config.prompt = "mountain landscape at sunset, dramatic clouds"
config.negative_prompt = "flat, boring"
config.controlnet_model_name_or_path = "lllyasviel/sd-controlnet-depth"
config.controlnet_image = "depth_map.png"
config.controlnet_conditioning_scale = 0.8

Next steps

  • SDXL inference: Higher quality base model for ControlNet
  • Stable Diffusion: Standard SD inference without control
  • Training overview: Train custom ControlNet models
  • Configuration: Full configuration reference
