The task_datasets module provides utilities for loading and tokenizing supervised learning datasets for classification tasks (e.g., GLUE SST-2) and sequence-to-sequence tasks (e.g., SAMSum summarization, GSM8K).
load_supervised_text_dataset
Load and tokenize a supervised dataset for classification or seq2seq tasks.
from modern_llm.data.task_datasets import (
    load_supervised_text_dataset,
    TaskDatasetConfig,
)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Load SST-2 sentiment classification
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    label_field="label",
    task_type="classification",
    max_source_length=128,
)
dataset = load_supervised_text_dataset(config, tokenizer)
Parameters
- config (TaskDatasetConfig, required): Configuration object specifying dataset parameters and task type
- tokenizer (PreTrainedTokenizerBase, required): Tokenizer whose vocabulary matches the model being fine-tuned
- target_tokenizer (Optional[PreTrainedTokenizerBase], default None): Separate tokenizer for targets in seq2seq tasks. If None, the same tokenizer is used for both source and target.
Returns a tokenized dataset with input_ids, attention_mask, and labels fields formatted for PyTorch.
Task types
Classification: Returns labels as integer class indices
# Binary or multi-class classification
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    label_field="label",
)
Seq2seq: Returns tokenized target sequences
# Summarization, translation, etc.
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    task_type="seq2seq",
    label_field="summary",
    max_source_length=512,
    max_target_length=128,
)
Returns
Returns a tokenized datasets.Dataset with:
- input_ids: Tokenized input sequences
- attention_mask: Attention masks (1 for real tokens, 0 for padding)
- labels: For classification, integer class labels; for seq2seq, tokenized target sequences
Complexity
O(num_examples · max_length) for tokenization.
TaskDatasetConfig
Configuration dataclass for supervised classification and seq2seq datasets.
from modern_llm.data.task_datasets import TaskDatasetConfig
# Classification task
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    label_field="label",
    task_type="classification",
    max_source_length=128,
    split="train",
)
# Seq2seq task
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
)
Parameters
- dataset_name (str, required): Hugging Face dataset name (e.g., "glue", "samsum", "gsm8k")
- text_fields (List[str], required): List of column names containing input text. Multiple fields are concatenated with newlines.
- dataset_config_name (Optional[str], default None): Dataset configuration name (e.g., "sst2" for GLUE)
- split (str): Dataset split to load ("train", "validation", or "test")
- label_field (Optional[str], default "label"): Name of the column containing labels or target sequences
- task_type (str, default "classification"): Task type, either "classification" or "seq2seq"
- max_source_length (int): Maximum sequence length for input text. Must be positive.
- max_target_length (Optional[int], default None): Maximum sequence length for targets in seq2seq tasks. Required for seq2seq and must be positive.
- padding (str): Padding strategy, "max_length" or "longest"
- num_proc (Optional[int], default None): Number of processes for parallel tokenization. None uses a single process.
Validation
The config validates on initialization (a sketch of catching a validation error follows the list):
- dataset_name must be non-empty
- text_fields must contain at least one column name
- max_source_length must be positive
- task_type must be "classification" or "seq2seq"
- Seq2seq tasks require max_target_length > 0
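Catching a validation failure looks roughly like the following; this is a sketch that assumes invalid values raise a ValueError (the exact exception type is an assumption, not documented above).
from modern_llm.data.task_datasets import TaskDatasetConfig

try:
    config = TaskDatasetConfig(
        dataset_name="samsum",
        text_fields=["dialogue"],
        label_field="summary",
        task_type="seq2seq",
        max_source_length=512,
        # max_target_length is omitted, so this seq2seq config should fail validation
    )
except ValueError as exc:  # assumed exception type
    print(f"Invalid config: {exc}")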
Examples
SST-2 sentiment classification
from modern_llm.data.task_datasets import (
    load_supervised_text_dataset,
    TaskDatasetConfig,
)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    label_field="label",
    task_type="classification",
    max_source_length=128,
    split="train",
)
train_dataset = load_supervised_text_dataset(config, tokenizer)
print(f"Loaded {len(train_dataset)} training examples")
# Inspect an example
example = train_dataset[0]
print(f"Input IDs shape: {example['input_ids'].shape}")
print(f"Label: {example['labels']}")
SAMSum dialogue summarization
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
    split="train",
)
dataset = load_supervised_text_dataset(config, tokenizer)
print(f"Loaded {len(dataset)} dialogue-summary pairs")
GSM8K math word problems
tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = TaskDatasetConfig(
    dataset_name="gsm8k",
    dataset_config_name="main",
    text_fields=["question"],
    label_field="answer",
    task_type="seq2seq",
    max_source_length=256,
    max_target_length=256,
)
dataset = load_supervised_text_dataset(config, tokenizer)
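Note that GPT-2's tokenizer ships without a pad token, so padded tokenization will generally fail until one is set. A common workaround, applied before calling the loader (whether it is needed depends on the padding strategy you choose), is to reuse the end-of-sequence token:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token
dataset = load_supervised_text_dataset(config, tokenizer)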
Multi-field concatenation
# Concatenate multiple text fields with newlines
config = TaskDatasetConfig(
    dataset_name="your-dataset",
    text_fields=["title", "context", "question"],  # All concatenated
    label_field="answer",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
)
dataset = load_supervised_text_dataset(config, tokenizer)
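Conceptually, the loader joins the listed columns with newlines to build each source string before tokenizing. A rough equivalent for a single example (the field values here are placeholders):
example = {"title": "...", "context": "...", "question": "..."}
source_text = "\n".join(example[field] for field in config.text_fields)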
Parallel tokenization for large datasets
# Use multiple processes for faster tokenization
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    max_source_length=128,
    num_proc=8,  # Use 8 processes
)
dataset = load_supervised_text_dataset(config, tokenizer)
Validation split for evaluation
# Load validation set
val_config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    max_source_length=128,
    split="validation",
)
val_dataset = load_supervised_text_dataset(val_config, tokenizer)
Custom padding strategy
# Use dynamic padding instead of max_length
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
    padding="longest",  # Pad to longest in batch
)
dataset = load_supervised_text_dataset(config, tokenizer)
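With padding="longest", examples can end up with different lengths, so the default DataLoader collation may not be able to stack them into a single tensor. One way to batch such a dataset is with a padding collator; the sketch below uses transformers.DataCollatorForSeq2Seq, which is one option rather than part of this module:
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq

# Pads input_ids, attention_mask, and labels to the longest example in each batch;
# label positions added as padding are set to -100 so the loss ignores them.
collator = DataCollatorForSeq2Seq(tokenizer, padding=True, label_pad_token_id=-100)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collator)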
Separate tokenizers for encoder-decoder models
from transformers import AutoTokenizer
# Use different tokenizers for source and target
encoder_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
decoder_tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = TaskDatasetConfig(
    dataset_name="samsum",
    text_fields=["dialogue"],
    label_field="summary",
    task_type="seq2seq",
    max_source_length=512,
    max_target_length=128,
)
dataset = load_supervised_text_dataset(
    config,
    tokenizer=encoder_tokenizer,
    target_tokenizer=decoder_tokenizer,
)
Use with PyTorch DataLoader
import torch
from torch.utils.data import DataLoader
config = TaskDatasetConfig(
    dataset_name="glue",
    dataset_config_name="sst2",
    text_fields=["sentence"],
    task_type="classification",
    max_source_length=128,
)
dataset = load_supervised_text_dataset(config, tokenizer)
# Dataset is already formatted for PyTorch
train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)
for batch in train_loader:
    input_ids = batch["input_ids"]            # [32, 128]
    attention_mask = batch["attention_mask"]  # [32, 128]
    labels = batch["labels"]                  # [32] for classification
    break
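From here a batch can be fed straight to a Hugging Face classification model. A brief sketch of one forward pass; the checkpoint and num_labels below are illustrative choices for SST-2, not part of this module:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
outputs = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    labels=labels,
)
print(outputs.loss)          # cross-entropy loss for the batch
print(outputs.logits.shape)  # torch.Size([32, 2])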