Overview
The training pipeline evaluates three different classifier architectures using stratified cross-validation, then selects and retrains the best performer on the full dataset. The system handles class imbalance through weighted loss functions and uses early stopping to prevent overfitting.
Model Architectures
1. k-Nearest Neighbors (Baseline)
Hyperparameters:
- n_neighbors: min(5, len(X_train) - 1)
- metric: "cosine" (cosine similarity, well-suited for normalized embeddings)
Why k-NN? k-NN serves as a strong baseline for this task because:
- No training required (inference-only)
- Works well with high-quality pre-trained embeddings (CLIP)
- Cosine distance naturally handles normalized feature vectors
- Simple and interpretable
train.py:145-152:
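The referenced code block is not reproduced here; a minimal sketch of how the k-NN baseline could be constructed with scikit-learn, using the hyperparameters listed above (the function name is illustrative):

```python
from sklearn.neighbors import KNeighborsClassifier

def build_knn(X_train, y_train):
    """k-NN baseline over pre-computed embeddings (no real training step)."""
    clf = KNeighborsClassifier(
        # Cap n_neighbors so it never exceeds the number of training samples.
        n_neighbors=min(5, len(X_train) - 1),
        metric="cosine",  # cosine distance suits normalized CLIP embeddings
    )
    clf.fit(X_train, y_train)  # for k-NN, fit() just stores the data
    return clf
```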
2. Logistic Regression
Hyperparameters:
- max_iter: 1000
- C: 1.0 (inverse regularization strength)
- class_weight: "balanced" (automatically adjusts weights inversely proportional to class frequencies)
Logistic regression provides a linear decision boundary in the 1024-d embedding space. With class_weight="balanced", it automatically handles class imbalance by weighting the loss for each sample by the inverse of its class frequency.
train.py:154-160:
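The referenced lines are not shown; a sketch of the logistic regression setup under the hyperparameters documented above (the helper name is an assumption):

```python
from sklearn.linear_model import LogisticRegression

def build_logreg(X_train, y_train):
    """Linear classifier with inverse-frequency class weighting."""
    clf = LogisticRegression(
        max_iter=1000,
        C=1.0,                    # inverse regularization strength
        class_weight="balanced",  # reweight loss by inverse class frequency
    )
    clf.fit(X_train, y_train)
    return clf
```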
3. Multi-Layer Perceptron (MLP)
Architecture:
- hidden_dim: 256 (first hidden layer)
- hidden_dim // 2: 128 (second hidden layer)
- dropout: 0.3 (first layer), 0.2 (second layer)
- learning_rate: 1e-3
- weight_decay: 1e-4 (L2 regularization)
- batch_size: 32
- epochs: 100 (with early stopping)
train.py:31-45:
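The module definition at those lines is not shown; a PyTorch sketch consistent with the architecture table above (the class name `EmbeddingMLP` is illustrative):

```python
import torch.nn as nn

class EmbeddingMLP(nn.Module):
    """Two hidden layers with a 256 -> 128 bottleneck, ReLU, and dropout."""

    def __init__(self, input_dim: int, num_classes: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),       # 1024 -> 256
            nn.ReLU(),
            nn.Dropout(0.3),                        # heavier dropout early
            nn.Linear(hidden_dim, hidden_dim // 2), # 256 -> 128
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, num_classes),
        )

    def forward(self, x):
        return self.net(x)  # raw logits; CrossEntropyLoss applies softmax
```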
Architecture Rationale:
- 2 hidden layers: Sufficient capacity for non-linear decision boundaries without overfitting
- 256 → 128 bottleneck: Progressively reduces dimensionality while learning hierarchical features
- Dropout regularization: Prevents co-adaptation of neurons, improves generalization
- ReLU activation: Fast, effective, and prevents vanishing gradients
Class Imbalance Handling
The Problem
User-organized collections often have severe class imbalance; for example, one category might have 82 examples while another has only 5.
Solution: Class-Weighted Loss
From train.py:53-58:
- Count samples per class: [82, 5, 15, 23]
- Compute inverse frequencies: [1/82, 1/5, 1/15, 1/23]
- Normalize weights to sum to num_classes
- Apply weights to cross-entropy loss
Effect: The loss for a minority-class sample is amplified proportionally to its rarity, forcing the model to pay attention to underrepresented categories. For example, a misclassified "funny" sample (5 examples) contributes ~16x more loss than a misclassified "soccer" sample (82 examples).
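The four steps above can be sketched as follows (a minimal illustration, not the verbatim contents of train.py:53-58):

```python
import numpy as np
import torch
import torch.nn as nn

def class_weighted_loss(labels, num_classes):
    """Cross-entropy with per-class weights proportional to inverse frequency."""
    # Step 1: count samples per class, e.g. [82, 5, 15, 23]
    counts = np.bincount(labels, minlength=num_classes).astype(float)
    # Step 2: inverse frequencies, e.g. [1/82, 1/5, 1/15, 1/23]
    inv = 1.0 / counts
    # Step 3: normalize so the weights sum to num_classes
    weights = inv * num_classes / inv.sum()
    # Step 4: feed the weights into the cross-entropy loss
    return nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32))
```

With the example counts, the "funny" weight is (1/5)/(1/82) = 16.4x the "soccer" weight, matching the ~16x figure above.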
Training Procedure (MLP)
Optimizer and Regularization
From train.py:51:
- Adam: Adaptive learning-rate optimizer (fast convergence, robust to hyperparameters)
- Learning rate: 1e-3 (standard default)
- Weight decay: 1e-4 (L2 regularization to prevent overfitting)
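The optimizer setup is likely a one-liner along these lines (the placeholder model here is illustrative, not the project's actual MLP):

```python
import torch
import torch.nn as nn

# Stand-in model for illustration; train.py would pass its MLP instead.
model = nn.Sequential(nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 4))

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,           # standard Adam default
    weight_decay=1e-4, # L2 regularization on all parameters
)
```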
Early Stopping
From train.py:66-93:
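That block is not reproduced here; a generic early-stopping loop of the kind it likely implements. The `patience` value and the callback signatures are assumptions, not taken from train.py:

```python
def train_with_early_stopping(model, optimizer, loss_fn, train_step, evaluate,
                              max_epochs=100, patience=10):
    """Stop when validation loss has not improved for `patience` epochs,
    then restore the best weights seen so far."""
    best_val = float("inf")
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(max_epochs):
        train_step(model, optimizer, loss_fn)  # one pass over the training set
        val_loss = evaluate(model)             # loss on the held-out split
        if val_loss < best_val:
            best_val = val_loss
            epochs_without_improvement = 0
            # Snapshot the best weights so we can roll back after stopping.
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model
```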
Cross-Validation Strategy
Stratified K-Fold
From train.py:131-136:
Stratified splitting ensures each fold maintains the same class distribution as the full dataset, which is crucial for obtaining reliable accuracy estimates on imbalanced data. For example, if "funny" is 4% of the dataset, it will be ~4% in each train/validation split.
- If the smallest class has only 5 examples, we can't do 5-fold CV (it would leave 1 example per fold)
- The system automatically reduces n_splits to ensure at least 1 sample per class in each split
- Minimum of 2 folds, maximum of 5 folds
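The fold-count adaptation described above might look like this (the `random_state` and helper name are assumptions):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def make_cv(y, max_splits=5):
    """Stratified K-fold with n_splits capped by the smallest class size."""
    smallest_class = np.bincount(y).min()
    # At least 2 folds, at most 5, never more than the smallest class allows.
    n_splits = max(2, min(max_splits, smallest_class))
    return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
```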
Cross-Validation Loop
From train.py:141-168:
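For the sklearn models, the evaluation loop is presumably of this shape (a sketch; the real loop also covers the MLP, which is not shown here):

```python
from sklearn.base import clone
from sklearn.metrics import accuracy_score

def cross_validate(model, X, y, cv):
    """Mean validation accuracy across stratified folds."""
    scores = []
    for train_idx, val_idx in cv.split(X, y):
        fold_model = clone(model)  # fresh, unfitted estimator per fold
        fold_model.fit(X[train_idx], y[train_idx])
        preds = fold_model.predict(X[val_idx])
        scores.append(accuracy_score(y[val_idx], preds))
    return sum(scores) / len(scores)
```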
Model Selection
From train.py:176-182:
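Selection by mean CV accuracy reduces to picking the argmax over the candidates, roughly:

```python
def select_best(cv_scores):
    """cv_scores: dict mapping model name -> mean CV accuracy."""
    return max(cv_scores, key=cv_scores.get)
```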
Final Model Retraining
After selecting the best architecture via CV, the system retrains on the full dataset to maximize available training data for deployment.
For MLP (Most Common Winner)
From train.py:205-218:
- Uses 90% for training, 10% holdout for early-stopping validation
- Trains for up to 200 epochs (more than CV's 100, since this is the final model)
- Saves the PyTorch state dict to model.pt
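The outline of that final-training step, with the training loop itself elided (the function name and `random_state` are assumptions):

```python
import torch
from sklearn.model_selection import train_test_split

def retrain_and_save(model, X, y, path="model.pt", max_epochs=200):
    """Final MLP retraining: 90/10 split, early stopping, save state dict."""
    # The 10% holdout only drives early stopping, not model selection.
    X_tr, X_val, y_tr, y_val = train_test_split(
        X, y, test_size=0.1, stratify=y, random_state=42
    )
    # ... training loop with early stopping on (X_val, y_val),
    #     up to max_epochs epochs, omitted here ...
    torch.save(model.state_dict(), path)
```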
For sklearn Models (k-NN or LogReg)
From train.py:187-203:
- sklearn models train on 100% of the data (no need for validation during fit)
- Saved as pickle files (model.pkl) instead of PyTorch state dicts
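The sklearn path likely amounts to a fit followed by a pickle dump, along these lines:

```python
import pickle

def save_sklearn_model(model, X, y, path="model.pkl"):
    """Fit on the full dataset and serialize the whole estimator."""
    model.fit(X, y)  # sklearn models can safely use 100% of the data here
    with open(path, "wb") as f:
        pickle.dump(model, f)
```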
Model Configuration
From train.py:220-225:
model_config.json: this configuration file is loaded by predict.py to reconstruct the model architecture and interpret predictions.
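The exact schema is defined by train.py and is not shown in this document; a hypothetical example of what such a file might contain (all field names here are illustrative, not confirmed by the source):

```json
{
  "model_type": "mlp",
  "input_dim": 1024,
  "hidden_dim": 256,
  "class_names": ["soccer", "funny"]
}
```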
Evaluation Metrics
From train.py:99-114:
- Per-class precision, recall, and F1-score (from sklearn's classification_report)
- Confusion matrix (which classes are confused with each other)
- Overall accuracy
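These three metrics map directly onto sklearn calls; a sketch (the helper name is illustrative):

```python
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def report_metrics(y_true, y_pred, class_names):
    """Print per-class precision/recall/F1 and the confusion matrix,
    and return overall accuracy."""
    print(classification_report(y_true, y_pred, target_names=class_names))
    print(confusion_matrix(y_true, y_pred))
    return accuracy_score(y_true, y_pred)
```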
Training Summary Workflow
Performance Expectations
Typical Results (will vary by dataset):
- k-NN: 70-80% accuracy (good baseline, fast inference)
- Logistic Regression: 75-85% accuracy (linear boundary, class weighting helps)
- MLP: 80-90% accuracy (best performer, learns non-linear patterns)
Key factors affecting accuracy:
- Dataset size (more data → better accuracy)
- Class balance (even distribution → easier learning)
- Category separability (visually distinct categories → higher accuracy)
Next Steps
- Learn about Multimodal Features to understand what the model is learning from
- See System Architecture for the complete pipeline
- Check the Prediction API to use your trained model