Pretraining is the first stage of the Modern LLM pipeline, where a decoder-only transformer is trained from random initialization on large text corpora using causal language modeling (next-token prediction).

Overview

The pretraining stage teaches the model basic language understanding, grammar, facts, and reasoning patterns by training it to predict the next token in sequences from diverse text sources.

Supported datasets

Modern LLM supports several high-quality text corpora for pretraining:
Dataset                  Size          Description
wikitext-2-raw-v1        2M tokens     Small, high-quality Wikipedia articles (Merity et al., 2016)
wikitext-103-raw-v1      103M tokens   Larger WikiText corpus of verified Wikipedia articles
roneneldan/TinyStories   ~25M tokens   Simple synthetic stories for small models (Eldan & Li, 2023)
openwebtext              ~8B tokens    Reddit-curated web content; GPT-2's training set
wikipedia                ~4B tokens    Full English Wikipedia dump (20231101.en)
Datasets are automatically downloaded from Hugging Face on first use and cached locally.
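As an illustration of how the preset names above can map onto Hugging Face `load_dataset()` arguments, here is a hypothetical helper; the project's actual loader (src/modern_llm/training/train_lm.py) may resolve names differently:

```python
# Hypothetical name resolver: WikiText variants are configs of the
# "wikitext" dataset, while other entries are top-level dataset ids.
WIKITEXT_CONFIGS = {"wikitext-2-raw-v1", "wikitext-103-raw-v1"}

def resolve_dataset(name: str):
    """Return (path, config) arguments for datasets.load_dataset()."""
    if name in WIKITEXT_CONFIGS:
        return "wikitext", name   # e.g. load_dataset("wikitext", "wikitext-2-raw-v1")
    return name, None             # e.g. "openwebtext", "roneneldan/TinyStories"

print(resolve_dataset("wikitext-2-raw-v1"))  # ('wikitext', 'wikitext-2-raw-v1')
```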

Usage

The easiest way to run pretraining is through the unified pipeline script:
python scripts/run_pipeline.py --config local-smoke --stage pretrain

Direct script usage

You can also use the standalone pretraining script:
python scripts/pretrain.py --config local
This provides the same functionality but without the pipeline orchestration features.

Configuration

Config presets

Pretraining hyperparameters are defined in the pipeline config presets:
# 5-minute smoke test
pretrain_max_steps: 100
pretrain_lr: 3e-4
pretrain_batch_size: 64
pretrain_micro_batch_size: 2
pretrain_warmup_steps: 500
pretrain_datasets: ["wikitext-2-raw-v1"]
The local preset uses only WikiText-2 for faster training. The gpu preset uses multiple large datasets for better quality.

Multi-dataset training

You can train on multiple datasets simultaneously. The datasets are concatenated and the model sees examples from all sources:
python scripts/run_pipeline.py --config local --stage pretrain \
    --pretrain-datasets "wikitext-2-raw-v1,roneneldan/TinyStories"
Dataset downsampling is supported by appending :N where N is the maximum number of examples:
python scripts/run_pipeline.py --config local --stage pretrain \
    --pretrain-datasets "openwebtext:100000,wikitext-103-raw-v1"
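The `:N` suffix can be split off with a small parser. This is an illustrative sketch of the syntax described above, not the pipeline's actual implementation:

```python
# Sketch: parse "dataset[:max_examples]" specs as used by
# --pretrain-datasets; the pipeline's own parser may differ.
def parse_dataset_spec(spec: str):
    """Split 'name[:N]' into (name, limit or None)."""
    name, sep, limit = spec.rpartition(":")
    if sep and limit.isdigit():
        return name, int(limit)
    return spec, None

specs = "openwebtext:100000,wikitext-103-raw-v1"
parsed = [parse_dataset_spec(s) for s in specs.split(",")]
print(parsed)  # [('openwebtext', 100000), ('wikitext-103-raw-v1', None)]
```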

Hyperparameter tuning

Key hyperparameters that affect pretraining quality:

Learning rate (pretrain_lr)
  • Default: 3e-4 works well for most model sizes
  • Larger models (>1B params) may need lower LR (1e-4 to 2e-4)
  • Smaller models can handle higher LR (5e-4 to 1e-3)
Batch size (pretrain_batch_size)
  • Default: 64 for local, 128 for GPU
  • Larger batches = more stable gradients but slower iteration
  • Use gradient accumulation if GPU memory is limited
Warmup steps (pretrain_warmup_steps)
  • Default: 500 steps
  • Linear warmup from 0 to pretrain_lr over first N steps
  • Prevents early training instability
Weight decay (default: 0.1)
  • Standard L2 regularization
  • Prevents overfitting on small datasets
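The batch-size guidance above can be sketched as a gradient-accumulation loop. A minimal PyTorch illustration with a toy model and random data (not the project's training loop):

```python
import torch
import torch.nn.functional as F

# batch_size=64 simulated with micro-batches of 2: the optimizer
# steps once per 32 micro-steps, with losses scaled to average
# over the full effective batch.
batch_size, micro_batch_size = 64, 2
accum_steps = batch_size // micro_batch_size  # 32

model = torch.nn.Linear(16, 4)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

opt.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(micro_batch_size, 16)
    y = torch.randint(0, 4, (micro_batch_size,))
    loss = F.cross_entropy(model(x), y)
    (loss / accum_steps).backward()  # scale so grads average over 64 examples
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```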

Training details

Optimization

The pretraining implementation uses:
  • Optimizer: AdamW with β₁=0.9, β₂=0.999
  • Learning rate schedule: Linear warmup + constant LR
  • Gradient accumulation: Automatic based on batch_size / micro_batch_size
  • Mixed precision: BF16 on supported GPUs, FP32 fallback
  • Gradient clipping: Max norm 1.0
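The linear-warmup-then-constant schedule listed above can be expressed with PyTorch's `LambdaLR`; this is a sketch assuming the documented defaults (lr=3e-4, 500 warmup steps), not the project's exact scheduler code:

```python
import torch

# Linear warmup from 0 to base_lr over warmup_steps, then constant.
base_lr, warmup_steps = 3e-4, 500
model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=base_lr, betas=(0.9, 0.999))

def warmup_then_constant(step: int) -> float:
    # Returns the fraction of base_lr to use at this step.
    return min(1.0, (step + 1) / warmup_steps)

sched = torch.optim.lr_scheduler.LambdaLR(opt, warmup_then_constant)
lrs = []
for _ in range(1000):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])
# LR ramps up during the first 500 steps, then stays at 3e-4.
```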

Loss function

Standard causal language modeling loss:
# Cross-entropy on shifted predictions; padding positions are
# excluded from the loss via ignore_index
import torch.nn.functional as F

labels = input_ids[:, 1:].clone()
labels[attention_mask[:, 1:] == 0] = -100   # mark padding as ignored

loss = F.cross_entropy(
    logits[:, :-1, :].reshape(-1, vocab_size),
    labels.reshape(-1),
    reduction='mean',
    ignore_index=-100,
)
Padding tokens are masked so they don’t contribute to the loss.

Evaluation

During training, the model is evaluated on a held-out validation set every eval_every steps (default: 500). The evaluation metric is perplexity, computed as:
perplexity = exp(validation_loss)
Lower perplexity indicates better language modeling performance.
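As a worked example of the formula above (the loss value here is illustrative):

```python
import math

# A validation loss of 3.0456 corresponds to perplexity
# exp(3.0456) ≈ 21.0 — roughly as uncertain as a uniform
# choice over ~21 tokens at each position.
validation_loss = 3.0456
perplexity = math.exp(validation_loss)
print(f"{perplexity:.2f}")  # 21.02
```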

Checkpoints

The pretraining stage saves checkpoints at regular intervals:
  • Regular checkpoints: Every save_every steps (default: 2000)
    • Saved as <run_name>-pretrain_step{N}.pt
    • Includes model state, optimizer state, config, step counter
  • Final checkpoint: At the end of training
    • Saved as <run_name>-pretrain_final.pt
    • Used as input for the SFT stage
Checkpoint contents (accessible via torch.load):
checkpoint = {
    'model_state': OrderedDict(...),      # Model weights
    'optimizer_state': {...},             # Optimizer state
    'config': {...},                      # Model architecture config
    'step': 20000,                        # Training step
    'run_name': 'local-full-pretrain',
}
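A sketch of the checkpoint round trip using the field names shown above. The `Linear` model is a stand-in; in practice you would rebuild the real architecture from `checkpoint['config']` before loading weights:

```python
import torch

# Save a checkpoint dict with the documented fields, then reload it.
model = torch.nn.Linear(4, 4)
checkpoint = {
    "model_state": model.state_dict(),
    "config": {"d_model": 4},
    "step": 20000,
    "run_name": "local-full-pretrain",
}
torch.save(checkpoint, "pretrain_demo.pt")

loaded = torch.load("pretrain_demo.pt", map_location="cpu")
restored = torch.nn.Linear(4, 4)
restored.load_state_dict(loaded["model_state"])
print(loaded["step"])  # 20000
```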

Monitoring

Training progress is logged to:
  1. Console output: Real-time progress bar with loss/perplexity
  2. Log file: experiments/runs/<run_name>/training.log
  3. Checkpoints: Model states saved at regular intervals
Example console output:
Pretraining: local-full-pretrain
Model: d=768, L=12, H=12
Steps: 20000
Batch: 64 (micro=2)
LR: 0.0003

Pretrain Training: 100%|████████| 20000/20000 [5:23:12<00:00, loss=3.2456, ppl=25.67]
step=5000 loss=3.8234 ppl=45.67 lr=3.000e-04
step=10000 loss=3.4123 ppl=30.32 lr=3.000e-04
step=15000 loss=3.1890 ppl=24.23 lr=3.000e-04
step=20000 loss=3.0456 ppl=21.03 lr=3.000e-04

Pretraining complete: experiments/runs/local-full/pretrain_final.pt

Implementation details

The pretraining implementation is located at:
  • src/modern_llm/training/train_lm.py:run_training() - Main training loop
  • scripts/pretrain.py - CLI wrapper
  • scripts/run_pipeline.py:run_pretrain() - Pipeline integration
Key functions:

run_training(model_config, train_config, dataset_names, tokenizer_name)
Main training entrypoint that:
  1. Loads tokenizer and datasets
  2. Initializes model from scratch
  3. Sets up optimizer and LR scheduler
  4. Runs training loop with evaluation
  5. Saves final checkpoint
load_multi_dataset(dataset_names, tokenizer, split, max_length)
Loads and concatenates multiple datasets:
  1. Fetches each dataset from Hugging Face
  2. Tokenizes with padding/truncation
  3. Concatenates into single dataset
  4. Returns PyTorch Dataset object
See src/modern_llm/training/train_lm.py:100-236 for full implementation.
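The concatenation step can be sketched with `torch.utils.data.ConcatDataset`; this is an illustrative stand-in for load_multi_dataset(), using dummy token-id tensors in place of real tokenized corpora:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Each source becomes a dataset of fixed-length token-id sequences;
# ConcatDataset merges them so training samples span all sources.
wikitext = TensorDataset(torch.zeros(100, 128, dtype=torch.long))
tinystories = TensorDataset(torch.ones(50, 128, dtype=torch.long))
combined = ConcatDataset([wikitext, tinystories])
print(len(combined))  # 150
```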

Performance tips

If GPU memory is limited:
  • Lower micro_batch_size (increases gradient accumulation)
  • Enable gradient checkpointing (trades compute for memory)
  • Use a smaller model (reduce d_model, n_layers)
  • Use shorter sequences (lower max_seq_len)
To speed up training:
  • Increase micro_batch_size if the GPU has headroom
  • Use multiple GPUs with DDP (not yet supported)
  • Enable Flash Attention (requires use_attention_sinks=False)
  • Use bf16 mixed precision on Ampere+ GPUs
To improve final quality:
  • Train longer (increase pretrain_max_steps)
  • Use larger or more diverse datasets
  • Increase model size (if compute allows)
  • Lower the learning rate for stability
When debugging:
  • Check for NaN losses (reduce LR or enable gradient clipping)
  • Verify dataset loading (check the first few examples)
  • Monitor the perplexity trend (it should decrease over time)
  • Inspect generated samples (use generate_text() after training)

Next steps

After pretraining completes:
  1. Verify the checkpoint exists at experiments/runs/<run_name>/pretrain_final.pt
  2. Run SFT using this checkpoint:
    python scripts/run_pipeline.py --config local --stage sft \
        --checkpoint experiments/runs/local-full/pretrain_final.pt
    
  3. Or continue with full pipeline:
    python scripts/run_pipeline.py --config local --stage all \
        --checkpoint experiments/runs/local-full/pretrain_final.pt
    

Supervised fine-tuning

Learn how to instruction-tune your pretrained model
