
Loss Function — Focal Loss with Label Smoothing

Training uses a custom FocalLoss that combines three mechanisms: focal modulation and per-class alpha weighting to handle class imbalance, and label smoothing to curb overconfidence:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal Loss with label smoothing"""

    def __init__(self, alpha=None, gamma=2.0, smoothing=0.1):
        super().__init__()
        self.alpha = alpha      # per-class weights tensor
        self.gamma = gamma      # focusing parameter
        self.smoothing = smoothing

    def forward(self, inputs, targets):
        num_classes = inputs.shape[1]

        # Label smoothing: spread (1 - smoothing) confidence onto the true class
        # and smoothing / (num_classes - 1) onto every other class
        confidence = 1.0 - self.smoothing
        smooth_targets = torch.zeros_like(inputs)
        smooth_targets.fill_(self.smoothing / (num_classes - 1))
        smooth_targets.scatter_(1, targets.unsqueeze(1), confidence)

        # Cross-entropy against smooth targets
        log_probs = F.log_softmax(inputs, dim=1)
        ce_loss = -(smooth_targets * log_probs).sum(dim=1)

        # Focal modulation: down-weight easy examples
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        # Per-class alpha weighting
        if self.alpha is not None:
            if self.alpha.device != inputs.device:
                self.alpha = self.alpha.to(inputs.device)
            at = self.alpha.gather(0, targets)
            focal_loss = at * focal_loss

        return focal_loss.mean()

How it handles class imbalance

  • alpha — Per-class weights derived from inverse class frequency in the training set (clipped to [0.5, 10.0]). Under-represented classes receive a higher multiplier.
  • gamma=2.0 — The (1 - pt)^gamma factor reduces the loss contribution of correctly-classified, high-confidence examples. The model therefore focuses gradient updates on hard or misclassified samples regardless of class.
  • smoothing=0.1 — Prevents the model from becoming overconfident by assigning 10% of the probability mass uniformly across non-target classes.
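The exact inverse-frequency formula is not shown here; a minimal sketch, assuming weights of the form total / (num_classes × count) with the stated [0.5, 10.0] clip (the label list is hypothetical):

```python
from collections import Counter

def compute_alpha(labels, min_w=0.5, max_w=10.0):
    """Per-class weights from inverse class frequency, clipped to [min_w, max_w]."""
    counts = Counter(labels)
    num_classes = len(counts)
    total = len(labels)
    alpha = []
    for c in range(num_classes):
        # 1.0 for a perfectly balanced class, >1 for under-represented classes
        w = total / (num_classes * counts[c])
        alpha.append(min(max(w, min_w), max_w))
    return alpha

# Example: class 2 is 10x rarer than class 0
labels = [0] * 100 + [1] * 50 + [2] * 10
print(compute_alpha(labels))  # rare class 2 gets the largest multiplier
```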
The criterion is instantiated with the per-class weights moved to the training device:
criterion = FocalLoss(
    alpha=self.class_weights.to(self.device),  # from WeightedRandomSampler weights
    gamma=2.0,
    smoothing=0.1
)
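To see concretely what gamma=2.0 does, compare the modulation factor for an easy and a hard example (the probabilities below are illustrative):

```python
def focal_weight(pt, gamma=2.0):
    """Modulation factor (1 - pt)^gamma applied to the cross-entropy loss."""
    return (1 - pt) ** gamma

# Easy example: model already assigns 0.9 to the true class
easy = focal_weight(0.9)   # (1 - 0.9)^2 = 0.01
# Hard example: model assigns only 0.4
hard = focal_weight(0.4)   # (1 - 0.4)^2 = 0.36

print(easy, hard)  # the hard example contributes 36x more loss
```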

Optimizer — AdamW

optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
    betas=(0.9, 0.999)
)
Parameter      Value
Learning rate  0.001
Weight decay   5e-4
β₁             0.9
β₂             0.999
Weight decay (5e-4) is applied in AdamW's decoupled form — the weights are shrunk directly at each step, rather than via an L2 term added to the gradient — discouraging large weight magnitudes.
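The decay step itself is simple; a minimal sketch of the decoupled shrink AdamW applies each step, separate from the moment-based gradient update (the scalar weight is illustrative):

```python
def decoupled_decay(w, lr=1e-3, weight_decay=5e-4):
    """AdamW shrinks each weight by lr * weight_decay * w per step,
    independently of the gradient-based Adam update."""
    return w - lr * weight_decay * w

w = 1.0
print(decoupled_decay(w))  # a 5e-7 relative shrink per step at these settings
```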

Learning Rate Scheduler — CosineAnnealingWarmRestarts

scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=20,       # first restart period (epochs)
    T_mult=2,     # period doubles after each restart
    eta_min=1e-6  # minimum learning rate
)
The scheduler is stepped every batch rather than every epoch. Because step() with no argument advances the cycle counter by one full unit per call, per-batch stepping should pass the fractional epoch so that T_0 keeps its meaning in epochs:
if scheduler and not isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
    scheduler.step(epoch + batch_idx / num_batches)
This produces a smooth cosine decay within each 20-epoch cycle. After each restart the cycle length doubles (20 → 40 → 80 epochs), letting the model escape local minima and re-explore the loss landscape at a temporarily higher learning rate.
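The schedule within one cycle follows the standard closed-form cosine annealing formula; a sketch of the first cycle, assuming the base LR of 1e-3 and eta_min of 1e-6 from the configuration above:

```python
import math

def cosine_lr(t, T_i, eta_max=1e-3, eta_min=1e-6):
    """LR at fractional epoch t within a cycle of length T_i epochs."""
    return eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * t / T_i)) / 2

# First cycle: T_0 = 20 epochs
print(cosine_lr(0, 20))    # starts at the full base LR, 1e-3
print(cosine_lr(10, 20))   # midpoint, roughly (1e-3 + 1e-6) / 2
print(cosine_lr(20, 20))   # end of cycle: eta_min = 1e-6
# After the restart, the next cycle runs with T_i = 40 (T_mult=2)
```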

Regularization Techniques

Dropout

A rate of 0.4 on LSTM inter-layer connections and multi-head attention; 0.2 (half that rate) on the input projection and attention-pooling layers.

Gradient Clipping

Global-norm clipping with max_norm=1.0, applied after loss.backward() and before each optimizer step.

Weight Decay

L2 penalty of 5e-4 via AdamW. Applied to all learnable parameters.

Label Smoothing

smoothing=0.1 in FocalLoss. Reduces overconfidence and improves calibration on visually ambiguous samples.
# Gradient clipping in the training loop
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
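Global-norm clipping rescales the entire gradient vector only when its combined L2 norm exceeds max_norm, preserving the gradient direction; a pure-Python sketch of the rule clip_grad_norm_ applies:

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients by max_norm / total_norm when the combined
    L2 norm exceeds max_norm; leave them untouched otherwise."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        return [g * scale for g in grads]
    return grads

grads = [3.0, 4.0]             # global norm = 5.0
clipped = clip_by_global_norm(grads)
print(clipped)                 # direction preserved, norm rescaled to 1.0
```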

Training Augmentation

Augmentation is applied at the feature sequence level inside EnhancedPreExtractedFeaturesDataset.__getitem__() — after CNN features are loaded from disk, before they enter the model.
if self.augment and self.tta_mode is None:
    num_frames = features.shape[0]

    if num_frames > 8:
        # 1. Temporal subsampling (50% probability)
        if random.random() < 0.5:
            sample_ratio = random.uniform(0.7, 1.0)  # keep 70–100% of frames
            new_length = max(int(num_frames * sample_ratio), 8)
            indices = sorted(random.sample(range(num_frames), new_length))
            features = features[indices]

        # 2. Temporal shift (30% probability)
        if random.random() < 0.3:
            shift = random.randint(-3, 3)
            if shift != 0:
                features = torch.roll(features, shifts=shift, dims=0)

        # 3. Gaussian noise (20% probability)
        if random.random() < 0.2:
            noise = torch.randn_like(features) * 0.01
            features = features + noise
Augmentation          Probability  Parameters
Temporal subsampling  50%          Retain 70–100% of frames, randomly sampled
Temporal shift        30%          Circular roll by ±3 frames
Gaussian noise        20%          σ = 0.01 added to all feature dimensions
Augmentation runs on CPU inside the DataLoader workers and operates on pre-extracted 1280-dim feature vectors — not on raw pixels. This makes it extremely cheap and introduces no GPU overhead.
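Note that the subsampling step keeps frames in their original temporal order because the sampled indices are sorted, and it never drops below the 8-frame floor; a pure-Python sketch of that index selection:

```python
import random

def temporal_subsample(num_frames, min_keep=8, lo=0.7, hi=1.0):
    """Pick a random 70-100% of frame indices, sorted so temporal
    order is preserved, never keeping fewer than min_keep frames."""
    ratio = random.uniform(lo, hi)
    new_length = max(int(num_frames * ratio), min_keep)
    return sorted(random.sample(range(num_frames), new_length))

idx = temporal_subsample(30)
print(len(idx), idx == sorted(idx))  # 20-30 frames kept, always in order
```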

Resource Monitoring

The ResourceMonitor class runs a background daemon thread that logs GPU, RAM, and CPU utilisation to a JSON file every 30 minutes throughout training:
monitor = ResourceMonitor(
    output_dir="results",
    interval_minutes=30
)
monitor.start_monitoring()

# ... training runs ...

monitor.stop_monitoring()
summary = monitor.get_summary()
Metrics logged per interval:
Metric                     Source
cpu_percent                psutil.cpu_percent()
ram_used_gb / ram_percent  psutil.virtual_memory()
gpu_allocated_gb           torch.cuda.memory_allocated()
gpu_reserved_gb            torch.cuda.memory_reserved()
gpu_max_allocated_gb       torch.cuda.max_memory_allocated()
The log is written incrementally to results/resource_utilization_log.json so data is preserved even if training is interrupted.
On the A100 MIG partition (9.8 GB VRAM), gpu_reserved_gb will consistently exceed gpu_allocated_gb due to PyTorch’s caching allocator. Call torch.cuda.empty_cache() between ensemble runs to release reserved-but-unused memory back to the driver.
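The incremental-write pattern can be sketched as appending each interval's sample and rewriting the JSON file atomically, so a crash mid-training never leaves a corrupt log (the function below is illustrative, not ResourceMonitor's actual API):

```python
import json
import os
import tempfile

def append_metrics(log_path, entry):
    """Load existing entries (if any), append the new sample, and write the
    file via a temp file + rename so an interrupted write cannot corrupt it."""
    entries = []
    if os.path.exists(log_path):
        with open(log_path) as f:
            entries = json.load(f)
    entries.append(entry)
    fd, tmp = tempfile.mkstemp(dir=os.path.dirname(log_path) or ".")
    with os.fdopen(fd, "w") as f:
        json.dump(entries, f, indent=2)
    os.replace(tmp, log_path)

append_metrics("resource_utilization_log.json",
               {"cpu_percent": 12.5, "ram_used_gb": 8.1})
```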
