Two complementary techniques combine at inference time to maximize accuracy: model ensembling averages the predictions of 4 independently trained checkpoints, while Test-Time Augmentation (TTA) applies 4 temporal transformations to each video before averaging predictions from all views.

Why Ensemble?

Each model in the ensemble is trained from a different random seed (42 + i), producing a different local minimum in weight space. When predictions from multiple diverse models are averaged, individual errors tend to cancel out—especially on ambiguous samples near decision boundaries.

Reduced Variance

A single model’s confidence on a hard sample may fluctuate. Averaging 4 models produces a smoother, more reliable probability estimate.

Error Cancellation

If model A misclassifies a Gaming clip as Animation but models B, C, and D are correct, the average still predicts Gaming.
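This cancellation can be sketched numerically. The probability values below are illustrative only (not taken from the real checkpoints), with classes ordered so that index 0 is Animation and index 1 is Gaming:

```python
import torch

# Hypothetical softmax outputs for one clip over four classes
# (illustrative values, not real model outputs).
preds = torch.tensor([
    [0.60, 0.30, 0.05, 0.05],  # model A: wrongly favors Animation
    [0.20, 0.65, 0.10, 0.05],  # model B: correct (Gaming)
    [0.15, 0.70, 0.10, 0.05],  # model C: correct
    [0.25, 0.60, 0.10, 0.05],  # model D: correct
])
avg = preds.mean(dim=0)     # ensemble average over the 4 models
print(avg.argmax().item())  # 1 -> Gaming, despite model A's error
```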

No Extra Training Cost

Once trained, ensemble inference requires only one forward pass per model. With pre-extracted features, this is negligible overhead.

Different Seeds

Seeds 42, 43, 44, 45 initialize weights differently and produce different dropout masks, leading to genuinely diverse feature representations.
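A minimal sketch of this seeding scheme, using a stand-in linear layer rather than the real (unnamed) architecture:

```python
import torch

# Seed each ensemble member with 42 + i, as described above.
# torch.nn.Linear is a stand-in for the real temporal model.
init_weights = []
for i in range(4):
    torch.manual_seed(42 + i)
    layer = torch.nn.Linear(8, 4)
    init_weights.append(layer.weight.detach().clone())

# Different seeds yield genuinely different starting weights.
print(torch.allclose(init_weights[0], init_weights[1]))  # False
```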

The 4 Checkpoints

From configuration_analysis.json, the four production checkpoints share identical architecture but differ in their training trajectories:
| Checkpoint | Best Epoch | Best Val Acc | Weighted F1 |
| --- | --- | --- | --- |
| best_ensemble_model_1.pt | 52 | 72.73% | 64.76% |
| best_ensemble_model_2.pt | 43 | 92.13% | 92.15% |
| best_ensemble_model_3.pt | 42 | 91.86% | 91.87% |
| best_ensemble_model_4.pt | 40 | 91.86% | 91.88% |
Model 1’s lower individual accuracy is not necessarily a weakness in an ensemble. It may have learned a different decision boundary than models 2–4, contributing complementary information. The ensemble’s combined accuracy exceeds any individual model.

Model Ensemble: Averaging Softmax Probabilities

Each model produces a softmax probability vector of shape [4]. These are stacked and averaged:
def predict_standard(self, features):
    """Standard prediction (no TTA) - same as test_standard()"""
    all_model_predictions = []
    
    with torch.no_grad():
        features_batch = features.unsqueeze(0).to(self.device)  # [1, T, D]
        lengths = torch.tensor([features.shape[0]], device=self.device)
        
        for model in self.models:
            outputs = model(features_batch, lengths)  # [1, num_classes]
            probs = F.softmax(outputs, dim=1)
            all_model_predictions.append(probs.squeeze(0).cpu())
    
    # Ensemble: average predictions
    ensemble_probs = torch.stack(all_model_predictions).mean(dim=0)
    
    return ensemble_probs
The models are loaded from the four checkpoint files at startup:
selected_names = [
    "best_ensemble_model_1.pt",
    "best_ensemble_model_2.pt",
    "best_ensemble_model_3.pt",
    "best_ensemble_model_4.pt",
]
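The loading loop itself is not shown in the source. A plausible sketch, using a hypothetical toy architecture (`ToyTemporalModel`) and temporary files in place of the real checkpoints:

```python
import os
import tempfile
import torch
import torch.nn as nn

# Hypothetical stand-in for the real (unnamed) temporal architecture.
class ToyTemporalModel(nn.Module):
    def __init__(self, feat_dim=8, num_classes=4):
        super().__init__()
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x, lengths):
        return self.head(x.mean(dim=1))  # [B, T, D] -> [B, num_classes]

tmp_dir = tempfile.mkdtemp()
selected_names = [f"best_ensemble_model_{i}.pt" for i in range(1, 5)]

# Simulate the four checkpoint files on disk.
for name in selected_names:
    torch.save(ToyTemporalModel().state_dict(), os.path.join(tmp_dir, name))

# Startup: load each checkpoint into its own model instance.
models = []
for name in selected_names:
    model = ToyTemporalModel()
    model.load_state_dict(torch.load(os.path.join(tmp_dir, name)))
    model.eval()  # inference mode: disables dropout
    models.append(model)

print(len(models))  # 4
```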

Test-Time Augmentation (TTA)

TTA generates multiple views of the same video at inference time by applying temporal transformations to the pre-extracted feature sequence. The model never sees pixels during TTA—only the feature vectors are manipulated, making it computationally cheap.

The 4 TTA Modes

Mode 1 (Original): the unmodified feature sequence is passed through the ensemble as-is. This is the baseline prediction.
# TTA Mode 1: Original
probs = self.predict_standard(features)
tta_predictions.append(probs)
Mode 2 (Reverse): the frame sequence is reversed end-to-end, simulating a video played backwards. Content type (Animation, Gaming, etc.) is invariant to temporal direction, so this is a valid augmentation.
# TTA Mode 2: Reverse
features_reversed = torch.flip(features, dims=[0])
probs = self.predict_standard(features_reversed)
tta_predictions.append(probs)
Mode 3 (Speed-up): the sequence is shortened to T/2 frames by sampling evenly spaced indices across the sequence, dropping roughly every other frame. This simulates a 2× speed-up.
# TTA Mode 3: Speed up (skip frames)
if features.shape[0] > 10:
    indices = torch.linspace(0, features.shape[0]-1,
                             features.shape[0]//2).long()
    features_speedup = features[indices]
    probs = self.predict_standard(features_speedup)
    tta_predictions.append(probs)
The minimum-length guard (> 10 frames) prevents degenerate sequences that are too short for the LSTM.
Mode 4 (Speed-down): the sequence is expanded to 1.5 × T frames by repeating existing frame features at uniformly spaced positions. This simulates playback at roughly 0.67× speed.
# TTA Mode 4: Speed down (interpolate frames)
if features.shape[0] > 10:
    indices = torch.linspace(0, features.shape[0]-1,
                             int(features.shape[0]*1.5)).long()
    indices = indices.clamp(max=features.shape[0]-1)
    features_speeddown = features[indices]
    probs = self.predict_standard(features_speeddown)
    tta_predictions.append(probs)
.clamp() ensures no index exceeds the last valid position (T − 1) after the floating-point linspace values are truncated to integers.
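The index math for both speed modes can be checked with a worked example, assuming a 20-frame sequence:

```python
import torch

T = 20  # example sequence length

# Speed-up: T // 2 = 10 evenly spaced indices out of 20.
up = torch.linspace(0, T - 1, T // 2).long()
print(up.shape[0], up[0].item(), up[-1].item())  # 10 0 19

# Speed-down: int(T * 1.5) = 30 indices; neighbouring frames repeat.
down = torch.linspace(0, T - 1, int(T * 1.5)).long().clamp(max=T - 1)
print(down.shape[0], down.max().item())          # 30 19
```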

The predict_with_tta() Method

The full TTA function from test_already_extracted.py:
def predict_with_tta(self, features):
    """Prediction with TTA - same as test_with_tta()"""
    tta_predictions = []
    
    # TTA Mode 1: Original
    probs = self.predict_standard(features)
    tta_predictions.append(probs)
    
    # TTA Mode 2: Reverse
    features_reversed = torch.flip(features, dims=[0])
    probs = self.predict_standard(features_reversed)
    tta_predictions.append(probs)
    
    # TTA Mode 3: Speed up (skip frames)
    if features.shape[0] > 10:
        indices = torch.linspace(0, features.shape[0]-1,
                                 features.shape[0]//2).long()
        features_speedup = features[indices]
        probs = self.predict_standard(features_speedup)
        tta_predictions.append(probs)
    
    # TTA Mode 4: Speed down (interpolate frames)
    if features.shape[0] > 10:
        indices = torch.linspace(0, features.shape[0]-1,
                                 int(features.shape[0]*1.5)).long()
        indices = indices.clamp(max=features.shape[0]-1)
        features_speeddown = features[indices]
        probs = self.predict_standard(features_speeddown)
        tta_predictions.append(probs)
    
    # Average TTA predictions
    ensemble_probs = torch.stack(tta_predictions).mean(dim=0)
    
    return ensemble_probs
Each call to predict_standard already averages all 4 model checkpoints, so the final prediction averages up to 4 TTA modes × 4 models = 16 forward passes.
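The full control flow can be reproduced in a self-contained toy version, where a mean-pool plus linear head stands in for each real checkpoint (the class and its internals are illustrative, not the production code):

```python
import torch
import torch.nn.functional as F

class ToyEnsemble:
    """Toy 4-model ensemble mirroring the TTA control flow above."""

    def __init__(self, num_models=4, feat_dim=8, num_classes=4):
        torch.manual_seed(0)
        # Stand-ins for the four real checkpoints.
        self.models = [torch.nn.Linear(feat_dim, num_classes)
                       for _ in range(num_models)]

    def predict_standard(self, features):
        # features: [T, D]; mean-pool replaces the real temporal model.
        with torch.no_grad():
            pooled = features.mean(dim=0)
            preds = [F.softmax(m(pooled), dim=0) for m in self.models]
        return torch.stack(preds).mean(dim=0)

    def predict_with_tta(self, features):
        T = features.shape[0]
        tta = [self.predict_standard(features),                       # original
               self.predict_standard(torch.flip(features, dims=[0]))] # reverse
        if T > 10:
            idx = torch.linspace(0, T - 1, T // 2).long()             # speed up
            tta.append(self.predict_standard(features[idx]))
            idx = torch.linspace(0, T - 1, int(T * 1.5)).long()       # speed down
            tta.append(self.predict_standard(idx.clamp(max=T - 1).apply_(int) * 0 + features[idx.clamp(max=T - 1)] if False else features[idx.clamp(max=T - 1)]))
        return torch.stack(tta).mean(dim=0)

ens = ToyEnsemble()
probs = ens.predict_with_tta(torch.randn(20, 8))
print(probs.shape, round(float(probs.sum()), 4))  # torch.Size([4]) 1.0
```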

TTA During Training Evaluation

The same four modes are used in the test_time_augmentation method of EnhancedTemporalModelTrainer during the training pipeline:
tta_modes = [None, 'reverse', 'speed_up', 'speed_down']
all_predictions = []

for tta_mode in tta_modes:
    tta_dataset = EnhancedPreExtractedFeaturesDataset(
        original_dataset.feature_file,
        augment=False,
        tta_mode=tta_mode
    )
    
    # ... inference loop ...
    mode_predictions = torch.cat(mode_predictions)
    all_predictions.append(mode_predictions)

# Average predictions across all TTA modes
avg_predictions = torch.stack(all_predictions).mean(dim=0)
The EnhancedPreExtractedFeaturesDataset.__getitem__ applies the transformation inline:
if self.tta_mode == 'reverse':
    features = torch.flip(features, dims=[0])
elif self.tta_mode == 'speed_up':
    if num_frames > 10:
        indices = torch.linspace(0, num_frames-1, num_frames//2).long()
        features = features[indices]
elif self.tta_mode == 'speed_down':
    if num_frames > 5:
        indices = torch.linspace(0, num_frames-1,
                                 int(num_frames*1.5)).long().clamp(max=num_frames-1)
        features = features[indices]

Performance Impact

| Configuration | Accuracy |
| --- | --- |
| Single model, no TTA | ~93% |
| Single model + TTA | ~94–95% |
| 4-model ensemble + TTA | ~95%+ |
The system documentation reports:
- Test Accuracy: ~93% (95% with TTA)
- Balanced class-wise F1 scores (>95% with ensemble + TTA)
TTA is disabled by default during interactive inference (use_tta=False). Pass use_tta=True to classify_video() when maximum accuracy is required and the additional latency (3–4× inference time) is acceptable.

Summary

Single video at inference

         ├── Original features          ─┐
         ├── Reversed features            │  each branch runs
         ├── Speed-up features (T/2)      │  through all 4 models
         └── Speed-down features (1.5T)  ─┘  → 4 softmax vectors,
                                               averaged per branch
                  ↓
4 branch predictions (one per TTA mode)
                  ↓
Average → final 4-class probability vector
                  ↓
argmax → predicted class
