Overview

The Training API provides functionality for fine-tuning vision-language models (Qwen2-VL and Molmo) with LoRA adapters. It includes checkpoint management, configuration handling, and integration with Weights & Biases for experiment tracking.

Main Functions

run_train

config (TrainConfig, required): Training configuration object containing model, data, and hyperparameter settings

Executes the complete training pipeline, including:
  • Environment setup and distributed training initialization
  • Model and processor loading
  • Dataset preparation
  • LoRA adapter configuration (optional)
  • Training with automatic checkpointing
  • Best model saving and merging
Source: olmocr/train/train.py:87
from olmocr.train.train import run_train
from olmocr.train.core.config import TrainConfig

# Load configuration from YAML or create programmatically
config = TrainConfig.from_yaml("config.yaml")

# Run training
run_train(config)
The function automatically detects distributed training setups and adjusts logging levels accordingly. Only the rank-0 process performs model saving and checkpoint uploads.

get_rank

Returns the current process rank in distributed training.
Returns rank (int): Current process rank (0 if not in distributed mode)
Source: olmocr/train/train.py:81
from olmocr.train.train import get_rank

rank = get_rank()
if rank == 0:
    print("Main process")

update_wandb_config

config (TrainConfig, required): Training configuration
trainer (Trainer, required): HuggingFace Trainer instance
model (torch.nn.Module, required): Model instance (may include PEFT adapters)

Updates the Weights & Biases run configuration with PEFT settings, script configuration, and Beaker environment variables.
Source: olmocr/train/train.py:57
from olmocr.train.train import update_wandb_config

update_wandb_config(config, trainer, model)

Classes

CheckpointUploadCallback

Callback for uploading checkpoints to remote storage (S3 or local) during training.
Inherits: transformers.TrainerCallback
Source: olmocr/train/train.py:41

Constructor

save_path (str, required): Base path where checkpoints will be saved (supports S3 paths like s3://bucket/path)
logger (Logger, optional): Logger instance (defaults to the class logger)

Methods

on_save

Called whenever the trainer saves a checkpoint; copies the latest checkpoint to the specified save path.

args (TrainingArguments, required): Trainer arguments
state (TrainerState, required): Current trainer state
control (TrainerControl, required): Trainer control object
from olmocr.train.train import CheckpointUploadCallback
from transformers import Trainer, TrainingArguments

# Create callback
checkpoint_callback = CheckpointUploadCallback(
    save_path="s3://my-bucket/experiments/run-001",
    logger=my_logger
)

# Add to trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[checkpoint_callback]
)
Checkpoint uploads only occur on the local process zero (main process) to avoid duplicate uploads in distributed training.

TrainConfig Structure

The TrainConfig dataclass contains all training configuration:

Key Configuration Sections

model (ModelConfig): Model configuration, including:
  • name_or_path: HuggingFace model identifier
  • use_flash_attn: Enable Flash Attention 2
lora (LoraConfig | None): LoRA adapter configuration:
  • rank: LoRA rank (r)
  • alpha: LoRA alpha scaling factor
  • dropout: LoRA dropout rate
  • bias: Bias training strategy
  • target_modules: List of module names to apply LoRA to
  • task_type: Task type (e.g., "CAUSAL_LM")
train_data (DataConfig): Training dataset configuration:
  • sources: List of data sources with glob paths
  • seed: Random seed for reproducibility
  • cache_location: PDF cache directory
valid_data (DataConfig): Validation dataset configuration:
  • sources: Validation data sources
  • metric_for_best_model: Metric used to select the best model
hparams (HyperparamsConfig): Training hyperparameters:
  • learning_rate: Learning rate
  • batch_size: Per-device batch size
  • gradient_accumulation_steps: Gradient accumulation steps
  • max_steps: Maximum training steps
  • warmup_steps: Warmup steps
  • warmup_ratio: Warmup ratio
  • weight_decay: Weight decay
  • clip_grad_norm: Gradient clipping norm
  • gradient_checkpointing: Enable gradient checkpointing
  • optim: Optimizer name
  • log_every_steps: Logging frequency
  • eval_every_steps: Evaluation frequency
save (SaveConfig): Checkpoint saving configuration:
  • path: Base save path (local or S3)
  • save_every_steps: Checkpoint frequency
generate (GenerateConfig): Generation configuration:
  • max_length: Maximum sequence length
wandb (WandbConfig): Weights & Biases configuration:
  • project: W&B project name
  • entity: W&B entity/team
  • api_key: API key (optional)
aws (AwsConfig): AWS configuration for S3 access:
  • access_key_id: AWS access key
  • secret_access_key: AWS secret key
max_workers (int): Number of dataloader workers
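The sections above map onto a YAML file that TrainConfig.from_yaml can load. The sketch below is illustrative: the top-level keys mirror the sections listed above, but the exact key names, value formats, and required fields should be verified against olmocr/train/core/config.py.

```yaml
# Hypothetical config.yaml sketch; verify exact keys against
# olmocr/train/core/config.py before use.
model:
  name_or_path: Qwen/Qwen2-VL-7B-Instruct
  use_flash_attn: true

lora:
  rank: 16
  alpha: 32
  dropout: 0.05
  bias: none
  target_modules: [q_proj, v_proj, k_proj, o_proj]
  task_type: CAUSAL_LM

hparams:
  learning_rate: 1e-5
  batch_size: 1
  gradient_accumulation_steps: 8
  max_steps: 5000
  warmup_ratio: 0.03
  gradient_checkpointing: true

save:
  path: s3://my-bucket/experiments/run-001
  save_every_steps: 500
```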

Training Pipeline Flow

1. Initialize Environment: Set up distributed training, configure logging, and initialize W&B
2. Load Model & Processor: Load the base vision-language model (Qwen2-VL or Molmo) and its processor
3. Prepare Datasets: Load and preprocess the training and validation datasets with the proper transforms
4. Apply LoRA (Optional): Wrap the model with PEFT LoRA adapters if configured
5. Configure Training: Set up TrainingArguments with hyperparameters and callbacks
6. Train Model: Execute the training loop with automatic evaluation and checkpointing
7. Save Best Model: Merge LoRA adapters (if used) and save the final model

Model Support

The training API supports two vision-language model architectures:

Qwen2-VL

Multi-modal models from Alibaba with vision and language capabilities
  • Auto-detected from model names containing "qwen"
  • Uses Flash Attention 2 when enabled
  • Requires pixel_values and image_grid_thw in batches

Molmo

Vision-language models with custom architecture
  • Auto-detected from model names containing "molmo"
  • Adjustable max position embeddings
  • Requires images, image_input_idx, and image_masks in batches
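The auto-detection described above amounts to a case-insensitive substring check on the model name. A minimal sketch (the function name, return values, and error handling here are illustrative; the actual logic lives in olmocr/train/train.py):

```python
def detect_model_family(name_or_path: str) -> str:
    """Map a model identifier to a supported architecture family.

    Illustrative sketch of the substring-based auto-detection described
    above, not the library's actual code.
    """
    lowered = name_or_path.lower()
    if "qwen" in lowered:
        # Qwen2-VL batches require pixel_values and image_grid_thw
        return "qwen2-vl"
    if "molmo" in lowered:
        # Molmo batches require images, image_input_idx, and image_masks
        return "molmo"
    raise ValueError(f"Unsupported model architecture: {name_or_path}")
```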

Advanced Usage

Custom Training Script

from olmocr.train.train import run_train
from olmocr.train.core.cli import make_cli
from olmocr.train.core.config import TrainConfig

def main():
    # Parse config from CLI arguments and YAML
    config = make_cli(TrainConfig)
    
    # Customize the config programmatically
    config.hparams.learning_rate = 2e-5
    config.lora.rank = 32  # assumes a lora section is present in the YAML
    
    # Run training
    run_train(config)

if __name__ == "__main__":
    main()

Distributed Training

# Using torchrun
torchrun --nproc_per_node=4 train_script.py --config config.yaml

# Using accelerate
accelerate launch --num_processes=4 train_script.py --config config.yaml

Best Practices

  • Enable gradient_checkpointing for large models
  • Use Flash Attention 2 (use_flash_attn=True)
  • Tune batch_size and gradient_accumulation_steps together to fit GPU memory while preserving the effective batch size
  • Consider using LoRA for memory-efficient fine-tuning
  • Set reasonable save_every_steps to balance disk space and recovery
  • Use S3 paths for centralized checkpoint storage
  • Verify load_best_model_at_end is enabled so the best checkpoint, not the last one, is saved
  • Keep config.yaml with checkpoints for reproducibility
  • Start with rank 16-32 for most tasks
  • Set alpha = 2 * rank as a good default
  • Target attention layers: ["q_proj", "v_proj", "k_proj", "o_proj"]
  • Merge adapters before inference for faster generation
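The LoRA rules of thumb above (rank 16-32, alpha = 2 × rank, attention-layer targets) can be captured in a small helper for building LoRA settings. This is a sketch: the dict keys follow the LoraConfig fields documented earlier, and the dropout value is an illustrative default, not taken from olmocr.

```python
def lora_defaults(rank: int = 16) -> dict:
    """Rule-of-thumb LoRA settings: alpha = 2 * rank, attention targets."""
    return {
        "rank": rank,
        "alpha": 2 * rank,  # alpha = 2 * rank, per the best practices above
        "dropout": 0.05,    # illustrative default, not from the source
        "bias": "none",
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
        "task_type": "CAUSAL_LM",
    }
```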
