This document walks you through the steps to extend a basic model so that it accepts multi-modal inputs.
It assumes that you have already implemented the base model in vLLM by following the guide for adding models.

Overview

Adding multimodal support involves:
  1. Updating the base vLLM model to handle multimodal embeddings
  2. Specifying processing information (max items, token counts)
  3. Specifying dummy inputs for memory profiling
  4. Specifying processing details (field config, prompt updates)
  5. Registering processor-related classes

Step 1: Update the base vLLM model

Implement placeholder string

Implement get_placeholder_str to define the placeholder string used to represent the multi-modal item in the text prompt:
class YourModelForImage2Seq(nn.Module):
    ...

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

The returned string should match the image placeholder used in the model's chat template.

Mark language and tower models

Inside the __init__ method, create the language components within the _mark_language_model context and the multimodal components within the _mark_tower_model context:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    super().__init__()

    config = vllm_config.model_config.hf_config

    with self._mark_tower_model(vllm_config, "image"):
        self.vision_encoder = ...
        self.multi_modal_projector = ...

    with self._mark_language_model(vllm_config):
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

Implement embed_multimodal

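Move the multi-modal embedding logic out of the forward method and into embed_multimodal. In the forward method below, that logic is the image-feature computation and merge under `if pixel_values is not None`: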
def forward(
    self,
    input_ids: torch.Tensor | None,
    pixel_values: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
        )
        special_image_mask = self.get_placeholder_mask(
            input_ids,
            inputs_embeds=inputs_embeds,
            image_features=image_features,
        )
        inputs_embeds = inputs_embeds.masked_scatter(
            special_image_mask,
            image_features,
        )

    hidden_states = self.language_model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
    ...

Embedding the text tokens and merging in the multimodal embeddings are handled automatically by the default implementation of embed_input_ids, so in most cases you do not need to override it.
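As a rough mental model of that merge, the pure-Python sketch below replaces each placeholder position, in order, with the next feature vector from the multimodal embeddings. This mirrors what the masked_scatter call in the forward method above does. It is a simplified illustration that uses nested lists instead of tensors, with a hypothetical function name and placeholder ID. It is not vLLM's embed_input_ids implementation.

```python
from collections.abc import Sequence

Vector = list[float]


def merge_multimodal_embeddings(
    input_ids: Sequence[int],
    text_embeds: Sequence[Vector],
    mm_embeds: Sequence[Sequence[Vector]],
    placeholder_id: int,
) -> list[Vector]:
    """Replace each placeholder position with the next multimodal feature vector."""
    if len(input_ids) != len(text_embeds):
        raise ValueError("input_ids and text_embeds must have the same length")
    # Flatten per-item embeddings in item order: item 0's rows, then item 1's, ...
    flat = [row for item in mm_embeds for row in item]
    num_placeholders = sum(1 for tok in input_ids if tok == placeholder_id)
    if num_placeholders != len(flat):
        raise ValueError(
            f"{num_placeholders} placeholder tokens but {len(flat)} feature vectors"
        )
    rows = iter(flat)
    # Keep text embeddings everywhere except placeholder positions.
    return [
        list(next(rows)) if tok == placeholder_id else list(emb)
        for tok, emb in zip(input_ids, text_embeds)
    ]
```

The count check reflects the key invariant: the number of placeholder tokens in the prompt must equal the total number of feature vectors produced by embed_multimodal.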

Typical implementation pattern

Below is boilerplate for a typical implementation:
def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
    image_features = self.vision_encoder(image_input)
    return self.multi_modal_projector(image_features)

def embed_multimodal(
    self,
    **kwargs: object,
) -> MultiModalEmbeddings | None:
    # Validate the multimodal input keyword arguments
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return None

    # Run multimodal inputs through encoder and projector
    vision_embeddings = self._process_image_input(image_input)
    return vision_embeddings
The returned multimodal_embeddings must be either a 3D tensor of shape (num_items, feature_size, hidden_size), or a list/tuple of 2D tensors of shape (feature_size, hidden_size), so that multimodal_embeddings[i] retrieves the embeddings generated from the i-th multimodal data item.
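The sketch below checks this shape contract using nested Python lists in place of tensors. check_mm_embeddings is a hypothetical helper, not part of vLLM. It also shows why the list/tuple form exists: a 3D tensor forces every item to have the same feature_size, while a list of 2D tensors lets each item have its own.

```python
from collections.abc import Sequence


def check_mm_embeddings(embeddings: Sequence, hidden_size: int) -> list[int]:
    """Validate per-item embeddings and return each item's feature_size.

    Nested lists stand in for tensors: embeddings[i] is the 2D
    (feature_size, hidden_size) block for the i-th multimodal item.
    """
    if not isinstance(embeddings, (list, tuple)):
        raise TypeError("expected a list/tuple of per-item embeddings")
    feature_sizes = []
    for i, item in enumerate(embeddings):
        for row in item:
            if len(row) != hidden_size:
                raise ValueError(
                    f"item {i}: expected hidden_size {hidden_size}, got {len(row)}"
                )
        feature_sizes.append(len(item))
    return feature_sizes
```

Here two images yield 3 and 5 feature vectors respectively. This is fine as a list of 2D tensors but could not be stacked into a single 3D tensor.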

Update model class interface

Once the above steps are done, add the SupportsMultiModal interface to the model class:
from vllm.model_executor.models.interfaces import SupportsMultiModal

class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...

Step 2: Specify processing information

Create a subclass of BaseProcessingInfo to provide basic information related to HF processing.

Maximum number of input items

Override the get_supported_mm_limits method to return the maximum number of input items for each modality:
from collections.abc import Mapping

from vllm.multimodal.processing import BaseProcessingInfo

class YourModelProcessingInfo(BaseProcessingInfo):
    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Support any number of images but only one video per prompt
        return {"image": None, "video": 1}

Step 3: Specify dummy inputs

Subclass BaseDummyInputsBuilder to construct dummy inputs for HF processing. The processed outputs are also used for memory profiling.

Implementing dummy input methods

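Override get_dummy_text and get_dummy_mm_data to construct the dummy text and multi-modal data, respectively. For example, get_dummy_text can return one image token per dummy image:
from collections.abc import Mapping

from vllm.multimodal.processing import BaseDummyInputsBuilder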

class YourModelDummyInputsBuilder(BaseDummyInputsBuilder):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images
These dummy inputs should result in the worst-case memory usage of the model so that vLLM can reserve the correct amount of memory.
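The arithmetic behind the worst case is simple: the dummy prompt should contain the maximum number of items, each at the size that yields the most feature tokens. The sketch below mirrors get_dummy_text with plain strings and computes the resulting placeholder-token count. The function names and the per-image token figure in the usage note are assumptions for illustration.

```python
from collections.abc import Mapping


def get_dummy_text(mm_counts: Mapping[str, int], image_token: str = "<image>") -> str:
    # Mirrors the builder above: one image token per dummy image.
    return image_token * mm_counts.get("image", 0)


def worst_case_image_tokens(num_images: int, max_tokens_per_image: int) -> int:
    # Every dummy image is assumed to be at the size that yields
    # the most feature tokens, so the total is a simple product.
    return num_images * max_tokens_per_image
```

For example, three dummy images at an assumed 576 feature tokens each require room for 1,728 image tokens in the profiling run.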

Step 4: Specify processing details

Create a subclass of BaseMultiModalProcessor to fill in the missing details about HF processing. For more information, see Multi-Modal Data Processing.

Multi-modal fields

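Override _get_mm_fields_config to return a schema of the tensors output by the HF processor:
from collections.abc import Mapping

from transformers import BatchFeature

from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.processing import BaseMultiModalProcessor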

class YourMultiModalProcessor(BaseMultiModalProcessor):
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
        )
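MultiModalFieldConfig.batched("image") indicates that the first dimension of pixel_values indexes images, so item i is simply pixel_values[i]. Some processors instead concatenate a variable number of rows per item (for example, image patches) into one flat tensor. For those, vLLM provides other configs such as MultiModalFieldConfig.flat_from_sizes. The pure-Python sketch below uses hypothetical helper names and lists instead of tensors to contrast how the two layouts split into per-item data. It sketches the intended semantics and is not vLLM's code.

```python
from collections.abc import Sequence
from typing import TypeVar

T = TypeVar("T")


def split_batched(values: Sequence[T], num_items: int) -> list[T]:
    # Batched layout: dimension 0 has exactly one entry per item.
    if len(values) != num_items:
        raise ValueError(f"expected {num_items} entries, got {len(values)}")
    return list(values)


def split_flat_from_sizes(values: Sequence[T], sizes: Sequence[int]) -> list[list[T]]:
    # Flat layout: items are concatenated along dimension 0,
    # and sizes[i] consecutive rows belong to item i.
    if sum(sizes) != len(values):
        raise ValueError(f"sizes sum to {sum(sizes)} but there are {len(values)} rows")
    items: list[list[T]] = []
    start = 0
    for n in sizes:
        items.append(list(values[start:start + n]))
        start += n
    return items
```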

Prompt updates

Override _get_prompt_updates to return a list of PromptUpdate instances that specify update operations performed by the HF processor:
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    hf_config = self.info.get_hf_config()
    image_token_id = hf_config.image_token_index

    def get_replacement(item_idx: int):
        images = mm_items.get_items("image", ImageProcessorItems)

        image_size = images.get_image_size(item_idx)
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=image_size.width,
            image_height=image_size.height,
        )

        return [image_token_id] * num_image_tokens

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        ),
    ]
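The PromptReplacement above changes the token sequence in a simple way, which the plain-Python sketch below reproduces: the i-th occurrence of the target token is expanded into get_replacement(i). The helper name, the token ID 32000, and the per-image token counts are hypothetical. Real vLLM matching also supports multi-token and string targets, which this sketch ignores.

```python
from collections.abc import Callable, Sequence


def apply_prompt_replacement(
    token_ids: Sequence[int],
    target_id: int,
    get_replacement: Callable[[int], list[int]],
) -> list[int]:
    """Expand the i-th occurrence of target_id into get_replacement(i)."""
    out: list[int] = []
    item_idx = 0
    for tok in token_ids:
        if tok == target_id:
            out.extend(get_replacement(item_idx))
            item_idx += 1
        else:
            out.append(tok)
    return out


IMAGE_TOKEN_ID = 32000          # hypothetical placeholder token ID
tokens_per_image = [3, 2]       # hypothetical feature sizes for two images
expanded = apply_prompt_replacement(
    [1, IMAGE_TOKEN_ID, 5, IMAGE_TOKEN_ID],
    IMAGE_TOKEN_ID,
    lambda i: [IMAGE_TOKEN_ID] * tokens_per_image[i],
)
```

After expansion, the prompt has one placeholder token per feature vector (3 + 2 = 5), which is what the embedding merge relies on.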

Step 5: Register processor-related classes

Decorate the model class with MULTIMODAL_REGISTRY.register_processor to register the processor classes:
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_processor(
    YourMultiModalProcessor,
    info=YourModelProcessingInfo,
    dummy_inputs=YourModelDummyInputsBuilder,
)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...

Advanced topics

Inserting feature tokens without replacement

Some HF processors directly insert feature tokens without replacing anything. Use PromptInsertion instead of PromptReplacement:
from vllm.multimodal.processing import PromptIndexTargets, PromptInsertion

return [
    PromptInsertion(
        modality="image",
        target=PromptIndexTargets.start(),  # Insert at the start of the prompt
        insertion=get_insertion,
    ),
]
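Unlike a replacement, an insertion removes nothing, so the prompt grows by exactly the inserted length. The hypothetical insert_tokens helper below shows the effect in plain Python for an insertion at a given index, where index 0 corresponds to the start of the prompt. It is an illustration, not vLLM's implementation.

```python
from collections.abc import Sequence


def insert_tokens(
    token_ids: Sequence[int], index: int, insertion: Sequence[int]
) -> list[int]:
    """Insert `insertion` before position `index` without removing any token."""
    if not 0 <= index <= len(token_ids):
        raise IndexError(f"index {index} out of range for {len(token_ids)} tokens")
    return [*token_ids[:index], *insertion, *token_ids[index:]]
```

For instance, inserting [7, 7] at index 0 of [1, 2, 3] gives [7, 7, 1, 2, 3]: all original tokens are kept and the length grows by two.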

Handling prompt updates unrelated to multi-modal data

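_get_prompt_updates assumes that each prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of the number of multi-modal items, override _apply_hf_processor_tokens_only so that processed token inputs are consistent with the result of applying the HF processor to text inputs. This is necessary because token inputs bypass the HF processor.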
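As a hypothetical illustration, suppose the HF processor always prepends a BOS token to the text, however many images there are. Because token inputs skip the HF processor, the tokens-only path must reproduce that step itself. In the sketch below, toy_hf_processor and BOS_ID are assumptions standing in for a real processor. The standalone function mirrors what an _apply_hf_processor_tokens_only override would need to do in that scenario.

```python
BOS_ID = 1  # hypothetical special token that the toy HF processor always adds


def toy_hf_processor(prompt_token_ids: list[int]) -> list[int]:
    # Stand-in for the HF processor's text behaviour (assumption):
    # it prepends BOS regardless of how many multi-modal items there are.
    return [BOS_ID, *prompt_token_ids]


def apply_hf_processor_tokens_only(prompt_tokens: list[int]) -> list[int]:
    # Token inputs never reach the HF processor, so apply the same
    # item-independent change here to keep both paths consistent.
    return [BOS_ID, *prompt_tokens]
```

The property to preserve is that both paths yield identical token sequences for the same prompt.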

Custom HF processor

Some models don’t define an HF processor class on the HF Hub. In that case, define a custom HF processor with the same call signature as HF processors and pass it to _call_hf_processor.

Next steps

Testing guide: write tests for your multimodal model.

Model registration: register your model with vLLM.
