
Overview

The GPT class implements a decoder-only transformer model with several modern improvements over the original GPT architecture.

GPTConfig

Configuration dataclass for the GPT model.
sequence_len (int, default: 2048)
Maximum context length for the model.

vocab_size (int, default: 32768)
Size of the vocabulary (typically 2^15).

n_layer (int, default: 12)
Number of transformer layers (depth). This is the primary complexity dial.

n_head (int, default: 6)
Number of query attention heads.

n_kv_head (int, default: 6)
Number of key/value heads for Grouped-Query Attention (GQA). Set equal to n_head for standard multi-head attention, or lower for GQA.

n_embd (int, default: 768)
Model embedding dimension (width). Typically calculated as depth * aspect_ratio.

window_pattern (str, default: "SSSL")
Sliding window attention pattern tiled across layers:
  • L = Long (full context)
  • S = Short (half context)
  • Final layer always uses full context
Examples: "L" (all full), "SL" (alternating), "SSSL" (three short, one long)

GPT Class

from nanochat.gpt import GPT, GPTConfig

# Create a model
config = GPTConfig(
    n_layer=12,
    n_head=6,
    n_embd=768,
    vocab_size=32768,
    sequence_len=2048
)
model = GPT(config)

Architecture Features

The nanochat GPT implementation includes several modern improvements:

Rotary Embeddings

RoPE for relative positional encoding (no learned positional embeddings)

QK Normalization

Normalizes queries and keys for stable attention

Untied Weights

Separate token embedding and language model head weights

ReLU² Activation

Squared ReLU activation in MLP layers
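
A minimal sketch of an MLP block with the squared-ReLU activation. The layer names and the 4x hidden expansion are assumptions for illustration; nanochat's actual module may differ:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """MLP block with ReLU² activation (sketch, not nanochat's exact code)."""
    def __init__(self, n_embd: int):
        super().__init__()
        # bias=False matches the "No Bias Terms" feature; 4x expansion is an assumption
        self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=False)
        self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=False)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU²: zero for negatives, quadratic for positives
        return self.c_proj(x)
```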

Grouped-Query Attention

GQA support for efficient inference with a KV cache
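
The head-sharing idea behind GQA can be illustrated with plain tensors (the shapes below are arbitrary examples, not nanochat's internals): several query heads attend against the same key/value head, so the KV cache shrinks by the ratio n_head / n_kv_head.

```python
import torch
import torch.nn.functional as F

# GQA sketch: 6 query heads share 2 K/V heads (3 queries per K/V head).
B, T, n_head, n_kv_head, head_dim = 2, 8, 6, 2, 16
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# Repeat each K/V head to cover its group of query heads before attention.
repeat = n_head // n_kv_head
k = k.repeat_interleave(repeat, dim=1)  # -> (B, n_head, T, head_dim)
v = v.repeat_interleave(repeat, dim=1)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```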

Flash Attention 3

Automatic FA3 on Hopper+ GPUs, SDPA fallback elsewhere

Sliding Window

Configurable attention window patterns per layer

No Bias Terms

All linear layers use bias=False

Forward Pass

# Training forward pass
logits, loss = model(input_ids, targets=target_ids)

# Inference forward pass
logits, _ = model(input_ids)
Parameters:

idx (torch.Tensor, required)
Input token indices of shape (B, T), where B is batch size and T is sequence length.

targets (torch.Tensor, optional)
Target token indices for computing loss. If provided, returns (logits, loss); otherwise (logits, None).

kv_cache (dict, optional)
Key-value cache for efficient autoregressive generation. Used by the Engine class.

Returns:

logits (torch.Tensor)
Output logits of shape (B, T, vocab_size).

loss (torch.Tensor)
Cross-entropy loss if targets are provided, otherwise None.
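
To show how the (logits, loss) interface composes into sampling, here is a minimal generation loop. This is a naive sketch that recomputes the full forward pass each step; the Engine class provides the real KV-cached implementation:

```python
import torch

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0):
    """Naive autoregressive sampling loop (sketch; no KV cache)."""
    for _ in range(max_new_tokens):
        logits, _ = model(idx)                 # (B, T, vocab_size)
        logits = logits[:, -1, :] / temperature  # keep only the last position
        probs = torch.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # sample one token
        idx = torch.cat([idx, next_id], dim=1)
    return idx
```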

Optimizer Configuration

The model provides a configure_optimizers() method that groups parameters for the MuonAdamW optimizer:
optimizer = model.configure_optimizers(
    weight_decay=0.2,
    embedding_lr=0.3,
    unembedding_lr=0.004,
    matrix_lr=0.02,
    scalar_lr=0.5,
    device_type="cuda"
)
Parameters are grouped into:
  • Muon group: Weight matrices (Muon optimizer)
  • Adam embedding: Token embeddings (AdamW)
  • Adam unembedding: Language model head (AdamW)
  • Adam scalar: Scalar parameters like gates (AdamW)
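
The grouping rule might be sketched as follows. The parameter names ("wte", "lm_head") and the ndim-based test are assumptions for illustration, not nanochat's exact logic:

```python
def group_params(model):
    """Sketch of the four-way split: 2-D weight matrices go to Muon;
    embeddings, the unembedding head, and scalars go to AdamW.
    (Name matching here is a guess, for illustration only.)"""
    muon, adam_embed, adam_unembed, adam_scalar = [], [], [], []
    for name, p in model.named_parameters():
        if "wte" in name:            # token embedding (assumed name)
            adam_embed.append(p)
        elif "lm_head" in name:      # language model head (assumed name)
            adam_unembed.append(p)
        elif p.ndim < 2:             # gates and other scalar parameters
            adam_scalar.append(p)
        else:                        # remaining weight matrices -> Muon
            muon.append(p)
    return muon, adam_embed, adam_unembed, adam_scalar
```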

Helper Functions

norm(x)

Purely functional RMSNorm with no learnable parameters:
from nanochat.gpt import norm

normalized = norm(x)  # RMSNorm over last dimension
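
An equivalent parameter-free RMSNorm can be written out directly (the epsilon value here is an assumption; nanochat's helper may use a different formulation):

```python
import torch

def norm(x: torch.Tensor) -> torch.Tensor:
    """Parameter-free RMSNorm over the last dimension:
    x / sqrt(mean(x^2) + eps), with no learnable scale or bias."""
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
```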

apply_rotary_emb(x, cos, sin)

Applies rotary position embeddings:
from nanochat.gpt import apply_rotary_emb

q_rotated = apply_rotary_emb(q, cos, sin)
k_rotated = apply_rotary_emb(k, cos, sin)
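
One common RoPE formulation, shown as a sketch; nanochat's exact variant (e.g. how dimensions are paired and how cos/sin are shaped) may differ:

```python
import torch

def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate pairs of channels by position-dependent angles (half-split pairing).
    Sketch only; the real helper's pairing convention may differ."""
    d = x.size(-1) // 2
    x1, x2 = x[..., :d], x[..., d:]
    # Standard 2-D rotation applied to each (x1, x2) channel pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

With `cos = 1` and `sin = 0` (zero rotation angle) the input passes through unchanged, which is a handy sanity check.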

has_ve(layer_idx, n_layer)

Determines if a layer should have Value Embedding:
from nanochat.gpt import has_ve

if has_ve(layer_idx, n_layer):
    # Layer uses value embedding
    pass
