The train.py script trains three different classifiers on your extracted features and automatically selects the best performer based on cross-validation accuracy.
Training Pipeline Overview
The script implements a comprehensive training and evaluation workflow:
- Load labeled embeddings from artifacts/labeled_embeddings.pt
- Perform stratified k-fold cross-validation (default: 5 folds)
- Train three model types in parallel:
- k-NN: Non-parametric baseline using cosine similarity
- Logistic Regression: Linear classifier with class balancing
- MLP: 2-layer neural network with dropout
- Select the best model based on mean CV accuracy
- Retrain on the full dataset and save to artifacts/
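The selection step above can be sketched as follows. This is an illustrative outline, not the actual train.py code — the variable names are assumptions, the data is synthetic, and the MLP candidate is omitted for brevity:

```python
# Sketch of model selection by mean cross-validation accuracy.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 16))          # stand-in for CLIP/Whisper embeddings
y = np.repeat([0, 1, 2], 20)           # three classes, 20 samples each

candidates = {
    "knn": KNeighborsClassifier(n_neighbors=5, metric="cosine"),
    "logreg": LogisticRegression(class_weight="balanced", max_iter=1000),
}
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores = {name: cross_val_score(model, X, y, cv=cv).mean()
          for name, model in candidates.items()}
best_name = max(scores, key=scores.get)   # highest mean CV accuracy wins
```

The winner is then refit on the full dataset before being saved.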
Running Training
Extract Features First
Training requires artifacts/labeled_embeddings.pt from the feature extraction step.
Run Training
When run, the script:
- Detects GPU/CPU
- Loads embeddings
- Performs cross-validation
- Saves the best model
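The GPU/CPU detection step is the standard PyTorch pattern; a minimal sketch (the dictionary keys inside the saved file are not documented here, so loading is shown only as a commented hint):

```python
import torch

# Prefer GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Embeddings saved by the extraction step would then be loaded, e.g.:
# data = torch.load("artifacts/labeled_embeddings.pt", map_location=device)
```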
Model Architectures
k-Nearest Neighbors (k-NN)
Non-parametric baseline that doesn’t learn weights:
- No training required (just stores data)
- Uses cosine similarity in embedding space
- Good baseline for CLIP features (already semantic)
- Slow inference (must compare to all training samples)
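A minimal sketch of a cosine-similarity k-NN over embeddings, using scikit-learn (toy 2-D vectors and labels are assumptions for illustration):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# fit() just stores the training set; no weights are learned.
X = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
y = np.array(["soccer", "soccer", "cooking", "cooking"])

knn = KNeighborsClassifier(n_neighbors=1, metric="cosine")
knn.fit(X, y)
pred = knn.predict([[0.95, 0.05]])[0]   # nearest training vector by angle
```

Because prediction compares the query against every stored sample, inference cost grows with the training-set size.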
Logistic Regression
Linear classifier with class balancing:
- Fast training and inference
- class_weight="balanced" handles imbalanced data
- Learns linear decision boundaries
- Often sufficient for well-separated CLIP features
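A sketch of the balanced logistic regression on imbalanced toy data (the class sizes mirror the 45-vs-8 example used later on this page; the cluster centers are arbitrary):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# class_weight="balanced" reweights each class by
# n_samples / (n_classes * class_count), so the minority class still matters.
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(loc=0.0, size=(45, 4)),    # majority class
               rng.normal(loc=3.0, size=(8, 4))])    # minority class
y = np.array([0] * 45 + [1] * 8)

clf = LogisticRegression(class_weight="balanced", max_iter=1000).fit(X, y)
acc = clf.score(X, y)
```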
Multi-Layer Perceptron (MLP)
Two-layer neural network with dropout:
- Optimizer: Adam (lr=1e-3, weight_decay=1e-4)
- Loss: CrossEntropyLoss with class weights
- Early stopping: patience=15 epochs
- Batch size: 32
- Learns non-linear decision boundaries
- Class-weighted loss for imbalanced datasets
- Early stopping prevents overfitting
- Best accuracy but slower inference than logistic regression
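The architecture and loss described above can be sketched in PyTorch. The embedding and hidden dimensions and the weight values are illustrative assumptions, not the exact train.py configuration:

```python
import torch
import torch.nn as nn

# Two-layer MLP: Linear -> ReLU -> Dropout -> Linear.
class MLPClassifier(nn.Module):
    def __init__(self, embed_dim: int, num_classes: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.net(x)

model = MLPClassifier(embed_dim=512, num_classes=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
class_weights = torch.tensor([1.0, 2.5, 5.0])   # heavier weight for rarer classes
criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = model(torch.randn(4, 512))             # batch of 4 embeddings
loss = criterion(logits, torch.tensor([0, 1, 2, 0]))
```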
Understanding Class Weighting
With imbalanced data (e.g., soccer=45 videos, funny=8 videos), the model can achieve high accuracy by always predicting the majority class. Class weighting fixes this by scaling each class's contribution to the loss inversely to its frequency, so mistakes on rare classes cost more.
Cross-Validation Strategy
Stratified k-fold ensures each fold has proportional representation:
- Each fold has ~31 test samples
- Each class appears in every fold proportionally
- Prevents data leakage between folds
If you have a very small class (e.g., only 3 videos), the script automatically reduces to 3-fold CV to ensure valid splits.
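A sketch of the stratified split, including the fold-count reduction for tiny classes (the min-based heuristic is an assumption about how the reduction might work, not train.py's exact logic):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Stratification keeps each class's proportion in every fold.
y = np.array([0] * 45 + [1] * 8)        # imbalanced labels
X = np.zeros((len(y), 2))               # features don't affect the split itself

n_splits = min(5, int(np.bincount(y).min()))   # shrink folds if a class is tiny
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=0)
fold_minority_counts = [int((y[test] == 1).sum()) for _, test in skf.split(X, y)]
```

Every test fold receives at least one minority-class sample, which an unstratified split cannot guarantee.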
Output Artifacts
Training saves two files to artifacts/:
model.pt or model.pkl
Which file you get depends on which model wins: the MLP is saved as a PyTorch checkpoint (model.pt), while the scikit-learn models are serialized as a pickle (model.pkl).
model_config.json
Metadata needed to load the model and run inference.
Interpreting Metrics
Precision, Recall, F1-Score
- Precision: Of videos predicted as “soccer”, 93% were actually soccer
- Recall: Of all actual soccer videos, 96% were correctly identified
- F1-Score: Harmonic mean of precision and recall (balanced metric)
- Support: Number of videos in that category
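The soccer example above can be verified on a tiny hand-made case (labels and predictions here are invented for the worked example, not real results):

```python
from sklearn.metrics import f1_score, precision_score, recall_score

y_true = ["soccer", "soccer", "soccer", "cooking", "cooking"]
y_pred = ["soccer", "soccer", "cooking", "cooking", "soccer"]

p = precision_score(y_true, y_pred, pos_label="soccer")   # 2 of 3 soccer predictions correct
r = recall_score(y_true, y_pred, pos_label="soccer")      # 2 of 3 true soccer found
f1 = f1_score(y_true, y_pred, pos_label="soccer")         # harmonic mean of the two
```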
Confusion Matrix
Rows = true labels, columns = predicted labels:
- Diagonal = correct predictions
- Off-diagonal = errors
For example:
- soccer → cooking: 1 soccer video misclassified as cooking
- funny → cooking: 1 funny video misclassified as cooking
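A small sketch reproducing those two off-diagonal errors with scikit-learn (the label sequences are fabricated to match the example, not actual output):

```python
from sklearn.metrics import confusion_matrix

labels = ["soccer", "cooking", "funny"]
y_true = ["soccer"] * 3 + ["cooking"] * 2 + ["funny"] * 2
y_pred = ["soccer", "soccer", "cooking",   # one soccer video -> cooking
          "cooking", "cooking",
          "funny", "cooking"]              # one funny video -> cooking

# Rows follow y_true, columns follow y_pred, both in `labels` order.
cm = confusion_matrix(y_true, y_pred, labels=labels)
```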
When to Retrain
Retrain your model after:
- Adding new labeled videos: More data improves accuracy
- Creating new categories: Adds new classes to learn
- Renaming/merging categories: Changes label structure
- Correcting mislabeled videos: Fixes training data quality
Hyperparameter Tuning
Default hyperparameters work well for most cases, but you can tune:
MLP Architecture
A larger hidden layer:
- Better for complex, multi-modal datasets
- Risk of overfitting with <100 samples
A smaller hidden layer:
- Faster training, less overfitting
- May underfit if categories are visually similar
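One way to frame the trade-off as code — the dimension values and the sample-count heuristic below are illustrative assumptions, not train.py's actual logic:

```python
# Hypothetical hidden-layer sizes trading capacity against overfitting.
SMALL_HIDDEN = 128     # faster, less overfitting; may underfit similar categories
DEFAULT_HIDDEN = 256
LARGE_HIDDEN = 512     # complex, multi-modal data; risky under ~100 samples

def pick_hidden_dim(n_samples: int) -> int:
    # Shrink the network when labeled data is scarce.
    if n_samples < 100:
        return SMALL_HIDDEN
    if n_samples < 500:
        return DEFAULT_HIDDEN
    return LARGE_HIDDEN
```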
Learning Rate
Early Stopping Patience
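The patience mechanism can be sketched as a loop over per-epoch validation losses. The loss sequence below is synthetic; in train.py the values would come from evaluating the MLP on a validation fold each epoch:

```python
def train_with_early_stopping(val_losses, patience=15):
    """Stop once the validation loss fails to improve for `patience` epochs."""
    best, bad_epochs, stopped_at = float("inf"), 0, len(val_losses)
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, bad_epochs = loss, 0      # improvement: reset the counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:      # plateau long enough: stop
                stopped_at = epoch
                break
    return best, stopped_at

# Loss plateaus after epoch 4, so training halts 15 epochs later.
best, stopped = train_with_early_stopping([1.0, 0.8, 0.6, 0.5, 0.45] + [0.5] * 30)
```

Raising patience tolerates noisier validation curves at the cost of extra epochs; lowering it stops sooner but may quit before a late improvement.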
Troubleshooting
ValueError: n_splits=5 cannot be greater than the number of members in each class
You have a category with <5 videos. Solutions:
- Add more videos to small categories (recommended)
- Merge small categories into broader groups
- Remove categories with <5 samples
Low accuracy (<60%) across all models
Possible causes:
- Ambiguous categories: Videos don’t have consistent visual/audio patterns
- Mislabeled data: Check that videos are in correct folders
- Insufficient data: Add more labeled examples (aim for 15+ per class)
- Feature quality: Try larger CLIP/Whisper models in extraction
High training accuracy but poor predictions
Indicates overfitting:
- Reduce MLP size: hidden_dim=128
- Increase dropout: nn.Dropout(0.5)
- Add more data to increase generalization
One class has 0% recall
The model never predicts this class. Check:
- Class size: Extremely small classes (<5 videos) may be ignored
- Class similarity: If very similar to another class, it gets absorbed
- Labeling errors: Verify videos are actually in the right folder
Advanced: Manual Model Selection
Force a specific model type by modifying train.py:178:
- Force k-NN: Fastest inference, good for small datasets
- Force LogReg: Fast and interpretable, sufficient for most cases
- Force MLP: Maximum accuracy, acceptable slower inference
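A hypothetical override, assuming the selection at train.py:178 picks the argmax of the cross-validation scores (the variable names and score values below are invented for illustration):

```python
# Scores as they might come out of cross-validation, keyed by model name.
cv_scores = {"knn": 0.81, "logreg": 0.88, "mlp": 0.90}

FORCE_MODEL = "logreg"   # set to None to keep automatic selection
best_name = FORCE_MODEL or max(cv_scores, key=cv_scores.get)
```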
Next Steps
Inference
Run predictions on unlabeled videos