LoRA (Low-Rank Adaptation) adapters enable efficient fine-tuning and serving of customized models. vLLM supports serving multiple LoRA adapters on top of a base model with minimal overhead.

Overview

LoRA adapters add small trainable parameters to specific layers of a pre-trained model, allowing you to customize model behavior without full fine-tuning. vLLM can:
  • Serve multiple LoRA adapters simultaneously
  • Switch between adapters on a per-request basis
  • Load adapters dynamically at runtime
  • Combine base model and adapter requests in the same batch
LoRA adapters work with any vLLM model that implements the SupportsLoRA interface. Most popular transformer models are supported.

Quick start

Offline inference

Download an adapter and use it with the base model:
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Download the LoRA adapter
sql_lora_path = snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider")

# Initialize base model with LoRA support
llm = LLM(
    model="meta-llama/Llama-3.2-3B-Instruct",
    enable_lora=True
)

# Configure sampling
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=256,
    stop=["[/assistant]"],
)

# Prompts for the SQL adapter
prompts = [
    "[user] Write a SQL query to answer the question based on the table schema.\n\n"
    "context: CREATE TABLE airports (icao VARCHAR, airport VARCHAR)\n\n"
    "question: Name the ICAO for Lilongwe International Airport [/user] [assistant]",
]

# Generate with LoRA adapter
outputs = llm.generate(
    prompts,
    sampling_params,
    lora_request=LoRARequest("sql_adapter", 1, sql_lora_path),
)

for output in outputs:
    print(f"Generated SQL: {output.outputs[0].text}")
The LoRARequest takes three arguments:
  • name: Human-readable identifier for the adapter
  • lora_int_id: Unique integer ID (must be unique across all adapters)
  • lora_path: Path to the adapter weights (local path or HuggingFace repo ID)
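Because lora_int_id must be unique and stable for a given adapter across requests, it helps to assign IDs centrally. A minimal sketch of one way to do this (an illustrative helper, not part of vLLM):

```python
# Illustrative helper (not part of vLLM): assign each adapter name a
# stable, unique integer ID on first use, as lora_int_id requires.
class AdapterIdRegistry:
    def __init__(self):
        self._ids = {}

    def id_for(self, name: str) -> int:
        # IDs start at 1; 0 is conventionally left for the base model.
        if name not in self._ids:
            self._ids[name] = len(self._ids) + 1
        return self._ids[name]

registry = AdapterIdRegistry()
print(registry.id_for("sql_adapter"))   # 1
print(registry.id_for("math_adapter"))  # 2
print(registry.id_for("sql_adapter"))   # 1 again: stable across requests
```

The same registry instance should serve all requests, so repeated requests for one adapter keep reusing the same slot ID.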

Serving with API

Start the server with LoRA support:
vllm serve meta-llama/Llama-3.2-3B-Instruct \
    --enable-lora \
    --lora-modules sql-lora=jeeejeee/llama32-3b-text2sql-spider
Make requests using the adapter name as the model ID:
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "sql-lora",
        "prompt": "Generate SQL for: SELECT * FROM users WHERE",
        "max_tokens": 100,
        "temperature": 0
    }'
The base model is also available:
curl http://localhost:8000/v1/models | jq .
{
    "object": "list",
    "data": [
        {
            "id": "meta-llama/Llama-3.2-3B-Instruct",
            "object": "model",
            "owned_by": "vllm"
        },
        {
            "id": "sql-lora",
            "object": "model",
            "owned_by": "vllm"
        }
    ]
}

Multiple LoRA adapters

Serve multiple adapters and let vLLM handle switching between them:
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest

# Initialize with multi-LoRA support
engine_args = EngineArgs(
    model="meta-llama/Llama-3.2-3B-Instruct",
    enable_lora=True,
    max_loras=2,  # Number of LoRAs that can be active simultaneously
    max_lora_rank=64,  # Maximum rank across all adapters
    max_cpu_loras=4,  # Size of CPU cache for swapped-out adapters
)

engine = LLMEngine.from_engine_args(engine_args)

# Define requests with different adapters
requests = [
    ("Base model prompt", None),
    ("SQL task prompt", LoRARequest("sql-lora", 1, "/path/to/sql-adapter")),
    ("Math task prompt", LoRARequest("math-lora", 2, "/path/to/math-adapter")),
    ("Another SQL prompt", LoRARequest("sql-lora", 1, "/path/to/sql-adapter")),
]

sampling_params = SamplingParams(temperature=0.8, max_tokens=256)

# Add requests to engine
for i, (prompt, lora_request) in enumerate(requests):
    engine.add_request(
        str(i),
        prompt,
        sampling_params,
        lora_request=lora_request
    )

# Process all requests (vLLM handles adapter switching)
while engine.has_unfinished_requests():
    outputs = engine.step()
    for output in outputs:
        if output.finished:
            print(f"Request {output.request_id}: {output.outputs[0].text}")
When max_loras=1, requests using different adapters are processed sequentially. Increase max_loras to batch requests with different adapters together (requires more memory).
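To build intuition for this constraint, here is a toy sketch (not vLLM's actual scheduler): requests can share a batch only while the number of distinct adapters in the batch stays within max_loras, with base-model requests (None) needing no slot.

```python
# Toy illustration of the max_loras batching constraint -- NOT vLLM's
# actual scheduler. Greedily pack requests into batches containing at
# most `max_loras` distinct adapters; None means base model (no slot).
def pack_batches(requests, max_loras):
    batches, current, adapters = [], [], set()
    for prompt, adapter in requests:
        needs_slot = adapter is not None and adapter not in adapters
        if needs_slot and len(adapters) >= max_loras:
            batches.append(current)       # flush: out of adapter slots
            current, adapters = [], set()
        if adapter is not None:
            adapters.add(adapter)
        current.append(prompt)
    if current:
        batches.append(current)
    return batches

reqs = [("p1", None), ("p2", "sql"), ("p3", "math"), ("p4", "sql")]
print(pack_batches(reqs, max_loras=1))  # [['p1', 'p2'], ['p3'], ['p4']]
print(pack_batches(reqs, max_loras=2))  # [['p1', 'p2', 'p3', 'p4']]
```

With max_loras=1 the mixed workload splits into three batches; with max_loras=2 all four requests fit in one, at the cost of a second preallocated adapter slot.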

Configuration parameters

LoRA configuration

enable_lora
bool
default:"False"
Enable LoRA adapter support. Required to use any LoRA features.
max_loras
int
default:"1"
Maximum number of LoRA adapters that can be active in a single batch.
  • Higher values allow more parallel adapter serving
  • Increases memory usage (each slot requires preallocated tensors)
max_lora_rank
int
default:"16"
Maximum rank allowed for LoRA adapters.
  • Must be >= the maximum rank of any adapter you plan to use
  • Higher values increase memory usage
  • Set to the actual maximum rank needed (avoid overallocation)
max_cpu_loras
int
default:"None"
Maximum number of LoRA adapters to cache in CPU memory.
  • Adapters not in GPU cache can be quickly swapped from CPU
  • Set higher if you frequently switch between many adapters
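The CPU cache behaves like a bounded least-recently-used store for adapter weights. A toy model of the eviction behavior (illustrative only; vLLM's internal cache is more involved):

```python
from collections import OrderedDict

# Toy LRU cache illustrating max_cpu_loras-style eviction -- not
# vLLM's actual implementation. The least-recently-used adapter is
# dropped once capacity is exceeded.
class LoRACpuCache:
    def __init__(self, max_cpu_loras: int):
        self.capacity = max_cpu_loras
        self.cache = OrderedDict()  # name -> weights (stubbed as strings)

    def get_or_load(self, name: str):
        if name in self.cache:
            self.cache.move_to_end(name)        # mark recently used
        else:
            if len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)  # evict the LRU adapter
            self.cache[name] = f"weights-of-{name}"
        return self.cache[name]

cache = LoRACpuCache(max_cpu_loras=2)
cache.get_or_load("sql")
cache.get_or_load("math")
cache.get_or_load("sql")    # "sql" is now most recently used
cache.get_or_load("code")   # evicts "math", the LRU entry
print(list(cache.cache))    # ['sql', 'code']
```

If your traffic cycles through more adapters than max_cpu_loras, each miss pays a reload from disk, which is the signal to raise the limit.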

Server configuration

lora_modules
str | list
Pre-load LoRA adapters at server startup. Format: name=path, or JSON with name, path, and optional base_model_name.
# Simple format
vllm serve model --enable-lora \
    --lora-modules sql-lora=/path/to/sql adapter=repo/name

# JSON format with base model lineage
vllm serve model --enable-lora \
    --lora-modules '{"name": "sql-lora", "path": "repo/name", "base_model_name": "meta-llama/Llama-3.2-3B"}'

Dynamic adapter loading

Load and unload adapters at runtime without restarting the server.

Enable runtime updates

export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
vllm serve meta-llama/Llama-3.2-3B-Instruct --enable-lora
Enabling runtime LoRA updates in production is risky: any client with API access can load or unload adapters. Use with caution and proper access controls.

Load adapter via API

curl -X POST http://localhost:8000/v1/load_lora_adapter \
    -H "Content-Type: application/json" \
    -d '{
        "lora_name": "new_adapter",
        "lora_path": "/path/to/adapter"
    }'
Response: Success: LoRA adapter 'new_adapter' added successfully

Unload adapter via API

curl -X POST http://localhost:8000/v1/unload_lora_adapter \
    -H "Content-Type: application/json" \
    -d '{
        "lora_name": "new_adapter"
    }'
Response: Success: LoRA adapter 'new_adapter' removed successfully

In-place reloading

Replace an adapter with updated weights while keeping the same name:
curl -X POST http://localhost:8000/v1/load_lora_adapter \
    -H "Content-Type: application/json" \
    -d '{
        "lora_name": "my-adapter",
        "lora_path": "/path/to/updated/weights",
        "load_inplace": true
    }'
Useful for online reinforcement learning scenarios where adapters are continuously updated.

Dynamic adapter resolution

Use plugins to automatically load adapters on-demand from various sources.

Built-in resolvers

vLLM includes two resolver plugins.
Local filesystem resolver:
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
export VLLM_PLUGINS=lora_filesystem_resolver
export VLLM_LORA_RESOLVER_CACHE_DIR=/path/to/lora/directory

vllm serve meta-llama/Llama-3.2-3B-Instruct --enable-lora
When a request arrives for adapter foobar, vLLM looks for /path/to/lora/directory/foobar and loads it automatically.
Hugging Face Hub resolver:
export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
export VLLM_PLUGINS=lora_hf_hub_resolver
export VLLM_LORA_RESOLVER_HF_REPO_LIST=org/repo1,org/repo2

vllm serve meta-llama/Llama-3.2-3B-Instruct --enable-lora
Requesting adapter org/repo1/subpath downloads the adapter from that subpath of the repository.
Hugging Face Hub resolver downloads remote content and is not intended for production. Use proper authentication and access controls.

Custom resolver plugin

Implement your own adapter resolver:
from vllm.lora.request import LoRARequest
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
import s3fs

class S3LoRAResolver(LoRAResolver):
    def __init__(self):
        self.s3 = s3fs.S3FileSystem()
        
    async def resolve_lora(self, base_model_name, lora_name):
        # Download from S3
        s3_path = f"s3://my-bucket/loras/{base_model_name}/{lora_name}"
        local_path = f"/tmp/loras/{lora_name}"
        
        await self.s3._get(s3_path, local_path, recursive=True)
        
        return LoRARequest(
            lora_name=lora_name,
            lora_path=local_path,
            lora_int_id=abs(hash(lora_name)),
        )

# Register the resolver
s3_resolver = S3LoRAResolver()
LoRAResolverRegistry.register_resolver("s3_resolver", s3_resolver)
See vLLM Plugin System for more details.

Multi-modal LoRA support

vLLM experimentally supports LoRA for vision/audio tower and connector components in multi-modal models.
from vllm import LLM, SamplingParams

llm = LLM(
    model="multi-modal-model",
    enable_lora=True,
    max_lora_rank=64,
    # LoRA can be applied to tower and connector layers
)
Multi-modal LoRA support requires implementing token helper functions for the tower and connector. See PR #26674 and Issue #31479 for current model support status.

Default multi-modal LoRAs

Some multi-modal models (e.g., Granite Speech, Phi-4-multimodal) include LoRA adapters that should always be applied for specific modalities.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(
    model="ibm-granite/granite-speech-3.3-2b",
    enable_lora=True,
    max_lora_rank=64,
    # Automatically apply audio LoRA when audio inputs are present
    default_mm_loras={"audio": "ibm-granite/granite-speech-3.3-2b"},
)

audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

outputs = llm.generate(
    {
        "prompt": "<|audio|>Transcribe this:",
        "multi_modal_data": {"audio": audio},
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
)
The audio LoRA is automatically applied without explicit LoRARequest. For serving:
vllm serve ibm-granite/granite-speech-3.3-2b \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64

Best practices

Set max_lora_rank correctly

Setting max_lora_rank too high wastes memory and degrades performance. Always set it to the actual maximum rank of your adapters.
# Check adapter ranks
ls -la /path/to/adapter/  # Look for adapter_config.json
cat /path/to/adapter/adapter_config.json | grep '"r"'  # "r" is the rank field

# If your adapters use rank 8, 16, and 32:
vllm serve model --enable-lora --max-lora-rank 32  # ✅ Good
vllm serve model --enable-lora --max-lora-rank 256  # ❌ Wasteful
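The right value can also be computed directly from the adapter configs. A small sketch, assuming the standard PEFT layout where each adapter directory holds an adapter_config.json whose "r" field is the LoRA rank (the demo writes stand-in configs to a temp directory):

```python
import json
import tempfile
from pathlib import Path

# Pick --max-lora-rank from the adapters you actually serve.
# Assumes the standard PEFT layout: each adapter directory holds an
# adapter_config.json whose "r" field is the LoRA rank.
def max_rank(adapter_dirs):
    ranks = []
    for d in adapter_dirs:
        config = json.loads((Path(d) / "adapter_config.json").read_text())
        ranks.append(config["r"])
    return max(ranks)

# Demo with stand-in configs written to a temp directory.
root = Path(tempfile.mkdtemp())
for name, r in [("sql", 8), ("summarize", 16), ("code", 32)]:
    d = root / name
    d.mkdir()
    (d / "adapter_config.json").write_text(json.dumps({"r": r}))

print(max_rank(root.iterdir()))  # 32 -> use --max-lora-rank 32
```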

Memory considerations

LoRA serving uses additional memory:
  • Each active LoRA slot: ~100-500 MB (depends on rank and model size)
  • CPU cache: stores inactive adapters
  • GPU memory: reduced available KV cache space
llm = LLM(
    model="large-model",
    enable_lora=True,
    max_loras=4,  # 4 active slots
    max_lora_rank=64,
    max_cpu_loras=10,  # Cache 10 adapters in CPU
    gpu_memory_utilization=0.85,  # May need to reduce from 0.9
)

Adapter organization

Organize adapters by use case:
lora_adapters/
├── base-model-name/
│   ├── sql-task/
│   │   ├── adapter_config.json
│   │   └── adapter_model.safetensors
│   ├── summarization/
│   │   ├── adapter_config.json
│   │   └── adapter_model.safetensors
│   └── code-generation/
│       ├── adapter_config.json
│       └── adapter_model.safetensors

Complete example

Here’s a complete example with multiple adapters:
from huggingface_hub import snapshot_download
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Download adapters
sql_adapter = snapshot_download("jeeejeee/llama32-3b-text2sql-spider")

# Initialize with LoRA support
llm = LLM(
    model="meta-llama/Llama-3.2-3B-Instruct",
    enable_lora=True,
    max_loras=2,
    max_lora_rank=64,
    max_cpu_loras=5,
)

sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=200,
)

# Mix base model and adapter requests
requests = [
    # Base model request
    {
        "prompt": "Explain what LoRA adapters are:",
        "lora_request": None,
    },
    # SQL adapter request
    {
        "prompt": "[user] Generate SQL: SELECT users WHERE age > [/user] [assistant]",
        "lora_request": LoRARequest("sql", 1, sql_adapter),
    },
    # Another SQL request (reuses cached adapter)
    {
        "prompt": "[user] Generate SQL: JOIN tables on [/user] [assistant]",
        "lora_request": LoRARequest("sql", 1, sql_adapter),
    },
]

# Generate (vLLM batches compatible requests)
for req in requests:
    outputs = llm.generate(
        req["prompt"],
        sampling_params,
        lora_request=req["lora_request"],
    )
    
    adapter = req["lora_request"].lora_name if req["lora_request"] else "base"
    print(f"Adapter: {adapter}")
    print(f"Output: {outputs[0].outputs[0].text}")
    print("-" * 80)
  • LoRA paper - Original research
  • Sampling parameters - Configure generation
  • Source: docs/features/lora.md
  • Example: examples/offline_inference/multilora_inference.py
