Overview
Trains a classifier on extracted CLIP+Whisper embeddings to predict TikTok folder categories. Compares three approaches using stratified cross-validation and selects the best model:
- k-NN baseline (no training, nearest neighbors)
- Logistic Regression
- Small MLP with 2 hidden layers
source/train.py
Input: artifacts/labeled_embeddings.pt (from extract_features.py)
Output Files:
artifacts/model.pt - Trained MLP state dict (if MLP is best)
artifacts/model.pkl - Trained sklearn model (if k-NN or LogReg is best)
artifacts/model_config.json - Model metadata and configuration
Configuration Constants
Directory containing embeddings and where trained models are saved.
Classes
MLP
Multi-layer perceptron classifier with 2 hidden layers and dropout. Architecture:
- Input layer: input_dim → hidden_dim (default 256)
- Hidden layer 1: ReLU activation + 30% dropout
- Hidden layer 2: hidden_dim → hidden_dim // 2 (128), ReLU activation + 20% dropout
- Output layer: hidden_dim // 2 → num_classes
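The architecture above can be sketched in PyTorch as follows (a minimal reconstruction from the documented layer sizes and dropout rates, not the script's exact source):

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """Two-hidden-layer MLP matching the documented architecture."""
    def __init__(self, input_dim: int, num_classes: int, hidden_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),          # input_dim -> hidden_dim
            nn.ReLU(),
            nn.Dropout(0.3),                           # 30% dropout
            nn.Linear(hidden_dim, hidden_dim // 2),    # hidden_dim -> hidden_dim // 2
            nn.ReLU(),
            nn.Dropout(0.2),                           # 20% dropout
            nn.Linear(hidden_dim // 2, num_classes),   # logits, no softmax
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```

Note that the output layer returns raw logits; softmax is folded into the cross-entropy loss during training.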
Constructor
- input_dim: Dimension of input features (typically 1024 for CLIP+Whisper)
- num_classes: Number of output categories
- hidden_dim: Size of first hidden layer (second layer is hidden_dim // 2)
forward(x)
- x: Input tensor of shape [batch_size, input_dim]
- Returns: torch.Tensor - Logits of shape [batch_size, num_classes]
Functions
train_mlp(X_train, y_train, X_val, y_val, num_classes, device, epochs=100, lr=1e-3)
Trains an MLP classifier with early stopping and class-weighted loss.
- X_train: Training features of shape [N, feature_dim]
- y_train: Training labels of shape [N] (integer class indices)
- X_val: Validation features of shape [M, feature_dim]
- y_val: Validation labels of shape [M]
- num_classes: Number of output categories
- device: Device to train on ("cuda" or "cpu")
- epochs: Maximum number of training epochs
- lr: Learning rate for Adam optimizer
- Returns: tuple - (model, best_val_acc) where model is the trained MLP and best_val_acc is the best validation accuracy achieved
Training Details:
- Optimizer: Adam with weight decay 1e-4
- Loss: Cross-entropy with class weights (inverse frequency, normalized)
- Batch size: 32
- Early stopping: Patience of 15 epochs
- Returns model with best validation accuracy
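The training details above can be sketched as follows. This is an illustrative reconstruction, not the script's exact code: the inline nn.Sequential stands in for the MLP class documented earlier, and the patience parameter is exposed for clarity:

```python
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_mlp(X_train, y_train, X_val, y_val, num_classes,
              device="cpu", epochs=100, lr=1e-3, patience=15):
    input_dim = X_train.shape[1]
    model = nn.Sequential(  # stand-in for the documented MLP class
        nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.2),
        nn.Linear(128, num_classes),
    ).to(device)

    # Inverse-frequency class weights, normalized to sum to num_classes.
    counts = torch.bincount(y_train, minlength=num_classes).float().clamp(min=1)
    weights = (1.0 / counts) / (1.0 / counts).sum() * num_classes
    criterion = nn.CrossEntropyLoss(weight=weights.to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=True)

    best_acc, best_state, stale = 0.0, copy.deepcopy(model.state_dict()), 0
    for _ in range(epochs):
        model.train()
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb.to(device)), yb.to(device))
            loss.backward()
            optimizer.step()

        # Track validation accuracy; keep the best weights seen so far.
        model.eval()
        with torch.no_grad():
            preds = model(X_val.to(device)).argmax(dim=1).cpu()
        val_acc = (preds == y_val).float().mean().item()
        if val_acc > best_acc:
            best_acc, best_state, stale = val_acc, copy.deepcopy(model.state_dict()), 0
        else:
            stale += 1
            if stale >= patience:  # early stopping after 15 stale epochs
                break

    model.load_state_dict(best_state)  # restore the best-validation weights
    return model, best_acc
```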
evaluate(name, y_true, y_pred, label_names)
Prints detailed evaluation metrics including classification report and confusion matrix.
- name: Name/title for this evaluation (e.g., "k-NN", "MLP")
- y_true: True labels
- y_pred: Predicted labels
- label_names: List of category names corresponding to label indices
- Returns: float - Overall accuracy
Prints:
- Per-class precision, recall, F1-score (via sklearn classification_report)
- Confusion matrix with category names
- Overall accuracy percentage
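A function with this behavior can be sketched with sklearn's metrics (a plausible shape for evaluate, not necessarily the script's exact formatting):

```python
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def evaluate(name, y_true, y_pred, label_names):
    """Print a classification report and confusion matrix; return accuracy."""
    print(f"=== {name} ===")
    # Per-class precision, recall, and F1-score.
    print(classification_report(y_true, y_pred,
                                target_names=label_names, zero_division=0))
    # Confusion matrix with category names on the rows.
    print("Confusion matrix (rows = true, cols = predicted):")
    for label, row in zip(label_names, confusion_matrix(y_true, y_pred)):
        print(f"{label:>15}: {row}")
    acc = accuracy_score(y_true, y_pred)
    print(f"Overall accuracy: {acc:.1%}")
    return acc
```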
main()
Main execution function that orchestrates model training and selection. Pipeline:
- Load labeled embeddings from artifacts/labeled_embeddings.pt
- Print dataset statistics (class distribution)
- Perform stratified k-fold cross-validation (2-5 folds depending on smallest class)
- For each fold, train and evaluate:
- k-NN classifier (k=5, cosine distance)
- Logistic Regression (max_iter=1000, balanced class weights)
- MLP (2 hidden layers, early stopping)
- Compare mean cross-validation accuracies
- Select best model type
- Retrain best model on all labeled data
- Save model and configuration to artifacts directory
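The cross-validation and selection steps above can be condensed into a sketch like the following. This is an assumption-laden outline (select_best_model is a hypothetical helper, and the MLP candidate is omitted for brevity), showing only the sklearn candidates:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import KNeighborsClassifier

def select_best_model(X, y, n_splits=5):
    """Compare candidates by mean stratified-CV accuracy; return the winner."""
    candidates = {
        "knn": KNeighborsClassifier(n_neighbors=5, metric="cosine"),
        "logreg": LogisticRegression(max_iter=1000, class_weight="balanced"),
    }
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = {name: [] for name in candidates}
    for train_idx, val_idx in cv.split(X, y):
        for name, clf in candidates.items():
            clf.fit(X[train_idx], y[train_idx])  # refit per fold
            scores[name].append((clf.predict(X[val_idx]) == y[val_idx]).mean())
    means = {name: float(np.mean(s)) for name, s in scores.items()}
    best = max(means, key=means.get)
    # The real script would now retrain `best` on all of X, y before saving.
    return best, means
```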
The chosen model type is recorded in model_config.json:
- For k-NN or Logistic Regression (model_type: "knn" or "logreg"), the model is saved to model.pkl using pickle.
- For MLP (model_type: "mlp"), the model state dict is saved to model.pt using torch.save.
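A downstream consumer could dispatch on model_type roughly as follows. This is a sketch under assumptions: load_model is a hypothetical helper, and the config keys beyond model_type (such as model_kwargs) are illustrative, not documented:

```python
import json
import pickle
from pathlib import Path

import torch

ARTIFACTS = Path("artifacts")  # assumed artifacts directory

def load_model(mlp_builder=None):
    """Load whichever model train.py saved, guided by model_config.json."""
    config = json.loads((ARTIFACTS / "model_config.json").read_text())
    if config["model_type"] in ("knn", "logreg"):
        with open(ARTIFACTS / "model.pkl", "rb") as f:
            return pickle.load(f), config
    # MLP: rebuild the architecture, then load the saved state dict.
    model = mlp_builder(**config.get("model_kwargs", {}))  # hypothetical key
    model.load_state_dict(torch.load(ARTIFACTS / "model.pt", map_location="cpu"))
    model.eval()
    return model, config
```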
Usage
Model Selection
The script automatically selects the best model based on cross-validation performance:
- k-NN: Good baseline, no training required, works well with small datasets
- Logistic Regression: Fast, interpretable, handles class imbalance well
- MLP: More expressive, can capture non-linear patterns, best for larger datasets
Class imbalance and distance-metric choices:
- Class-weighted loss for MLP (inverse frequency)
- class_weight="balanced" for Logistic Regression
- Cosine distance for k-NN (works well with normalized embeddings)
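These settings map directly onto sklearn constructor arguments; a minimal sketch (make_candidates is a hypothetical helper):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

def make_candidates():
    """The two sklearn candidates with the settings described above."""
    return {
        # Cosine distance suits normalized CLIP+Whisper embeddings.
        "knn": KNeighborsClassifier(n_neighbors=5, metric="cosine"),
        # "balanced" reweights classes inversely to their frequency.
        "logreg": LogisticRegression(max_iter=1000, class_weight="balanced"),
    }
```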
Cross-Validation Strategy
Stratified k-fold ensures each fold maintains the same class distribution:
- Number of folds: min(5, smallest_class_size), at least 2
- Each fold gets a representative sample of all classes
- Final model is retrained on all data for deployment
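The fold-count rule can be sketched with sklearn's StratifiedKFold (make_cv is a hypothetical helper; the random_state is an assumption):

```python
from collections import Counter

from sklearn.model_selection import StratifiedKFold

def make_cv(y, max_folds=5):
    """Choose min(5, smallest class size) folds, but never fewer than 2."""
    smallest = min(Counter(y).values())
    n_splits = max(2, min(max_folds, smallest))
    return StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
```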