Overview

ChemLactica provides custom trainer implementations that extend HuggingFace Transformers’ Trainer and TRL’s IterativeSFTTrainer with additional functionality for debugging, distributed training, and SLURM evaluation.

CustomArguments

CustomArguments extends TrainingArguments from HuggingFace Transformers with ChemLactica-specific parameters.

Class Definition

from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class CustomArguments(TrainingArguments):
    slurm_eval: bool = field(
        default=False, metadata={"help": "Whether to run eval via slurm job."}
    )
    command: str = field(default=None)
    experiment_name: str = field(default=None)
    tokenizer_path: str = field(
        default="/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66"
    )

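Because CustomArguments is a dataclass extending TrainingArguments, it can also be populated from command-line flags; a minimal sketch, assuming the usual HfArgumentParser pattern (ChemLactica's own entry point may parse arguments differently):

from transformers import HfArgumentParser

from chemlactica.custom_trainer import CustomArguments

# Every field of CustomArguments (slurm_eval, command, experiment_name,
# tokenizer_path) becomes a CLI flag alongside the inherited
# TrainingArguments options such as --output_dir and --learning_rate.
parser = HfArgumentParser(CustomArguments)
(training_args,) = parser.parse_args_into_dataclasses()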
Parameters

slurm_eval (bool, default: False)
    Whether to run evaluation via SLURM job submission. Enables distributed evaluation on compute clusters.

command (str, default: None)
    The training command being executed, stored for reproducibility and debugging.

experiment_name (str, default: None)
    Name of the experiment for tracking, logging, and checkpoint organization.

tokenizer_path (str, default: the ChemLacticaTokenizer66 path shown in the class definition)
    Path to the tokenizer directory or model.

Inherited Parameters

CustomArguments inherits all parameters from HuggingFace’s TrainingArguments, including:
  • output_dir: Directory to save checkpoints
  • per_device_train_batch_size: Training batch size per device
  • per_device_eval_batch_size: Evaluation batch size per device
  • learning_rate: Initial learning rate
  • num_train_epochs: Number of training epochs
  • max_steps: Maximum number of training steps
  • warmup_steps: Number of warmup steps
  • logging_steps: Log metrics every N steps
  • save_steps: Save checkpoint every N steps
  • eval_steps: Run evaluation every N steps
  • gradient_accumulation_steps: Number of gradient accumulation steps
  • bf16: Enable bfloat16 mixed precision
  • fp16: Enable float16 mixed precision
  • And many more…
See HuggingFace TrainingArguments documentation for the complete list.

Usage Example

from chemlactica.custom_trainer import CustomArguments

training_args = CustomArguments(
    command="python train.py --model 125M",
    slurm_eval=False,
    experiment_name="chemlactica-pretrain",
    tokenizer_path="chemlactica/tokenizer/ChemLacticaTokenizer66",
    output_dir="./checkpoints",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    learning_rate=1e-4,
    warmup_steps=1000,
    max_steps=100000,
    logging_steps=1,
    save_steps=10000,
    eval_steps=5000,
    gradient_accumulation_steps=4,
    bf16=True,
)

CustomTrainer

CustomTrainer extends HuggingFace’s Trainer with debugging capabilities and custom distributed training setup.

Class Definition

from transformers import Trainer

class CustomTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        # Number of samples to print when training begins, for debugging
        self.num_samples_to_print = 10
        self.tokenizer_path = kwargs["args"].tokenizer_path
        super().__init__(*args, **kwargs)
    
    def training_step(self, model, inputs):
        # Prints first N samples at the start of training
        ...
    
    def create_accelerator_and_postprocess(self):
        # Custom accelerator creation with CustomAccelerator
        ...

Key Features

Sample Debugging

The training_step method prints the first 10 training samples when training begins:

def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
    if self.num_samples_to_print:
        # get_tokenizer is ChemLactica's tokenizer-loading helper
        tokenizer = get_tokenizer(self.tokenizer_path)
        for i in range(min(inputs["input_ids"].size(0), self.num_samples_to_print)):
            print(f"Sample {i + 1}:", tokenizer.decode(inputs["input_ids"][i]))
        self.num_samples_to_print = None
    return super().training_step(model, inputs)

This helps verify that data preprocessing is working correctly before committing to a long run.
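The number of printed samples lives on the trainer instance, so it can be adjusted before training starts. A minimal sketch, assuming a trainer constructed as in the usage example below:

# Print 20 samples instead of the default 10; setting the attribute to None
# (or 0) skips the debug printout entirely, since training_step only prints
# while num_samples_to_print is truthy.
trainer.num_samples_to_print = 20
trainer.train()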

Custom Accelerator

The create_accelerator_and_postprocess method uses CustomAccelerator instead of the standard Accelerate Accelerator:

def create_accelerator_and_postprocess(self):
    # GradientAccumulationPlugin comes from accelerate.utils
    grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
    grad_acc_kwargs["sync_with_dataloader"] = False
    gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

    # CustomAccelerator is ChemLactica's Accelerator subclass
    self.accelerator = CustomAccelerator(
        deepspeed_plugin=self.args.deepspeed_plugin,
        gradient_accumulation_plugin=gradient_accumulation_plugin,
        **self.args.accelerator_config.to_dict(),
    )
    ...
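For reference, the same gradient-accumulation configuration can be reproduced with the stock Accelerate API; a minimal sketch (CustomAccelerator itself, ChemLactica's Accelerator subclass, is not shown here):

from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# Mirror the trainer's setup: accumulate gradients over 4 steps and do not
# tie gradient synchronization to dataloader boundaries.
plugin = GradientAccumulationPlugin(num_steps=4, sync_with_dataloader=False)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)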

Usage Example

from chemlactica.custom_trainer import CustomTrainer, CustomArguments
from chemlactica.utils.model_utils import load_model

model = load_model(
    "OSS-Models/ChemLactica-125M",
    use_flash_attn=True,
)

training_args = CustomArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=8,
    tokenizer_path="chemlactica/tokenizer/ChemLacticaTokenizer66",
)

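# train_dataset and eval_dataset are assumed to come from your data pipeline
# (not shown here).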
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
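Because CustomTrainer is a regular Trainer subclass, the standard checkpointing workflow still applies; a minimal sketch:

# Resume from the most recent checkpoint found in output_dir
# (inherited Trainer behaviour, unchanged by CustomTrainer).
trainer.train(resume_from_checkpoint=True)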

CustomIterativeSFTTrainer

CustomIterativeSFTTrainer extends TRL’s IterativeSFTTrainer with sample debugging for iterative SFT training.

Class Definition

from trl import IterativeSFTTrainer

class CustomIterativeSFTTrainer(IterativeSFTTrainer):
    def __init__(self, *args, **kwargs):
        # Number of samples to print when training begins
        self.num_samples_to_print = 5
        super().__init__(*args, **kwargs)
    
    def training_step(self, model, inputs):
        if self.num_samples_to_print:
            for i in range(min(inputs["input_ids"].size(0), self.num_samples_to_print)):
                print(f"Sample {i + 1}:", self.tokenizer.decode(inputs["input_ids"][i]))
            self.num_samples_to_print = None
        return super().training_step(model, inputs)

Usage Example

from chemlactica.custom_trainer import CustomIterativeSFTTrainer, CustomArguments

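# model and training_args are assumed to be prepared as in the CustomTrainer
# example above.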
trainer = CustomIterativeSFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
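In addition to the train() call shown above, TRL's IterativeSFTTrainer exposes a step() method for feeding batches explicitly in iterative workflows; a minimal sketch, assuming text inputs (the [START_SMILES] tags are illustrative, not a required format):

# Feed one batch of raw texts; the trainer tokenizes them and performs a
# single optimization pass over the batch.
batch_texts = [
    "[START_SMILES]CCO[END_SMILES]",
    "[START_SMILES]c1ccccc1[END_SMILES]",
]
trainer.step(texts=batch_texts)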

Source Reference

Implemented in chemlactica/custom_trainer.py
