The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase.

How plugins work in vLLM

Plugins are user-registered code that vLLM executes. Given vLLM’s architecture, multiple processes may be involved, especially when using distributed inference with various parallelism techniques. Key requirement: Every process created by vLLM needs to load the plugin. This is done by the load_plugins_by_group function in the vllm.plugins module.

How vLLM discovers plugins

vLLM’s plugin system uses the standard Python entry_points mechanism. This allows developers to register functions in their Python packages for use by other packages.

Example plugin

from setuptools import setup

setup(
    name='vllm_add_dummy_model',
    version='0.1',
    packages=['vllm_add_dummy_model'],
    entry_points={
        'vllm.general_plugins': [
            "register_dummy_model = vllm_add_dummy_model:register"
        ]
    }
)
For more information on adding entry points, see the official setuptools documentation.

Plugin components

Every plugin has three parts:

1. Plugin group

The name of the entry point group. This is the key of the entry_points dictionary in setup.py. For vLLM's general plugins, use the group vllm.general_plugins.

2. Plugin name

The name of the plugin. This is the left-hand side of each entry in the entry_points dictionary, e.g. register_dummy_model. Plugins can be filtered by name using the VLLM_PLUGINS environment variable: to load only a specific plugin, set VLLM_PLUGINS to that plugin's name.

3. Plugin value

The fully qualified name of the function or module to register in the plugin system. Example: vllm_add_dummy_model:register refers to a function named register in the vllm_add_dummy_model module.

Types of supported plugins

General plugins

Group name: vllm.general_plugins
Primary use case: Register custom, out-of-tree models into vLLM by calling ModelRegistry.register_model inside the plugin function.
Example: a bart-plugin that adds support for BartForConditionalGeneration.
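A general plugin's register function might look like the following sketch. The module path bart_plugin.bart is an assumption for illustration, while ModelRegistry.register_model is vLLM's documented registration API:

```python
# Hypothetical plugin module, e.g. bart_plugin/__init__.py.
def register():
    from vllm import ModelRegistry

    # Guard with get_supported_archs so repeated calls stay harmless.
    if "BartForConditionalGeneration" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model(
            "BartForConditionalGeneration",
            "bart_plugin.bart:BartForConditionalGeneration",
        )
```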

Platform plugins

Group name: vllm.platform_plugins
Primary use case: Register custom, out-of-tree platforms into vLLM.
Return value: The plugin function should return:
  • None when the platform is not supported in the current environment
  • The platform class’s fully qualified name when the platform is supported

IO processor plugins

Group name: vllm.io_processor_plugins
Primary use case: Register custom pre-/post-processing of the model prompt and model output for pooling models.
Return value: The IOProcessor subclass's fully qualified name.

Stat logger plugins

Group name: vllm.stat_logger_plugins
Primary use case: Register custom, out-of-tree stat loggers into vLLM.
Requirements: The entry point should be a class that subclasses StatLoggerBase.

Guidelines for writing plugins

General guidelines

Re-entrant requirement: The function specified in the entry point must be re-entrant (can be called multiple times without causing issues). This is necessary because the function might be called multiple times in some processes.

Platform plugin guidelines

Platform plugins allow you to add support for custom hardware platforms to vLLM.

1. Create project structure

Create a platform plugin project:
vllm_add_dummy_platform/
├── vllm_add_dummy_platform/
│   ├── __init__.py
│   ├── my_dummy_platform.py
│   ├── my_dummy_worker.py
│   ├── my_dummy_attention.py
│   ├── my_dummy_device_communicator.py
│   ├── my_dummy_custom_ops.py
└── setup.py

2. Register entry point

In setup.py, add the entry point:
setup(
    name="vllm_add_dummy_platform",
    ...
    entry_points={
        "vllm.platform_plugins": [
            "my_dummy_platform = vllm_add_dummy_platform:register"
        ]
    },
    ...
)
The register function should return the platform class’s fully qualified name:
def register():
    return "vllm_add_dummy_platform.my_dummy_platform.MyDummyPlatform"

3. Implement platform class

In my_dummy_platform.py, implement the platform class inheriting from vllm.platforms.interface.Platform. Key properties and methods:
  • _enum: Device enumeration from PlatformEnum (usually PlatformEnum.OOT for out-of-tree)
  • device_type: Type of device PyTorch uses (e.g., "cpu", "cuda")
  • device_name: Usually same as device_type, mainly for logging
  • check_and_update_config: Called early in initialization to update vLLM config. Must set worker_cls here
  • get_attn_backend_cls: Return the attention backend class’s fully qualified name
  • get_device_communicator_cls: Return the device communicator class’s fully qualified name
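The properties above give the platform class the following shape. This is a hedged outline: in a real plugin the class inherits from vllm.platforms.interface.Platform with _enum = PlatformEnum.OOT, and exact method signatures may vary across vLLM versions.

```python
class MyDummyPlatform:  # real plugin: class MyDummyPlatform(Platform)
    _enum = "OOT"        # real plugin: PlatformEnum.OOT (out-of-tree)
    device_type = "cpu"  # device string PyTorch should use
    device_name = "cpu"  # usually same as device_type; used for logging

    @classmethod
    def check_and_update_config(cls, vllm_config) -> None:
        # Called early in initialization; the worker class must be set here.
        vllm_config.parallel_config.worker_cls = (
            "vllm_add_dummy_platform.my_dummy_worker.MyDummyWorker"
        )

    @classmethod
    def get_attn_backend_cls(cls, *args, **kwargs) -> str:
        return "vllm_add_dummy_platform.my_dummy_attention.MyDummyAttentionBackend"

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        return ("vllm_add_dummy_platform.my_dummy_device_communicator."
                "MyDummyDeviceCommunicator")
```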

4. Implement worker class

In my_dummy_worker.py, implement the worker class inheriting from WorkerBase. Required methods:
  • init_device: Set up the device for the worker
  • initialize_cache: Set cache config for the worker
  • load_model: Load model weights to device
  • get_kv_cache_spec: Generate KV cache spec for the model
  • determine_available_memory: Profile peak memory usage
  • initialize_from_config: Allocate device KV cache
  • execute_model: Execute model inference (called every step)
Optional methods:
  • sleep and wakeup: Support sleep mode feature
  • compile_or_warm_up_model: Support graph mode feature
  • take_draft_token_ids: Support speculative decoding
  • add_lora, remove_lora, list_loras, pin_lora: Support LoRA
  • execute_dummy_batch: Support data parallelism
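The required methods give the worker class the following skeleton. In the real plugin it inherits from WorkerBase; parameter names here are assumptions, not vLLM's exact signatures.

```python
class MyDummyWorker:  # real plugin: class MyDummyWorker(WorkerBase)
    def init_device(self):
        """Set up the device this worker runs on."""

    def initialize_cache(self, cache_config):
        """Record the cache configuration for this worker."""

    def load_model(self):
        """Load the model weights onto the device."""

    def get_kv_cache_spec(self):
        """Describe the KV cache layout the model needs."""

    def determine_available_memory(self):
        """Profile peak memory usage to size the KV cache."""

    def initialize_from_config(self, kv_cache_config):
        """Allocate the device KV cache."""

    def execute_model(self, scheduler_output):
        """Run one inference step; called every scheduler step."""
```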

5. Implement attention backend

In my_dummy_attention.py, implement the attention backend class inheriting from AttentionBackend. Its purpose is to compute attention on your device. See vllm.v1.attention.backends for examples of attention backend implementations.

6. Implement custom ops (optional)

Implement custom ops for high performance. vLLM supports the following kinds of PyTorch ops:
  • Communicator ops: Device communicator operations (all-reduce, all-gather, etc.). Inherit from DeviceCommunicatorBase.
  • Common ops: Common operations (matmul, softmax, etc.). Register them using the CustomOp class.
  • C++ ops: Implemented in C++ and registered as torch custom ops. Follow the csrc module and vllm._custom_ops as references.
Triton ops: The custom-op mechanism does not currently work for Triton ops.

7. Implement other modules (optional)

Implement other pluggable modules:
  • LoRA support
  • Graph backend
  • Quantization
  • Mamba attention backend

Compatibility guarantee

vLLM guarantees that the interface of documented plugins (such as ModelRegistry.register_model) will always be available.
Plugin developer responsibility: It is the responsibility of plugin developers to ensure their plugins are compatible with the version of vLLM they are targeting.
The interface for models/modules may change during vLLM's development. If you see deprecation messages in the logs, please update your plugin to the latest interface.

Deprecation announcements

Deprecated features:
  • use_v1 parameter in Platform.get_attn_backend_cls - Removed in v0.13.0
  • _Backend in vllm.attention - Removed in v0.13.0. Use vllm.v1.attention.backends.registry.register_backend instead
  • seed_everything platform interface - Removed in v0.16.0. Use vllm.utils.torch_utils.set_random_seed instead
  • prompt in Platform.validate_request - Will be removed in v0.18.0

Example: Complete general plugin

Here’s a complete example of a general plugin that registers a custom model:
from setuptools import setup, find_packages

setup(
    name="vllm_custom_model",
    version="0.1.0",
    packages=find_packages(),
    install_requires=["vllm>=0.4.0"],
    entry_points={
        "vllm.general_plugins": [
            "register_custom = vllm_custom_model:register"
        ]
    },
)
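The entry point above refers to a register function in the package's top-level module. A sketch of that function follows; the architecture name and model module path are placeholders:

```python
# vllm_custom_model/__init__.py
def register():
    from vllm import ModelRegistry

    ModelRegistry.register_model(
        "CustomForCausalLM",                          # placeholder architecture name
        "vllm_custom_model.model:CustomForCausalLM",  # placeholder module path
    )
```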

Next steps

  • Model registration: learn more about registering models
  • Architecture: understand vLLM's architecture
