Overview

ChemLactica provides utilities for loading and preparing datasets for both pretraining and supervised fine-tuning. The dataset system supports streaming JSONL files for efficient memory usage during pretraining and standard HuggingFace datasets for SFT.

get_dataset()

Main function for loading and preparing datasets based on training type.

Function Signature

def get_dataset(
    train_type,
    training_data_dirs,
    valid_data_dir,
    dir_data_types,
    train_config,
    model_config,
    shared_jsonl_files,
    evaluate_only,
    slurm_eval,
    shuffle_buffer_size,
)

Parameters

train_type
str
required
Type of training. Options: "pretrain" or "sft"
training_data_dirs
list[str]
required
List of directories containing training data files. For pretrain: directories of JSONL files. For SFT: the HuggingFace dataset name or path, passed as a single-element list (see the SFT Mode example)
valid_data_dir
str
required
Directory containing validation data files (JSONL format for pretrain)
dir_data_types
list[str]
required
List of data types corresponding to each training directory; must match the length of training_data_dirs. Supported types are defined in DIR_DATA_TYPES
train_config
dict
required
Training configuration object containing training hyperparameters
model_config
dict
required
Model configuration object containing model architecture parameters and tokenizer path
shared_jsonl_files
multiprocessing.Manager.dict
required
Shared dictionary for tracking JSONL file reading positions across processes (pretrain only). Pass None for SFT
evaluate_only
bool
required
If True, only load validation dataset
slurm_eval
bool
required
If True and not evaluate_only, skip loading validation dataset (evaluation runs separately via SLURM)
shuffle_buffer_size
int
required
Size of shuffle buffer for assay datasets during pretraining

Returns

dataset
dict
Dictionary containing "train" and "validation" datasets. For pretrain, returns iterable datasets. For SFT, returns standard HuggingFace dataset dict

Pretrain Mode

When train_type="pretrain", the function:
  1. Validates that data types are supported
  2. Loads JSONL files from each training directory
  3. Creates iterable datasets using samples_generator
  4. Processes each dataset with tokenization and formatting
  5. Shuffles assay-type datasets with specified buffer size
  6. Interleaves multiple datasets if provided
  7. Loads the validation dataset (skipped when slurm_eval=True and evaluate_only=False)

Example:
from chemlactica.get_dataset import get_dataset
import multiprocessing

with multiprocessing.Manager() as manager:
    shared_jsonl_files = manager.dict()
    
    dataset = get_dataset(
        train_type="pretrain",
        training_data_dirs=[
            "/data/molecules",
            "/data/assays"
        ],
        valid_data_dir="/data/valid",
        dir_data_types=["molecules", "assay_split"],
        train_config=train_config,
        model_config=model_config,
        shared_jsonl_files=shared_jsonl_files,
        evaluate_only=False,
        slurm_eval=False,
        shuffle_buffer_size=10000,
    )
    
    train_dataset = dataset["train"]
    eval_dataset = dataset["validation"]
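
Because the pretrain datasets are iterable and keep their read positions in shared_jsonl_files, the Manager (and with it the with block above) must remain alive for as long as the datasets are being consumed.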

SFT Mode

When train_type="sft", the function loads a standard HuggingFace dataset:
from chemlactica.get_dataset import get_dataset

dataset = get_dataset(
    train_type="sft",
    training_data_dirs=["username/dataset-name"],
    valid_data_dir=None,
    dir_data_types=None,
    train_config=train_config,
    model_config=model_config,
    shared_jsonl_files=None,
    evaluate_only=False,
    slurm_eval=False,
    shuffle_buffer_size=None,
)
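
The pretrain-only parameters (valid_data_dir, dir_data_types, shared_jsonl_files, shuffle_buffer_size) are not used in SFT mode, which is why the example passes None for each of them.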

Data Types

Supported data types are defined in DIR_DATA_TYPES. Common types include:
  • "molecules": Molecular structure data
  • "assay_split": Assay data that should be shuffled
  • Other domain-specific types defined in DIR_DATA_TYPES
Datasets with "assay" in the type name are automatically shuffled using the specified shuffle_buffer_size.
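
As a rough sketch, the selection amounts to the following (names here are illustrative, not the exact source):
if "assay" in data_type:
    # Shuffle streaming assay data within a bounded in-memory buffer.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)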

samples_generator()

Generator function for streaming JSONL files in a distributed manner.

Function Signature

def samples_generator(
    files: List[str],
    shared_jsonl_files,
    chunk_size=25000,
    return_line_info=False
)

Parameters

files
list[str]
required
List of JSONL file paths to read
shared_jsonl_files
multiprocessing.Manager.dict
required
Shared dictionary for tracking file positions across processes, enabling checkpoint resumption
chunk_size
int
default:"25000"
Chunk size for reading files (reserved; the current line-by-line reader does not use it)
return_line_info
bool
default:"False"
Whether to return line position information with samples

Yields

sample
dict
Dictionary with "text" key containing the line content from the JSONL file

Features

Distributed Reading

The generator distributes samples across multiple processes:
def should_yield_on_current_rank(i, num_processes, process_index):
    # Round-robin: sample i is yielded by exactly one process rank.
    return i % num_processes == process_index
Each process only yields samples at indices matching its rank, ensuring no duplicate processing.
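With num_processes=2, for example, the process at index 0 yields samples 0, 2, 4, … while the process at index 1 yields samples 1, 3, 5, …, so every line is read by exactly one process.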

Checkpoint Resumption

The generator tracks file positions in shared_jsonl_files:
# Start every file from the beginning by default.
file_states = {f: {"position": 0, "line_number": 0} for f in files}
for file in file_states.keys():
    if shared_jsonl_files.get(file):
        # A previous run saved progress for this file; resume from it.
        jsonl_state = shared_jsonl_files[file]
        file_states[file] = jsonl_state
        print(f"loaded {file}: {jsonl_state['position']}")
When resuming from a checkpoint, the generator seeks to the saved position in each file.
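
A simplified sketch of the resulting read loop (illustrative only; the actual implementation in chemlactica/jsonl_dataset.py differs in details):
state = file_states[file]
with open(file, "r") as f:
    f.seek(state["position"])             # jump to the saved byte offset
    while True:
        line = f.readline()
        if not line:                      # end of file
            break
        state["line_number"] += 1
        state["position"] = f.tell()      # record progress for the next checkpoint
        shared_jsonl_files[file] = state  # publish state so checkpoints can save it
        yield format_sample(line)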

Sample Format

Each line is formatted as:
def format_sample(line):
    # Strip surrounding whitespace (including the trailing newline) and
    # wrap the raw line as a {"text": ...} sample.
    sample = line.strip()
    ret = {"text": sample}
    return ret
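
For example, format_sample('{"SMILES": "CCO"}\n') returns {"text": '{"SMILES": "CCO"}'}; the field name is illustrative, and the raw line is passed through without JSON parsing.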

Usage Example

from chemlactica.jsonl_dataset import samples_generator
from datasets.iterable_dataset import IterableDataset
import multiprocessing

with multiprocessing.Manager() as manager:
    shared_jsonl_files = manager.dict()
    
    training_files = [
        "/data/train/molecules_001.jsonl",
        "/data/train/molecules_002.jsonl",
    ]
    
    dataset = IterableDataset.from_generator(
        samples_generator,
        gen_kwargs={
            "files": training_files,
            "shared_jsonl_files": shared_jsonl_files,
        },
    )
    
    for sample in dataset:
        print(sample["text"])

Dataset Processing

All datasets loaded by get_dataset() are processed using process_dataset from chemlactica.utils.dataset_utils, which handles:
  • Tokenization with the specified tokenizer
  • Sequence formatting and truncation
  • Batching and padding
  • Special handling for assay vs. non-assay data
Processing parameters:
dataset = process_dataset(
    dataset=dataset,
    train_config=train_config,
    model_config=model_config,
    process_batch_sizes=(50, 50),
    is_eval=False,
    assay=is_assay_split,
)
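
Here is_assay_split indicates whether the directory's data type is an assay type, enabling the assay-specific handling listed above; by its name, process_batch_sizes presumably controls the batch sizes used while mapping the dataset.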

Source Reference

  • get_dataset(): chemlactica/get_dataset.py:10
  • samples_generator(): chemlactica/jsonl_dataset.py:37
