This guide provides a step-by-step process for adding a new model architecture to TensorRT-LLM’s PyTorch backend.
## Prerequisites
- A working TensorRT-LLM installation
- Understanding of PyTorch and Transformer architectures
- Familiarity with your target model architecture
## Overview
Adding a new model involves four main steps:
- **Model Configuration**: Define or reuse a HuggingFace configuration
- **Model Definition**: Implement the PyTorch model classes
- **Weight Loading**: Map weights from source checkpoints
- **Model Registration**: Register the model for auto-discovery
## Step 1: Model Configuration
If the model is already supported in HuggingFace `transformers`, reuse its configuration class:

```python
from transformers import LlamaConfig
```
For new models, create a configuration class:

```python
from transformers.configuration_utils import PretrainedConfig


class MyConfig(PretrainedConfig):
    model_type = "mymodel"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
```
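The `**kwargs` passthrough is what lets fields you did not declare explicitly (for example, entries in a checkpoint's `config.json`) survive on the config object instead of being dropped. A transformers-free sketch of that pattern, where `BaseConfig` is a hypothetical stand-in for `PretrainedConfig`:

```python
class BaseConfig:
    """Minimal stand-in for PretrainedConfig: keeps unknown fields."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)


class MyConfigSketch(BaseConfig):
    def __init__(self, vocab_size=32000, hidden_size=4096, **kwargs):
        super().__init__(**kwargs)  # unknown keys survive on the instance
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
```

Declared fields get defaults and explicit attributes; everything else rides through to the base class.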
## Step 2: Model Definition
Create `modeling_mymodel.py` with the model structure:
```python
from typing import Optional

import torch
from torch import nn

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
                                                       DecoderModelForCausalLM)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.rms_norm import RMSNorm


class MyAttention(Attention):
    def __init__(self, model_config: ModelConfig, layer_idx: Optional[int] = None):
        super().__init__(
            hidden_size=model_config.pretrained_config.hidden_size,
            num_attention_heads=model_config.pretrained_config.num_attention_heads,
            # ... other parameters
        )


class MyDecoderLayer(DecoderLayer):
    def __init__(self, model_config: ModelConfig, layer_idx: int):
        super().__init__()
        self.self_attn = MyAttention(model_config, layer_idx)
        self.mlp = MyMLP(model_config)  # define MyMLP analogously to MyAttention
        self.input_layernorm = RMSNorm(...)

    def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata):
        # Attention block with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attn_metadata)
        hidden_states = residual + hidden_states
        # ... MLP and residual
        return hidden_states


class MyModel(DecoderModel):
    def __init__(self, model_config: ModelConfig):
        super().__init__(model_config)
        self.embed_tokens = Embedding(...)
        self.layers = nn.ModuleList([
            MyDecoderLayer(model_config, i)
            for i in range(model_config.pretrained_config.num_hidden_layers)
        ])
        self.norm = RMSNorm(...)


class MyModelForCausalLM(DecoderModelForCausalLM):
    def __init__(self, model_config: ModelConfig):
        super().__init__(model_config)
        self.model = MyModel(model_config)
        self.lm_head = Linear(...)
```
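The attention constructor above ultimately needs per-head sizes derived from the config. A minimal sketch of that arithmetic; the helper name and the `num_kv_heads` fallback are illustrative assumptions, not TensorRT-LLM API:

```python
from typing import Optional


def attention_geometry(hidden_size: int, num_heads: int,
                       num_kv_heads: Optional[int] = None) -> dict:
    """Derive the per-head sizes an attention module is built from.

    Without a separate num_kv_heads (grouped-query attention), fall
    back to plain multi-head attention where KV heads == query heads.
    """
    assert hidden_size % num_heads == 0, "heads must divide hidden size"
    head_dim = hidden_size // num_heads
    kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
    return {
        "head_dim": head_dim,
        "q_size": num_heads * head_dim,
        "kv_size": kv_heads * head_dim,
    }
```

Getting these sizes right matters later, because the fused QKV projection is sliced back into Q, K, and V using exactly this geometry.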
## Step 3: Weight Loading
Map the HuggingFace weights onto the TensorRT-LLM model:
```python
def load_weights_from_hf(self, hf_model_dir):
    # Load the HuggingFace checkpoint
    hf_state_dict = torch.load(f"{hf_model_dir}/pytorch_model.bin")

    # Map weights to the TensorRT-LLM structure
    for layer_idx in range(self.config.num_hidden_layers):
        # Fuse the separate Q/K/V projections into the single QKV weight
        self.layers[layer_idx].self_attn.qkv.weight.copy_(
            torch.cat([
                hf_state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"],
                hf_state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"],
                hf_state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"],
            ])
        )
```
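Because the checkpoint keys follow a fixed layout, the name mapping is mechanical. A pure-Python sketch for the fused QKV projection; the helper is hypothetical, but the key names match the Llama-style layout used above:

```python
def fused_qkv_sources(layer_idx: int) -> list:
    """Return the HF parameter names whose weights are concatenated,
    in order, into the fused QKV projection of one decoder layer."""
    prefix = f"model.layers.{layer_idx}.self_attn"
    return [f"{prefix}.{proj}_proj.weight" for proj in ("q", "k", "v")]
```

Concatenation order matters: it must match how the fused `qkv` module slices its output back into Q, K, and V.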
## Step 4: Model Registration
### Core Models
Add the registration to `tensorrt_llm/models/automodel.py`:
```python
register_model(
    "MyModelForCausalLM",
    "modeling_mymodel",
    "MyModelForCausalLM",
    "MyConfig",
)
```
### Out-of-Tree Models (Recommended)
Create a separate Python package, without modifying TensorRT-LLM itself:
```python
# my_custom_model/__init__.py
from tensorrt_llm.models import register_model

from .modeling_mymodel import MyConfig, MyModelForCausalLM

register_model(
    "MyModelForCausalLM",
    __name__ + ".modeling_mymodel",
    "MyModelForCausalLM",
    "MyConfig",
)
```
Use it:

```python
import my_custom_model  # registers the model on import

from tensorrt_llm import LLM

llm = LLM(model="path/to/mymodel")
```
## Testing
Test your model implementation end to end:

```python
from tensorrt_llm import LLM

# Run a quick inference smoke test
llm = LLM(model="path/to/mymodel")
outputs = llm.generate(["Hello, world!"])
print(outputs[0].outputs[0].text)
```
## Best Practices
### Use TensorRT-LLM Optimized Modules
Always use TensorRT-LLM's optimized modules for best performance:
- `tensorrt_llm._torch.modules.linear.Linear`
- `tensorrt_llm._torch.modules.attention.Attention`
- `tensorrt_llm._torch.modules.rms_norm.RMSNorm`
### Handle Tensor Parallelism
Use the `Mapping` configuration to handle distributed execution automatically.
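Under the hood, tensor parallelism shards each parallel linear's weight across ranks. A sketch of the row-sharding arithmetic only; the helper is hypothetical, and in practice the `Mapping`-aware modules do this splitting for you:

```python
def shard_rows(num_rows: int, tp_size: int, tp_rank: int):
    """Compute the [start, end) row range a given tensor-parallel rank
    owns in a column-parallel weight, assuming rows divide evenly."""
    assert num_rows % tp_size == 0, "even sharding assumed in this sketch"
    per_rank = num_rows // tp_size
    return tp_rank * per_rank, (tp_rank + 1) * per_rank
```

This is why head counts are typically required to be divisible by the tensor-parallel size: each rank must own a whole number of heads.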
Study existing models such as `modeling_llama.py` as reference implementations.
## Next Steps
- **Custom Kernels**: Learn how to write custom CUDA kernels.
- **Model Configuration**: Explore model configuration options.