Overview
ChemLactica provides custom trainer implementations that extend HuggingFace Transformers’ Trainer and TRL’s IterativeSFTTrainer with additional functionality for debugging, distributed training, and SLURM evaluation.
CustomArguments
CustomArguments extends TrainingArguments from HuggingFace Transformers with ChemLactica-specific parameters.
Class Definition
```python
from dataclasses import dataclass, field

from transformers import TrainingArguments


@dataclass
class CustomArguments(TrainingArguments):
    slurm_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval via slurm job."}
    )
    command: str = field(default=None)
    experiment_name: str = field(default=None)
    tokenizer_path: str = field(
        default="/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66"
    )
```
Parameters
| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| slurm_eval | bool | False | Whether to run evaluation via SLURM job submission. Enables distributed evaluation on compute clusters. |
| command | str | None | The training command being executed, stored for reproducibility and debugging. |
| experiment_name | str | None | Name of the experiment for tracking, logging, and checkpoint organization. |
| tokenizer_path | str | "ChemLacticaTokenizer66" (full path in the class definition above) | Path to the tokenizer directory or model. |
Inherited Parameters
CustomArguments inherits all parameters from HuggingFace’s TrainingArguments, including:
- output_dir: Directory to save checkpoints
- per_device_train_batch_size: Training batch size per device
- per_device_eval_batch_size: Evaluation batch size per device
- learning_rate: Initial learning rate
- num_train_epochs: Number of training epochs
- max_steps: Maximum number of training steps
- warmup_steps: Number of warmup steps
- logging_steps: Log metrics every N steps
- save_steps: Save a checkpoint every N steps
- eval_steps: Run evaluation every N steps
- gradient_accumulation_steps: Number of gradient accumulation steps
- bf16: Enable bfloat16 mixed precision
- fp16: Enable float16 mixed precision
- And many more…
See HuggingFace TrainingArguments documentation for the complete list.
Usage Example
```python
from chemlactica.custom_trainer import CustomArguments

training_args = CustomArguments(
    command="python train.py --model 125M",
    slurm_eval=False,
    experiment_name="chemlactica-pretrain",
    tokenizer_path="chemlactica/tokenizer/ChemLacticaTokenizer66",
    output_dir="./checkpoints",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=1e-4,
    warmup_steps=1000,
    max_steps=100000,
    logging_steps=1,
    save_steps=10000,
    eval_steps=5000,
    gradient_accumulation_steps=4,
    bf16=True,
)
```
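Because CustomArguments is a dataclass built on TrainingArguments, it can also be populated from the command line rather than constructed in code. A minimal sketch using HuggingFace's HfArgumentParser (the launch-script setup here is an assumption, not part of ChemLactica):

```python
from transformers import HfArgumentParser

from chemlactica.custom_trainer import CustomArguments

# Exposes both the inherited flags (--output_dir, --learning_rate, ...) and the
# ChemLactica-specific ones (--slurm_eval, --experiment_name, --tokenizer_path).
parser = HfArgumentParser(CustomArguments)
(training_args,) = parser.parse_args_into_dataclasses()
```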
CustomTrainer
CustomTrainer extends HuggingFace’s Trainer with debugging capabilities and custom distributed training setup.
Class Definition
```python
from transformers import Trainer


class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        # Number of samples to print when training begins, for debugging
        self.num_samples_to_print = 10
        self.tokenizer_path = kwargs["args"].tokenizer_path
        super().__init__(*args, **kwargs)

    def training_step(self, model, inputs):
        # Prints first N samples at the start of training
        ...

    def create_accelerator_and_postprocess(self):
        # Custom accelerator creation with CustomAccelerator
        ...
```
Key Features
Sample Debugging
The training_step method decodes and prints up to the first 10 samples of the first training batch:
```python
def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
    if self.num_samples_to_print:
        tokenizer = get_tokenizer(self.tokenizer_path)
        for i in range(min(inputs["input_ids"].size(0), self.num_samples_to_print)):
            print(f"Sample {i + 1}:", tokenizer.decode(inputs["input_ids"][i]))
        self.num_samples_to_print = None
    return super().training_step(model, inputs)
```
This helps verify that data preprocessing is working correctly.
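The same check can be run by hand before launching a job. A minimal sketch, assuming the tokenizer directory loads with AutoTokenizer and that train_dataloader (a hypothetical name here) yields batches containing an input_ids tensor:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "chemlactica/tokenizer/ChemLacticaTokenizer66"  # same path passed as tokenizer_path
)

# Decode one batch from the dataloader and inspect the text by eye.
batch = next(iter(train_dataloader))
for i, ids in enumerate(batch["input_ids"][:10]):
    print(f"Sample {i + 1}:", tokenizer.decode(ids))
```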
Custom Accelerator
The create_accelerator_and_postprocess method uses CustomAccelerator instead of the standard Accelerate accelerator:
```python
def create_accelerator_and_postprocess(self):
    grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
    grad_acc_kwargs["sync_with_dataloader"] = False
    gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

    self.accelerator = CustomAccelerator(
        deepspeed_plugin=self.args.deepspeed_plugin,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
        **self.args.accelerator_config.to_dict(),
    )
    ...
```
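For orientation, the equivalent call against the stock Accelerate API is sketched below; num_steps stands in for args.gradient_accumulation_steps, and sync_with_dataloader=False keeps gradient synchronization tied to the trainer's own accumulation counter rather than the end of the dataloader, matching the choice made by the stock Trainer.

```python
from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Plain Accelerate equivalent; the custom trainer swaps Accelerator for
# CustomAccelerator but passes the same gradient accumulation plugin.
plugin = GradientAccumulationPlugin(num_steps=4, sync_with_dataloader=False)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)
```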
Usage Example
```python
from chemlactica.custom_trainer import CustomTrainer, CustomArguments
from chemlactica.utils.model_utils import load_model

model = load_model(
    "OSS-Models/ChemLactica-125M",
    use_flash_attn=True,
)

training_args = CustomArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=8,
    tokenizer_path="chemlactica/tokenizer/ChemLacticaTokenizer66",
)

# train_dataset / eval_dataset: any datasets accepted by the HuggingFace Trainer
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
```
CustomIterativeSFTTrainer
CustomIterativeSFTTrainer extends TRL’s IterativeSFTTrainer with sample debugging for iterative SFT training.
Class Definition
```python
from trl import IterativeSFTTrainer


class CustomIterativeSFTTrainer(IterativeSFTTrainer):
    def __init__(self, *args, **kwargs):
        # Number of samples to print when training begins
        self.num_samples_to_print = 5
        super().__init__(*args, **kwargs)

    def training_step(self, model, inputs):
        if self.num_samples_to_print:
            for i in range(min(inputs["input_ids"].size(0), self.num_samples_to_print)):
                print(f"Sample {i + 1}:", self.tokenizer.decode(inputs["input_ids"][i]))
            self.num_samples_to_print = None
        return super().training_step(model, inputs)
```
Usage Example
```python
from chemlactica.custom_trainer import CustomIterativeSFTTrainer, CustomArguments

# model, training_args, and the datasets are prepared as in the CustomTrainer example above
trainer = CustomIterativeSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
```
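Because the class inherits from TRL's IterativeSFTTrainer, optimization can also be driven batch-by-batch through the inherited step() method instead of train(). A minimal sketch, with hypothetical training strings:

```python
# Hypothetical raw training texts; step() tokenizes them with the trainer's
# tokenizer, then runs the forward/backward passes and optimizer update.
batch_texts = [
    "[START_SMILES]CCO[END_SMILES]",
    "[START_SMILES]c1ccccc1[END_SMILES]",
]
trainer.step(texts=batch_texts)
```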
Source Reference
Implemented in chemlactica/custom_trainer.py