This guide walks you through the steps to implement a new model architecture in vLLM.
Many decoder language models can now be automatically loaded using the Transformers modeling backend without having to implement them in vLLM. Try running vllm serve <model> first to see if it works!

Overview

vLLM models are specialized PyTorch models that take advantage of various vLLM features to optimize their performance. The complexity of integrating a model into vLLM depends heavily on the model’s architecture:
  • Simple: Model shares similar architecture with an existing vLLM model
  • Moderate: Model has standard components but unique architecture
  • Complex: Model includes new operators (e.g., new attention mechanism)

Step 1: Bring your model code

First, clone the PyTorch model code from the source repository. For instance, vLLM’s OPT model was adapted from HuggingFace’s modeling_opt.py file.
Make sure to review and adhere to the original code’s copyright and licensing terms!
It is recommended to find a model similar to yours in vllm/model_executor/models and adapt it to your model’s architecture.

Step 2: Make your code compatible with vLLM

Initialization code

All vLLM modules within the model must include a prefix argument in their constructor. The prefix is typically the full name of the module in the model’s state dictionary and is crucial for:
  • Runtime support: vLLM’s attention operators are registered in a model’s state by their full names
  • Non-uniform quantization support: Quantized checkpoints can selectively quantize certain layers. By providing the prefix during initialization, vLLM can match the current layer’s prefix with the quantization configuration
from torch import nn
from vllm.config import VllmConfig
from vllm.attention import Attention

class MyAttention(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.attn = Attention(prefix=f"{prefix}.attn")

class MyDecoderLayer(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.self_attn = MyAttention(
            vllm_config,
            prefix=f"{prefix}.self_attn"
        )

class MyModel(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        self.layers = nn.ModuleList([
            MyDecoderLayer(
                vllm_config,
                prefix=f"{prefix}.layers.{i}"
            )
            for i in range(vllm_config.model_config.hf_config.num_hidden_layers)
        ])

from vllm.model_executor.models.utils import maybe_prefix

class MyModelForCausalLM(nn.Module):
    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        # maybe_prefix avoids a leading "." when prefix is empty
        self.model = MyModel(vllm_config, prefix=maybe_prefix(prefix, "model"))
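As a sanity check of the naming scheme, the prefixes chain together into full state-dict keys like those in a HuggingFace checkpoint. The following is an illustrative pure-Python sketch (the two-layer depth is a hypothetical config value, not from the code above):

```python
# Illustrative only: compose state-dict-style prefixes the way the
# constructors above chain them together (hypothetical 2-layer model).
num_hidden_layers = 2

def attention_prefixes(root: str = "model") -> list[str]:
    # root -> root.layers.{i} -> ...self_attn -> ...self_attn.attn
    return [
        f"{root}.layers.{i}.self_attn.attn"
        for i in range(num_hidden_layers)
    ]

print(attention_prefixes())
# ['model.layers.0.self_attn.attn', 'model.layers.1.self_attn.attn']
```

Each of these strings is exactly the key under which the corresponding attention module's weights would appear in the checkpoint, which is what lets vLLM match layers against a quantization configuration.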

Computation code

1. Add embed_input_ids method

Add an embed_input_ids method inside your model module that returns the text embeddings given input_ids:
class MyModel(nn.Module):
    ...

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)
This provides a unified interface in case your model is used within a composite multimodal model.
2. Rewrite the forward method

Modify the forward method to:
  • Remove unnecessary code (e.g., training-specific code)
  • Treat input_ids and positions as flattened tensors with a single batch size dimension
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
    ...
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
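To see what the flattened layout means in practice, here is an illustrative pure-Python sketch (no vLLM required): all sequences in a batch are concatenated along one token dimension, and positions restart at 0 for each sequence.

```python
# Illustrative only: vLLM flattens a batch of variable-length
# sequences into single 1-D input_ids / positions tensors.
seqs = [[11, 12, 13], [21, 22, 23, 24, 25]]  # token ids of two prompts

input_ids = [tok for seq in seqs for tok in seq]
positions = [pos for seq in seqs for pos in range(len(seq))]

print(input_ids)  # [11, 12, 13, 21, 22, 23, 24, 25]
print(positions)  # [0, 1, 2, 0, 1, 2, 3, 4]
```

There is no padding and no separate batch dimension; downstream attention kernels use metadata about sequence boundaries to keep the prompts independent.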

Step 3: Implement tensor parallelism and quantization support

If your model is too large to fit on a single GPU, you can use tensor parallelism to shard it across multiple GPUs.

Parallel layers

Replace your model’s linear and embedding layers with their tensor-parallel versions:
| Layer | Purpose | Usage |
|---|---|---|
| VocabParallelEmbedding | Embedding layer | Input embeddings |
| ParallelLMHead | Output layer | LM head |
| ReplicatedLinear | Replicated linear | No memory saving; inputs and weights replicated |
| RowParallelLinear | Row-parallel linear | Second FFN layer, attention output |
| ColumnParallelLinear | Column-parallel linear | First FFN layer, QKV projection |
| MergedColumnParallelLinear | Merged column-parallel | First FFN layer with weighted activation |
| QKVParallelLinear | QKV projection | Multi-head and grouped-query attention |
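The column/row split behind these layers can be sanity-checked with plain NumPy (this illustrates the math only; it is not vLLM code). Splitting a weight matrix by output columns and concatenating the shard outputs, or splitting by input rows and summing the partial products (the all-reduce step), both reproduce the full matmul:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))   # (tokens, hidden)
w = rng.standard_normal((8, 16))  # (hidden, out)

# Column parallel: each rank holds a slice of the output columns and
# outputs are concatenated (e.g. first FFN layer, QKV projection).
w_cols = np.split(w, 2, axis=1)
col_out = np.concatenate([x @ shard for shard in w_cols], axis=1)

# Row parallel: each rank holds a slice of the input rows and sees the
# matching slice of x; partial results are summed (the all-reduce),
# e.g. second FFN layer, attention output projection.
x_parts = np.split(x, 2, axis=1)
w_rows = np.split(w, 2, axis=0)
row_out = sum(xp @ wp for xp, wp in zip(x_parts, w_rows))

assert np.allclose(col_out, x @ w)
assert np.allclose(row_out, x @ w)
```

This is why a column-parallel layer feeding directly into a row-parallel layer (FFN up-projection into down-projection) needs only one communication step at the end.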

Linear method for quantization

All linear layers above take linear_method as an input. vLLM will set this parameter according to different quantization schemes to support weight quantization.
self.qkv_proj = QKVParallelLinear(
    hidden_size,
    head_dim,
    total_num_heads,
    total_num_kv_heads,
    bias=False,
    linear_method=linear_method,
)

Step 4: Implement weight loading logic

Implement the load_weights method in your *ForCausalLM class. This method should:
  1. Load weights from the HuggingFace checkpoint file
  2. Assign them to the corresponding layers in your model
  3. Handle merged layers (MergedColumnParallelLinear, QKVParallelLinear) by loading separated weight matrices
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

class MyModelForCausalLM(nn.Module):
    ...

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # The checkpoint stores q_proj/k_proj/v_proj separately;
                # map each onto its slice of the fused qkv_proj parameter.
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
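Why the mapping is needed: HuggingFace checkpoints store q_proj, k_proj, and v_proj as separate tensors, while QKVParallelLinear holds one fused parameter. The merge can be illustrated with NumPy (the sizes and offsets here are hypothetical; real loading goes through each parameter's weight_loader):

```python
import numpy as np

hidden, head_dim, n_heads = 8, 2, 4  # hypothetical model sizes
rng = np.random.default_rng(0)
# Checkpoint-style separate projections: (out_features, in_features)
q = rng.standard_normal((n_heads * head_dim, hidden))
k = rng.standard_normal((n_heads * head_dim, hidden))
v = rng.standard_normal((n_heads * head_dim, hidden))

# Fused qkv_proj parameter: shards stacked along the output dimension,
# each shard_id selecting its own slice.
qkv = np.empty((3 * n_heads * head_dim, hidden))
shard_offsets = {"q": 0, "k": 1, "v": 2}
for shard_id, w in [("q", q), ("k", k), ("v", v)]:
    off = shard_offsets[shard_id] * n_heads * head_dim
    qkv[off : off + n_heads * head_dim] = w

assert np.array_equal(qkv, np.concatenate([q, k, v], axis=0))
```

MergedColumnParallelLinear works the same way for gate/up projections in the FFN: several checkpoint tensors are loaded shard-by-shard into one stacked parameter.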

Step 5: Special model architectures

Models with interleaving sliding windows

To support a model with interleaving sliding windows:
  1. Make sure the model’s config.json contains layer_types
  2. In the modeling code, parse the correct sliding window value for every layer and pass it to the attention layer’s per_layer_sliding_window argument
sliding_window = (
    config.sliding_window if layer_type == "sliding" else None
)
attn = Attention(
    ...,
    per_layer_sliding_window=sliding_window,
)

Models that use Mamba

vLLM supports three different scenarios:
Models that use Mamba layers (Mamba-1 or Mamba-2) but do not use attention layers.
  • Inherit protocol IsAttentionFree
  • Implement class methods get_mamba_state_dtype_from_config and get_mamba_state_shape_from_config
  • Use MambaMixer (Mamba-1) or MambaMixer2 (Mamba-2) classes
  • Add model to MODELS_CONFIG_MAP in vllm/model_executor/models/config.py
Example: MambaForCausalLM (Mamba-1) or Mamba2ForCausalLM (Mamba-2)

Next steps

After implementing your model:

  • Model registration: Register your model with vLLM
  • Testing guide: Write tests for your model
  • Multimodal support: Add multimodal capabilities

Getting help

If you encounter issues while integrating your model into vLLM, feel free to open a GitHub issue or ask on our developer Slack. We will be happy to help you out!
