Overview
The dataset module provides PyTorch Dataset classes and utility functions for loading, splitting, and creating DataLoaders for malware image datasets. It supports stratified splitting, class imbalance handling, and flexible data augmentation.
Classes
MalwareDataset
PyTorch Dataset for loading malware images from disk.
class MalwareDataset(Dataset):
def __init__(self, image_paths: list[Path], labels: list[int], transform=None)
image_paths: List of Path objects pointing to image files
labels: List of integer labels corresponding to each image
transform: Optional torchvision transforms to apply to images
Methods
__len__()
Returns the total number of samples in the dataset.
__getitem__(idx: int)
Loads and returns a sample from the dataset at the given index.
- Opens the image file and converts to RGB
- Applies transforms if provided
- Returns tuple of (image_tensor, label)
Example
from pathlib import Path
from training.dataset import MalwareDataset
from training.transforms import create_train_transforms
# Create dataset
image_paths = [Path("dataset/trojan/sample1.png"), Path("dataset/worm/sample2.png")]
labels = [0, 1]
transform = create_train_transforms(config)
dataset = MalwareDataset(image_paths, labels, transform=transform)
# Access samples
image, label = dataset[0]
print(f"Dataset size: {len(dataset)}")
Functions
scan_dataset
Scans a dataset directory and collects all image paths, labels, and class names.
def scan_dataset(
dataset_path: Path,
selected_families: list[str] | None = None
) -> tuple[list[Path], list[int], list[str]]
dataset_path: Path to the dataset directory containing family subdirectories
selected_families: Optional list of family names to include. If None, includes all families.
Returns: Tuple of (image_paths, labels, class_names)
Directory Structure:
dataset/
├── trojan/
│ ├── sample1.png
│ └── sample2.png
├── worm/
│ └── sample3.png
└── ransomware/
└── sample4.png
Example
from pathlib import Path
from training.dataset import scan_dataset
dataset_path = Path("dataset")
image_paths, labels, class_names = scan_dataset(dataset_path)
print(f"Found {len(image_paths)} images")
print(f"Classes: {class_names}")
# Filter specific families
image_paths, labels, class_names = scan_dataset(
dataset_path,
selected_families=["trojan", "worm"]
)
create_splits
Creates stratified train/validation/test splits from the dataset.
def create_splits(
image_paths: list[Path],
labels: list[int],
train_ratio: float = 0.7,
val_ratio: float = 0.15,
test_ratio: float = 0.15,
stratified: bool = True,
random_seed: int = 72,
) -> dict
train_ratio: Proportion of data for training (0.0 to 1.0)
val_ratio: Proportion of data for validation (0.0 to 1.0)
test_ratio: Proportion of data for testing (0.0 to 1.0)
stratified: Whether to maintain class distribution in splits
random_seed: Random seed for reproducibility
Returns: Dictionary with keys 'train', 'val', 'test', each containing 'paths' and 'labels'
Example
from training.dataset import scan_dataset, create_splits
image_paths, labels, class_names = scan_dataset(Path("dataset"))
splits = create_splits(
image_paths,
labels,
train_ratio=0.7,
val_ratio=0.15,
test_ratio=0.15,
stratified=True,
random_seed=42
)
print(f"Train: {len(splits['train']['paths'])} samples")
print(f"Val: {len(splits['val']['paths'])} samples")
print(f"Test: {len(splits['test']['paths'])} samples")
compute_class_weights
Computes inverse frequency class weights for handling class imbalance.
def compute_class_weights(labels: list[int], num_classes: int) -> torch.Tensor
Returns: Tensor of shape (num_classes,) with normalized class weights
Formula: weight[i] = total_samples / (num_classes * count[i])
Example
import torch
from training.dataset import compute_class_weights
labels = [0, 0, 0, 1, 1, 2] # Imbalanced: 3:2:1 ratio
num_classes = 3
weights = compute_class_weights(labels, num_classes)
print(weights) # Higher weights for minority classes
# tensor([0.6667, 1.0000, 2.0000])
create_weighted_sampler
Creates a WeightedRandomSampler for balanced batch sampling.
def create_weighted_sampler(
labels: list[int],
num_classes: int
) -> WeightedRandomSampler
Returns: WeightedRandomSampler that oversamples minority classes
Example
from torch.utils.data import DataLoader
from training.dataset import MalwareDataset, create_weighted_sampler
sampler = create_weighted_sampler(train_labels, num_classes=3)
train_loader = DataLoader(
train_dataset,
batch_size=32,
sampler=sampler, # Use sampler instead of shuffle
num_workers=4
)
create_dataloaders
High-level function that creates train/val/test DataLoaders from configuration.
def create_dataloaders(
dataset_config: dict,
training_config: dict,
num_workers: int = 4,
) -> tuple[dict[str, DataLoader], list[str], torch.Tensor | None]
dataset_config: Dataset configuration with keys:
dataset_path: Path to dataset directory
selected_families: Optional list of family names
split: Dict with train/val/test ratios
preprocessing: Preprocessing settings
augmentation: Augmentation settings
training_config: Training configuration with keys:
batch_size: Batch size for DataLoaders
class_weights: Class weighting method ("None", "Auto Class Weights", "Focal Loss")
num_workers: Number of worker processes for data loading
Returns: Tuple of (dataloaders_dict, class_names, class_weights)
dataloaders_dict: Dict with 'train', 'val', 'test' DataLoaders
class_names: List of class names
class_weights: Tensor of class weights or None
Example
from training.dataset import create_dataloaders
dataset_config = {
"dataset_path": "dataset",
"selected_families": None,
"split": {
"train": 70,
"val": 15,
"test": 15,
"stratified": True,
"random_seed": 72
},
"preprocessing": {
"target_size": (224, 224),
"normalization": "ImageNet Mean/Std",
"color_mode": "RGB"
},
"augmentation": {
"preset": "Moderate"
}
}
training_config = {
"batch_size": 32,
"class_weights": "Auto Class Weights"
}
dataloaders, class_names, class_weights = create_dataloaders(
dataset_config,
training_config,
num_workers=4
)
train_loader = dataloaders["train"]
val_loader = dataloaders["val"]
test_loader = dataloaders["test"]
print(f"Classes: {class_names}")
print(f"Train batches: {len(train_loader)}")