The task_datasets module provides utilities for loading and tokenizing supervised learning datasets for classification tasks (e.g., GLUE SST-2) and sequence-to-sequence (seq2seq) tasks (e.g., SAMSum summarization, GSM8K).

load_supervised_text_dataset

Load and tokenize a supervised dataset for classification or seq2seq tasks.
from modern_llm.data.task_datasets import (
    load_supervised_text_dataset,
    TaskDatasetConfig,
)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Load SST-2 sentiment classification
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    label_field="label",
    task_type="classification",
    max_source_length=128,
)

dataset = load_supervised_text_dataset(config, tokenizer)
Parameters

config (TaskDatasetConfig, required)
    Configuration object specifying dataset parameters and task type.

tokenizer (PreTrainedTokenizerBase, required)
    Tokenizer whose vocabulary matches the model being fine-tuned.

target_tokenizer (Optional[PreTrainedTokenizerBase], default: None)
    Separate tokenizer for targets in seq2seq tasks. If None, the same tokenizer is used for both source and target.

Task types

Classification: Returns labels as integer class indices
# Binary or multi-class classification
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    label_field="label",
)
Seq2seq: Returns tokenized target sequences
# Summarization, translation, etc.
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    task_type="seq2seq",
    label_field="summary",
    max_source_length=512,
    max_target_length=128,
)

Returns

Returns a tokenized datasets.Dataset, formatted for PyTorch, with:
  • input_ids: Tokenized input sequences
  • attention_mask: Attention masks (1 for real tokens, 0 for padding)
  • labels: For classification, integer class labels; for seq2seq, tokenized target sequences
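
For seq2seq tasks, a quick sanity check is to decode the labels back to text. A sketch, assuming labels hold plain target token IDs (no -100 masking of padding positions):

# Decode one tokenized target back into the reference summary/answer.
text = tokenizer.decode(dataset[0]["labels"], skip_special_tokens=True)
print(text)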

Complexity

O(num_examples · max_length) for tokenization.

TaskDatasetConfig

Configuration dataclass for supervised classification and seq2seq datasets.
from modern_llm.data.task_datasets import TaskDatasetConfig

# Classification task
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    label_field="label",
    task_type="classification",
    max_source_length=128,
    split="train",
)

# Seq2seq task
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
)
Parameters

dataset_name (str, required)
    Hugging Face dataset name (e.g., “glue”, “samsum”, “gsm8k”).

text_fields (Sequence[str], required)
    List of column names containing input text. Multiple fields are concatenated with newlines.

dataset_config_name (Optional[str], default: None)
    Dataset configuration name (e.g., “sst2” for GLUE).

split (str, default: "train")
    Dataset split to load (“train”, “validation”, or “test”).

label_field (Optional[str], default: "label")
    Name of the column containing labels or target sequences.

task_type (str, default: "classification")
    Task type: “classification” or “seq2seq”.

max_source_length (int, default: 512)
    Maximum sequence length for input text. Must be positive.

max_target_length (Optional[int], default: None)
    Maximum sequence length for targets in seq2seq tasks. Required for seq2seq tasks; must be positive.

padding (str, default: "max_length")
    Padding strategy: “max_length” or “longest”.

num_proc (Optional[int], default: None)
    Number of processes for parallel tokenization. None uses a single process.

Validation

The config is validated on initialization (see the sketch after this list):
  • dataset_name must be non-empty
  • text_fields must contain at least one column name
  • max_source_length must be positive
  • task_type must be “classification” or “seq2seq”
  • Seq2seq tasks require max_target_length > 0
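
A minimal sketch of a failing check (assuming invalid configs raise ValueError):

from modern_llm.data.task_datasets import TaskDatasetConfig

# Violates the last rule: seq2seq without max_target_length.
try:
    TaskDatasetConfig(
        dataset_name="samsum",
        text_fields=["dialogue"],
        label_field="summary",
        task_type="seq2seq",
        # max_target_length omitted
    )
except ValueError as err:
    print(err)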

Examples

SST-2 sentiment classification

from modern_llm.data.task_datasets import (
    load_supervised_text_dataset,
    TaskDatasetConfig,
)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    label_field="label",
    task_type="classification",
    max_source_length=128,
    split="train",
)

train_dataset = load_supervised_text_dataset(config, tokenizer)
print(f"Loaded {len(train_dataset)} training examples")

# Inspect an example
example = train_dataset[0]
print(f"Input IDs shape: {example['input_ids'].shape}")
print(f"Label: {example['labels']}")

SAMSum dialogue summarization

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
    split="train",
)

dataset = load_supervised_text_dataset(config, tokenizer)
print(f"Loaded {len(dataset)} dialogue-summary pairs")

GSM8K math word problems

tokenizer = AutoTokenizer.from_pretrained("gpt2")
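
# GPT-2 ships without a padding token; tokenizing with padding would fail,
# so reuse EOS as PAD (a common workaround, assumed compatible with this loader).
tokenizer.pad_token = tokenizer.eos_token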

config = TaskDatasetConfig(
    dataset_name="gsm8k",
    dataset_config_name="main",
    text_fields=["question"],
    label_field="answer",
    task_type="seq2seq",
    max_source_length=256,
    max_target_length=256,
)

dataset = load_supervised_text_dataset(config, tokenizer)

Multi-field concatenation

# Concatenate multiple text fields with newlines
config = TaskDatasetConfig(
    dataset_name="your-dataset",
    text_fields=["title", "context", "question"],  # All concatenated
    label_field="answer",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
)

dataset = load_supervised_text_dataset(config, tokenizer)
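
To sanity-check the concatenation, decode one tokenized input back to text. A sketch; the column names above are placeholders for your dataset's actual fields.

# The decoded source should show the fields joined by newlines:
print(tokenizer.decode(dataset[0]["input_ids"], skip_special_tokens=True))
# -> "<title>\n<context>\n<question>"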

Parallel tokenization for large datasets

# Use multiple processes for faster tokenization
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    max_source_length=128,
    num_proc=8,  # Use 8 processes
)

dataset = load_supervised_text_dataset(config, tokenizer)

Validation split for evaluation

# Load validation set
val_config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    max_source_length=128,
    split="validation",
)

val_dataset = load_supervised_text_dataset(val_config, tokenizer)

Custom padding strategy

# Use dynamic padding instead of max_length
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
    padding="longest",  # Pad to longest in batch
)

dataset = load_supervised_text_dataset(config, tokenizer)
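
With padding="longest", examples come out at varying lengths, so batching for training needs a collate function that pads each batch on the fly. A minimal sketch using transformers' DataCollatorForSeq2Seq (assuming the dataset fields match what the collator expects):

from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Pads input_ids, attention_mask, and labels per batch at load time.
collator = DataCollatorForSeq2Seq(tokenizer)
loader = DataLoader(dataset, batch_size=16, collate_fn=collator)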

Separate tokenizers for encoder-decoder models

from transformers import AutoTokenizer

# Use different tokenizers for source and target
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")

config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
)

dataset = load_supervised_text_dataset(
    config,
    tokenizer=encoder_tokenizer,
    target_tokenizer=decoder_tokenizer,
)
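
Because the two vocabularies differ, the label IDs are only meaningful under the decoder tokenizer. A quick check (sizes are those of these specific checkpoints):

# BERT's WordPiece vocabulary vs. GPT-2's byte-level BPE vocabulary.
print(len(encoder_tokenizer))  # 30522 for bert-base-uncased
print(len(decoder_tokenizer))  # 50257 for gpt2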

Use with PyTorch DataLoader

import torch
from torch.utils.data import DataLoader

config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    max_source_length=128,
)

dataset = load_supervised_text_dataset(config, tokenizer)

# Dataset is already formatted for PyTorch
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

for batch in train_loader:
    input_ids = batch["input_ids"]        # [32, 128]
    attention_mask = batch["attention_mask"]  # [32, 128]
    labels = batch["labels"]              # [32] for classification
    break
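
From here a batch feeds directly into a model. A minimal sketch for the classification case (the checkpoint and head size are illustrative, not part of this module):

from transformers import AutoModelForSequenceClassification

# Two-class head for SST-2-style binary labels.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(outputs.loss)  # cross-entropy loss computed from the labels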
