Skip to main content

Overview

The dataset module provides PyTorch Dataset classes and utility functions for loading, splitting, and creating DataLoaders for malware image datasets. It supports stratified splitting, class imbalance handling, and flexible data augmentation.

Classes

MalwareDataset

PyTorch Dataset for loading malware images from disk.
class MalwareDataset(Dataset):
    def __init__(self, image_paths: list[Path], labels: list[int], transform=None)
image_paths
list[Path]
required
List of Path objects pointing to image files
labels
list[int]
required
List of integer labels corresponding to each image
transform
callable
default:"None"
Optional torchvision transforms to apply to images

Methods

__len__(): Returns the total number of samples in the dataset. __getitem__(idx: int): Loads and returns the sample at the given index.
  • Opens the image file and converts to RGB
  • Applies transforms if provided
  • Returns tuple of (image_tensor, label)

Example

from pathlib import Path
from training.dataset import MalwareDataset
from training.transforms import create_train_transforms

# Create dataset
image_paths = [Path("dataset/trojan/sample1.png"), Path("dataset/worm/sample2.png")]
labels = [0, 1]
transform = create_train_transforms(config)

dataset = MalwareDataset(image_paths, labels, transform=transform)

# Access samples
image, label = dataset[0]
print(f"Dataset size: {len(dataset)}")

Functions

scan_dataset

Scans a dataset directory and collects all image paths, labels, and class names.
def scan_dataset(
    dataset_path: Path,
    selected_families: list[str] | None = None
) -> tuple[list[Path], list[int], list[str]]
dataset_path
Path
required
Path to the dataset directory containing family subdirectories
selected_families
list[str] | None
default:"None"
Optional list of family names to include. If None, includes all families.
Returns: Tuple of (image_paths, labels, class_names). Expected directory structure:
dataset/
  ├── trojan/
  │   ├── sample1.png
  │   └── sample2.png
  ├── worm/
  │   └── sample3.png
  └── ransomware/
      └── sample4.png

Example

from pathlib import Path
from training.dataset import scan_dataset

dataset_path = Path("dataset")
image_paths, labels, class_names = scan_dataset(dataset_path)

print(f"Found {len(image_paths)} images")
print(f"Classes: {class_names}")

# Filter specific families
image_paths, labels, class_names = scan_dataset(
    dataset_path,
    selected_families=["trojan", "worm"]
)

create_splits

Creates stratified train/validation/test splits from the dataset.
def create_splits(
    image_paths: list[Path],
    labels: list[int],
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    stratified: bool = True,
    random_seed: int = 72,
) -> dict
image_paths
list[Path]
required
List of all image paths
labels
list[int]
required
List of all labels
train_ratio
float
default:"0.7"
Proportion of data for training (0.0 to 1.0)
val_ratio
float
default:"0.15"
Proportion of data for validation (0.0 to 1.0)
test_ratio
float
default:"0.15"
Proportion of data for testing (0.0 to 1.0)
stratified
bool
default:"True"
Whether to maintain class distribution in splits
random_seed
int
default:"72"
Random seed for reproducibility
Returns: Dictionary with keys 'train', 'val', 'test', each containing 'paths' and 'labels'

Example

from training.dataset import scan_dataset, create_splits

image_paths, labels, class_names = scan_dataset(Path("dataset"))

splits = create_splits(
    image_paths,
    labels,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    stratified=True,
    random_seed=42
)

print(f"Train: {len(splits['train']['paths'])} samples")
print(f"Val: {len(splits['val']['paths'])} samples")
print(f"Test: {len(splits['test']['paths'])} samples")

compute_class_weights

Computes inverse frequency class weights for handling class imbalance.
def compute_class_weights(labels: list[int], num_classes: int) -> torch.Tensor
labels
list[int]
required
List of training labels
num_classes
int
required
Total number of classes
Returns: Tensor of shape (num_classes,) with normalized class weights. Formula: weight[i] = total_samples / (num_classes * count[i])

Example

import torch
from training.dataset import compute_class_weights

labels = [0, 0, 0, 1, 1, 2]  # Imbalanced: 3:2:1 ratio
num_classes = 3

weights = compute_class_weights(labels, num_classes)
print(weights)  # Higher weights for minority classes
# tensor([0.6667, 1.0000, 2.0000])

create_weighted_sampler

Creates a WeightedRandomSampler for balanced batch sampling.
def create_weighted_sampler(
    labels: list[int],
    num_classes: int
) -> WeightedRandomSampler
labels
list[int]
required
List of training labels
num_classes
int
required
Total number of classes
Returns: WeightedRandomSampler that oversamples minority classes

Example

from torch.utils.data import DataLoader
from training.dataset import MalwareDataset, create_weighted_sampler

sampler = create_weighted_sampler(train_labels, num_classes=3)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    sampler=sampler,  # Use sampler instead of shuffle
    num_workers=4
)

create_dataloaders

High-level function that creates train/val/test DataLoaders from configuration.
def create_dataloaders(
    dataset_config: dict,
    training_config: dict,
    num_workers: int = 4,
) -> tuple[dict[str, DataLoader], list[str], torch.Tensor | None]
dataset_config
dict
required
Dataset configuration with keys:
  • dataset_path: Path to dataset directory
  • selected_families: Optional list of family names
  • split: Dict with train/val/test ratios
  • preprocessing: Preprocessing settings
  • augmentation: Augmentation settings
training_config
dict
required
Training configuration with keys:
  • batch_size: Batch size for DataLoaders
  • class_weights: Class weighting method ("None", "Auto Class Weights", "Focal Loss")
num_workers
int
default:"4"
Number of worker processes for data loading
Returns: Tuple of (dataloaders_dict, class_names, class_weights)
  • dataloaders_dict: Dict with 'train', 'val', 'test' DataLoaders
  • class_names: List of class names
  • class_weights: Tensor of class weights or None

Example

from training.dataset import create_dataloaders

dataset_config = {
    "dataset_path": "dataset",
    "selected_families": None,
    "split": {
        "train": 70,
        "val": 15,
        "test": 15,
        "stratified": True,
        "random_seed": 72
    },
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "ImageNet Mean/Std",
        "color_mode": "RGB"
    },
    "augmentation": {
        "preset": "Moderate"
    }
}

training_config = {
    "batch_size": 32,
    "class_weights": "Auto Class Weights"
}

dataloaders, class_names, class_weights = create_dataloaders(
    dataset_config,
    training_config,
    num_workers=4
)

train_loader = dataloaders["train"]
val_loader = dataloaders["val"]
test_loader = dataloaders["test"]

print(f"Classes: {class_names}")
print(f"Train batches: {len(train_loader)}")

Build docs developers (and LLMs) love