
Overview

The transforms module provides functions to create torchvision transform pipelines for training and validation. It supports flexible augmentation presets and normalization strategies optimized for malware image classification.

Functions

create_train_transforms

Creates a transform pipeline for training data with optional augmentation.
def create_train_transforms(config: dict) -> transforms.Compose
config
dict
required
Dataset configuration dictionary with keys:
  • preprocessing: Preprocessing settings
    • target_size: Tuple (height, width) for resizing (default: (224, 224))
    • normalization: Normalization strategy ("[0,1]", "[-1,1]", "ImageNet Mean/Std")
    • color_mode: Color mode ("RGB", "Grayscale")
  • augmentation: Augmentation settings
    • preset: Augmentation preset ("None", "Light", "Moderate", "Heavy", "Custom")
    • custom: Custom augmentation parameters (if preset="Custom")
Returns: Composed transform pipeline
Transform Pipeline Order:
  1. Resize to target size
  2. Convert to RGB (if needed)
  3. Apply augmentations
  4. Convert to tensor
  5. Apply normalization
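As an illustrative sketch (not the module's actual code), the ordering above can be expressed as a small planner that lists which steps a given config would produce:

```python
# Illustrative sketch only: lists the documented pipeline steps a config
# would produce, in order. Not the module's real implementation.
def plan_pipeline(config: dict) -> list:
    preprocessing = config.get("preprocessing", {})
    steps = ["Resize"]
    if preprocessing.get("color_mode", "RGB") != "RGB":
        steps.append("ConvertToRGB")  # grayscale inputs become 3-channel
    if config.get("augmentation", {}).get("preset", "None") != "None":
        steps.append("Augment")
    steps.append("ToTensor")  # scales pixel bytes to [0, 1]
    if preprocessing.get("normalization", "[0,1]") != "[0,1]":
        steps.append("Normalize")  # only needed beyond plain [0, 1]
    return steps
```

For example, a grayscale config with the Light preset and ImageNet normalization exercises every step, while a bare RGB config with no augmentation reduces to resize-and-tensorize.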

Augmentation Presets

None

No augmentation, only basic preprocessing.

Light

  • Random horizontal flip (p=0.5)
  • Random 90° rotation (0°, 90°, 180°, 270°)
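torchvision has no single built-in transform for discrete 90° rotations, so this step is presumably implemented by sampling one of the four angles at random. A pure-Python sketch of that logic, with a nested list standing in for the image:

```python
import random

def rotate90(grid, angle):
    """Rotate a nested list clockwise by a multiple of 90 degrees."""
    for _ in range((angle // 90) % 4):
        # one clockwise quarter-turn: reverse the rows, then transpose
        grid = [list(row) for row in zip(*grid[::-1])]
    return grid

def random_rotate90(grid):
    # each of the four orientations is equally likely
    return rotate90(grid, random.choice([0, 90, 180, 270]))
```

In the real pipeline the same choice would drive a tensor/PIL rotation rather than list surgery, but the sampling logic is the interesting part.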

Moderate

  • Random horizontal flip (p=0.5)
  • Random vertical flip (p=0.5)
  • Random 90° rotation (0°, 90°, 180°, 270°)
  • Color jitter (brightness±10%, contrast±10%)

Heavy

  • Random horizontal flip (p=0.5)
  • Random vertical flip (p=0.5)
  • Random 90° rotation (0°, 90°, 180°, 270°)
  • Color jitter (brightness±20%, contrast±20%)
  • Gaussian blur (kernel=3, sigma=0.1-0.5)

Custom

Configure individual augmentations:
  • horizontal_flip: Enable horizontal flips
  • vertical_flip: Enable vertical flips
  • rotation: Enable rotations
  • rotation_angles: List of rotation angles (default: [90, 180, 270])
  • brightness_range: Brightness variation percentage (0-100)
  • contrast_range: Contrast variation percentage (0-100)
  • gaussian_noise: Enable Gaussian blur (despite the key name, this toggles the blur step)

Normalization Strategies

| Strategy | Mean | Std | Use Case |
| --- | --- | --- | --- |
| [0,1] | - | - | Default; no normalization after ToTensor() |
| [-1,1] | [0.5, 0.5, 0.5] | [0.5, 0.5, 0.5] | Neural networks with tanh activation |
| ImageNet Mean/Std | [0.485, 0.456, 0.406] | [0.229, 0.224, 0.225] | Transfer learning from ImageNet models |
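The arithmetic behind each strategy, worked for a single red-channel byte (ToTensor() divides by 255 first; 0.485/0.229 are the standard ImageNet red-channel constants):

```python
raw = 128                # one red-channel byte from the image
x = raw / 255            # after ToTensor(): value in [0, 1]

zero_one = x                        # "[0,1]": ToTensor() alone
minus_one_one = (x - 0.5) / 0.5     # "[-1,1]": mean = std = 0.5
imagenet = (x - 0.485) / 0.229      # ImageNet red channel: (x - mean) / std

print(round(zero_one, 3), round(minus_one_one, 3), round(imagenet, 3))
# → 0.502 0.004 0.074
```

Note that a mid-gray byte maps to roughly 0 under "[-1,1]" but not exactly 0 under ImageNet statistics, since the ImageNet channel means are not 0.5.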

Example

from training.transforms import create_train_transforms

# Light augmentation
config = {
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "ImageNet Mean/Std",
        "color_mode": "RGB"
    },
    "augmentation": {
        "preset": "Light"
    }
}

train_transform = create_train_transforms(config)

# Apply to image
from PIL import Image
image = Image.open("sample.png")
transformed = train_transform(image)
print(transformed.shape)  # torch.Size([3, 224, 224])

# Moderate augmentation for better generalization
config = {
    "preprocessing": {
        "target_size": (256, 256),
        "normalization": "[-1,1]",
        "color_mode": "RGB"
    },
    "augmentation": {
        "preset": "Moderate"
    }
}

train_transform = create_train_transforms(config)

# Custom augmentation with specific settings
config = {
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "ImageNet Mean/Std",
        "color_mode": "RGB"
    },
    "augmentation": {
        "preset": "Custom",
        "custom": {
            "horizontal_flip": True,
            "vertical_flip": False,
            "rotation": True,
            "rotation_angles": [90, 180],
            "brightness_range": 15,
            "contrast_range": 10,
            "gaussian_noise": False
        }
    }
}

train_transform = create_train_transforms(config)

# Grayscale images with heavy augmentation
config = {
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "[0,1]",
        "color_mode": "Grayscale"
    },
    "augmentation": {
        "preset": "Heavy"
    }
}

train_transform = create_train_transforms(config)
# Note: Grayscale images are converted to 3-channel RGB for compatibility
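The conversion noted above can be pictured without torchvision: each single-channel intensity is simply replicated across three channels. A sketch of the idea, not the module's code:

```python
def gray_to_rgb(gray):
    """Replicate a 2D grid of intensities into three identical channels."""
    return [[[v, v, v] for v in row] for row in gray]

# a 1x2 grayscale "image" becomes 1x2x3
print(gray_to_rgb([[0, 255]]))  # → [[[0, 0, 0], [255, 255, 255]]]
```

This keeps pretrained 3-channel backbones usable on grayscale malware images at the cost of carrying redundant channels.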

create_val_transforms

Creates a transform pipeline for validation/test data without augmentation.
def create_val_transforms(config: dict) -> transforms.Compose
config
dict
required
Dataset configuration dictionary with keys:
  • preprocessing: Preprocessing settings
    • target_size: Tuple (height, width) for resizing (default: (224, 224))
    • normalization: Normalization strategy ("[0,1]", "[-1,1]", "ImageNet Mean/Std")
    • color_mode: Color mode ("RGB", "Grayscale")
Returns: Composed transform pipeline
Transform Pipeline:
  1. Resize to target size
  2. Convert to RGB (if needed)
  3. Convert to tensor
  4. Apply normalization
Note: No augmentation is applied for validation/test sets to ensure consistent evaluation.

Example

from training.transforms import create_val_transforms

config = {
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "ImageNet Mean/Std",
        "color_mode": "RGB"
    }
}

val_transform = create_val_transforms(config)

# Apply to image
from PIL import Image
image = Image.open("sample.png")
transformed = val_transform(image)
print(transformed.shape)  # torch.Size([3, 224, 224])

Usage with Dataset

from pathlib import Path
from training.dataset import MalwareDataset, scan_dataset, create_splits
from training.transforms import create_train_transforms, create_val_transforms

# Configuration
config = {
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "ImageNet Mean/Std",
        "color_mode": "RGB"
    },
    "augmentation": {
        "preset": "Moderate"
    }
}

# Create transforms
train_transform = create_train_transforms(config)
val_transform = create_val_transforms(config)

# Scan dataset and create splits
image_paths, labels, class_names = scan_dataset(Path("dataset"))
splits = create_splits(image_paths, labels)

# Create datasets with transforms
train_dataset = MalwareDataset(
    splits["train"]["paths"],
    splits["train"]["labels"],
    transform=train_transform  # Augmentation applied
)

val_dataset = MalwareDataset(
    splits["val"]["paths"],
    splits["val"]["labels"],
    transform=val_transform  # No augmentation
)

test_dataset = MalwareDataset(
    splits["test"]["paths"],
    splits["test"]["labels"],
    transform=val_transform  # No augmentation
)

Visualizing Augmentations

import matplotlib.pyplot as plt
from PIL import Image
from training.transforms import create_train_transforms

config = {
    "preprocessing": {
        "target_size": (224, 224),
        "normalization": "[0,1]",
        "color_mode": "RGB"
    },
    "augmentation": {
        "preset": "Moderate"
    }
}

transform = create_train_transforms(config)
image = Image.open("sample.png")

# Apply transform multiple times to see variations
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

for i, ax in enumerate(axes):
    augmented = transform(image)
    # Convert from tensor to displayable format
    augmented = augmented.permute(1, 2, 0).numpy()
    ax.imshow(augmented)
    ax.axis('off')
    ax.set_title(f'Augmentation {i+1}')

plt.tight_layout()
plt.savefig('augmentation_examples.png')
plt.show()

Best Practices

Choosing Augmentation Strength

# Small dataset (< 1000 samples per class) - Use Heavy
config["augmentation"]["preset"] = "Heavy"

# Medium dataset (1000-5000 samples per class) - Use Moderate
config["augmentation"]["preset"] = "Moderate"

# Large dataset (> 5000 samples per class) - Use Light or None
config["augmentation"]["preset"] = "Light"
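These thresholds can be folded into a small helper; `suggest_preset` is hypothetical (not part of the module) and simply encodes the rule of thumb above:

```python
def suggest_preset(samples_per_class: int) -> str:
    """Hypothetical helper: stronger augmentation for smaller datasets."""
    if samples_per_class < 1000:
        return "Heavy"
    if samples_per_class <= 5000:
        return "Moderate"
    return "Light"  # or "None" for very large, diverse datasets

print(suggest_preset(800))  # → Heavy
```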

Choosing Normalization

# Transfer learning from ImageNet models (ResNet, EfficientNet, ViT)
config["preprocessing"]["normalization"] = "ImageNet Mean/Std"

# Training from scratch
config["preprocessing"]["normalization"] = "[0,1]"

# Models with tanh activation
config["preprocessing"]["normalization"] = "[-1,1]"

Consistent Preprocessing

# IMPORTANT: Use same config for train and val transforms
config = {
    "preprocessing": {...},
    "augmentation": {...}
}

train_transform = create_train_transforms(config)  # With augmentation
val_transform = create_val_transforms(config)      # Without augmentation

# Both use same preprocessing (resize, color mode, normalization)
# Only augmentation differs
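If the train and val configs are ever built separately, a guard like this catches preprocessing drift before training starts (`check_same_preprocessing` is a hypothetical helper, not part of the module):

```python
def check_same_preprocessing(train_cfg: dict, val_cfg: dict) -> None:
    """Raise if train/val configs would preprocess images differently."""
    if train_cfg.get("preprocessing") != val_cfg.get("preprocessing"):
        raise ValueError("train and val preprocessing settings differ")
```

Running it right after loading the configs turns a silent evaluation skew (e.g. mismatched normalization) into an immediate, explicit error.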
