
Overview

Model merging in ComfyUI allows you to combine weights from multiple models to create hybrid models with characteristics of both sources. This technique is widely used in the generative AI community to blend artistic styles, improve model capabilities, or create specialized variants.

Basic Concepts

State Dictionaries

Models in PyTorch store their learnable parameters in a state dictionary:
# Get the state dict of the underlying PyTorch module
state_dict = model.model.state_dict()

# Illustrative structure: parameter names mapped to tensors
# {
#     'layer1.weight': tensor([...]),
#     'layer1.bias': tensor([...]),
#     'layer2.weight': tensor([...]),
#     ...
# }
Each key is the dotted name of a parameter (or buffer) in the module hierarchy, and each value is the tensor holding its weights.
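To see this structure concretely outside ComfyUI, the sketch below builds a small, hypothetical plain PyTorch network and lists its state dict. Layers without parameters (here the ReLU) contribute no keys, which is why the indices skip from 0 to 2.

```python
import torch
from torch import nn

# A tiny illustrative network (not a ComfyUI model)
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

state_dict = net.state_dict()
for key, tensor in state_dict.items():
    # Keys are "<submodule index>.<parameter name>"; ReLU has no parameters
    print(key, tuple(tensor.shape))
```

This prints `0.weight (3, 4)`, `0.bias (3,)`, `2.weight (2, 3)` and `2.bias (2,)`: a Linear layer's weight is stored as (out_features, in_features).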

Weight Interpolation

The simplest merging strategy is linear interpolation:
merged_weight = weight1 * (1 - alpha) + weight2 * alpha
where alpha is the merge ratio (0.0 = 100% model1, 1.0 = 100% model2).
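A worked example on a single two-element tensor, with made-up values, shows the arithmetic. PyTorch's built-in `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`, which is algebraically the same blend.

```python
import torch

weight1 = torch.tensor([0.0, 2.0])  # parameter from model1
weight2 = torch.tensor([4.0, 6.0])  # same parameter from model2
alpha = 0.25  # 75% model1, 25% model2

merged = weight1 * (1 - alpha) + weight2 * alpha
print(merged)  # tensor([1., 3.])

# torch.lerp computes the same blend
assert torch.allclose(merged, torch.lerp(weight1, weight2, alpha))
```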

Merge Strategies

Linear Interpolation (Weighted Sum)

The most common method blends weights linearly:
def merge_linear(model1, model2, ratio=0.5):
    """
    Merge two models using linear interpolation

    Args:
        model1: First model (ComfyUI model object)
        model2: Second model (ComfyUI model object)
        ratio: Blend ratio (0.0 = all model1, 1.0 = all model2)

    Returns:
        Merged model
    """
    # Clone the first model. Note: clone() wraps the same underlying module,
    # so load_state_dict() below also overwrites model1's weights in place.
    merged_model = model1.clone()

    # Get state dicts
    state_dict1 = model1.model.state_dict()
    state_dict2 = model2.model.state_dict()

    # Merge weights
    merged_state = {}
    for key in state_dict1.keys():
        if key in state_dict2 and state_dict1[key].is_floating_point():
            merged_state[key] = (
                state_dict1[key] * (1 - ratio) +
                state_dict2[key] * ratio
            )
        else:
            # Keep model1's value if the key is missing from model2
            # or the tensor is not floating point (e.g. integer buffers)
            merged_state[key] = state_dict1[key]

    # Load merged weights
    merged_model.model.load_state_dict(merged_state)

    return merged_model
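Linear interpolation requires models with the same architecture (identical keys and tensor shapes), and it tends to give the best results when both models were fine-tuned from the same base or trained on similar data.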
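A minimal sketch of the same blend on plain PyTorch state dicts makes the architecture requirement explicit: it refuses to merge when keys or shapes differ. `lerp_state_dicts` is a hypothetical helper (not part of ComfyUI), and the two `nn.Linear` layers stand in for real models.

```python
import torch
from torch import nn


def lerp_state_dicts(sd1, sd2, ratio):
    """Blend two plain state dicts; raise if their architectures differ."""
    if sd1.keys() != sd2.keys():
        raise ValueError("State dicts have different keys")
    merged = {}
    for key, w1 in sd1.items():
        w2 = sd2[key]
        if w1.shape != w2.shape:
            raise ValueError(f"Shape mismatch for {key}: {w1.shape} vs {w2.shape}")
        if w1.is_floating_point():
            merged[key] = torch.lerp(w1, w2.to(w1.dtype), ratio)
        else:
            # Integer buffers (e.g. counters) cannot be interpolated; keep model1's
            merged[key] = w1.clone()
    return merged


torch.manual_seed(0)
a, b = nn.Linear(3, 2), nn.Linear(3, 2)
merged_sd = lerp_state_dicts(a.state_dict(), b.state_dict(), 0.5)

# Load into a fresh module so neither source model is modified
c = nn.Linear(3, 2)
c.load_state_dict(merged_sd)
```

Loading into a separate module, rather than into one of the sources, keeps both inputs intact at the cost of one extra copy of the weights.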

Add Difference (Model Arithmetic)

Combine models using arithmetic operations:
def merge_add_difference(base_model, model_a, model_b, multiplier=1.0):
    """
    Merge using: base + (model_a - model_b) * multiplier
    
    This technique adds the "difference" between two models to a base.
    Useful for transferring specific features or styles.
    """
    merged_model = base_model.clone()
    
    state_base = base_model.model.state_dict()
    state_a = model_a.model.state_dict()
    state_b = model_b.model.state_dict()
    
    merged_state = {}
    for key in state_base.keys():
        if key in state_a and key in state_b:
            difference = state_a[key] - state_b[key]
            merged_state[key] = state_base[key] + difference * multiplier
        else:
            merged_state[key] = state_base[key]
    
    merged_model.model.load_state_dict(merged_state)
    return merged_model
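For example, suppose you have:
  • Base: a general-purpose model
  • Model A: a model fine-tuned for anime style
  • Model B: the original model that A was fine-tuned from
Then base + (A - B) approximately transfers the anime style to your base model, because (A - B) captures what the fine-tune changed.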
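The arithmetic on a single tensor, with made-up values: the fine-tune moved each weight by (A - B) = [2, 3], and that delta is what gets added to the base. The multiplier scales how much of the delta is applied. `add_difference` here is a hypothetical per-tensor helper.

```python
import torch

base = torch.tensor([1.0, 1.0])     # general-purpose model
model_a = torch.tensor([3.0, 5.0])  # fine-tuned model
model_b = torch.tensor([1.0, 2.0])  # the model A was fine-tuned from


def add_difference(base, a, b, multiplier=1.0):
    # base + (a - b) * multiplier, applied to one tensor
    return base + (a - b) * multiplier


full = add_difference(base, model_a, model_b, 1.0)
half = add_difference(base, model_a, model_b, 0.5)
print(full)  # tensor([3., 4.])
print(half)  # tensor([2.0000, 2.5000])
```

As a sanity check, using Model B itself as the base with a multiplier of 1.0 reproduces Model A exactly.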

Selective Layer Merging

Merge different layers with different ratios:
def merge_selective(model1, model2, layer_ratios):
    """
    Merge with different ratios per layer

    Args:
        model1: First model (ComfyUI model object)
        model2: Second model (ComfyUI model object)
        layer_ratios: Dict mapping key substrings to ratios; the first
                     matching pattern wins, 'default' applies otherwise
                     e.g., {'encoder': 0.3, 'decoder': 0.7, 'default': 0.5}
    """
    merged_model = model1.clone()
    state_dict1 = model1.model.state_dict()
    state_dict2 = model2.model.state_dict()

    merged_state = {}
    for key in state_dict1.keys():
        # Determine ratio for this layer (first matching pattern wins)
        ratio = layer_ratios.get('default', 0.5)
        for pattern, r in layer_ratios.items():
            if pattern != 'default' and pattern in key:
                ratio = r
                break
        
        if key in state_dict2:
            merged_state[key] = (
                state_dict1[key] * (1 - ratio) + 
                state_dict2[key] * ratio
            )
        else:
            merged_state[key] = state_dict1[key]
    
    merged_model.model.load_state_dict(merged_state)
    return merged_model
Selective merging is useful when you want to preserve certain model characteristics. For example, layer_ratios={'encoder': 0.0, 'decoder': 1.0} keeps the encoder layers from model1 and takes the decoder layers from model2.
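Because patterns are plain substring matches and the first match wins, the order of entries in `layer_ratios` matters. The hypothetical helper below isolates that lookup so it can be checked on sample keys before running a full merge; the key names are illustrative, not taken from a specific model.

```python
def ratio_for_key(key, layer_ratios):
    """Return the blend ratio for one state-dict key.

    The first non-'default' pattern that occurs in the key wins, so list
    more specific patterns before broader ones.
    """
    for pattern, ratio in layer_ratios.items():
        if pattern != "default" and pattern in key:
            return ratio
    return layer_ratios.get("default", 0.5)


ratios = {"encoder.attn": 0.0, "encoder": 0.3, "decoder": 1.0, "default": 0.5}
print(ratio_for_key("encoder.attn.q.weight", ratios))  # 0.0
print(ratio_for_key("encoder.mlp.weight", ratios))     # 0.3
print(ratio_for_key("decoder.out.bias", ratios))       # 1.0
print(ratio_for_key("time_embed.0.weight", ratios))    # 0.5
```

Had "encoder" been listed before "encoder.attn", the attention keys would have matched the broader pattern first and received 0.3.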

Working with Model Patches

ComfyUI uses a patching system for efficient model modifications:
# Clone a model (creates a patcher)
cloned_model = original_model.clone()

# Access the underlying PyTorch model
pytorch_model = cloned_model.model

# Get all models in the patch chain
for patch_model in cloned_model.model_patches_models():
    print(f"Patch: {patch_model}")

Understanding Model Cloning

When you clone a model in ComfyUI:
  1. Creates a new model patcher
  2. Shares the underlying weights with the original (memory efficient)
  3. Allows independent modifications through patches, without copying the weights
  4. Tracks parent-child relationships
# Clone preserves the patch hierarchy
cloned = model.clone()

# is_clone() is True when both patchers wrap the same underlying model
if model.is_clone(cloned):
    print("Models share base weights")
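Sharing weights has a consequence for the merge functions on this page: they call `load_state_dict()` on the clone's module. The sketch below uses a hypothetical `TinyPatcher` with plain PyTorch (not ComfyUI's `ModelPatcher`) to show what happens when a clone shares its module with the original, as in step 2 above: writing weights through the clone also overwrites the original.

```python
import torch
from torch import nn


class TinyPatcher:
    """Hypothetical stand-in for a patcher whose clone() shares the module."""

    def __init__(self, model):
        self.model = model

    def clone(self):
        # A new wrapper around the same underlying module
        return TinyPatcher(self.model)


original = TinyPatcher(nn.Linear(2, 2))
cloned = original.clone()

# Writing weights through the clone's module...
zeros = {k: torch.zeros_like(v) for k, v in cloned.model.state_dict().items()}
cloned.model.load_state_dict(zeros)

# ...also changes the original, because both wrap the same nn.Module
print(torch.count_nonzero(original.model.weight).item())  # 0
```

If the source model must stay intact, load the merged weights into a separate copy of the module (for example one made with `copy.deepcopy`), at the cost of extra memory. ComfyUI's built-in merge nodes (such as ModelMergeSimple) avoid the issue by recording the blend as weight patches with `add_patches()` instead of rewriting the module.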

Merging Quantized Models

Quantized weights cannot be blended directly; dequantize them to a floating-point dtype before merging:
from comfy.quant_ops import QuantizedTensor
import torch

def merge_quantized(model1, model2, ratio=0.5):
    """
    Merge models that may contain quantized weights
    """
    merged_model = model1.clone()
    state_dict1 = model1.model.state_dict()
    state_dict2 = model2.model.state_dict()
    
    merged_state = {}
    for key in state_dict1.keys():
        if key not in state_dict2:
            merged_state[key] = state_dict1[key]
            continue
        
        weight1 = state_dict1[key]
        weight2 = state_dict2[key]
        
        # Check if either weight is quantized
        is_quantized1 = isinstance(weight1, QuantizedTensor)
        is_quantized2 = isinstance(weight2, QuantizedTensor)
        
        if is_quantized1 or is_quantized2:
            # Dequantize before merging
            if is_quantized1:
                weight1 = weight1.dequantize()
            if is_quantized2:
                weight2 = weight2.dequantize()
            
            # Merge in high precision
            merged_weight = weight1 * (1 - ratio) + weight2 * ratio
            
            # Re-quantize if needed
            # (In practice, you might want to keep merged models unquantized)
            merged_state[key] = merged_weight
        else:
            # Standard merge for non-quantized weights
            merged_state[key] = weight1 * (1 - ratio) + weight2 * ratio
    
    merged_model.model.load_state_dict(merged_state)
    return merged_model
Merging quantized models requires dequantization, which increases memory usage. Consider merging in FP16/BF16 and re-quantizing the result.

Memory Considerations

Model merging can be memory-intensive, because both source models and the merged state dict are held in memory at the same time:
from comfy import model_management
import gc
import torch

def merge_with_memory_management(model1, model2, ratio=0.5):
    """
    Merge models with explicit memory management
    """
    # Check available memory
    device = model_management.get_torch_device()
    free_memory = model_management.get_free_memory(device)
    
    # Estimate memory needed (3x model size for safety)
    model_size = model_management.module_size(model1.model)
    required_memory = model_size * 3
    
    if required_memory > free_memory:
        print("Insufficient memory, freeing models...")
        # Free memory by unloading other models
        model_management.free_memory(
            required_memory, 
            device
        )
    
    # Perform merge
    merged_model = model1.clone()
    state_dict1 = model1.model.state_dict()
    state_dict2 = model2.model.state_dict()
    
    merged_state = {}
    for key in state_dict1.keys():
        if key in state_dict2:
            merged_state[key] = (
                state_dict1[key] * (1 - ratio) + 
                state_dict2[key] * ratio
            )
        else:
            merged_state[key] = state_dict1[key]
    
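    # Drop the state-dict references. Their tensors alias the models' own
    # parameters, so this frees little by itself; empty_cache() returns
    # cached, no-longer-used GPU blocks to the driver.
    del state_dict1
    del state_dict2
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()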
    
    merged_model.model.load_state_dict(merged_state)
    return merged_model
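The 3x heuristic above can be checked by hand: model1, model2 and the merged state dict each hold a full copy of the weights. The hypothetical `module_bytes` helper below is a plain PyTorch approximation of a module's size, used here only for illustration.

```python
import torch
from torch import nn


def module_bytes(module):
    """Bytes held by a module's parameters and buffers (approximate model size)."""
    tensors = list(module.parameters()) + list(module.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)


layer = nn.Linear(1024, 1024).to(torch.float16)
size = module_bytes(layer)   # (1024*1024 weights + 1024 biases) * 2 bytes
required = size * 3          # model1 + model2 + merged copy
print(size, required)        # 2099200 6297600
```

The same layer in float32 would need twice as much, so merging in half precision roughly halves the estimate.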

  • CPU merging: move models to CPU before merging if VRAM is limited.
  • Batch processing: merge layers in batches to reduce peak memory usage.
  • Clear cache: call torch.cuda.empty_cache() after merging.
  • Monitor usage: use get_free_memory() to track available VRAM.
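Batch processing can be taken to its limit by merging one tensor at a time directly into a destination module, so a full `merged_state` dict never exists. A sketch with plain PyTorch modules, assuming the destination does not share weights with either source; `merge_into` and its `compute_device` parameter are hypothetical.

```python
import torch
from torch import nn


@torch.no_grad()
def merge_into(target, sd1, sd2, ratio, compute_device="cpu"):
    """Merge one tensor at a time straight into target's parameters.

    Only one key's worth of temporaries is alive at once, instead of a
    full merged state dict, which lowers peak memory.
    """
    target_sd = target.state_dict()  # these tensors alias target's storage
    for key, dst in target_sd.items():
        w1 = sd1[key].to(compute_device, torch.float32)
        w2 = sd2[key].to(compute_device, torch.float32)
        # copy_ writes in place and casts back to dst's dtype and device
        dst.copy_(w1 * (1 - ratio) + w2 * ratio)
        del w1, w2


torch.manual_seed(0)
a, b, out = nn.Linear(8, 4), nn.Linear(8, 4), nn.Linear(8, 4)
merge_into(out, a.state_dict(), b.state_dict(), 0.3)
```

Passing `compute_device="cuda"` would move one pair of tensors at a time to the GPU, keeping VRAM use bounded by the largest single layer.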

Saving Merged Models

After merging, save the result:
import safetensors.torch
import torch
from comfy.quant_ops import QuantizedTensor

def save_merged_model(model, output_path):
    """
    Save merged model to disk
    """
    # Get state dict
    state_dict = model.model.state_dict()

    # Convert to CPU and appropriate dtype
    cpu_state_dict = {}
    for key, value in state_dict.items():
        # Dequantize quantized tensors only. A hasattr(value, 'dequantize')
        # check matches every torch.Tensor, and on a regular tensor
        # dequantize() returns a float32 copy, inflating the file.
        if isinstance(value, QuantizedTensor):
            value = value.dequantize()

        # Move to CPU; safetensors also requires contiguous tensors
        cpu_value = value.cpu().contiguous()

        # Convert FP8 to FP16 for compatibility
        if cpu_value.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
            cpu_value = cpu_value.to(torch.float16)

        cpu_state_dict[key] = cpu_value
    
    # Save using safetensors (recommended) or torch
    if output_path.endswith('.safetensors'):
        safetensors.torch.save_file(cpu_state_dict, output_path)
    else:
        torch.save(cpu_state_dict, output_path)
    
    print(f"Model saved to {output_path}")

Advanced Techniques

SLERP (Spherical Linear Interpolation)

For better interpolation of directional data:
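import torch

def slerp(val, low, high, eps=1e-6):
    """
    Spherical linear interpolation

    Better than linear interpolation for some cases,
    especially when weights represent directions/orientations.
    Operates row-wise along the last dimension.
    """
    # Compute in float32 for numerical stability, then restore the dtype
    orig_dtype = low.dtype
    low = low.float()
    high = high.float()

    low_norm = low / (torch.norm(low, dim=-1, keepdim=True) + eps)
    high_norm = high / (torch.norm(high, dim=-1, keepdim=True) + eps)

    # Clamp to acos's domain; rounding can push the dot product past +/-1
    dot = (low_norm * high_norm).sum(-1).clamp(-1.0, 1.0)
    omega = torch.acos(dot)
    so = torch.sin(omega)

    # (Anti)parallel rows have sin(omega) ~ 0: fall back to linear there
    parallel = so.abs() < eps
    safe_so = torch.where(parallel, torch.ones_like(so), so)
    res = (torch.sin((1.0 - val) * omega) / safe_so).unsqueeze(-1) * low + \
          (torch.sin(val * omega) / safe_so).unsqueeze(-1) * high
    lerp = low * (1.0 - val) + high * val
    res = torch.where(parallel.unsqueeze(-1), lerp, res)

    return res.to(orig_dtype)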

def merge_slerp(model1, model2, ratio=0.5):
    """
    Merge using SLERP instead of linear interpolation
    """
    merged_model = model1.clone()
    state_dict1 = model1.model.state_dict()
    state_dict2 = model2.model.state_dict()
    
    merged_state = {}
    for key in state_dict1.keys():
        if key in state_dict2:
            # Use SLERP for multi-dimensional tensors
            if state_dict1[key].dim() > 1:
                merged_state[key] = slerp(
                    ratio, 
                    state_dict1[key], 
                    state_dict2[key]
                )
            else:
                # Fall back to linear for 1D (biases)
                merged_state[key] = (
                    state_dict1[key] * (1 - ratio) + 
                    state_dict2[key] * ratio
                )
        else:
            merged_state[key] = state_dict1[key]
    
    merged_model.model.load_state_dict(merged_state)
    return merged_model
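A small geometric check shows what SLERP preserves. For two unit vectors at right angles, linear interpolation at 0.5 lands inside the unit circle (norm about 0.707), while SLERP stays on it (norm 1). `slerp_vec` is a simplified 1-D helper for illustration and assumes the inputs are not parallel.

```python
import torch


def slerp_vec(t, v0, v1):
    """SLERP between two 1-D vectors (assumes they are not parallel)."""
    cos_omega = torch.dot(v0 / v0.norm(), v1 / v1.norm()).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    return (torch.sin((1 - t) * omega) * v0 + torch.sin(t * omega) * v1) / torch.sin(omega)


v0 = torch.tensor([1.0, 0.0])
v1 = torch.tensor([0.0, 1.0])
mid_slerp = slerp_vec(0.5, v0, v1)
mid_lerp = torch.lerp(v0, v1, 0.5)
print(mid_slerp.norm().item(), mid_lerp.norm().item())  # ~1.0 vs ~0.7071
```

Whether preserving norms improves a merged model depends on the weights involved, so compare both methods on your own models.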

Block-Weighted Merging

Different ratios for different transformer blocks:
import re

def merge_block_weighted(model1, model2, block_weights):
    """
    Merge with different weights per transformer block

    Args:
        block_weights: List of ratios, one per block
                      e.g., [0.1, 0.2, 0.3, ..., 0.9]
    """
    merged_model = model1.clone()
    state_dict1 = model1.model.state_dict()
    state_dict2 = model2.model.state_dict()

    merged_state = {}
    for key in state_dict1.keys():
        # Determine which block this layer belongs to
        ratio = 0.5  # Default for keys outside any numbered block

        # Extract block number from key (e.g., 'layers.5.attn.weight').
        # This pattern is architecture-specific: keys that name their
        # blocks differently will not match and keep the default ratio.
        match = re.search(r'layers\.(\d+)\.', key)
        if match:
            block_idx = int(match.group(1))
            if block_idx < len(block_weights):
                ratio = block_weights[block_idx]
        
        if key in state_dict2:
            merged_state[key] = (
                state_dict1[key] * (1 - ratio) + 
                state_dict2[key] * ratio
            )
        else:
            merged_state[key] = state_dict1[key]
    
    merged_model.model.load_state_dict(merged_state)
    return merged_model
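Since block naming varies between architectures, it helps to test the key pattern and the ratio schedule separately before merging. The sketch below uses hypothetical helpers: `block_index` tries a list of illustrative patterns, and `linear_ramp` builds evenly spaced per-block ratios. Check the keys of your own state dict before relying on any pattern.

```python
import re

# Illustrative key patterns; real checkpoints may name their blocks differently
BLOCK_PATTERNS = [r"layers\.(\d+)\.", r"blocks\.(\d+)\."]


def block_index(key):
    """Return the first block number found in a key, or None."""
    for pattern in BLOCK_PATTERNS:
        match = re.search(pattern, key)
        if match:
            return int(match.group(1))
    return None


def linear_ramp(num_blocks, start=0.0, end=1.0):
    """Per-block ratios rising evenly from start to end."""
    if num_blocks == 1:
        return [start]
    step = (end - start) / (num_blocks - 1)
    return [start + i * step for i in range(num_blocks)]


weights = linear_ramp(5)  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(block_index("layers.3.attn.weight"), block_index("double_blocks.12.img_mlp.0.weight"))  # 3 12
```

A ramp like this keeps early blocks close to model1 and late blocks close to model2; reversing `start` and `end` does the opposite.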

Best Practices

  1. Verify compatibility: ensure models have the same architecture before merging.
  2. Test ratios: experiment with different merge ratios (e.g., 0.3, 0.5, 0.7) to find the best result.
  3. Monitor memory: use memory management functions to avoid out-of-memory (OOM) errors.
  4. Validate output: test the merged model thoroughly before deploying it.
  5. Document settings: keep records of merge configurations for reproducibility.

Common Issues

Problem: Models have different layer structures
Solution: Only merge models with identical architectures. Check layer names and shapes:
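sd1 = model1.model.state_dict()
sd2 = model2.model.state_dict()

if set(sd1.keys()) != set(sd2.keys()):
    print("Architecture mismatch!")
    print("Keys in model1 only:", set(sd1.keys()) - set(sd2.keys()))
    print("Keys in model2 only:", set(sd2.keys()) - set(sd1.keys()))

# Keys can match while tensor shapes differ, so compare shared keys too
shape_mismatches = [k for k in sd1.keys() & sd2.keys() if sd1[k].shape != sd2[k].shape]
if shape_mismatches:
    print("Shape mismatch in:", shape_mismatches)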
Problem: CUDA OOM during merge
Solution: Merge on CPU or use memory management:
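from comfy import model_management

# Option 1: Merge on CPU
model1.model.to('cpu')
model2.model.to('cpu')
merged = merge_linear(model1, model2, 0.5)

# Option 2: Free memory first (estimate: ~3x the size of one model)
device = model_management.get_torch_device()
required_memory = model_management.module_size(model1.model) * 3
model_management.free_memory(required_memory, device)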
Problem: Merged model produces bad results
Solutions:
  • Try different merge ratios
  • Use SLERP instead of linear interpolation
  • Consider selective layer merging
  • Make sure the models share a training lineage (e.g., fine-tunes of the same base model)
