AlignmentPipeline
class AlignmentPipeline:
    def __init__(
        self,
        config: PipelineConfig,
        checkpoint_dir: Optional[Path] = None,
    )
Orchestrates the full alignment pipeline with four stages: pretraining, supervised fine-tuning (SFT), Direct Preference Optimization (DPO), and verifier training.
Pipeline stages
The alignment pipeline implements a complete RLHF-inspired training workflow:
- Pretrain: Language model pretraining on unlabeled text (e.g., WikiText-2)
- SFT: Supervised fine-tuning on instruction-following data
- DPO: Direct Preference Optimization on human preference data (loss sketched after this list)
- Verifier: Train answer correctness classifier for reranking
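For orientation, the DPO stage minimizes the preference loss of Rafailov et al. (2023) (see References). Below is a minimal sketch of that loss in PyTorch, assuming per-sequence log-probabilities for the chosen and rejected responses have already been gathered; the function and argument names are illustrative, not part of this API:

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # log pi(y_w | x) under the policy being trained
    policy_rejected_logps: torch.Tensor,  # log pi(y_l | x) under the policy being trained
    ref_chosen_logps: torch.Tensor,       # log pi_ref(y_w | x) under the frozen reference (SFT) model
    ref_rejected_logps: torch.Tensor,     # log pi_ref(y_l | x) under the frozen reference (SFT) model
    beta: float = 0.1,                    # corresponds to dpo_beta in PipelineConfig
) -> torch.Tensor:
    # Implicit reward of each response, measured relative to the reference model
    chosen_margin = policy_chosen_logps - ref_chosen_logps
    rejected_margin = policy_rejected_logps - ref_rejected_logps
    # -log sigmoid(beta * (margin difference)), averaged over the batch
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin)).mean()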
Parameters
config (PipelineConfig)
Pipeline configuration specifying:
- Model hyperparameters (size, layers, etc.)
- Training hyperparameters for each stage
- Dataset names and paths
- Tokenizer name
- Run name for checkpointing
checkpoint_dir (Optional[Path], default: None)
Directory to save checkpoints. Defaults to experiments/runs/{run_name}.
Attributes
checkpoint_dir (Path)
Directory containing all stage checkpoints.
state (PipelineState)
Current pipeline state with checkpoint paths and metrics.
state_path (Path)
Path to the saved state JSON file (pipeline_state.json).
Methods
run
def run(
    self,
    skip_pretrain: bool = False,
    pretrain_checkpoint: Optional[Path] = None,
    skip_sft: bool = False,
    skip_dpo: bool = False,
    skip_verifier: bool = False,
) -> PipelineState
Execute the full pipeline with optional stage skipping.
Parameters:
skip_pretrain (bool, default: False)
If True, skip the pretraining stage. Must provide pretrain_checkpoint or have an existing checkpoint in state.
pretrain_checkpoint (Optional[Path], default: None)
Path to an existing pretrain checkpoint to use instead of training.
skip_sft (bool, default: False)
If True, skip the SFT stage and use the pretrain checkpoint for DPO.
skip_dpo (bool, default: False)
If True, skip the DPO stage and use the SFT checkpoint as the final model.
skip_verifier (bool, default: False)
If True, skip the verifier training stage.
Returns:
Final pipeline state containing all checkpoint paths and metrics.
Raises:
ValueError - If skip_pretrain=True but no pretrain checkpoint is available
Side effects:
- Saves checkpoints to checkpoint_dir after each stage
- Updates and saves pipeline_state.json after each stage
- Logs progress and metrics to console and log files
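To make the skip semantics and side effects concrete, here is an illustrative sketch of the control flow run() implements. It is not the library's source; the run_stage callable and state_path argument stand in for the pipeline's internal training steps and state file:

from pathlib import Path
from typing import Callable, Optional

def run_stages(
    state,                             # a PipelineState to update
    state_path: Path,                  # where pipeline_state.json is written
    run_stage: Callable[[str], Path],  # hypothetical: stage name -> checkpoint path
    skip_pretrain: bool = False,
    pretrain_checkpoint: Optional[Path] = None,
    skip_sft: bool = False,
    skip_dpo: bool = False,
    skip_verifier: bool = False,
):
    if skip_pretrain:
        ckpt = pretrain_checkpoint or state.pretrain_checkpoint
        if ckpt is None:
            raise ValueError("skip_pretrain=True but no pretrain checkpoint is available")
        state.pretrain_checkpoint = ckpt
    else:
        state.pretrain_checkpoint = run_stage("pretrain")
    state.save(state_path)  # state is persisted after every stage

    if not skip_sft:
        state.sft_checkpoint = run_stage("sft")
        state.save(state_path)

    if not skip_dpo:
        # When SFT was skipped, DPO starts from the pretrain checkpoint instead
        state.dpo_checkpoint = run_stage("dpo")
        state.save(state_path)

    if not skip_verifier:
        state.verifier_checkpoint = run_stage("verifier")
        state.save(state_path)

    return state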
load_model
def load_model(self, stage: str) -> ModernDecoderLM
Load model from a specific pipeline stage.
Parameters:
stage (str)
Stage name: "pretrain", "sft", or "dpo".
Returns:
Loaded model on appropriate device (CUDA if available, otherwise CPU).
Raises:
ValueError - If stage is unknown or checkpoint doesn’t exist
load_verifier
def load_verifier(self) -> VerifierModel
Load trained verifier model.
Returns:
Loaded verifier model on appropriate device.
Raises:
ValueError - If verifier checkpoint doesn’t exist
Usage
Basic pipeline execution
from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig
# Configure pipeline
config = PipelineConfig(
    run_name="modern-llm-alignment",
    tokenizer_name="gpt2",
    d_model=512,
    n_layers=6,
    n_heads=8,
    max_seq_len=512,
    # Pretraining
    pretrain_steps=10_000,
    pretrain_batch_size=16,
    pretrain_lr=1e-3,
    pretrain_datasets=["wikitext"],
    # SFT
    sft_steps=5_000,
    sft_batch_size=8,
    sft_lr=5e-5,
    sft_dataset="tatsu-lab/alpaca",
    # DPO
    dpo_steps=3_000,
    dpo_batch_size=4,
    dpo_lr=1e-5,
    dpo_beta=0.1,
    dpo_dataset="Anthropic/hh-rlhf",
    # Verifier
    verifier_steps=2_000,
    verifier_batch_size=16,
    verifier_lr=1e-4,
)
# Run full pipeline
pipeline = AlignmentPipeline(config)
state = pipeline.run()
print(f"Pretrain checkpoint: {state.pretrain_checkpoint}")
print(f"SFT checkpoint: {state.sft_checkpoint}")
print(f"DPO checkpoint: {state.dpo_checkpoint}")
print(f"Verifier checkpoint: {state.verifier_checkpoint}")
Resume from existing checkpoint
from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig
config = PipelineConfig(
    run_name="modern-llm-alignment",
    tokenizer_name="gpt2",
    max_seq_len=512,
)
# Pipeline automatically loads existing state from checkpoint_dir
pipeline = AlignmentPipeline(config)
# Skip pretrain and SFT, only run DPO and verifier
state = pipeline.run(
    skip_pretrain=True,  # Use existing pretrain checkpoint
    skip_sft=True,       # Use existing SFT checkpoint
)
Use external pretrained model
from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig
config = PipelineConfig(
    run_name="alignment-from-gpt2",
    tokenizer_name="gpt2",
    max_seq_len=512,
)
pipeline = AlignmentPipeline(config)
state = pipeline.run(
    skip_pretrain=True,
    pretrain_checkpoint=Path("pretrained_models/gpt2-base.pt"),
)
Load and use trained models
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig
import torch
# Load pipeline state
config = PipelineConfig(run_name="modern-llm-alignment", tokenizer_name="gpt2")
pipeline = AlignmentPipeline(config)
# Load final DPO model
model = pipeline.load_model("dpo")
model.eval()
# Generate text
prompt = "Explain how photosynthesis works:"
tokenizer = pipeline.tokenizer
input_ids = tokenizer.encode(prompt, return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(input_ids, max_length=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
# Load verifier for reranking
verifier = pipeline.load_verifier()
verifier.eval()
# Score multiple candidate answers
candidates = [
"Photosynthesis converts light into energy...",
"Plants use sunlight to make food...",
"It's a process in leaves...",
]
scores = []
for candidate in candidates:
    input_ids = tokenizer.encode(
        prompt + " " + candidate,
        return_tensors="pt",
    )
    with torch.no_grad():
        score = verifier(input_ids)["logits"].squeeze()
    scores.append(score.item())
best_idx = max(range(len(scores)), key=lambda i: scores[i])
print(f"Best answer: {candidates[best_idx]}")
PipelineState
@dataclass
class PipelineState:
    pretrain_checkpoint: Optional[Path] = None
    sft_checkpoint: Optional[Path] = None
    dpo_checkpoint: Optional[Path] = None
    verifier_checkpoint: Optional[Path] = None
    pretrain_metrics: Optional[dict] = None
    sft_metrics: Optional[dict] = None
    dpo_metrics: Optional[dict] = None
    verifier_metrics: Optional[dict] = None
Tracks checkpoint paths and metrics across pipeline stages.
Attributes
pretrain_checkpoint (Optional[Path])
Path to pretrained base model checkpoint.
sft_checkpoint (Optional[Path])
Path to supervised fine-tuned model checkpoint.
dpo_checkpoint (Optional[Path])
Path to DPO-aligned model checkpoint.
verifier_checkpoint (Optional[Path])
Path to trained verifier model checkpoint.
pretrain_metrics (Optional[dict])
Training metrics from pretraining stage (loss, perplexity, etc.).
sft_metrics (Optional[dict])
Training metrics from SFT stage.
dpo_metrics (Optional[dict])
Training metrics from DPO stage.
verifier_metrics (Optional[dict])
Training metrics from verifier stage.
Methods
to_dict
def to_dict(self) -> dict
Serialize state to dictionary.
save
def save(self, path: Path) -> None
Save state to JSON file. Creates parent directories if needed.
load
@classmethod
def load(cls, path: Path) -> PipelineState
Load state from JSON file.
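Since JSON has no native Path type, save and load presumably round-trip checkpoint paths as strings. A minimal sketch of that round trip under this assumption (the helper names and field handling are illustrative, not the actual implementation):

import json
from dataclasses import asdict
from pathlib import Path

from modern_llm.alignment.alignment_pipeline import PipelineState

def state_to_dict(state: PipelineState) -> dict:
    # Stringify Path fields so the dict is JSON-serializable (assumed behavior)
    return {k: str(v) if isinstance(v, Path) else v for k, v in asdict(state).items()}

def save_state(state: PipelineState, path: Path) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)  # create parent directories if needed
    path.write_text(json.dumps(state_to_dict(state), indent=2))

def load_state(path: Path) -> PipelineState:
    raw = json.loads(path.read_text())
    # Re-wrap checkpoint strings as Paths; metrics dicts pass through unchanged
    return PipelineState(**{
        k: Path(v) if k.endswith("_checkpoint") and v is not None else v
        for k, v in raw.items()
    })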
Usage
from pathlib import Path
from modern_llm.alignment.alignment_pipeline import PipelineState
# Create state
state = PipelineState(
    pretrain_checkpoint=Path("checkpoints/pretrain.pt"),
    sft_checkpoint=Path("checkpoints/sft.pt"),
    pretrain_metrics={"loss": 3.2, "perplexity": 24.5},
    sft_metrics={"loss": 2.1, "accuracy": 0.78},
)
# Save state
state.save(Path("experiments/pipeline_state.json"))
# Load state
loaded_state = PipelineState.load(Path("experiments/pipeline_state.json"))
print(f"Pretrain checkpoint: {loaded_state.pretrain_checkpoint}")
run_alignment_pipeline
def run_alignment_pipeline(
    config: PipelineConfig,
    checkpoint_dir: Optional[Path] = None,
    pretrain_checkpoint: Optional[Path] = None,
    skip_pretrain: bool = False,
    skip_sft: bool = False,
    skip_dpo: bool = False,
    skip_verifier: bool = False,
) -> PipelineState
Convenience function to execute the full alignment pipeline.
This is equivalent to creating an AlignmentPipeline instance and calling run(), but provides a simpler functional interface.
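Concretely, the two forms below produce the same final state (ckpt_dir here is just an illustrative variable):

from pathlib import Path
from modern_llm.alignment.alignment_pipeline import (
    AlignmentPipeline,
    run_alignment_pipeline,
)
from modern_llm.config import PipelineConfig

config = PipelineConfig(run_name="quick-alignment", tokenizer_name="gpt2")
ckpt_dir = Path("experiments/runs/quick-alignment")

# Functional form
state = run_alignment_pipeline(config, checkpoint_dir=ckpt_dir, skip_verifier=True)

# Equivalent class-based form
pipeline = AlignmentPipeline(config, checkpoint_dir=ckpt_dir)
state = pipeline.run(skip_verifier=True)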
Parameters
config (PipelineConfig)
Pipeline configuration with model and training hyperparameters.
checkpoint_dir (Optional[Path], default: None)
Directory to save checkpoints. Defaults to experiments/runs/{run_name}.
pretrain_checkpoint (Optional[Path], default: None)
Path to existing pretrain checkpoint to use instead of training.
skip_pretrain (bool, default: False)
If True, skip pretraining stage.
skip_sft (bool, default: False)
If True, skip SFT stage and use pretrain checkpoint for DPO.
skip_dpo (bool, default: False)
If True, skip DPO stage and use SFT checkpoint as the final model.
skip_verifier (bool, default: False)
If True, skip verifier training stage.
Returns
Final pipeline state with all checkpoint paths and metrics.
Usage
from modern_llm.alignment.alignment_pipeline import run_alignment_pipeline
from modern_llm.config import PipelineConfig
config = PipelineConfig(
    run_name="quick-alignment",
    tokenizer_name="gpt2",
    max_seq_len=512,
)
# Run full pipeline
state = run_alignment_pipeline(config)
print(f"DPO model: {state.dpo_checkpoint}")
References
- SFT: Ouyang et al. (2022). Training language models to follow instructions with human feedback. (InstructGPT)
- DPO: Rafailov et al. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model.
- Verifier reranking: Lightman et al. (2023). Let’s Verify Step by Step.