
Model Configuration

This page documents the PretrainedConfig class, which defines the architecture parameters and configuration options for TensorRT-LLM models.

Overview

The PretrainedConfig class is the base configuration class for all TensorRT-LLM models. It contains architecture-specific parameters that define the model structure, such as layer counts, hidden dimensions, attention configuration, and more. Model-specific config classes (like LlamaConfig, GPTConfig, etc.) inherit from PretrainedConfig and add model-specific parameters.

Location

Source file: tensorrt_llm/models/modeling_utils.py:346-548

Constructor Parameters

Core Architecture Parameters

architecture
str
required
The model architecture type. Examples:
  • "LlamaForCausalLM"
  • "GPTForCausalLM"
  • "MixtralForCausalLM"
Note: This is automatically set when using AutoConfig.from_hugging_face().
dtype
str
required
The data type for model weights and activations. Supported values:
  • "float16": FP16
  • "bfloat16": BF16
  • "float32": FP32
hidden_size
int
required
Dimensionality of the model’s hidden layers. Example: 4096 for Llama-2-7B
num_hidden_layers
int
required
Number of transformer layers in the model. Example: 32 for Llama-2-7B
num_attention_heads
int
required
Number of attention heads in each layer. Example: 32 for Llama-2-7B
vocab_size
Optional[int]
Size of the vocabulary. Example: 32000 for Llama-2

Activation and Normalization

hidden_act
str
default:"gelu"
The activation function used in the feed-forward layers. Common values:
  • "gelu": Gaussian Error Linear Unit
  • "silu": Sigmoid Linear Unit (used in Llama)
  • "relu": Rectified Linear Unit
norm_epsilon
float
The epsilon value for layer normalization.
logits_dtype
str
default:"float32"
The data type for output logits.

Position Embeddings

position_embedding_type
Union[PositionEmbeddingType, str]
default:"learned_absolute"
The type of position embedding to use. Options:
  • "learned_absolute": Learned absolute position embeddings
  • "rope_gpt_neox": Rotary Position Embeddings (GPT-NeoX style)
  • "rope_gptj": Rotary Position Embeddings (GPT-J style)
  • "alibi": Attention with Linear Biases
  • "alibi_with_scale": ALiBi with learned scale
  • "relative": Relative position embeddings
max_position_embeddings
Optional[int]
The maximum sequence length that the model can handle. Example: 4096 for Llama-2
rotary_embedding_dim
Optional[int]
The dimensionality of rotary position embeddings. Default calculation: If not specified, computed as head_size * rotary_pct, where rotary_pct defaults to 1.0.
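The default calculation above can be sketched as plain arithmetic (the helper name here is hypothetical, not part of the TensorRT-LLM API):

```python
def default_rotary_embedding_dim(hidden_size: int,
                                 num_attention_heads: int,
                                 rotary_pct: float = 1.0) -> int:
    # head_size falls back to hidden_size // num_attention_heads,
    # then rotary_embedding_dim = head_size * rotary_pct
    head_size = hidden_size // num_attention_heads
    return int(head_size * rotary_pct)

print(default_rotary_embedding_dim(4096, 32))        # 128: full rotary (Llama-style)
print(default_rotary_embedding_dim(4096, 32, 0.25))  # 32: partial rotary (GPT-NeoX-style)
```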

Attention Configuration

num_key_value_heads
Optional[int]
Number of key-value heads for Grouped Query Attention (GQA). Default: If not specified, equals num_attention_heads (Multi-Head Attention). Examples:
  • num_key_value_heads == num_attention_heads: Multi-Head Attention (MHA)
  • num_key_value_heads < num_attention_heads: Grouped Query Attention (GQA)
  • num_key_value_heads == 1: Multi-Query Attention (MQA)
head_size
Optional[int]
The dimension of each attention head. Default calculation: If not specified, computed as hidden_size // num_attention_heads.
qk_layernorm
bool
default:false
Whether to apply layer normalization to queries and keys in attention.
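The three attention variants implied by num_key_value_heads can be summarized in a small classifier (a sketch for illustration; the function name is hypothetical):

```python
def attention_variant(num_attention_heads, num_key_value_heads=None):
    # Per the documented default, an unset num_key_value_heads means MHA
    kv_heads = (num_key_value_heads if num_key_value_heads is not None
                else num_attention_heads)
    if kv_heads == num_attention_heads:
        return "MHA"  # one KV head per query head
    if kv_heads == 1:
        return "MQA"  # a single KV head shared by all query heads
    return "GQA"      # query heads share KV heads in groups

print(attention_variant(32))      # MHA (Llama-2-7B)
print(attention_variant(32, 8))   # GQA (Mixtral-8x7B)
print(attention_variant(32, 1))   # MQA
```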

Feed-Forward Network

intermediate_size
Optional[int]
The dimensionality of the feed-forward network’s intermediate layer. Default calculation: If not specified, computed as hidden_size * 4. Example: 11008 for Llama-2-7B (Llama overrides the 4x default).
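The documented fallbacks for head_size and intermediate_size amount to simple arithmetic; this hypothetical helper makes them concrete:

```python
def derived_defaults(hidden_size, num_attention_heads,
                     head_size=None, intermediate_size=None):
    # Fallbacks documented above; real configs may override both
    # (Llama-2-7B sets intermediate_size=11008, not hidden_size * 4).
    if head_size is None:
        head_size = hidden_size // num_attention_heads
    if intermediate_size is None:
        intermediate_size = hidden_size * 4
    return head_size, intermediate_size

print(derived_defaults(4096, 32))  # (128, 16384)
```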

Parallel Configuration

mapping
Optional[Union[Mapping, dict]]
The parallel mapping configuration. Mapping fields:
  • world_size: Total number of GPUs
  • rank: Current GPU rank
  • tp_size: Tensor parallel size
  • pp_size: Pipeline parallel size
  • cp_size: Context parallel size
  • gpus_per_node: Number of GPUs per node
Default: If not specified, creates a single-GPU mapping.
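A minimal dict-form mapping might look like the following. Field names follow the list above; the requirement that the parallel sizes multiply out to world_size is the usual convention, so treat it as an assumption:

```python
mapping_dict = {
    "world_size": 4,
    "rank": 0,
    "tp_size": 4,   # tensor parallel
    "pp_size": 1,   # pipeline parallel
    "cp_size": 1,   # context parallel
    "gpus_per_node": 8,
}

# Assumed invariant: the parallel sizes multiply out to the world size
assert (mapping_dict["tp_size"] * mapping_dict["pp_size"]
        * mapping_dict["cp_size"]) == mapping_dict["world_size"]
```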

Quantization Configuration

quantization
Optional[Union[QuantConfig, dict]]
Quantization configuration for the model. QuantConfig fields:
  • quant_algo: Quantization algorithm
  • kv_cache_quant_algo: KV cache quantization algorithm
  • group_size: Group size for group-wise quantization
  • smoothquant_val: Smoothing parameter
  • exclude_modules: Modules to exclude from quantization
See Quantization Configuration for full details.

Embedding Configuration

use_parallel_embedding
bool
default:false
Whether to use parallel embedding tables (sharded across GPUs).
embedding_sharding_dim
int
default:0
The dimension along which to shard the embedding table. Options:
  • 0: Shard along vocabulary dimension
  • 1: Shard along hidden dimension
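The effect of embedding_sharding_dim on the per-GPU shard shape can be sketched as follows (a hypothetical helper, not the library's API):

```python
def embedding_shard_shape(vocab_size, hidden_size, tp_size,
                          embedding_sharding_dim=0):
    # dim 0 splits the vocabulary across GPUs; dim 1 splits the hidden dim
    if embedding_sharding_dim == 0:
        return (vocab_size // tp_size, hidden_size)
    return (vocab_size, hidden_size // tp_size)

print(embedding_shard_shape(32000, 4096, 4, 0))  # (8000, 4096)
print(embedding_shard_shape(32000, 4096, 4, 1))  # (32000, 1024)
```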

Runtime Defaults

runtime_defaults
Optional[Union[RuntimeDefaults, dict]]
Default runtime configuration values. RuntimeDefaults fields:
  • KV cache defaults
  • Scheduling defaults
  • Performance knob defaults
These are typically loaded from saved engine configurations.

Properties

Quantization Mode

quant_mode
QuantMode
The quantization mode derived from the quantization configuration. Accessed via:
config.quant_mode
quant_algo
Optional[QuantAlgo]
The quantization algorithm. Accessed via:
config.quant_algo

KV Cache Data Type

kv_dtype
str
The data type for the KV cache. Returns:
  • "int8": If using INT8 KV cache quantization
  • "fp8": If using FP8 KV cache quantization
  • "fp4": If using FP4 KV cache quantization
  • config.dtype: Otherwise (same as model dtype)
Accessed via:
config.kv_dtype
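The fallback behavior above amounts to a small lookup. This is only a sketch of the documented behavior (the real property reads the QuantConfig; the string keys here are illustrative assumptions):

```python
def resolve_kv_dtype(model_dtype, kv_cache_quant_algo=None):
    # KV-cache quantization overrides the model dtype; otherwise fall through
    overrides = {"INT8": "int8", "FP8": "fp8", "NVFP4": "fp4"}
    return overrides.get(kv_cache_quant_algo, model_dtype)

print(resolve_kv_dtype("bfloat16"))         # bfloat16
print(resolve_kv_dtype("bfloat16", "FP8"))  # fp8
```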

Methods

Loading and Saving

from_dict
classmethod
Create a PretrainedConfig from a dictionary. Parameters:
  • config (dict): Configuration dictionary
Returns: PretrainedConfig instance
Example:
config_dict = {
    "architecture": "LlamaForCausalLM",
    "dtype": "bfloat16",
    "hidden_size": 4096,
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
}
config = PretrainedConfig.from_dict(config_dict)
from_json_file
classmethod
Load a PretrainedConfig from a JSON file. Parameters:
  • config_file (str): Path to the config.json file
Returns: PretrainedConfig instance
Example:
config = PretrainedConfig.from_json_file("/path/to/config.json")
from_checkpoint
classmethod
Load a PretrainedConfig from a checkpoint directory. Parameters:
  • ckpt_dir (str): Path to checkpoint directory
Returns: PretrainedConfig instance
Example:
config = PretrainedConfig.from_checkpoint("/path/to/checkpoint")
to_dict
method
Convert the config to a dictionary. Returns: dict representation of the config
Example:
config_dict = config.to_dict()
to_json_file
method
Save the config to a JSON file. Parameters:
  • config_file (str): Path to save the config
Example:
config.to_json_file("/path/to/config.json")

Rank Management

set_rank
method
Set the rank for this config instance. Parameters:
  • rank (int): The GPU rank
Example:
config.set_rank(1)  # Set to rank 1
for_each_rank
method
Iterate over all ranks, yielding a config copy for each rank. Returns: Generator yielding config instances for each rank
Example:
for rank_config in config.for_each_rank():
    # rank_config is configured for a specific rank
    save_checkpoint(rank_config)

Model-Specific Configurations

Model-specific config classes extend PretrainedConfig with additional parameters:

Llama Configuration

class LlamaConfig(PretrainedConfig):
    # Inherits all PretrainedConfig parameters
    # Additional Llama-specific fields may be added
    pass

Mixtral Configuration

class MixtralConfig(PretrainedConfig):
    num_experts: int  # Number of MoE experts
    num_experts_per_tok: int  # Top-K experts per token
    # Additional Mixtral-specific fields

GPT Configuration

class GPTConfig(PretrainedConfig):
    # GPT-specific parameters
    bias: bool  # Whether to use bias in linear layers
    # Additional GPT-specific fields

Quantization Configuration

The quantization field accepts a QuantConfig object:
from tensorrt_llm.models.modeling_utils import QuantConfig, QuantAlgo

quant_config = QuantConfig(
    quant_algo=QuantAlgo.FP8,
    kv_cache_quant_algo=QuantAlgo.FP8,
    group_size=128,
    exclude_modules=["lm_head"]
)

config = PretrainedConfig(
    architecture="LlamaForCausalLM",
    dtype="bfloat16",
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    quantization=quant_config
)

QuantConfig Fields

quant_algo
Optional[QuantAlgo]
Quantization algorithm for weights. Options:
  • QuantAlgo.W8A16: INT8 weight-only
  • QuantAlgo.W4A16: INT4 weight-only
  • QuantAlgo.FP8: FP8 quantization
  • QuantAlgo.NVFP4: NVFP4 quantization
  • QuantAlgo.W4A16_AWQ: AWQ INT4 quantization
  • QuantAlgo.W8A8_SQ_PER_CHANNEL: SmoothQuant per-channel
kv_cache_quant_algo
Optional[QuantAlgo]
Quantization algorithm for the KV cache. Options:
  • QuantAlgo.INT8: INT8 KV cache
  • QuantAlgo.FP8: FP8 KV cache
  • QuantAlgo.NVFP4: NVFP4 KV cache
group_size
int
default:128
Group size for group-wise quantization.
smoothquant_val
float
Smoothing parameter alpha used in SmoothQuant.
exclude_modules
Optional[List[str]]
Module name patterns to exclude from quantization. Example:
exclude_modules=["lm_head", "*embedding*"]

Example Configurations

Llama-2-7B Configuration

config = PretrainedConfig(
    architecture="LlamaForCausalLM",
    dtype="bfloat16",
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=32,
    vocab_size=32000,
    intermediate_size=11008,
    hidden_act="silu",
    max_position_embeddings=4096,
    position_embedding_type="rope_gpt_neox",
    norm_epsilon=1e-5,
)

Mixtral-8x7B Configuration

config = PretrainedConfig(
    architecture="MixtralForCausalLM",
    dtype="bfloat16",
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,  # GQA
    vocab_size=32000,
    intermediate_size=14336,
    hidden_act="silu",
    max_position_embeddings=32768,
    position_embedding_type="rope_gpt_neox",
    num_experts=8,
    num_experts_per_tok=2,
)

Tensor Parallel Configuration

from tensorrt_llm.mapping import Mapping

mapping = Mapping(
    world_size=4,
    rank=0,
    tp_size=4,
    pp_size=1,
)

config = PretrainedConfig(
    architecture="LlamaForCausalLM",
    dtype="bfloat16",
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    mapping=mapping,
)

With Quantization

from tensorrt_llm.models.modeling_utils import QuantConfig, QuantAlgo

config = PretrainedConfig(
    architecture="LlamaForCausalLM",
    dtype="bfloat16",
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    quantization=QuantConfig(
        quant_algo=QuantAlgo.FP8,
        kv_cache_quant_algo=QuantAlgo.FP8,
    ),
)

Loading from HuggingFace

Most commonly, you’ll load configurations from HuggingFace models:
from tensorrt_llm.models import AutoConfig

config = AutoConfig.from_hugging_face(
    "meta-llama/Llama-2-7b-hf",
    dtype="bfloat16",
    trust_remote_code=False,
)
This automatically:
  • Downloads the model configuration
  • Converts HuggingFace config to TensorRT-LLM format
  • Sets appropriate defaults for the model architecture
