Overview

The train() function is the primary entry point for training ChemLactica models. It handles both pretraining and supervised fine-tuning (SFT), managing distributed training setup, model loading, dataset preparation, and the complete training loop.

Function Signature

def train(
    command,
    slurm_eval,
    train_type,
    from_pretrained,
    model_config_name,
    training_data_dirs,
    dir_data_types,
    valid_data_dir,
    learning_rate,
    warmup_steps,
    scheduler_max_steps,
    eval_steps,
    save_steps,
    train_batch_size,
    shuffle_buffer_size,
    experiment_name,
    checkpoints_root_dir,
    dataloader_num_workers,
    flash_attn,
    gradient_accumulation_steps,
    gradient_checkpointing,
    evaluate_only,
    max_steps,
    num_train_epochs=1,
    track=False,
    track_dir=None,
    check_reproducability=False,
    valid_batch_size=None,
    profile=False,
    profile_dir=None,
)

Parameters

  • command (str, required): The training command being executed (captured for reproducibility)
  • slurm_eval (bool, required): Whether to run evaluation via SLURM job submission
  • train_type (str, required): Type of training to perform. Options: "pretrain" or "sft" (supervised fine-tuning)
  • from_pretrained (str, required): Path to the pretrained model or checkpoint to start training from
  • model_config_name (str, required): Name of the model configuration to use
  • training_data_dirs (list[str], required): List of directories containing training data files
  • dir_data_types (list[str], required): List of data types corresponding to each training data directory; must match the length of training_data_dirs
  • valid_data_dir (str, required): Directory containing validation data files
  • learning_rate (float, required): Learning rate for the optimizer; if not specified, train_config.max_learning_rate is used
  • warmup_steps (int, required): Number of warmup steps for the learning rate scheduler
  • scheduler_max_steps (int, required): Maximum number of steps for the scheduler; if not specified, defaults to max_steps
  • eval_steps (int, required): Number of training steps between evaluations
  • save_steps (int, required): Number of training steps between checkpoint saves
  • train_batch_size (int, required): Per-device batch size for training
  • shuffle_buffer_size (int, required): Size of the shuffle buffer for data loading (used for assay datasets)
  • experiment_name (str, required): Name of the experiment for tracking and logging
  • checkpoints_root_dir (str, required): Root directory where model checkpoints are saved
  • dataloader_num_workers (int, required): Number of worker processes for data loading
  • flash_attn (bool, required): Whether to use Flash Attention for efficient attention computation
  • gradient_accumulation_steps (int, required): Number of gradient accumulation steps before each optimizer update
  • gradient_checkpointing (bool, required): Whether to use gradient checkpointing to reduce memory usage
  • evaluate_only (bool, required): If True, only run evaluation without training
  • max_steps (int, required): Maximum number of training steps to perform
  • num_train_epochs (int, default: 1): Number of training epochs
  • track (bool, default: False): Whether to enable experiment tracking with Aim
  • track_dir (str, default: None): Directory for tracking logs and metadata
  • check_reproducability (bool, default: False): Whether to enable reproducibility checks during pretraining
  • valid_batch_size (int, default: None): Per-device batch size for validation; if not specified, train_batch_size is used
  • profile (bool, default: False): Whether to enable PyTorch profiling
  • profile_dir (str, default: None): Directory to save profiling traces (required if profile=True)

Usage Example

from chemlactica.train import train

train(
    command="python train.py ...",
    slurm_eval=False,
    train_type="pretrain",
    from_pretrained="OSS-Models/ChemLactica-125M",
    model_config_name="125M",
    training_data_dirs=["/data/pretrain/molecules"],
    dir_data_types=["molecules"],
    valid_data_dir="/data/valid",
    learning_rate=1e-4,
    warmup_steps=1000,
    scheduler_max_steps=100000,
    eval_steps=5000,
    save_steps=10000,
    train_batch_size=8,
    shuffle_buffer_size=10000,
    experiment_name="chemlactica-pretrain",
    checkpoints_root_dir="./checkpoints",
    dataloader_num_workers=4,
    flash_attn=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    evaluate_only=False,
    max_steps=100000,
    track=True,
    track_dir="./aim_logs",
)

Training Types

Pretrain

For pretraining (train_type="pretrain"), the function:
  • Loads JSONL datasets from specified directories
  • Supports multiple dataset types that can be interleaved
  • Enables epoch callbacks and early stopping
  • Supports reproducibility checks
  • Uses streaming datasets for efficient memory usage
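
The pairing of training_data_dirs and dir_data_types can be sketched as follows (the directory paths and type labels here are hypothetical, not values from the ChemLactica repository):

```python
# Hypothetical pretraining data layout: each entry in dir_data_types
# labels the directory at the same index in training_data_dirs, so the
# two lists must have equal length.
training_data_dirs = ["/data/pretrain/molecules", "/data/pretrain/assays"]
dir_data_types = ["molecules", "assay"]

# Pair each directory with its declared data type for interleaving.
paired = list(zip(training_data_dirs, dir_data_types))
```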

Supervised Fine-Tuning (SFT)

For SFT (train_type="sft"), the function:
  • Loads datasets using HuggingFace datasets library
  • Enables numerical evaluation callbacks
  • Uses standard dataset loading (non-streaming)
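
A minimal sketch of how an SFT run might differ from the pretraining example; all paths, checkpoint names, and hyperparameter values below are illustrative placeholders:

```python
# Illustrative keyword arguments for an SFT run; paths and values are
# placeholders, not values from the ChemLactica repository.
sft_kwargs = dict(
    train_type="sft",
    from_pretrained="./checkpoints/org/model/hash/last",  # placeholder checkpoint path
    training_data_dirs=["/data/sft/assays"],
    dir_data_types=["assay"],
    valid_data_dir="/data/sft/valid",
    learning_rate=2e-5,   # SFT typically uses a smaller learning rate than pretraining
    num_train_epochs=3,   # SFT commonly runs for a small number of epochs
)
# These would be passed as train(**sft_kwargs, ...) together with the
# remaining required arguments shown in the Usage Example.
```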

Callbacks

The training function automatically configures several callbacks based on the provided parameters:
  • WPSCounterCallback: Tracks words-per-second training speed
  • CustomProgressCallback: Displays training progress with FLOP metrics
  • CustomAimCallback: Logs metrics to Aim (if track=True)
  • EarlyStoppingCallback: Stops training at max_steps (pretrain only)
  • EpochCallback: Tracks epoch completion (pretrain only)
  • ReproducabilityCallback: Validates reproducibility (if check_reproducability=True)
  • JsonlDatasetResumeCallback: Enables checkpoint resumption for JSONL datasets (pretrain only)
  • ProfCallback: Manages PyTorch profiling (if profile=True)
  • SFTNumericalEval: Evaluates numerical predictions (SFT only)
  • GradientAccumulationScheduler: Dynamically adjusts gradient accumulation (if configured)
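
The conditional selection described above can be sketched as follows. This is a simplified illustration, not the actual implementation: the real code in chemlactica/train.py instantiates callback objects rather than returning names, and the GradientAccumulationScheduler condition is omitted here.

```python
# Simplified sketch of the callback-selection logic: which callbacks are
# active depends on train_type and the boolean flags passed to train().
def select_callbacks(train_type, track=False, check_reproducability=False, profile=False):
    # Always-on callbacks: throughput and progress reporting.
    callbacks = ["WPSCounterCallback", "CustomProgressCallback"]
    if track:
        callbacks.append("CustomAimCallback")
    if train_type == "pretrain":
        callbacks += ["EarlyStoppingCallback", "EpochCallback", "JsonlDatasetResumeCallback"]
        if check_reproducability:
            callbacks.append("ReproducabilityCallback")
    elif train_type == "sft":
        callbacks.append("SFTNumericalEval")
    if profile:
        callbacks.append("ProfCallback")
    return callbacks
```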

Checkpoint Management

Checkpoints are saved to:
{checkpoints_root_dir}/{organization}/{model_name}/{experiment_hash}/
The final model is saved to:
{checkpoints_root_dir}/{organization}/{model_name}/{experiment_hash}/last/
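
The layout above can be sketched as a path helper; "organization", "model_name", and "experiment_hash" stand in for values the function derives at runtime, and this helper is illustrative rather than part of the ChemLactica API:

```python
import os

# Hypothetical sketch of the checkpoint directory layout described above.
def checkpoint_dir(checkpoints_root_dir, organization, model_name, experiment_hash, final=False):
    # Intermediate checkpoints live under the experiment hash; the final
    # model goes into a "last" subdirectory.
    path = os.path.join(checkpoints_root_dir, organization, model_name, experiment_hash)
    return os.path.join(path, "last") if final else path
```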

Source Reference

Implemented in chemlactica/train.py:59