Skip to main content

Overview

Data transforms are used to preprocess and augment data before feeding it to models. While PyTorch core focuses on tensor operations, data transformations are commonly handled by:
  • torchvision.transforms - For image transformations (vision tasks)
  • torchaudio.transforms - For audio transformations
  • Custom transforms - User-defined transformations
This page covers the general transformation concepts. For vision-specific transforms, see the torchvision documentation.

Transform Composition

Compose multiple transforms together.
from torchvision import transforms

# Standard ImageNet-style preprocessing: resize, center-crop, convert to a
# (C, H, W) float tensor in [0, 1], then normalize with the usual ImageNet
# channel means/stds.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Apply to data
# NOTE(review): `image` is assumed to be a PIL Image defined elsewhere.
transformed = transform(image)

Custom Transforms

Create custom transformation functions or classes.

Function-Based Transform

def custom_transform(x):
    """Apply the affine map ``2*x + 1`` to the input.

    Works on any value supporting ``*`` and ``+`` (tensors, arrays, floats).
    """
    scaled = x * 2.0
    return scaled + 1.0

# Use in Compose
# Plain callables can be mixed with torchvision transform objects:
# Compose simply calls each entry in order.
transform = transforms.Compose([
    transforms.ToTensor(),
    custom_transform
])

Class-Based Transform

class AddGaussianNoise:
    """Add i.i.d. Gaussian noise to a tensor.

    Args:
        mean (float): Mean of the noise distribution. Default: 0.
        std (float): Standard deviation of the noise. Default: 1.
    """

    def __init__(self, mean=0., std=1.):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # randn_like draws noise with the same shape, dtype and device as
        # the input, so this also works for CUDA tensors; torch.randn(size)
        # would always allocate the noise on the CPU and fail for GPU input.
        noise = torch.randn_like(tensor) * self.std + self.mean
        return tensor + noise

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"

# Usage
# ToTensor runs first so the noise transform receives a tensor, not a PIL Image.
transform = transforms.Compose([
    transforms.ToTensor(),
    AddGaussianNoise(mean=0, std=0.1)
])

Common Transform Patterns

Training vs Validation Transforms

# Training: with augmentation
# Random crop/flip/jitter show the model a slightly different view of each
# image every epoch, acting as regularization.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

# Validation: no augmentation
# Deterministic resize + center crop keeps evaluation reproducible; the
# normalization statistics must match the training pipeline exactly.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])
])

Transform with Dataset

from torch.utils.data import Dataset

class ImageDataset(Dataset):
    """Map-style dataset pairing image files with labels.

    Images are loaded lazily on access and run through an optional
    transform before being returned.
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        # NOTE(review): `load_image` is defined elsewhere — presumably it
        # returns a PIL Image or ndarray for the given path.
        image = load_image(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)
        return image, label

# Create datasets with different transforms
# Create datasets with different transforms
# Same Dataset class, different pipelines: augmentation for training only.
# NOTE(review): train_paths/train_labels/val_paths/val_labels come from
# elsewhere in the surrounding program.
train_dataset = ImageDataset(train_paths, train_labels, transform=train_transform)
val_dataset = ImageDataset(val_paths, val_labels, transform=val_transform)

Tensor Transforms

Transforms that work directly on PyTorch tensors.

Normalization

import torch

class Normalize:
    """Channel-wise normalization for (C, H, W) tensors.

    The statistics are stored as (C, 1, 1) tensors so they broadcast
    across the spatial dimensions.
    """

    def __init__(self, mean, std):
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, tensor):
        """Return ``(tensor - mean) / std``.

        Args:
            tensor (Tensor): Image tensor of shape (C, H, W).

        Returns:
            Tensor: Normalized tensor of the same shape.
        """
        return tensor.sub(self.mean).div(self.std)

    def inverse(self, tensor):
        """Undo the normalization: ``tensor * std + mean``."""
        return tensor.mul(self.std).add(self.mean)

# Usage
normalizer = Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])

image_tensor = torch.rand(3, 224, 224)
normalized = normalizer(image_tensor)
denormalized = normalizer.inverse(normalized)

Random Augmentation

import random

class RandomApply:
    """Wrap a transform so it fires only with probability ``p``."""

    def __init__(self, transform, p=0.5):
        self.transform = transform
        self.p = p

    def __call__(self, x):
        # One uniform draw per call; skip the transform on the
        # complementary event.
        if random.random() >= self.p:
            return x
        return self.transform(x)

# Usage
transform = RandomApply(
    transforms.ColorJitter(brightness=0.2),
    p=0.7
)

Functional Transforms

Lower-level functional API for transforms.
import torchvision.transforms.functional as TF
import random

class CustomTransform:
    """Rotate, randomly crop, and normalize an image via the functional API.

    Unlike the class-based transforms, `torchvision.transforms.functional`
    lets each random parameter be drawn explicitly, so the same parameters
    could be logged or reused across multiple inputs.
    """
    
    def __call__(self, img):
        # Random rotation
        # Angle drawn uniformly from [-30, 30] degrees (randint is inclusive).
        angle = random.randint(-30, 30)
        img = TF.rotate(img, angle)
        
        # Random crop
        # get_params only samples the crop box; TF.crop actually applies it.
        i, j, h, w = transforms.RandomCrop.get_params(
            img, output_size=(224, 224)
        )
        img = TF.crop(img, i, j, h, w)
        
        # To tensor and normalize
        img = TF.to_tensor(img)
        img = TF.normalize(img, mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
        
        return img

Multi-Modal Transforms

Transforms for multiple inputs (e.g., image + mask).
class DualTransform:
    """Apply the same random transform to an image and its mask.

    Useful for segmentation, where geometric augmentations (flips,
    rotations) must stay synchronized between the input and the target.

    Args:
        transform: Callable applied to both inputs.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image, mask):
        # Snapshot BOTH generators: recent torchvision transforms draw from
        # torch's global RNG, but older releases (and custom transforms)
        # use the `random` module — restoring only the torch state would
        # leave those desynchronized between image and mask.
        torch_state = torch.get_rng_state()
        py_state = random.getstate()

        # Apply to image
        image = self.transform(image)

        # Rewind the generators so the mask sees the exact same draws
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        mask = self.transform(mask)

        return image, mask

# Usage
# Both image and mask receive the same random flip/rotation parameters.
# NOTE(review): `image` and `mask` must already be defined (e.g. PIL Images).
transform = DualTransform(
    transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10)
    ])
)

image, mask = transform(image, mask)

Batched Transforms

Transforms that operate on batches.
import torch.nn as nn

class BatchNormalize(nn.Module):
    """Normalize a batch of images with fixed channel statistics.

    Args:
        mean (sequence): Per-channel means, length C.
        std (sequence): Per-channel standard deviations, length C.
    """

    def __init__(self, mean, std):
        super().__init__()
        # Register the statistics as buffers so they follow the module
        # through .to()/.cuda() and are persisted in its state_dict.
        # Shaped (1, C, 1, 1) to broadcast over (N, C, H, W) batches;
        # forced to float so integer inputs can't break the division.
        self.register_buffer(
            "mean", torch.tensor(mean, dtype=torch.float32).view(1, -1, 1, 1))
        self.register_buffer(
            "std", torch.tensor(std, dtype=torch.float32).view(1, -1, 1, 1))

    def forward(self, batch):
        """
        Args:
            batch (Tensor): Batch of shape (N, C, H, W)

        Returns:
            Tensor: Normalized batch
        """
        # .to(batch.device) keeps this correct even when the module itself
        # was never moved off the CPU.
        return (batch - self.mean.to(batch.device)) / self.std.to(batch.device)

# Usage
batch_normalizer = BatchNormalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

batch = torch.rand(32, 3, 224, 224)
normalized_batch = batch_normalizer(batch)

Advanced Patterns

Transform with State

class AdaptiveTransform:
    """Normalization transform whose statistics are estimated from a dataset.

    Call :meth:`fit` once before using the transform; :meth:`__call__`
    raises until then.
    """

    def __init__(self):
        # Populated by fit(); __call__ refuses to run before that.
        self.mean = None
        self.std = None

    def fit(self, dataset):
        """Compute per-channel mean/std over every sample in `dataset`.

        Each ``dataset[i][0]`` is expected to be a (C, H, W) tensor.
        """
        # torch.stack builds an (N, C, H, W) batch. The original torch.cat
        # concatenated the samples along the channel dim, yielding a 3-D
        # (N*C, H, W) tensor and making the dim=(0, 2, 3) reductions fail.
        all_data = torch.stack([dataset[i][0] for i in range(len(dataset))])
        self.mean = all_data.mean(dim=(0, 2, 3))
        self.std = all_data.std(dim=(0, 2, 3))

    def __call__(self, tensor):
        if self.mean is None:
            raise RuntimeError("Must call fit() before using transform")
        return (tensor - self.mean.view(-1, 1, 1)) / self.std.view(-1, 1, 1)

# Usage
transform = AdaptiveTransform()
transform.fit(train_dataset)
normalized = transform(image)

Conditional Transforms

class ConditionalTransform:
    """Dispatch to a per-label transform; unknown labels pass through."""

    def __init__(self, transform_dict):
        self.transform_dict = transform_dict

    def __call__(self, image, label):
        # EAFP lookup: missing labels fall back to the identity.
        try:
            transform = self.transform_dict[label]
        except KeyError:
            return image
        return transform(image)

# Usage
# Labels without an entry fall back to the identity transform.
transform = ConditionalTransform({
    0: transforms.ColorJitter(brightness=0.2),
    1: transforms.RandomRotation(10),
    2: transforms.RandomHorizontalFlip()
})

Common Transform Functions

Image Preprocessing

def preprocess_image(image, size=224, mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)):
    """Standard image preprocessing pipeline.

    Args:
        image: PIL Image to preprocess.
        size (int): Side length of the final center crop. Default: 224.
        mean (sequence): Per-channel normalization means.
        std (sequence): Per-channel normalization standard deviations.

    Returns:
        Tensor: Preprocessed (C, size, size) image tensor.
    """
    # Defaults are tuples, not lists: mutable default arguments are shared
    # across calls and can be mutated accidentally.
    transform = transforms.Compose([
        transforms.Resize(size + 32),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    return transform(image)

Data Augmentation

def get_augmentation(mode='strong'):
    """Return an augmentation pipeline for the requested strength.

    'weak'   -> horizontal flip only;
    'strong' -> flip + rotation + color jitter + small translation;
    any other mode -> bare ToTensor conversion.
    """
    if mode == 'weak':
        ops = [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor()
        ]
    elif mode == 'strong':
        ops = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.3, contrast=0.3),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
            transforms.ToTensor()
        ]
    else:
        return transforms.ToTensor()
    return transforms.Compose(ops)

Best Practices

  • Use transforms.ToTensor() to convert to PyTorch tensors
  • Apply size-reducing transforms (e.g., resize, crop) early so that subsequent transforms operate on less data
  • Consider caching transformed data for small datasets
  • Use num_workers > 0 in DataLoader to parallelize transforms
  • Set random seeds before creating transforms
  • Use torch.manual_seed() for deterministic augmentation
  • Document all transform parameters
  • Save transform configuration with model checkpoints
  • Visualize transformed samples to verify correctness
  • Test edge cases (e.g., very small/large images)
  • Check data ranges after normalization
  • Verify transforms maintain label consistency
  • Start with light augmentation, increase gradually
  • Don’t augment validation/test data
  • Use domain-specific augmentations
  • Monitor if augmentation helps or hurts performance

Visualization Helper

import matplotlib.pyplot as plt
import torchvision.utils as vutils

def visualize_transforms(dataset, transform, n_samples=8):
    """Plot original vs. transformed versions of the first samples.

    Args:
        dataset: Indexable dataset yielding (image, label) pairs; images
            must be displayable by ``plt.imshow`` (e.g. PIL Images).
        transform: Callable producing a (C, H, W) tensor from an image.
        n_samples (int): Number of columns to plot. Default: 8.
    """
    # Top row: originals; bottom row: transformed counterparts.
    fig, axes = plt.subplots(2, n_samples, figsize=(16, 4))
    
    for i in range(n_samples):
        # Original
        img, _ = dataset[i]
        axes[0, i].imshow(img)
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=12)
        
        # Transformed
        transformed = transform(img)
        # Denormalize if needed
        # NOTE(review): normalized tensors will display with clipped or odd
        # colors unless denormalized first — no denormalization happens here.
        # (C, H, W) tensor -> (H, W, C) array for imshow.
        transformed = transformed.permute(1, 2, 0).numpy()
        axes[1, i].imshow(transformed)
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Transformed', fontsize=12)
    
    plt.tight_layout()
    plt.show()

# Usage
# NOTE(review): `dataset` and `train_transform` must already exist in scope.
visualize_transforms(dataset, train_transform, n_samples=8)

See Also

Build docs developers (and LLMs) love