This document walks you through the steps to extend a basic model so that it accepts multi-modal inputs.
It is assumed that you have already implemented the base model in vLLM according to the adding models guide.
Overview
Adding multimodal support involves:
Updating the base vLLM model to handle multimodal embeddings
Specifying processing information (max items, token counts)
Specifying dummy inputs for memory profiling
Specifying processing details (field config, prompt updates)
Registering processor-related classes
Step 1: Update the base vLLM model
Implement placeholder string
Implement get_placeholder_str to define the placeholder string used to represent the multi-modal item in the text prompt:
```python
class YourModelForImage2Seq(nn.Module):
    ...

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")
```
This should be consistent with the chat template of the model.
Mark language and tower models
Inside the __init__ method, initialize the language components inside _mark_language_model, and the multimodal components inside _mark_tower_model:
```python
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    super().__init__()
    config = vllm_config.model_config.hf_config

    with self._mark_tower_model(vllm_config, "image"):
        self.vision_encoder = ...
        self.multi_modal_projector = ...

    with self._mark_language_model(vllm_config):
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
```
Implement embed_multimodal
Move the multi-modal embedding logic from the forward method to embed_multimodal:
```python
def forward(
    self,
    input_ids: torch.Tensor | None,
    pixel_values: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
            )
            special_image_mask = self.get_placeholder_mask(
                input_ids,
                inputs_embeds=inputs_embeds,
                image_features=image_features,
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask,
                image_features,
            )

    hidden_states = self.language_model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
    ...
```
Embedding the text tokens and merging in the multimodal embeddings are handled automatically by the default implementation of embed_input_ids, which does not need to be overridden in most cases.
Typical implementation pattern
Below is boilerplate for a typical implementation pattern:
```python
def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
    image_features = self.vision_encoder(image_input)
    return self.multi_modal_projector(image_features)

def embed_multimodal(
    self,
    **kwargs: object,
) -> MultiModalEmbeddings | None:
    # Validate the multimodal input keyword arguments
    image_input = self._parse_and_validate_image_input(**kwargs)
    if image_input is None:
        return None

    # Run multimodal inputs through encoder and projector
    vision_embeddings = self._process_image_input(image_input)
    return vision_embeddings
```
The returned multimodal_embeddings must be either a 3D tensor of shape (num_items, feature_size, hidden_size), or a list/tuple of 2D tensors of shape (feature_size, hidden_size), so that multimodal_embeddings[i] retrieves the embeddings generated from the i-th multimodal data item.
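When items have differing feature sizes, the list/tuple form is the natural fit. As an illustrative sketch (plain Python lists standing in for tensors, and `split_by_item` a hypothetical helper rather than a vLLM function), a flat (total_features, hidden_size) array can be split into per-item 2D chunks so that indexing by item works as required:

```python
def split_by_item(flat_features, feature_sizes):
    # flat_features: rows of length hidden_size, len(flat_features) == sum(feature_sizes).
    # Returns one 2D chunk per multimodal item, so result[i] holds the
    # embeddings generated from the i-th item -- the contract described above.
    out, start = [], 0
    for size in feature_sizes:
        out.append(flat_features[start:start + size])
        start += size
    return tuple(out)

hidden = [[0.0, 0.0]] * 5            # 5 feature rows, hidden_size = 2
items = split_by_item(hidden, [2, 3])  # item 0 has 2 features, item 1 has 3
assert len(items) == 2
assert len(items[0]) == 2 and len(items[1]) == 3
```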
Update model class interface
Once the above steps are done, update the model class with the SupportsMultiModal interface:
```python
from vllm.model_executor.models.interfaces import SupportsMultiModal

class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...
```
Step 2: Specify processing information
Create a subclass of BaseProcessingInfo to provide basic information related to HF processing.
Override the get_supported_mm_limits method to return the maximum number of input items for each modality:
```python
from vllm.multimodal.processing import BaseProcessingInfo

class YourModelProcessingInfo(BaseProcessingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Support any number of images but only one video per prompt
        return {"image": None, "video": 1}
```
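These limits are enforced per prompt; a value of None means unlimited. A minimal standalone sketch of that check (`check_mm_limits` is a hypothetical helper for illustration, not part of vLLM's API):

```python
def check_mm_limits(limits, counts):
    # limits: modality -> max items per prompt (None = unlimited).
    # counts: modality -> number of items actually passed in the prompt.
    for modality, count in counts.items():
        limit = limits.get(modality)
        if limit is not None and count > limit:
            raise ValueError(
                f"At most {limit} {modality} item(s) allowed, got {count}"
            )

limits = {"image": None, "video": 1}
check_mm_limits(limits, {"image": 5, "video": 1})  # OK: images are unlimited
```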
Step 3: Specify dummy inputs
Inherit BaseDummyInputsBuilder to construct dummy inputs for HF processing. The processed outputs are also used for memory profiling.
Override get_dummy_text and get_dummy_mm_data to construct dummy inputs:
Dummy text
```python
from vllm.multimodal.profiling import BaseDummyInputsBuilder

class YourModelDummyInputsBuilder(BaseDummyInputsBuilder):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token = processor.image_token

        return image_token * num_images
```
These dummy inputs should result in the worst-case memory usage of the model so that vLLM can reserve the correct amount of memory.
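To see why worst-case inputs matter, consider a hypothetical patch-based encoder (illustrative only; the token-count formula below is an assumption, not a vLLM function) where the number of image tokens grows with resolution. Profiling with anything smaller than the largest supported image would under-reserve memory:

```python
import math

def get_num_image_tokens(width: int, height: int, patch_size: int = 14) -> int:
    # Hypothetical ViT-style count: one token per image patch.
    return math.ceil(width / patch_size) * math.ceil(height / patch_size)

sizes = [(224, 224), (336, 336), (448, 448)]
counts = [get_num_image_tokens(w, h) for w, h in sizes]

# The largest supported size dominates, so dummy inputs should use it.
assert max(counts) == get_num_image_tokens(448, 448)
```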
Step 4: Specify processing details
Create a subclass of BaseMultiModalProcessor to fill in the missing details about HF processing.
For more information, see Multi-Modal Data Processing.
Multi-modal fields
Override _get_mm_fields_config to return a schema of the tensors output by the HF processor:
```python
from vllm.multimodal.processing import BaseMultiModalProcessor

class YourMultiModalProcessor(BaseMultiModalProcessor):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
        )
```
Prompt updates
Override _get_prompt_updates to return a list of PromptUpdate instances that specify update operations performed by the HF processor:
```python
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    hf_config = self.info.get_hf_config()
    image_token_id = hf_config.image_token_index

    def get_replacement(item_idx: int):
        images = mm_items.get_items("image", ImageProcessorItems)

        image_size = images.get_image_size(item_idx)
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=image_size.width,
            image_height=image_size.height,
        )

        return [image_token_id] * num_image_tokens

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        ),
    ]
```
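Conceptually, each PromptReplacement expands the i-th occurrence of the target with the tokens returned for the i-th multimodal item. A self-contained sketch of that expansion (plain Python, illustrative only; `expand_placeholders` is not a vLLM function):

```python
def expand_placeholders(token_ids, target_id, get_replacement):
    # Replace the i-th occurrence of `target_id` with the token list
    # returned by get_replacement(i), mirroring PromptReplacement above.
    out, item_idx = [], 0
    for tok in token_ids:
        if tok == target_id:
            out.extend(get_replacement(item_idx))
            item_idx += 1
        else:
            out.append(tok)
    return out

IMAGE = 99
# First image needs 2 feature tokens, second needs 3.
result = expand_placeholders([1, IMAGE, 2, IMAGE], IMAGE, lambda i: [IMAGE] * (i + 2))
assert result == [1, 99, 99, 2, 99, 99, 99]
```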
Step 5: Register processor-related classes
Decorate the model class with MULTIMODAL_REGISTRY.register_processor to register the processor classes:
```python
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_processor(
    YourMultiModalProcessor,
    info=YourModelProcessingInfo,
    dummy_inputs=YourModelDummyInputsBuilder,
)
class YourModelForImage2Seq(nn.Module, SupportsMultiModal):
    ...
```
Advanced topics
Inserting feature tokens without replacement
Some HF processors directly insert feature tokens without replacing anything. Use PromptInsertion instead of PromptReplacement:
```python
from vllm.multimodal.processing import PromptInsertion

return [
    PromptInsertion(
        modality="image",
        offset=0,  # Insert at start
        insertion=get_insertion,
    ),
]
```
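The insertion variant adds tokens at a position instead of consuming a target token. An illustrative standalone sketch (`insert_feature_tokens` is a hypothetical helper, not a vLLM function):

```python
def insert_feature_tokens(token_ids, offset, insertion):
    # Insert `insertion` at `offset` without replacing anything,
    # analogous to PromptInsertion with offset=0 above.
    return token_ids[:offset] + insertion + token_ids[offset:]

IMAGE = 99
assert insert_feature_tokens([1, 2, 3], 0, [IMAGE, IMAGE]) == [99, 99, 1, 2, 3]
```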
Handling processing unrelated to multi-modal data
If the HF processor performs additional processing regardless of the number of multi-modal items, override _apply_hf_processor_tokens_only.
Custom HF processor
Some models don’t define an HF processor class on HF Hub. In that case, define a custom HF processor and pass it to _call_hf_processor.
Next steps
- Testing guide: Write tests for your multimodal model
- Model registration: Register your model with vLLM