
evaluate_pipeline_stages

def evaluate_pipeline_stages(
    state: PipelineState,
    config: PipelineConfig,
    max_eval_batches: Optional[int] = 100,
) -> Path
Evaluate all pipeline stages (Base, SFT, DPO) and save comparison results. This function loads each checkpoint from the pipeline state, computes perplexity on WikiText-2 validation data, and generates both JSON and CSV reports comparing the stages.

Parameters

state
PipelineState
required
Pipeline state containing checkpoint paths for pretrain, SFT, and DPO stages.
config
PipelineConfig
required
Pipeline configuration with tokenizer name, sequence length, and run name.
max_eval_batches
Optional[int]
default:"100"
Maximum number of batches to evaluate per stage. Set to None to evaluate the full dataset.

Returns

results_path
Path
Path to the saved JSON results file (typically experiments/results/{run_name}_eval.json).

Side effects

  • Creates experiments/results/ directory if it doesn’t exist
  • Saves JSON results to experiments/results/{run_name}_eval.json
  • Saves CSV comparison table to experiments/results/{run_name}_comparison.csv
  • Logs perplexity and loss for each stage

Usage

from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.evaluation.pipeline_eval import evaluate_pipeline_stages
from modern_llm.config import PipelineConfig

# After running alignment pipeline
config = PipelineConfig(
    run_name="modern-llm-experiment",
    tokenizer_name="gpt2",
    max_seq_len=512,
)

pipeline = AlignmentPipeline(config)
state = pipeline.run()

# Evaluate all stages
results_path = evaluate_pipeline_stages(
    state=state,
    config=config,
    max_eval_batches=50,  # Faster evaluation
)

print(f"Results saved to: {results_path}")
# Output: Results saved to: experiments/results/modern-llm-experiment_eval.json
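
Both output files are plain JSON and CSV, so they can be loaded back for further analysis with the standard library. A minimal sketch, continuing from the example above; the comparison file name follows the pattern listed under Side effects, and the JSON structure is not asserted here:

import json
import csv

# Load the per-stage metrics from the JSON report
with open(results_path) as f:
    eval_results = json.load(f)
print(list(eval_results.keys()))

# Load the stage comparison table ({run_name}_comparison.csv in the same directory)
comparison_path = results_path.parent / f"{config.run_name}_comparison.csv"
with open(comparison_path, newline="") as f:
    for row in csv.DictReader(f):
        print(row)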

compute_perplexity

def compute_perplexity(
    model: ModernDecoderLM,
    dataloader: DataLoader,
    device: torch.device,
    max_batches: Optional[int] = None,
) -> tuple[float, float]
Compute perplexity on a dataset as the exponential of the average cross-entropy loss.

Mathematical definition

$$\text{perplexity} = \exp\left(\frac{1}{N} \sum_{i=1}^{N} \mathcal{L}_i\right)$$

where $\mathcal{L}_i$ is the cross-entropy loss for batch $i$ and $N$ is the number of batches.
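
In code this is simply the exponential of the mean per-batch loss; a small illustrative sketch with made-up loss values (not the library implementation):

import math

batch_losses = [3.42, 3.38, 3.51]  # example per-batch cross-entropy values
avg_loss = sum(batch_losses) / len(batch_losses)
perplexity = math.exp(avg_loss)  # exp of the mean loss, per the definition above
print(f"Perplexity: {perplexity:.2f}")  # ≈ 31.1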

Parameters

model
ModernDecoderLM
required
Language model to evaluate. Must be on the specified device and in eval mode.
dataloader
DataLoader
required
DataLoader providing batches with input_ids, attention_mask, and labels.
device
torch.device
required
Device to run evaluation on (typically torch.device("cuda") or torch.device("cpu")).
max_batches
Optional[int]
default:"None"
Maximum number of batches to evaluate. If None, evaluates the entire dataset.

Returns

perplexity
float
Computed perplexity value. Returns inf if all batches had NaN/inf loss.
avg_loss
float
Average cross-entropy loss across batches. Returns inf if all batches had NaN/inf loss.

Preconditions

  • Model must be on the specified device and in eval mode
  • DataLoader batches must contain valid input tensors

Complexity

  • Time: O(N × B × (L × D² + L² × D)) where N is the number of batches, B is the batch size, L is the sequence length, and D is the model dimension; the L² × D term comes from self-attention
  • Space: O(B × L) for the input token tensors

Usage

from pathlib import Path
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from modern_llm.models.transformer import ModernDecoderLM
from modern_llm.data.lm_datasets import load_causal_lm_dataset, LanguageModelingDatasetConfig
from modern_llm.evaluation.pipeline_eval import compute_perplexity
from modern_llm.utils.checkpointing import load_checkpoint
from modern_llm.config import ModernLLMConfig

# Load model from checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt = load_checkpoint(Path("checkpoints/final.pt"))
model_config = ModernLLMConfig(**ckpt["config"])
model = ModernDecoderLM(model_config)
model.load_state_dict(ckpt["model_state"])
model.to(device)
model.eval()

# Prepare evaluation data
tokenizer = AutoTokenizer.from_pretrained("gpt2")
eval_config = LanguageModelingDatasetConfig(
    dataset_name="wikitext",
    dataset_config_name="wikitext-2-raw-v1",
    split="validation",
    max_length=512,
)
eval_dataset = load_causal_lm_dataset(eval_config, tokenizer)
eval_dataloader = DataLoader(eval_dataset, batch_size=4, shuffle=False)

# Compute perplexity
perplexity, loss = compute_perplexity(
    model=model,
    dataloader=eval_dataloader,
    device=device,
    max_batches=100,
)

print(f"Perplexity: {perplexity:.2f}")
print(f"Average Loss: {loss:.4f}")

StageMetrics

@dataclass
class StageMetrics:
    stage: str
    perplexity: float
    loss: float
    gsm8k_em: float = 0.0
    gsm8k_em_verifier: float = 0.0
    num_params: int = 0
Metrics for a single pipeline stage (Base, SFT, or DPO).

Attributes

stage
str
Stage name (“base”, “sft”, or “dpo”).
perplexity
float
Perplexity on WikiText-2 validation set.
loss
float
Average cross-entropy loss on WikiText-2 validation set.
gsm8k_em
float
default:"0.0"
GSM8K exact match accuracy (0-1 range).
gsm8k_em_verifier
float
default:"0.0"
GSM8K exact match accuracy with verifier reranking (0-1 range).
num_params
int
default:"0"
Total number of trainable parameters in the model.
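
A StageMetrics record can be filled directly from compute_perplexity output. A minimal sketch, reusing the model, eval_dataloader, and device from the compute_perplexity example above; the parameter count is derived from the model, not from the dataclass:

from modern_llm.evaluation.pipeline_eval import StageMetrics, compute_perplexity

perplexity, loss = compute_perplexity(model, eval_dataloader, device, max_batches=100)
sft_metrics = StageMetrics(
    stage="sft",
    perplexity=perplexity,
    loss=loss,
    num_params=sum(p.numel() for p in model.parameters() if p.requires_grad),
)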

PipelineEvalResults

@dataclass
class PipelineEvalResults:
    base: Optional[StageMetrics] = None
    sft: Optional[StageMetrics] = None
    dpo: Optional[StageMetrics] = None
    verifier_accuracy: float = 0.0
Full evaluation results across all pipeline stages.

Attributes

base
Optional[StageMetrics]
Metrics for the pretrained base model.
sft
Optional[StageMetrics]
Metrics for the supervised fine-tuned model.
dpo
Optional[StageMetrics]
Metrics for the DPO-aligned model.
verifier_accuracy
float
default:"0.0"
Verifier model accuracy on answer correctness classification.

Methods

to_dict

def to_dict(self) -> dict
Convert results to a dictionary for serialization.

save

def save(self, path: Path) -> None
Save results to a JSON file. Creates parent directories if needed.

to_csv

def to_csv(self, path: Path) -> None
Save stage comparison table as a CSV file.

Usage

from pathlib import Path
from modern_llm.evaluation.pipeline_eval import (
    PipelineEvalResults,
    StageMetrics,
)

# Create results
results = PipelineEvalResults(
    base=StageMetrics(
        stage="base",
        perplexity=28.5,
        loss=3.35,
        num_params=124_000_000,
    ),
    sft=StageMetrics(
        stage="sft",
        perplexity=22.1,
        loss=3.10,
        gsm8k_em=0.45,
        num_params=124_000_000,
    ),
    dpo=StageMetrics(
        stage="dpo",
        perplexity=24.3,
        loss=3.19,
        gsm8k_em=0.52,
        num_params=124_000_000,
    ),
    verifier_accuracy=0.78,
)

# Save as JSON
results.save(Path("experiments/results/eval.json"))

# Save as CSV for comparison
results.to_csv(Path("experiments/results/comparison.csv"))

# Access metrics
print(f"Base PPL: {results.base.perplexity:.1f}")
print(f"SFT PPL: {results.sft.perplexity:.1f}")
print(f"DPO PPL: {results.dpo.perplexity:.1f}")
print(f"PPL improvement: {results.base.perplexity - results.sft.perplexity:.1f}")
