Overview
The Training API provides functionality for fine-tuning vision-language models (Qwen2-VL and Molmo) with LoRA adapters. It includes checkpoint management, configuration handling, and integration with Weights & Biases for experiment tracking.

Main Functions
run_train
Takes a training configuration object containing model, data, and hyperparameter settings. The pipeline performs:
- Environment setup and distributed training initialization
- Model and processor loading
- Dataset preparation
- LoRA adapter configuration (optional)
- Training with automatic checkpointing
- Best model saving and merging
Source: olmocr/train/train.py:87
The function automatically detects distributed training setup and adjusts logging levels accordingly. Only rank 0 processes perform model saving and checkpoint uploads.
get_rank
Returns the current process rank in distributed training (0 if not in distributed mode).
Source: olmocr/train/train.py:81
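As a hedged sketch, the rank lookup can be approximated from the environment variables that launchers such as torchrun export; this is illustrative, not the exact body of olmocr's get_rank (which may go through torch.distributed instead):

```python
import os

def get_rank() -> int:
    """Illustrative sketch: launchers such as torchrun export RANK for each
    worker; outside distributed mode the variable is unset, so default to 0."""
    return int(os.environ.get("RANK", "0"))

# Checks like `if get_rank() == 0:` gate model saving and checkpoint
# uploads to a single process, as described above.
print(get_rank())
```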
update_wandb_config
Updates the Weights & Biases run configuration. Parameters:
- Training configuration
- HuggingFace Trainer instance
- Model instance (may include PEFT adapters)
Source: olmocr/train/train.py:57
Classes
CheckpointUploadCallback
Callback for uploading checkpoints to remote storage (S3 or local) during training. Inherits: transformers.TrainerCallback
Source: olmocr/train/train.py:41
Constructor
- Base path where checkpoints will be saved (supports S3 paths like s3://bucket/path)
- Optional logger instance (defaults to the class logger)
Methods
on_save
Called whenever the trainer saves a checkpoint. Copies the latest checkpoint to the specified save path. Parameters:
- Trainer arguments
- Current trainer state
- Trainer control object
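A self-contained sketch of the copy-on-save behavior, stripped of the transformers.TrainerCallback plumbing and of S3 handling (illustrative only; the real callback in olmocr/train/train.py differs):

```python
import os
import shutil

class CheckpointCopyCallback:
    """Sketch of on_save: copy the newest checkpoint-* directory to save_path."""

    def __init__(self, save_path: str):
        self.save_path = save_path

    def on_save(self, output_dir: str) -> str:
        # Trainer writes directories named checkpoint-<step>; pick the
        # one with the highest step number.
        ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
        latest = max(ckpts, key=lambda d: int(d.split("-")[-1]))
        dest = os.path.join(self.save_path, latest)
        shutil.copytree(os.path.join(output_dir, latest), dest, dirs_exist_ok=True)
        return dest
```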
TrainConfig Structure
The TrainConfig dataclass contains all training configuration:
Key Configuration Sections
Model configuration including:
- name_or_path: HuggingFace model identifier
- use_flash_attn: Enable Flash Attention 2
LoRA adapter configuration:
- rank: LoRA rank (r)
- alpha: LoRA alpha scaling
- dropout: LoRA dropout rate
- bias: Bias training strategy
- target_modules: List of module names to apply LoRA to
- task_type: Task type (e.g., "CAUSAL_LM")
Training dataset configuration:
- sources: List of data sources with glob paths
- seed: Random seed for reproducibility
- cache_location: PDF cache directory
Validation dataset configuration:
- sources: Validation data sources
- metric_for_best_model: Metric to use for best model selection
Training hyperparameters:
- learning_rate: Learning rate
- batch_size: Per-device batch size
- gradient_accumulation_steps: Gradient accumulation
- max_steps: Maximum training steps
- warmup_steps: Warmup steps
- warmup_ratio: Warmup ratio
- weight_decay: Weight decay
- clip_grad_norm: Gradient clipping
- gradient_checkpointing: Enable gradient checkpointing
- optim: Optimizer name
- log_every_steps: Logging frequency
- eval_every_steps: Evaluation frequency
Checkpoint saving configuration:
- path: Base save path (local or S3)
- save_every_steps: Checkpoint frequency
Generation configuration:
- max_length: Maximum sequence length
Weights & Biases configuration:
- project: W&B project name
- entity: W&B entity/team
- api_key: API key (optional)
AWS configuration for S3 access:
- access_key_id: AWS access key
- secret_access_key: AWS secret key
Dataloader configuration:
- Number of dataloader workers
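Put together, a configuration file covering these sections might look like the following. This is an illustrative layout only: the section and key names mirror the field descriptions above, but the exact YAML schema expected by TrainConfig should be checked against a config.yaml shipped with olmocr.

```yaml
# Illustrative TrainConfig layout -- verify keys against a real olmocr config.
model:
  name_or_path: Qwen/Qwen2-VL-7B-Instruct
  use_flash_attn: true
lora:
  rank: 16
  alpha: 32
  dropout: 0.05
  target_modules: [q_proj, v_proj, k_proj, o_proj]
  task_type: CAUSAL_LM
train_data:
  seed: 42
  # sources: list of glob paths to your training data
hparams:
  learning_rate: 1.0e-5
  batch_size: 1
  gradient_accumulation_steps: 8
  gradient_checkpointing: true
  max_steps: 5000
save:
  path: s3://my-bucket/checkpoints   # or a local directory
  save_every_steps: 500
wandb:
  project: my-project
  entity: my-team
```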
Training Pipeline Flow
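In outline, run_train moves through the stages listed under Main Functions above; this skeleton restates them in code (the stage and helper names are illustrative stand-ins, not real olmocr internals):

```python
def run_train_outline(config) -> list[str]:
    """Illustrative outline of the run_train pipeline (stand-in names only)."""
    stages = []
    stages.append("setup_environment_and_distributed")  # env + distributed init
    stages.append("load_model_and_processor")           # from the model section
    stages.append("prepare_datasets")                   # train + eval sources
    if getattr(config, "use_lora", False):              # hypothetical flag name
        stages.append("attach_lora_adapters")           # optional LoRA setup
    stages.append("train_with_checkpointing")           # CheckpointUploadCallback
    stages.append("save_and_merge_best_model")          # rank 0 only
    return stages
```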
Model Support
The training API supports two vision-language model architectures:

Qwen2-VL
Multi-modal models from Alibaba with vision and language capabilities
- Auto-detected from model name containing "qwen"
- Uses Flash Attention 2 when enabled
- Requires pixel_values and image_grid_thw in batches
Molmo
Vision-language models with custom architecture
- Auto-detected from model name containing "molmo"
- Adjustable max position embeddings
- Requires images, image_input_idx, and image_masks in batches
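A small helper capturing the auto-detection rule and required batch keys described above (a sketch; the real detection and collator code in olmocr may differ):

```python
def required_batch_keys(model_name: str) -> list[str]:
    """Map a model name to the batch keys each architecture expects (sketch)."""
    name = model_name.lower()
    if "qwen" in name:
        return ["pixel_values", "image_grid_thw"]
    if "molmo" in name:
        return ["images", "image_input_idx", "image_masks"]
    raise ValueError(f"Unsupported model: {model_name}")

print(required_batch_keys("Qwen/Qwen2-VL-7B-Instruct"))
```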
Advanced Usage
Custom Training Script
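A minimal custom script might wrap run_train as below. The import path follows the source reference above (olmocr/train/train.py); the YAML loading and the construction of a TrainConfig from it are assumptions to adapt:

```python
def launch(config_path: str) -> None:
    """Hypothetical wrapper: load a YAML config and hand it to run_train."""
    # Deferred imports so this sketch can be read without olmocr installed.
    # run_train lives in olmocr/train/train.py per the source references above;
    # how a TrainConfig is built from the YAML is an assumption to verify.
    import yaml
    from olmocr.train.train import run_train

    with open(config_path) as f:
        raw = yaml.safe_load(f)
    run_train(raw)  # adapt: run_train expects a TrainConfig dataclass,
                    # not a raw dict -- construct it per the schema above

# Usage (from a shell): python my_train.py path/to/config.yaml
```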
Distributed Training
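Since run_train detects the distributed setup from the environment, multi-GPU runs can be launched with torchrun, which exports RANK, LOCAL_RANK, and WORLD_SIZE to each worker. The module path and --config flag below are assumptions; adjust them to your actual entry point:

```shell
# 4 GPUs on one node; each worker reads its rank from the environment.
torchrun --nproc_per_node=4 -m olmocr.train.train --config configs/train.yaml
```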
Best Practices
Memory Optimization
- Enable gradient_checkpointing for large models
- Use Flash Attention 2 (use_flash_attn=True)
- Adjust batch_size and gradient_accumulation_steps
- Consider using LoRA for memory-efficient fine-tuning
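batch_size and gradient_accumulation_steps trade off against each other; the quantity to hold steady while lowering per-device memory is the effective batch size per optimizer step. A quick sketch:

```python
def effective_batch_size(per_device: int, grad_accum: int, world_size: int = 1) -> int:
    """Effective (per optimizer step) batch size across all devices."""
    return per_device * grad_accum * world_size

# Halving the per-device batch size while doubling accumulation
# keeps the effective batch size constant:
print(effective_batch_size(4, 8, 2))   # 64
print(effective_batch_size(2, 16, 2))  # 64
```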
Checkpoint Management
- Set a reasonable save_every_steps to balance disk space and recovery
- Use S3 paths for centralized checkpoint storage
- Monitor load_best_model_at_end to ensure the best model is saved
- Keep config.yaml with checkpoints for reproducibility
LoRA Configuration
- Start with rank 16-32 for most tasks
- Set alpha = 2 * rank as a good default
- Target attention layers: ["q_proj", "v_proj", "k_proj", "o_proj"]
- Merge adapters before inference for faster generation
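As an illustration, these recommendations map onto the keyword arguments of peft's LoraConfig (argument names follow the peft API; the concrete values here are illustrative, not olmocr defaults):

```python
# Sketch: kwargs you would pass to peft.LoraConfig.
rank = 16
lora_kwargs = dict(
    r=rank,
    lora_alpha=2 * rank,  # the alpha = 2 * rank heuristic above
    lora_dropout=0.05,    # illustrative value
    bias="none",          # illustrative value
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
print(lora_kwargs["lora_alpha"])  # 32
```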
See Also
- Data Loading API - Dataset preparation and loading
- Evaluation API - Model evaluation and metrics
- Configuration Guide - Detailed configuration options