# run_training

```python
from modern_llm.training.train_lm import run_training
```
Run causal language model pretraining on one or more datasets.
## Parameters

**model_config** (`ModernLLMConfig`)
Model architecture configuration defining layers, dimensions, and attention settings.

**train_config** (`TrainingConfig`)
Training hyperparameters including batch size, learning rate, and optimization settings.

**dataset_names** (`Optional[list]`, default: `None`)
List of dataset names to train on. If `None`, defaults to `["wikitext-2-raw-v1"]`. Supported values:

- `"wikitext-2-raw-v1"`
- `"wikitext-103-raw-v1"`
- `"tinystories"`
- `"openwebtext"`
- `"c4"`

**tokenizer_name** (`str`)
HuggingFace tokenizer name or path. The model's vocab size will be updated to match.
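The vocab-size update amounts to copying the tokenizer's vocabulary size into the model configuration before the model is built. A minimal sketch with a stand-in config class (the real attribute name on `ModernLLMConfig` is an assumption here):

```python
class Config:
    """Stand-in for ModernLLMConfig; `vocab_size` is an assumed attribute."""
    def __init__(self, vocab_size=0):
        self.vocab_size = vocab_size

def sync_vocab_size(config, tokenizer_len):
    # run_training is documented to update the model's vocab size to match
    # the tokenizer, i.e. tokenizer_len = len(tokenizer).
    config.vocab_size = tokenizer_len
    return config

cfg = sync_vocab_size(Config(), 50257)  # the GPT-2 tokenizer has 50257 tokens
```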
## Returns

Path to the final checkpoint file (e.g., `experiments/runs/{run_name}_final.pt`).
## Usage

```python
from pathlib import Path

from modern_llm.config import ModernLLMConfig, TrainingConfig
from modern_llm.training.train_lm import run_training

# Define model architecture
model_config = ModernLLMConfig(
    d_model=512,
    n_layers=8,
    n_heads=8,
    ffn_hidden_size=2048,
    max_seq_len=1024,
    dropout=0.1,
    use_rope=True,
    use_swiglu=True,
)

# Define training hyperparameters
train_config = TrainingConfig(
    run_name="gpt-pretrain",
    dataset_name="multi",
    tokenizer_name="gpt2",
    output_dir=Path("experiments/pretrain"),
    batch_size=128,
    micro_batch_size=8,
    gradient_accumulation_steps=16,
    learning_rate=3e-4,
    max_steps=50000,
    warmup_steps=2000,
    weight_decay=0.1,
    eval_every=5000,
    save_every=10000,
    log_every=100,
    mixed_precision="bf16",
    compile_model=True,
)

# Run pretraining
checkpoint = run_training(
    model_config=model_config,
    train_config=train_config,
    dataset_names=["wikitext-2-raw-v1", "tinystories"],
    tokenizer_name="gpt2",
)
print(f"Training complete! Checkpoint: {checkpoint}")
```
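As a sanity check on the batch settings in the configuration above: with gradient accumulation, the batch size actually used per optimizer step is `micro_batch_size * gradient_accumulation_steps`, which should equal `batch_size`.

```python
# Sanity-check the relationship between the three batch-size settings.
micro_batch_size = 8
gradient_accumulation_steps = 16
batch_size = 128

effective_batch_size = micro_batch_size * gradient_accumulation_steps
assert effective_batch_size == batch_size  # 8 * 16 == 128
```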
## Multi-Dataset Training

When multiple datasets are provided, they are concatenated during training:
```python
# Train on multiple corpora
checkpoint = run_training(
    model_config=model_config,
    train_config=train_config,
    dataset_names=[
        "wikitext-2-raw-v1",
        "tinystories",
        "openwebtext",
    ],
)
```
Validation always uses the WikiText-2 validation split as a standard benchmark.
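A rough sketch of what concatenation means here, using a hypothetical helper over already-tokenized streams (the real loader's shuffling and batching details are not described in this doc):

```python
from itertools import chain

def concat_token_streams(streams):
    # streams: dataset name -> iterable of token ids.
    # Training then draws from one continuous combined stream.
    return chain.from_iterable(streams.values())

streams = {
    "wikitext-2-raw-v1": [10, 11, 12],
    "tinystories": [20, 21],
}
combined = list(concat_token_streams(streams))  # [10, 11, 12, 20, 21]
```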
# generate_text

```python
from modern_llm.training.train_lm import generate_text
```
Generate text from a trained ModernDecoderLM model using sampling with temperature and top-k filtering.
## Parameters

**model** (`ModernDecoderLM`)
The trained language model. It will be set to eval mode during generation.

**tokenizer** (`PreTrainedTokenizer`, required)
Tokenizer used to encode the prompt and decode generated tokens.

**prompt** (`str`)
Input text to condition generation on. Must be non-empty and must tokenize to fewer than `model.config.max_seq_len` tokens.

**max_new_tokens** (`int`)
Maximum number of tokens to generate. Must be positive.

**temperature** (`float`)
Sampling temperature. Higher values (e.g., 1.5) increase randomness; lower values (e.g., 0.7) make output more deterministic. Must be > 0.

**top_k** (`Optional[int]`, default: `50`)
If specified, only sample from the top-k most likely tokens at each step. Set to `None` to disable top-k filtering.
## Returns

The complete generated text, including the original prompt and the continuation.
## Usage

```python
import torch
from transformers import AutoTokenizer

from modern_llm.models import ModernDecoderLM
from modern_llm.training.train_lm import generate_text
from modern_llm.utils.checkpointing import load_checkpoint

# Load trained model
ckpt = load_checkpoint("experiments/runs/my-model_final.pt")
model = ModernDecoderLM.from_config(ckpt["config"])
model.load_state_dict(ckpt["model_state"])
model.eval()

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Generate text
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="Once upon a time",
    max_new_tokens=100,
    temperature=0.8,
    top_k=50,
)
print(output)
```
## Sampling Parameters

### Temperature

Controls randomness in sampling:

- `temperature=0.7`: more focused, coherent output
- `temperature=1.0`: balanced sampling (default)
- `temperature=1.5`: more creative, diverse output
```python
# Low temperature plus a small top-k gives focused, near-deterministic output
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="The capital of France is",
    max_new_tokens=10,
    temperature=0.5,
    top_k=10,
)
```
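To see why lower temperature sharpens the output, here is a standalone illustration (plain Python, not the library's internals) of logits being divided by the temperature before softmax:

```python
import math

def softmax_with_temperature(logits, temperature):
    # Dividing logits by T < 1 exaggerates their differences (peakier
    # distribution); T > 1 shrinks them (flatter distribution).
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
focused = softmax_with_temperature(logits, 0.5)  # most mass on the top token
diverse = softmax_with_temperature(logits, 1.5)  # mass spread more evenly
```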
### Top-k Filtering

Restricts sampling to the k most likely tokens:
```python
# Conservative sampling (top-10)
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="In this tutorial, we will",
    max_new_tokens=50,
    temperature=1.0,
    top_k=10,
)

# No filtering (sample from the full distribution)
output = generate_text(
    model=model,
    tokenizer=tokenizer,
    prompt="The meaning of life is",
    max_new_tokens=50,
    temperature=1.0,
    top_k=None,
)
```
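Under the hood, top-k filtering is typically implemented by masking all but the k largest logits to `-inf` so they receive zero probability after softmax. A minimal sketch (illustrative only; ties at the threshold may keep slightly more than k tokens):

```python
def top_k_filter(logits, k):
    # Keep the k largest logits; everything else becomes -inf and
    # therefore gets zero probability after softmax.
    if k is None:
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    return [l if l >= threshold else float("-inf") for l in logits]

filtered = top_k_filter([3.0, 1.0, 2.0, 0.5], k=2)
# Only the logits 3.0 and 2.0 survive; the rest are -inf.
```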
## Constraints

- The prompt must tokenize to fewer than `model.config.max_seq_len` tokens.
- Generation stops when `max_new_tokens` is reached or the context length limit is hit.
- Temperature must be positive (typically 0.1 to 2.0).
- The model must be in eval mode (set automatically by the function).
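Putting the pieces together, a single sampling step as described above (scale by temperature, optionally apply top-k, then draw from the resulting distribution) can be sketched in plain Python. This illustrates the algorithm, not the library's implementation:

```python
import math
import random

def sample_next_token(logits, temperature=1.0, top_k=None):
    assert temperature > 0, "temperature must be positive"
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        # Mask everything outside the top-k to -inf.
        threshold = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [s if s >= threshold else float("-inf") for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # math.exp(-inf) == 0.0
    total = sum(exps)
    probs = [e / total for e in exps]
    # Draw one token index in proportion to its probability.
    return random.choices(range(len(logits)), weights=probs)[0]

token = sample_next_token([3.0, 1.0, 2.0, 0.0], temperature=0.8, top_k=2)
```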