TensorRT-LLM provides a flexible framework for adding custom model architectures. This guide walks through the process of implementing, registering, and using custom models.

Overview

Adding a custom model involves four main steps:
  1. Model Configuration - Define the configuration class
  2. Model Definition - Implement the model architecture
  3. Weight Loading - Handle checkpoint loading and conversion
  4. Model Registration - Register the model for auto-discovery
If your model is already supported in HuggingFace Transformers, you can reuse the HuggingFace configuration and adapt the modeling code.

Prerequisites

  • Working TensorRT-LLM installation
  • Understanding of PyTorch and Transformer architectures
  • Familiarity with your target model architecture

Step 1: Model Configuration

Create a configuration class that defines all model hyperparameters.
If your model exists in HuggingFace Transformers, reuse their config:
configuration_mymodel.py
from transformers import LlamaConfig

# Reuse existing HuggingFace config
MyConfig = LlamaConfig
Or extend it with custom parameters:
from transformers import PretrainedConfig

class MyConfig(PretrainedConfig):
    model_type = "mymodel"
    
    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        intermediate_size=11008,
        hidden_act="silu",
        max_position_embeddings=2048,
        rope_theta=10000.0,
        rms_norm_eps=1e-6,
        # Custom parameters
        custom_param=None,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rms_norm_eps = rms_norm_eps
        self.custom_param = custom_param
        super().__init__(**kwargs)
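A few derived quantities follow directly from these hyperparameters, and checking them early catches configuration mistakes. A quick sketch using the default values above (plain arithmetic, independent of any framework):

```python
# Derived quantities from the default hyperparameters above
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 32

# Per-head dimension; hidden_size must divide evenly by the head count
assert hidden_size % num_attention_heads == 0
head_dim = hidden_size // num_attention_heads

# Query heads per KV head: 1 means standard MHA, >1 means GQA
gqa_groups = num_attention_heads // num_key_value_heads

print(head_dim, gqa_groups)  # 128 1
```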

Step 2: Model Definition

Implement your model architecture using TensorRT-LLM’s base classes and modules.

Model Structure

A typical decoder model consists of:
  1. Attention Layer - Inherits from Attention
  2. Decoder Layer - Inherits from DecoderLayer
  3. Model - Inherits from DecoderModel
  4. Model for Causal LM - Inherits from DecoderModelForCausalLM

Implementation Example

modeling_mymodel.py
from typing import Optional

import torch
from torch import nn
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (
    DecoderModel,
    DecoderModelForCausalLM,
    register_auto_model
)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.functional import AttentionMaskType, PositionEmbeddingType

from configuration_mymodel import MyConfig


class MyAttention(Attention):
    """Custom attention implementation."""
    
    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: Optional[int] = None):
        config = model_config.pretrained_config
        
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position_embeddings=config.max_position_embeddings,
            dtype=model_config.dtype,
            attention_mask_type=AttentionMaskType.causal,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            rotary_embedding_base=config.rope_theta,
            rotary_embedding_scaling=getattr(config, 'rope_scaling', None),
            layer_idx=layer_idx,
            tp_group=model_config.mapping.tp_group,
            tp_size=model_config.mapping.tp_size,
        )


class MyDecoderLayer(DecoderLayer):
    """Single transformer decoder layer."""
    
    def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
        super().__init__()
        config = model_config.pretrained_config
        
        # Layer normalization
        self.input_layernorm = RMSNorm(
            normalized_shape=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=model_config.dtype
        )
        
        # Self-attention
        self.self_attn = MyAttention(model_config, layer_idx)
        
        # Post-attention layer norm
        self.post_attention_layernorm = RMSNorm(
            normalized_shape=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=model_config.dtype
        )
        
        # MLP/FFN
        self.mlp = GatedMLP(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            dtype=model_config.dtype,
            bias=False,
            tp_group=model_config.mapping.tp_group,
            tp_size=model_config.mapping.tp_size,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs
    ):
        # Pre-norm + attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attn_metadata=attn_metadata
        )
        hidden_states = residual + hidden_states
        
        # Pre-norm + MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        
        return hidden_states


class MyModel(DecoderModel):
    """Main model without language modeling head."""
    
    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(model_config)
        config = model_config.pretrained_config
        
        # Token embeddings
        self.embed_tokens = Embedding(
            num_embeddings=config.vocab_size,
            embedding_dim=config.hidden_size,
            dtype=model_config.dtype,
            tp_group=model_config.mapping.tp_group,
            tp_size=model_config.mapping.tp_size,
            sharding_dim=model_config.embedding_sharding_dim,
            tp_rank=model_config.mapping.tp_rank,
        )
        
        # Transformer layers
        self.layers = nn.ModuleList([
            MyDecoderLayer(model_config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        
        # Final layer norm
        self.norm = RMSNorm(
            normalized_shape=config.hidden_size,
            eps=config.rms_norm_eps,
            dtype=model_config.dtype
        )

    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        # Embed tokens
        if inputs_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = inputs_embeds
        
        # Forward through layers
        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                position_ids=position_ids
            )
        
        # Final layer norm
        hidden_states = self.norm(hidden_states)
        
        return hidden_states


@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    """Model with language modeling head for causal LM."""
    
    def __init__(self, model_config: ModelConfig[MyConfig]):
        super().__init__(
            MyModel(model_config),
            config=model_config,
            hidden_size=model_config.pretrained_config.hidden_size,
            vocab_size=model_config.pretrained_config.vocab_size
        )
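As a sanity check on an architecture like this, a rough parameter count can be computed from the config alone. An illustrative sketch using the default hyperparameters from Step 1 (norms and biases are ignored; the lm_head term assumes untied embeddings):

```python
# Rough parameter count for a LLaMA-style decoder (bias-free, norms ignored)
V, H, L, I = 32000, 4096, 32, 11008  # vocab, hidden, layers, intermediate

embed = V * H                  # token embedding table
attn_per_layer = 4 * H * H     # q, k, v, o projections (MHA)
mlp_per_layer = 3 * H * I      # gate, up, down projections
lm_head = V * H                # output projection (if untied)

total = embed + L * (attn_per_layer + mlp_per_layer) + lm_head
print(f"{total / 1e9:.1f}B parameters")  # 6.7B
```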

Key Components

The Attention module handles all attention computation:
from tensorrt_llm._torch.modules.attention import Attention

class MyAttention(Attention):
    def __init__(self, model_config, layer_idx):
        config = model_config.pretrained_config
        super().__init__(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,  # For GQA
            max_position_embeddings=config.max_position_embeddings,
            attention_mask_type=AttentionMaskType.causal,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            # ... other params
        )
Features:
  • Automatic KV cache management
  • Support for GQA, MQA, MHA
  • RoPE, ALiBi position embeddings
  • Flash Attention integration
  • Tensor parallelism
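For reference, the RoPE base (rope_theta in the config) determines the per-dimension rotation frequencies. A minimal sketch of the standard computation in plain Python (illustrative, not TensorRT-LLM's internal implementation):

```python
# Standard RoPE inverse frequencies: theta^(-2i/d) for each rotary pair
def rope_inv_freq(head_dim, theta=10000.0):
    return [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

freqs = rope_inv_freq(head_dim=8)
print(freqs[0])  # 1.0 — the lowest dimension rotates fastest
```

Larger theta values stretch the lower frequencies, which is why long-context variants of a model typically raise rope_theta.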
Use TensorRT-LLM’s optimized modules for best performance:
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
Benefits:
  • Automatic tensor parallelism
  • Quantization support
  • Optimized CUDA kernels
  • Reduced memory usage
Packed Mode: Input tensors are in packed format where the first dimension is the total number of tokens in the batch:
# input_ids shape: [total_tokens]
# NOT [batch_size, seq_len]

# Example batch:
# Sequence 1: [1, 2, 3] (3 tokens)
# Sequence 2: [4, 5] (2 tokens)
# Packed: [1, 2, 3, 4, 5] (5 tokens total)
AttentionMetadata: Contains batching and KV cache metadata:
  • Sequence lengths
  • Cumulative sequence lengths
  • KV cache locations
  • Attention mask information
Pass attn_metadata to all attention modules.
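The packed layout and the cumulative sequence lengths carried in the metadata can be sketched as follows (plain Python; the real AttentionMetadata is constructed by the runtime):

```python
# Pack a batch of variable-length sequences into one token stream
sequences = [[1, 2, 3], [4, 5]]

packed = [tok for seq in sequences for tok in seq]

# Cumulative sequence lengths let attention kernels locate each
# sequence's boundaries inside the packed stream
cu_seqlens = [0]
for seq in sequences:
    cu_seqlens.append(cu_seqlens[-1] + len(seq))

print(packed)      # [1, 2, 3, 4, 5]
print(cu_seqlens)  # [0, 3, 5]
```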

Step 3: Weight Loading

Implement weight loading to convert checkpoint weights to your model format.

Default Weight Loading

The base DecoderModelForCausalLM provides automatic weight loading that works for most models:
@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    # Default load_weights works automatically
    pass

Custom Weight Loading

If you need custom weight mapping, override load_weights:
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
    
    def load_weights(self, weights: dict):
        """
        Load and convert weights from checkpoint.
        
        Args:
            weights: Dict of parameter names to tensors
        """
        # Example: Fused QKV projection
        tp_size = self.model_config.mapping.tp_size
        tp_rank = self.model_config.mapping.tp_rank

        for layer_idx in range(self.config.num_hidden_layers):
            # Collect Q, K, V weights from checkpoint
            q_weight = weights[f"model.layers.{layer_idx}.self_attn.q_proj.weight"]
            k_weight = weights[f"model.layers.{layer_idx}.self_attn.k_proj.weight"]
            v_weight = weights[f"model.layers.{layer_idx}.self_attn.v_proj.weight"]

            # Handle tensor parallelism: shard Q, K, and V independently along
            # the output dimension before fusing, so each rank keeps whole
            # heads (splitting the already-fused tensor would mix Q/K/V rows)
            if tp_size > 1:
                q_weight = torch.chunk(q_weight, tp_size, dim=0)[tp_rank]
                k_weight = torch.chunk(k_weight, tp_size, dim=0)[tp_rank]
                v_weight = torch.chunk(v_weight, tp_size, dim=0)[tp_rank]

            # Concatenate for fused QKV and assign to the model
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
            self.model.layers[layer_idx].self_attn.qkv_proj.weight.data = qkv_weight
        
        # Load other weights using default method
        super().load_weights(weights)

Weight Loading Examples

# Inside load_weights: fuse gate and up projections for SwiGLU
# (i is the layer index)
gate_weight = weights[f"layers.{i}.mlp.gate_proj.weight"]
up_weight = weights[f"layers.{i}.mlp.up_proj.weight"]

fused_weight = torch.cat([gate_weight, up_weight], dim=0)
self.model.layers[i].mlp.gate_up_proj.weight.data = fused_weight
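One subtlety when fusing under tensor parallelism: each component must be sharded before fusing, so every rank ends up with matching Q/K/V (or gate/up) rows. A toy illustration with plain Python lists standing in for weight matrices:

```python
# Toy "weight matrices": lists of rows
q = [["q0"], ["q1"]]
k = [["k0"], ["k1"]]
v = [["v0"], ["v1"]]

tp_size = 2

def shard(w, rank):
    """Row-wise shard of one weight for one tensor-parallel rank."""
    n = len(w) // tp_size
    return w[rank * n:(rank + 1) * n]

# Correct: shard each component, then fuse per rank
rank0 = shard(q, 0) + shard(k, 0) + shard(v, 0)
print(rank0)  # [['q0'], ['k0'], ['v0']]

# Wrong: fusing first and then splitting the fused tensor would give
# rank 0 all of Q plus part of K, instead of one shard of each
fused = q + k + v
print(fused[:3])  # [['q0'], ['q1'], ['k0']]
```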

Step 4: Model Registration

Register your model so it can be auto-discovered by TensorRT-LLM. The @register_auto_model decorator applied in Step 2 performs the registration; importing the module is enough to make the class available.
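Conceptually, the decorator maps an architecture name (the architectures entry in the checkpoint's config.json) to your model class. A simplified sketch of the pattern (illustrative, not TensorRT-LLM's actual internals):

```python
# Minimal registry pattern behind auto-discovery (illustrative)
MODEL_REGISTRY = {}

def register_auto_model(name):
    def decorator(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorator

@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM:
    pass

# At load time, the framework reads the architecture name from the
# checkpoint config and looks up the registered class
cls = MODEL_REGISTRY["MyModelForCausalLM"]
print(cls.__name__)  # MyModelForCausalLM
```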

Using Your Custom Model

Once registered, use your model like any other TensorRT-LLM model:
from tensorrt_llm import LLM
import modeling_mymodel  # Trigger registration

llm = LLM(model="/path/to/checkpoint")
outputs = llm.generate(["Tell me a story"])

Complete Example

See the full out-of-tree model example:
cd tensorrt_llm/examples/llm-api/out_of_tree_example
python main.py
This example implements a complete OPT model as an out-of-tree model.

Testing Your Model

1. Unit Tests

Test individual components:
def test_mymodel_forward():
    from tensorrt_llm._torch.model_config import ModelConfig
    from modeling_mymodel import MyModel
    from configuration_mymodel import MyConfig
    
    config = MyConfig(num_hidden_layers=2)  # small model for a fast test
    model = MyModel(ModelConfig(pretrained_config=config))
    
    # Test forward pass; inputs are packed: [total_tokens], not [batch_size, seq_len]
    input_ids = torch.randint(0, config.vocab_size, (10,))
    attn_metadata = ...  # build AttentionMetadata for one 10-token sequence (backend-specific)
    outputs = model(attn_metadata=attn_metadata, input_ids=input_ids)
    
    assert outputs.shape == (10, config.hidden_size)
2. Generation Test

Test end-to-end generation:
def test_generation():
    from tensorrt_llm import LLM, SamplingParams
    import modeling_mymodel  # noqa: F401 (triggers registration)
    
    llm = LLM(model="/path/to/checkpoint")
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=50))
    
    assert len(outputs[0].outputs[0].text) > 0
3. Accuracy Test

Compare outputs with reference implementation:
def test_accuracy():
    import torch
    from transformers import AutoModelForCausalLM
    
    input_ids = torch.randint(0, 32000, (1, 16))
    
    # Reference logits from the HuggingFace implementation
    ref_model = AutoModelForCausalLM.from_pretrained("/path/to/checkpoint")
    ref_logits = ref_model(input_ids).logits
    
    # Logits from your TensorRT-LLM model on the same inputs
    # (run_trtllm_model is your own helper around MyModelForCausalLM;
    # the exact invocation depends on how you build the model and metadata)
    trt_logits = run_trtllm_model(input_ids)
    
    torch.testing.assert_close(ref_logits, trt_logits, rtol=1e-3, atol=1e-3)

Best Practices

Model design:
  • Inherit from base classes: Use DecoderModel, DecoderLayer, Attention for compatibility
  • Use optimized modules: Prefer TensorRT-LLM modules over plain PyTorch for better performance
  • Handle packed input: Ensure your model handles packed token sequences correctly
  • Pass attn_metadata: Always pass AttentionMetadata to attention modules
Weight loading:
  • Use module-level load_weights: Call Linear.load_weights(), Embedding.load_weights() for TP/quantization
  • Handle weight fusion: Fuse QKV and gate/up projections for better performance
  • Test with different TP sizes: Verify weight loading works with tensor parallelism
  • Support quantization: Ensure your weight loader handles quantized checkpoints
Performance:
  • Fuse operations: Combine QKV and gate/up projections into single Linear layers
  • Use RMSNorm: Faster than LayerNorm when applicable
  • Enable tensor parallelism: Use Linear(..., tp_group=...) for automatic TP
  • Optimize memory: Use torch.utils.checkpoint for gradient checkpointing during training
Testing:
  • Unit test components: Test attention, layer, and model separately
  • Compare with reference: Validate outputs against the HuggingFace/original implementation
  • Test all TP configurations: Test with TP=1, 2, 4, 8
  • Benchmark performance: Measure throughput and latency against baselines

Common Issues

Problem: Tensor shape errors in the forward pass
Solution: Remember that inputs are in packed mode:
# Input shape is [total_tokens], not [batch_size, seq_len]
input_ids.shape  # [5] for batch with sequences of length 3 and 2
Problem: Weights don’t load correctly
Solution: Check weight names and shapes:
# Debug weight loading
print("Checkpoint keys:", weights.keys())
print("Model keys:", model.state_dict().keys())

# Check shapes
for name, param in model.named_parameters():
    if name in weights:
        print(f"{name}: model {param.shape} vs checkpoint {weights[name].shape}")
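A small helper along these lines (illustrative, plain Python) makes the mismatch obvious at a glance:

```python
# Diff checkpoint keys against model keys (sets of parameter names)
checkpoint_keys = {
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
}
model_keys = {
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.qkv_proj.weight",  # fused in the model
}

missing = sorted(model_keys - checkpoint_keys)     # in model, not in checkpoint
unexpected = sorted(checkpoint_keys - model_keys)  # in checkpoint, not in model

print("missing:", missing)
print("unexpected:", unexpected)
```

Fused projections (QKV, gate/up) typically show up here as "missing" model keys with corresponding "unexpected" per-component checkpoint keys, which is exactly what a custom load_weights must reconcile.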
Problem: Tensor parallelism produces wrong results
Solution: Ensure all Linear layers have TP config:
Linear(
    in_features=hidden_size,
    out_features=output_size,
    dtype=dtype,
    tp_group=model_config.mapping.tp_group,  # Required
    tp_size=model_config.mapping.tp_size,    # Required
    gather_output=True  # For column-parallel layers
)

Example Models

Study these reference implementations:

LLaMA

Standard decoder-only model with RoPE and GQA

GPT

Classic GPT architecture with learned positions

Gemma

Model with custom normalization and sliding window

OPT (Out-of-tree)

Complete out-of-tree model example

Next Steps

Model Configuration

Configure your custom model

Quantization

Add quantization support

Deployment

Deploy your model to production
