## Overview

The `train()` function is the primary entry point for training ChemLactica models. It handles both pretraining and supervised fine-tuning (SFT), managing distributed training setup, model loading, dataset preparation, and the complete training loop.
## Function Signature
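The signature itself did not survive extraction. The following is a hypothetical reconstruction from the parameter descriptions below; every parameter name here is illustrative only, and the authoritative signature lives in `chemlactica/train.py`:

```python
# Hypothetical reconstruction of the train() signature.
# Parameter names are illustrative; see chemlactica/train.py for the real ones.
def train(
    command="",                    # training command, captured for reproducibility
    train_type="pretrain",         # "pretrain" or "sft"
    training_data_dirs=None,       # directories containing training data files
    dir_data_types=None,           # one data type per training data directory
    valid_data_dir=None,           # directory containing validation data files
    learning_rate=None,            # falls back to train_config.max_learning_rate
    train_batch_size=16,           # per-device training batch size
    max_steps=-1,                  # maximum number of training steps
    track=False,                   # enable Aim experiment tracking
    check_reproducability=False,   # reproducibility checks during pretraining
    profile=False,                 # enable PyTorch profiling
    profile_dir=None,              # required if profile=True
    **kwargs,                      # remaining options from the Parameters list
):
    ...
```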
## Parameters
- The training command being executed (captured for reproducibility)
- Whether to run evaluation via SLURM job submission
- Type of training to perform. Options: `"pretrain"` or `"sft"` (supervised fine-tuning)
- Path to a pretrained model or checkpoint to start training from
- Name of the model configuration to use
- List of directories containing training data files
- List of data types corresponding to each training data directory. Must match the length of `training_data_dirs`
- Directory containing validation data files
- Learning rate for the optimizer. If not specified, uses `train_config.max_learning_rate`
- Number of warmup steps for the learning rate scheduler
- Maximum number of steps for the scheduler. If not specified, defaults to `max_steps`
- Number of training steps between evaluations
- Number of training steps between checkpoint saves
- Per-device batch size for training
- Size of the shuffle buffer for data loading (used for assay datasets)
- Name of the experiment for tracking and logging
- Root directory where model checkpoints will be saved
- Number of worker processes for data loading
- Whether to use Flash Attention for efficient attention computation
- Number of gradient accumulation steps before each optimizer update
- Whether to use gradient checkpointing to reduce memory usage
- If True, only run evaluation without training
- Maximum number of training steps to perform
- Number of training epochs
- Whether to enable experiment tracking with Aim
- Directory for tracking logs and metadata
- Whether to enable reproducibility checks during pretraining
- Per-device batch size for validation. If not specified, uses `train_batch_size`
- Whether to enable PyTorch profiling
- Directory to save profiling traces (required if `profile=True`)

## Usage Example
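The original usage example was lost in extraction. Below is a hypothetical sketch of how `train()` might be invoked; every keyword name and value is invented for illustration, so check `chemlactica/train.py` for the real argument names:

```python
# Hypothetical invocation sketch -- all keyword names below are illustrative,
# not taken from chemlactica/train.py.
train_kwargs = {
    "train_type": "pretrain",
    "training_data_dirs": ["/data/pubchem", "/data/assays"],
    "dir_data_types": ["computed", "assay"],  # one type per data directory
    "valid_data_dir": "/data/valid",
    "train_batch_size": 16,
    "max_steps": 20000,
    "eval_steps": 1000,
    "save_steps": 1000,
    "experiment_name": "pretrain_demo",
    "track": True,
}

# Invariant stated in the Parameters section: the list of data types must
# match the list of training data directories in length.
assert len(train_kwargs["training_data_dirs"]) == len(train_kwargs["dir_data_types"])

# from chemlactica.train import train
# train(**train_kwargs)
```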
## Training Types

### Pretrain

For pretraining (`train_type="pretrain"`), the function:
- Loads JSONL datasets from specified directories
- Supports multiple dataset types that can be interleaved
- Enables epoch callbacks and early stopping
- Supports reproducibility checks
- Uses streaming datasets for efficient memory usage
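The interleaving of multiple streaming dataset types can be sketched with plain Python generators (the actual implementation presumably builds on streaming JSONL datasets; the helper names below are hypothetical):

```python
def stream_jsonl_examples(records):
    """Stand-in for a streaming JSONL reader: yields one example at a time
    without loading the whole file into memory."""
    yield from records


def interleave(*streams):
    """Round-robin interleaving of several streaming datasets, stopping when
    the shortest stream is exhausted (one simple interleaving policy)."""
    for group in zip(*streams):
        yield from group


# Two toy "dataset types" being mixed into one training stream.
computed = stream_jsonl_examples(["c1", "c2", "c3"])
assay = stream_jsonl_examples(["a1", "a2", "a3"])
mixed = list(interleave(computed, assay))
# mixed == ["c1", "a1", "c2", "a2", "c3", "a3"]
```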
### Supervised Fine-Tuning (SFT)

For SFT (`train_type="sft"`), the function:
- Loads datasets using HuggingFace datasets library
- Enables numerical evaluation callbacks
- Uses standard dataset loading (non-streaming)
## Callbacks

The training function automatically configures several callbacks based on the provided parameters:
- `WPSCounterCallback`: Tracks words-per-second training speed
- `CustomProgressCallback`: Displays training progress with FLOP metrics
- `CustomAimCallback`: Logs metrics to Aim (if `track=True`)
- `EarlyStoppingCallback`: Stops training at `max_steps` (pretrain only)
- `EpochCallback`: Tracks epoch completion (pretrain only)
- `ReproducabilityCallback`: Validates reproducibility (if `check_reproducability=True`)
- `JsonlDatasetResumeCallback`: Enables checkpoint resumption for JSONL datasets (pretrain only)
- `ProfCallback`: Manages PyTorch profiling (if `profile=True`)
- `SFTNumericalEval`: Evaluates numerical predictions (SFT only)
- `GradientAccumulationScheduler`: Dynamically adjusts gradient accumulation (if configured)
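To illustrate the callback pattern these classes follow, here is a toy stand-in (not the actual `WPSCounterCallback`) that measures throughput between Trainer-style step hooks:

```python
import time


class ToyWPSCounter:
    """Minimal stand-in for WPSCounterCallback: computes tokens processed per
    second between on_step_begin/on_step_end hooks, in the style of a
    HuggingFace TrainerCallback. Not the real implementation."""

    def __init__(self):
        self._t0 = None
        self.last_wps = None

    def on_step_begin(self):
        self._t0 = time.perf_counter()

    def on_step_end(self, tokens_in_step):
        elapsed = time.perf_counter() - self._t0
        self.last_wps = tokens_in_step / elapsed if elapsed > 0 else float("inf")


counter = ToyWPSCounter()
counter.on_step_begin()
time.sleep(0.01)  # pretend one training step ran
counter.on_step_end(tokens_in_step=2048)
```

The real callbacks plug into the HuggingFace `Trainer` event loop the same way: each one reacts to step, epoch, and save events rather than being called explicitly from the training loop.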
## Checkpoint Management

Checkpoints are saved under the checkpoints root directory supplied to `train()` (see the Parameters section).

## Source Reference

Implemented in `chemlactica/train.py:59`.