AlignmentPipeline

class AlignmentPipeline:
    def __init__(
        self,
        config: PipelineConfig,
        checkpoint_dir: Optional[Path] = None,
    )
Orchestrates the full alignment pipeline with four stages: pretraining, supervised fine-tuning (SFT), Direct Preference Optimization (DPO), and verifier training.

Pipeline stages

The alignment pipeline implements a complete RLHF-inspired training workflow:
  1. Pretrain: Language model pretraining on unlabeled text (e.g., WikiText-2)
  2. SFT: Supervised fine-tuning on instruction-following data
  3. DPO: Direct Preference Optimization on human preference data (loss sketched after this list)
  4. Verifier: Train answer correctness classifier for reranking

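For reference, here is a minimal sketch of the objective the DPO stage optimizes (Rafailov et al., 2023); it illustrates the loss, not the pipeline's internal implementation. It assumes you already have per-sequence log-probabilities of the chosen and rejected responses under both the policy and a frozen reference model, with beta playing the role of dpo_beta in PipelineConfig.

import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    # How much more the policy prefers each response than the reference does
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    # Negative log-sigmoid of the scaled preference margin, batch-averaged
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()
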
Parameters

config
PipelineConfig
required
Pipeline configuration specifying:
  • Model hyperparameters (size, layers, etc.)
  • Training hyperparameters for each stage
  • Dataset names and paths
  • Tokenizer name
  • Run name for checkpointing
checkpoint_dir
Optional[Path]
default:"None"
Directory to save checkpoints. Defaults to experiments/runs/{run_name}.

Attributes

config
PipelineConfig
Pipeline configuration.
checkpoint_dir
Path
Directory containing all stage checkpoints.
state
PipelineState
Current pipeline state with checkpoint paths and metrics.
state_path
Path
Path to saved state JSON file (pipeline_state.json).

Methods

run

def run(
    self,
    skip_pretrain: bool = False,
    pretrain_checkpoint: Optional[Path] = None,
    skip_sft: bool = False,
    skip_dpo: bool = False,
    skip_verifier: bool = False,
) -> PipelineState
Execute the full pipeline with optional stage skipping. Parameters:
skip_pretrain
bool
default:"False"
If True, skip pretraining stage. Requires pretrain_checkpoint or an existing pretrain checkpoint recorded in the pipeline state.
pretrain_checkpoint
Optional[Path]
default:"None"
Path to existing pretrain checkpoint to use instead of training.
skip_sft
bool
default:"False"
If True, skip SFT stage and use pretrain checkpoint for DPO.
skip_dpo
bool
default:"False"
If True, skip DPO stage and use SFT checkpoint as final model.
skip_verifier
bool
default:"False"
If True, skip verifier training stage.
Returns:
state
PipelineState
Final pipeline state containing all checkpoint paths and metrics.
Raises:
  • ValueError - If skip_pretrain=True but no pretrain checkpoint is available (see the defensive sketch after the side effects below)
Side effects:
  • Saves checkpoints to checkpoint_dir after each stage
  • Updates and saves pipeline_state.json after each stage
  • Logs progress and metrics to console and log files
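
A defensive sketch of the documented failure mode, assuming a pipeline and config as in the Usage section below (the fallback checkpoint path is hypothetical):

from pathlib import Path

pipeline = AlignmentPipeline(config)
try:
    # Raises ValueError if no pretrain checkpoint is recorded in state
    state = pipeline.run(skip_pretrain=True)
except ValueError:
    # Fall back to an explicit pretrain checkpoint
    state = pipeline.run(
        skip_pretrain=True,
        pretrain_checkpoint=Path("checkpoints/pretrain.pt"),
    )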

load_model

def load_model(self, stage: str) -> ModernDecoderLM
Load model from a specific pipeline stage. Parameters:
stage
str
required
Stage name: "pretrain", "sft", or "dpo".
Returns:
model
ModernDecoderLM
Loaded model on appropriate device (CUDA if available, otherwise CPU).
Raises:
  • ValueError - If stage is unknown or the checkpoint doesn’t exist (see the fallback sketch below)
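
Because missing checkpoints raise ValueError, a simple fallback loop can grab the most-aligned model available; a sketch, assuming a pipeline constructed as in the Usage section below:

model = None
for stage in ("dpo", "sft", "pretrain"):
    try:
        model = pipeline.load_model(stage)
        break  # keep the latest stage whose checkpoint exists
    except ValueError:
        continue
if model is None:
    raise RuntimeError("No stage checkpoints found; run the pipeline first.")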

load_verifier

def load_verifier(self) -> VerifierModel
Load trained verifier model. Returns:
verifier
VerifierModel
Loaded verifier model on appropriate device.
Raises:
  • ValueError - If verifier checkpoint doesn’t exist

Usage

Basic pipeline execution

from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig

# Configure pipeline
config = PipelineConfig(
    run_name="modern-llm-alignment",
    tokenizer_name="gpt2",
    d_model=512,
    n_layers=6,
    n_heads=8,
    max_seq_len=512,
    # Pretraining
    pretrain_steps=10_000,
    pretrain_batch_size=16,
    pretrain_lr=1e-3,
    pretrain_datasets=["wikitext"],
    # SFT
    sft_steps=5_000,
    sft_batch_size=8,
    sft_lr=5e-5,
    sft_dataset="tatsu-lab/alpaca",
    # DPO
    dpo_steps=3_000,
    dpo_batch_size=4,
    dpo_lr=1e-5,
    dpo_beta=0.1,
    dpo_dataset="Anthropic/hh-rlhf",
    # Verifier
    verifier_steps=2_000,
    verifier_batch_size=16,
    verifier_lr=1e-4,
)

# Run full pipeline
pipeline = AlignmentPipeline(config)
state = pipeline.run()

print(f"Pretrain checkpoint: {state.pretrain_checkpoint}")
print(f"SFT checkpoint: {state.sft_checkpoint}")
print(f"DPO checkpoint: {state.dpo_checkpoint}")
print(f"Verifier checkpoint: {state.verifier_checkpoint}")

Resume from existing checkpoint

from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig

config = PipelineConfig(
    run_name="modern-llm-alignment",
    tokenizer_name="gpt2",
    max_seq_len=512,
)

# Pipeline automatically loads existing state from checkpoint_dir
pipeline = AlignmentPipeline(config)

# Skip pretrain and SFT, only run DPO and verifier
state = pipeline.run(
    skip_pretrain=True,  # Use existing pretrain checkpoint
    skip_sft=True,       # Use existing SFT checkpoint
)

Use external pretrained model

from pathlib import Path
from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig

config = PipelineConfig(
    run_name="alignment-from-gpt2",
    tokenizer_name="gpt2",
    max_seq_len=512,
)

pipeline = AlignmentPipeline(config)
state = pipeline.run(
    skip_pretrain=True,
    pretrain_checkpoint=Path("pretrained_models/gpt2-base.pt"),
)

Load and use trained models

from modern_llm.alignment.alignment_pipeline import AlignmentPipeline
from modern_llm.config import PipelineConfig
import torch

# Load pipeline state
config = PipelineConfig(run_name="modern-llm-alignment", tokenizer_name="gpt2")
pipeline = AlignmentPipeline(config)

# Load final DPO model
model = pipeline.load_model("dpo")
model.eval()

# Generate text
prompt = "Explain how photosynthesis works:"
tokenizer = pipeline.tokenizer
input_ids = tokenizer.encode(prompt, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(input_ids, max_length=100)
    
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)

# Load verifier for reranking
verifier = pipeline.load_verifier()
verifier.eval()

# Score multiple candidate answers
candidates = [
    "Photosynthesis converts light into energy...",
    "Plants use sunlight to make food...",
    "It's a process in leaves...",
]

scores = []
for candidate in candidates:
    input_ids = tokenizer.encode(
        prompt + " " + candidate,
        return_tensors="pt",
    )
    with torch.no_grad():
        score = verifier(input_ids)["logits"].squeeze()
    scores.append(score.item())

best_idx = max(range(len(scores)), key=lambda i: scores[i])
print(f"Best answer: {candidates[best_idx]}")

PipelineState

@dataclass
class PipelineState:
    pretrain_checkpoint: Optional[Path] = None
    sft_checkpoint: Optional[Path] = None
    dpo_checkpoint: Optional[Path] = None
    verifier_checkpoint: Optional[Path] = None
    pretrain_metrics: Optional[dict] = None
    sft_metrics: Optional[dict] = None
    dpo_metrics: Optional[dict] = None
    verifier_metrics: Optional[dict] = None
Tracks checkpoint paths and metrics across pipeline stages.

Attributes

pretrain_checkpoint
Optional[Path]
Path to pretrained base model checkpoint.
sft_checkpoint
Optional[Path]
Path to supervised fine-tuned model checkpoint.
dpo_checkpoint
Optional[Path]
Path to DPO-aligned model checkpoint.
verifier_checkpoint
Optional[Path]
Path to trained verifier model checkpoint.
pretrain_metrics
Optional[dict]
Training metrics from pretraining stage (loss, perplexity, etc.).
sft_metrics
Optional[dict]
Training metrics from SFT stage.
dpo_metrics
Optional[dict]
Training metrics from DPO stage.
verifier_metrics
Optional[dict]
Training metrics from verifier stage.

Methods

to_dict

def to_dict(self) -> dict
Serialize state to dictionary.

save

def save(self, path: Path) -> None
Save state to JSON file. Creates parent directories if needed.

load

@classmethod
def load(cls, path: Path) -> PipelineState
Load state from JSON file.

Usage

from pathlib import Path
from modern_llm.alignment.alignment_pipeline import PipelineState

# Create state
state = PipelineState(
    pretrain_checkpoint=Path("checkpoints/pretrain.pt"),
    sft_checkpoint=Path("checkpoints/sft.pt"),
    pretrain_metrics={"loss": 3.2, "perplexity": 24.5},
    sft_metrics={"loss": 2.1, "accuracy": 0.78},
)

# Save state
state.save(Path("experiments/pipeline_state.json"))

# Load state
loaded_state = PipelineState.load(Path("experiments/pipeline_state.json"))
print(f"Pretrain checkpoint: {loaded_state.pretrain_checkpoint}")

run_alignment_pipeline

def run_alignment_pipeline(
    config: PipelineConfig,
    checkpoint_dir: Optional[Path] = None,
    pretrain_checkpoint: Optional[Path] = None,
    skip_pretrain: bool = False,
    skip_sft: bool = False,
    skip_dpo: bool = False,
    skip_verifier: bool = False,
) -> PipelineState
Convenience function to execute the full alignment pipeline. This is equivalent to creating an AlignmentPipeline instance and calling run(), but provides a simpler functional interface.
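
Given a config as in the Usage examples below, the two call styles are interchangeable (a sketch restating the equivalence above; skip_sft is just an example flag):

from modern_llm.alignment.alignment_pipeline import (
    AlignmentPipeline,
    run_alignment_pipeline,
)

# Functional interface
state = run_alignment_pipeline(config, skip_sft=True)

# Equivalent class-based interface
state = AlignmentPipeline(config).run(skip_sft=True)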

Parameters

config
PipelineConfig
required
Pipeline configuration with model and training hyperparameters.
checkpoint_dir
Optional[Path]
default:"None"
Directory to save checkpoints. Defaults to experiments/runs/{run_name}.
pretrain_checkpoint
Optional[Path]
default:"None"
Path to existing pretrain checkpoint to use instead of training.
skip_pretrain
bool
default:"False"
If True, skip pretraining stage.
skip_sft
bool
default:"False"
If True, skip SFT stage.
skip_dpo
bool
default:"False"
If True, skip DPO stage.
skip_verifier
bool
default:"False"
If True, skip verifier training stage.

Returns

state
PipelineState
Final pipeline state with all checkpoint paths and metrics.

Usage

from modern_llm.alignment.alignment_pipeline import run_alignment_pipeline
from modern_llm.config import PipelineConfig

config = PipelineConfig(
    run_name="quick-alignment",
    tokenizer_name="gpt2",
    max_seq_len=512,
)

# Run full pipeline
state = run_alignment_pipeline(config)

print(f"DPO model: {state.dpo_checkpoint}")

References

  • SFT: Ouyang et al. (2022). Training language models to follow instructions with human feedback. (InstructGPT)
  • DPO: Rafailov et al. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model.
  • Verifier reranking: Lightman et al. (2023). Let’s Verify Step by Step.
