run_training

from modern_llm.training.train_lm import run_training
Run causal language model pretraining on one or more datasets.

Parameters

model_config
ModernLLMConfig
required
Model architecture configuration defining layers, dimensions, and attention settings.
train_config
TrainingConfig
required
Training hyperparameters including batch size, learning rate, and optimization settings.
dataset_names
Optional[list]
default:"None"
List of dataset names to train on. If None, defaults to ["wikitext-2-raw-v1"]. Supports:
  • "wikitext-2-raw-v1"
  • "wikitext-103-raw-v1"
  • "tinystories"
  • "openwebtext"
  • "c4"
tokenizer_name
str
default:"gpt2"
HuggingFace tokenizer name or path. The model’s vocab size will be updated to match.

Returns

checkpoint_path
Path
Path to the final checkpoint file (e.g., experiments/runs/{run_name}_final.pt).

Usage

from pathlib import Path
from modern_llm.config import ModernLLMConfig, TrainingConfig
from modern_llm.training.train_lm import run_training

# Define model architecture
model_config = ModernLLMConfig(
    d_model=512,
    n_layers=8,
    n_heads=8,
    ffn_hidden_size=2048,
    max_seq_len=1024,
    dropout=0.1,
    use_rope=True,
    use_swiglu=True,
)

# Define training hyperparameters
train_config = TrainingConfig(
    run_name="gpt-pretrain",
    dataset_name="multi",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/pretrain"),
    batch_size=128,
    micro_batch_size=8,
    gradient_accumulation_steps=16,
    learning_rate=3e-4,
    max_steps=50000,
    warmup_steps=2000,
    weight_decay=0.1,
    eval_every=5000,
    save_every=10000,
    log_every=100,
    mixed_precision="bf16",
    compile_model=True,
)

# Run pretraining
checkpoint = run_training(
    model_config=model_config,
    train_config=train_config,
    dataset_names=["wikitext-2-raw-v1", "tinystories"],
    tokenizer_name="gpt2",
)

print(f"Training complete! Checkpoint: {checkpoint}")

Multi-Dataset Training

When multiple datasets are provided, they are concatenated during training:
# Train on multiple corpora
checkpoint = run_training(
    model_config=model_config,
    train_config=train_config,
    dataset_names=[
        "wikitext-2-raw-v1",
        "tinystories",
        "openwebtext",
    ],
)
Validation always uses the WikiText-2 validation split as a standard benchmark, regardless of the training datasets.
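Conceptually, the concatenation step can be pictured with PyTorch's ConcatDataset. The toy sketch below is illustration only: the actual pipeline inside run_training downloads and tokenizes real corpora, so the tensors and dataset objects here are stand-ins.

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Toy stand-ins for two tokenized corpora (real code would load
# wikitext-2-raw-v1, tinystories, etc.).
corpus_a = TensorDataset(torch.tensor([[1, 2, 3], [4, 5, 6]]))
corpus_b = TensorDataset(torch.tensor([[7, 8, 9]]))

# Joined end to end into a single training stream; a DataLoader with
# shuffle=True would then mix batches across all sources.
combined = ConcatDataset([corpus_a, corpus_b])

print(len(combined))  # 3 sequences total
```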

generate_text

from modern_llm.training.train_lm import generate_text
Generate text from a trained ModernDecoderLM model using sampling with temperature and top-k filtering.

Parameters

model
ModernDecoderLM
required
The trained language model. Will be set to eval mode during generation.
tokenizer
PreTrainedTokenizer
required
Tokenizer used to encode the prompt and decode generated tokens.
prompt
str
required
Input text to condition generation on. Must be non-empty and must tokenize to fewer than model.config.max_seq_len tokens.
max_new_tokens
int
required
Maximum number of tokens to generate. Must be positive.
temperature
float
default:"1.0"
Sampling temperature. Higher values (e.g., 1.5) increase randomness, lower values (e.g., 0.7) make output more deterministic. Must be > 0.
top_k
Optional[int]
default:"50"
If specified, only sample from the top-k most likely tokens at each step. Set to None to disable top-k filtering.

Returns

text
str
The complete generated text including the original prompt and continuation.

Usage

from modern_llm.models import ModernDecoderLM
from modern_llm.training.train_lm import generate_text
from modern_llm.utils.checkpointing import load_checkpoint
from transformers import AutoTokenizer
import torch

# Load trained model
ckpt = load_checkpoint("experiments/runs/my-model_final.pt")
model = ModernDecoderLM.from_config(ckpt["config"])
model.load_state_dict(ckpt["model_state"])
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Generate text
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="Once upon a time",
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)

print(output)

Sampling Parameters

Temperature controls randomness in sampling:
  • temperature=0.7: More focused, coherent output
  • temperature=1.0: Balanced sampling (default)
  • temperature=1.5: More creative, diverse output
# Focused, near-deterministic output
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="The capital of France is",
    max_new_tokens=10,
    temperature=0.5,
    top_k=10,
)
Top-k filtering restricts sampling to the k most likely tokens:
# Conservative sampling (top-10)
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="In this tutorial, we will",
    max_new_tokens=50,
    temperature=1.0,
    top_k=10,
)

# No filtering (sample from full distribution)
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="The meaning of life is",
    max_new_tokens=50,
    temperature=1.0,
    top_k=None,
)

Constraints

  • Prompt must tokenize to fewer than model.config.max_seq_len tokens
  • Generation stops when max_new_tokens is reached or the context length limit is hit
  • Temperature must be positive (typically 0.1 to 2.0)
  • Model must be in eval mode (automatically set by the function)