The train.py script trains three different classifiers on your extracted features and automatically selects the best performer based on cross-validation accuracy.

Training Pipeline Overview

The script implements a comprehensive training and evaluation workflow:
  1. Load labeled embeddings from artifacts/labeled_embeddings.pt
  2. Perform stratified k-fold cross-validation (default: 5 folds)
  3. Train three model types in parallel:
    • k-NN: Non-parametric baseline using cosine similarity
    • Logistic Regression: Linear classifier with class balancing
    • MLP: 2-layer neural network with dropout
  4. Select the best model based on mean CV accuracy
  5. Retrain on full dataset and save to artifacts/
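The selection logic can be sketched as follows. This is a minimal illustration with synthetic stand-in data: the model names and flow mirror the script's output, but the code here is illustrative, not the actual train.py source, and it omits the MLP for brevity:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Synthetic stand-in data: 60 samples, 4-dim features, 3 balanced classes
rng = np.random.default_rng(42)
X = rng.normal(size=(60, 4))
y = np.repeat([0, 1, 2], 20)

models = {
    "knn": KNeighborsClassifier(n_neighbors=5, metric="cosine"),
    "logreg": LogisticRegression(max_iter=1000, class_weight="balanced"),
}

# Cross-validate every candidate with the same stratified folds
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
mean_accs = {name: cross_val_score(m, X, y, cv=skf).mean()
             for name, m in models.items()}

# Pick the winner by mean CV accuracy and retrain it on the full dataset
best_name = max(mean_accs, key=mean_accs.get)
best_model = models[best_name].fit(X, y)
```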

Running Training

Step 1: Extract Features First

Training requires artifacts/labeled_embeddings.pt from the feature extraction step:
python extract_features.py
Step 2: Run Training

python train.py
The script automatically:
  • Detects GPU/CPU
  • Loads embeddings
  • Performs cross-validation
  • Saves the best model
Step 3: Review Results

Terminal output shows per-fold accuracy and final metrics:
Loaded 156 samples, 8 classes: ['cooking', 'funny', 'motivational', 'pets', 'quran', 'soccer', 'tiktok', 'travel']
Feature dimension: 1024
Class distribution: {'cooking': 12, 'funny': 8, 'motivational': 15, 'pets': 18, 'quran': 24, 'soccer': 45, 'tiktok': 20, 'travel': 14}

Using 5-fold stratified cross-validation
  Fold 1: kNN=74.2%  LogReg=80.6%  MLP=87.1%
  Fold 2: kNN=71.0%  LogReg=77.4%  MLP=90.3%
  Fold 3: kNN=69.4%  LogReg=83.9%  MLP=87.1%
  Fold 4: kNN=77.4%  LogReg=80.6%  MLP=93.5%
  Fold 5: kNN=75.0%  LogReg=78.1%  MLP=90.6%

============================================================
Cross-validation results (mean accuracy):
     kNN: 73.4% (+/- 2.8%)
  logreg: 80.1% (+/- 2.3%)
     mlp: 89.7% (+/- 2.2%)

Best model: mlp (89.7%)

============================================================
  Best Model (mlp) - Full CV Predictions
============================================================
              precision    recall  f1-score   support

     cooking       0.83      0.75      0.79        12
       funny       0.88      0.88      0.88         8
motivational       0.87      0.93      0.90        15
        pets       0.94      0.89      0.91        18
       quran       0.96      1.00      0.98        24
      soccer       0.93      0.96      0.94        45
      tiktok       0.85      0.80      0.82        20
      travel       0.86      0.86      0.86        14

    accuracy                           0.90       156
   macro avg       0.89      0.88      0.89       156
weighted avg       0.90      0.90      0.90       156

Confusion Matrix:
           cookin  funny motiva   pets  quran soccer tiktok travel
cookin          9      0      1      0      0      1      1      0
 funny          0      7      0      0      0      0      1      0
motiva          0      0     14      1      0      0      0      0
  pets          0      0      1     16      0      1      0      0
 quran          0      0      0      0     24      0      0      0
soccer          0      0      0      1      0     43      1      0
tiktok          1      1      0      0      0      2     16      0
travel          1      0      0      1      0      0      0     12

Overall accuracy: 89.7%

Retraining mlp on all 156 samples...
Model saved to artifacts
Done!

Model Architectures

k-Nearest Neighbors (k-NN)

Non-parametric baseline that doesn’t learn weights:
from sklearn.neighbors import KNeighborsClassifier

k = min(5, len(X_train) - 1)  # cap k below the training-set size
knn = KNeighborsClassifier(n_neighbors=k, metric="cosine")
knn.fit(X_train, y_train)
Characteristics:
  • No training required (just stores data)
  • Uses cosine similarity in embedding space
  • Good baseline for CLIP features (already semantic)
  • Slow inference (must compare to all training samples)

Logistic Regression

Linear classifier with class balancing:
from sklearn.linear_model import LogisticRegression

lr = LogisticRegression(max_iter=1000, C=1.0, class_weight="balanced")
lr.fit(X_train, y_train)
Characteristics:
  • Fast training and inference
  • class_weight="balanced" handles imbalanced data
  • Learns linear decision boundaries
  • Often sufficient for well-separated CLIP features

Multi-Layer Perceptron (MLP)

Two-layer neural network with dropout:
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),        # 1024 → 256
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),  # 256 → 128
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, num_classes), # 128 → N
        )

    def forward(self, x):
        return self.net(x)
Training configuration:
  • Optimizer: Adam (lr=1e-3, weight_decay=1e-4)
  • Loss: CrossEntropyLoss with class weights
  • Early stopping: patience=15 epochs
  • Batch size: 32
Characteristics:
  • Learns non-linear decision boundaries
  • Class-weighted loss for imbalanced datasets
  • Early stopping prevents overfitting
  • Best accuracy but slower inference than logistic regression
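The training configuration above can be sketched as a minimal loop. This uses synthetic stand-in data and a tiny stand-in network; for simplicity it monitors the training loss for early stopping, whereas a real setup would typically monitor held-out validation loss:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Toy stand-in data; the real features come from artifacts/labeled_embeddings.pt
torch.manual_seed(0)
X = torch.randn(64, 16)
y = torch.arange(64) % 4  # 4 classes, 16 samples each

# Small stand-in network (the real script uses the MLP class above)
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))

# Class-weighted loss, mirroring the configuration described above
class_counts = torch.bincount(y, minlength=4).float()
weights = 1.0 / class_counts
weights = weights / weights.sum() * 4  # normalize so weights average to 1
criterion = nn.CrossEntropyLoss(weight=weights)
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Full-batch training with early stopping
best_loss, patience, bad_epochs = float("inf"), 15, 0
for epoch in range(200):
    optimizer.zero_grad()
    loss = criterion(model(X), y)
    loss.backward()
    optimizer.step()
    if loss.item() < best_loss - 1e-4:   # meaningful improvement
        best_loss, bad_epochs = loss.item(), 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:       # stop after 15 stagnant epochs
            break
```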

Understanding Class Weighting

With imbalanced data (e.g., soccer=45 videos, funny=8 videos), the model can achieve high accuracy by always predicting the majority class. Class weighting fixes this:
import numpy as np

class_counts = np.bincount(y_train, minlength=num_classes)
weights = 1.0 / class_counts
weights = weights / weights.sum() * num_classes  # normalize so weights average to 1
Effect: Misclassifying a rare class (funny) has higher penalty than a common class (soccer), forcing the model to learn all categories.
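With the class distribution from the training output above, the weighting works out as follows (a quick illustration, not part of train.py):

```python
import numpy as np

# Counts from the example dataset: cooking, funny, motivational, pets,
# quran, soccer, tiktok, travel
class_counts = np.array([12, 8, 15, 18, 24, 45, 20, 14])
num_classes = len(class_counts)

weights = 1.0 / class_counts
weights = weights / weights.sum() * num_classes  # normalize

# funny (8 videos) is penalized far more heavily than soccer (45 videos):
# the ratio is simply 45/8, since normalization cancels out
print(round(weights[1] / weights[5], 1))  # → 5.6
```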

Cross-Validation Strategy

Stratified k-fold ensures each fold has proportional representation:
import numpy as np
from sklearn.model_selection import StratifiedKFold

n_splits = min(5, min(np.bincount(y)))  # can't have more folds than the smallest class
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
Example with 156 samples, 8 classes, 5 folds:
  • Each fold has ~31 test samples
  • Each class appears in every fold proportionally
  • Prevents data leakage between folds
If you have a very small class (e.g., only 3 videos), the script automatically reduces to 3-fold CV to ensure valid splits.
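To see stratification in action, here is a toy split with one deliberately rare class (illustrative data, not from train.py):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Imbalanced toy labels: 20 samples of class 0, 5 of class 1
y = np.array([0] * 20 + [1] * 5)
X = np.arange(len(y)).reshape(-1, 1)

n_splits = min(5, np.bincount(y).min())  # clamp folds to the smallest class size
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

for fold, (train_idx, test_idx) in enumerate(skf.split(X, y), start=1):
    # Stratification puts exactly one rare-class sample in each test fold
    n_rare = (y[test_idx] == 1).sum()
    print(f"Fold {fold}: {len(test_idx)} test samples, {n_rare} from the rare class")
```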

Output Artifacts

Training saves two files to artifacts/:

model.pt or model.pkl

The filename depends on which model wins.

If the MLP wins:
# artifacts/model.pt (PyTorch state dict)
import torch
state_dict = torch.load("artifacts/model.pt")  # OrderedDict of weights
If k-NN or Logistic Regression wins:
# artifacts/model.pkl (sklearn model)
import pickle
with open("artifacts/model.pkl", "rb") as f:
    model = pickle.load(f)

model_config.json

Metadata for loading and inference:
{
  "model_type": "mlp",
  "input_dim": 1024,
  "num_classes": 8,
  "hidden_dim": 256,
  "label_names": [
    "cooking",
    "funny",
    "motivational",
    "pets",
    "quran",
    "soccer",
    "tiktok",
    "travel"
  ],
  "feature_dim": 1024,
  "best_cv_accuracy": 0.8974358974358975
}
Usage in predict.py:
import json
import torch

with open("artifacts/model_config.json") as f:
    config = json.load(f)

if config["model_type"] == "mlp":
    model = MLP(config["input_dim"], config["num_classes"])
    model.load_state_dict(torch.load("artifacts/model.pt"))
    model.eval()  # disable dropout for inference
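Once the model is loaded, inference typically looks like the sketch below. The nn.Sequential stand-in and the softmax/argmax step are assumptions for illustration, not the actual predict.py code:

```python
import torch
import torch.nn as nn

# Stand-in for the loaded MLP (the real one comes from MLP(...) + load_state_dict)
model = nn.Sequential(nn.Linear(1024, 8))
model.eval()

label_names = ["cooking", "funny", "motivational", "pets",
               "quran", "soccer", "tiktok", "travel"]

embedding = torch.randn(1, 1024)  # one video's feature vector
with torch.no_grad():
    probs = torch.softmax(model(embedding), dim=-1)  # class probabilities
pred = label_names[probs.argmax(dim=-1).item()]
```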

Interpreting Metrics

Precision, Recall, F1-Score

              precision    recall  f1-score   support

      soccer       0.93      0.96      0.94        45
       funny       0.88      0.88      0.88         8
  • Precision: Of videos predicted as “soccer”, 93% were actually soccer
  • Recall: Of all actual soccer videos, 96% were correctly identified
  • F1-Score: Harmonic mean of precision and recall (balanced metric)
  • Support: Number of videos in that category

Confusion Matrix

Rows = true labels, Columns = predicted labels
         soccer  funny  cooking
soccer       43      1        1
funny         0      7        1
cooking       2      0       10
Reading the matrix:
  • Diagonal = correct predictions
  • Off-diagonal = errors
  • soccer → cooking: 1 soccer video misclassified as cooking
  • funny → cooking: 1 funny video misclassified as cooking
Finding confusion patterns: Look for high off-diagonal values to identify categories the model conflates.
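The reading above can be checked programmatically. This snippet computes precision and recall directly from the 3-class excerpt (the values differ slightly from the full 8-class report because this is only an excerpt):

```python
import numpy as np

# The 3-class excerpt above: rows = true labels, columns = predicted labels
labels = ["soccer", "funny", "cooking"]
cm = np.array([
    [43, 1, 1],
    [0, 7, 1],
    [2, 0, 10],
])

for i, name in enumerate(labels):
    precision = cm[i, i] / cm[:, i].sum()  # correct / all predicted as this class
    recall = cm[i, i] / cm[i, :].sum()     # correct / all true members of the class
    print(f"{name}: precision={precision:.2f} recall={recall:.2f}")
```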

When to Retrain

Retrain your model after:
  1. Adding new labeled videos: More data improves accuracy
  2. Creating new categories: Adds new classes to learn
  3. Renaming/merging categories: Changes label structure
  4. Correcting mislabeled videos: Fixes training data quality
Always re-run feature extraction before retraining if you’ve added/modified videos.
# Full retraining workflow
python extract_features.py  # Re-extract (if data changed)
python train.py             # Train new model
python predict.py           # Generate new predictions

Hyperparameter Tuning

Default hyperparameters work well for most cases, but you can tune:

MLP Architecture

# Edit train.py:32
class MLP(nn.Module):
    def __init__(self, input_dim, num_classes, hidden_dim=512):  # Increase capacity
        # ...
Larger networks (hidden_dim=512):
  • Better for complex, multi-modal datasets
  • Risk of overfitting with <100 samples
Smaller networks (hidden_dim=128):
  • Faster training, less overfitting
  • May underfit if categories are visually similar

Learning Rate

# Edit train.py:52
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
Lower LR (1e-4): More stable, slower convergence
Higher LR (5e-3): Faster, but may overshoot optimal weights

Early Stopping Patience

# Edit train.py:68
patience = 25  # Wait 25 epochs before stopping
Increase if your model improves slowly; decrease for faster training.

Troubleshooting

Cross-validation fails or the script reduces the fold count

You have a category with fewer than 5 videos. Solutions:
  1. Add more videos to small categories (recommended)
  2. Merge small categories into broader groups
  3. Remove categories with fewer than 5 samples
The script automatically reduces folds, but very small classes may still fail.
Accuracy is low across all models

Possible causes:
  1. Ambiguous categories: Videos don’t have consistent visual/audio patterns
  2. Mislabeled data: Check that videos are in correct folders
  3. Insufficient data: Add more labeled examples (aim for 15+ per class)
  4. Feature quality: Try larger CLIP/Whisper models in extraction
Training accuracy is much higher than CV accuracy

This indicates overfitting:
  1. Reduce MLP size: hidden_dim=128
  2. Increase dropout: nn.Dropout(0.5)
  3. Add more data to improve generalization
One class is never predicted (zero recall)

The model never predicts this class. Check:
  1. Class size: Extremely small classes (<5 videos) may be ignored
  2. Class similarity: If very similar to another class, it gets absorbed
  3. Labeling errors: Verify videos are actually in the right folder
Increase the class weight manually or add more diverse examples.

Advanced: Manual Model Selection

Force a specific model type by modifying train.py:178:
# Override automatic selection
best_name = "logreg"  # Force logistic regression
# best_name = max(mean_accs, key=mean_accs.get)  # Comment out original
When to override:
  • Force k-NN: Fastest inference, good for small datasets
  • Force LogReg: Fast and interpretable, sufficient for most cases
  • Force MLP: Maximum accuracy, at the cost of slower inference

Next Steps

Inference

Run predictions on unlabeled videos
