
Overview

Trains a classifier on extracted CLIP+Whisper embeddings to predict TikTok folder categories. Compares three approaches using stratified cross-validation and selects the best model:
  1. k-NN baseline (no training, nearest neighbors)
  2. Logistic Regression
  3. Small MLP with 2 hidden layers
Location: source/train.py
Input: artifacts/labeled_embeddings.pt (from extract_features.py)
Output Files:
  • artifacts/model.pt - Trained MLP state dict (if MLP is best)
  • artifacts/model.pkl - Trained sklearn model (if k-NN or LogReg is best)
  • artifacts/model_config.json - Model metadata and configuration

Configuration Constants

ARTIFACTS_DIR
Path
default:"artifacts"
Directory containing embeddings and where trained models are saved.

Classes

MLP

Multi-layer perceptron classifier with 2 hidden layers and dropout. Architecture:
  • Input layer: input_dim → hidden_dim (default 256)
  • Hidden layer 1: ReLU activation + 30% dropout
  • Hidden layer 2: hidden_dim → hidden_dim // 2 (128), ReLU activation + 20% dropout
  • Output layer: hidden_dim // 2 → num_classes
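The architecture above maps to a small PyTorch module. This is a sketch of the likely shape, not necessarily line-for-line identical to source/train.py:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    """2-hidden-layer MLP matching the documented architecture (sketch)."""
    def __init__(self, input_dim, num_classes, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),        # input_dim -> 256
            nn.ReLU(),
            nn.Dropout(0.3),                         # 30% dropout after hidden layer 1
            nn.Linear(hidden_dim, hidden_dim // 2),  # 256 -> 128
            nn.ReLU(),
            nn.Dropout(0.2),                         # 20% dropout after hidden layer 2
            nn.Linear(hidden_dim // 2, num_classes), # 128 -> num_classes
        )

    def forward(self, x):
        # x: [batch_size, input_dim] -> logits: [batch_size, num_classes]
        return self.net(x)
```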

Constructor

MLP(input_dim, num_classes, hidden_dim=256)
input_dim
int
required
Dimension of input features (typically 1024 for CLIP+Whisper)
num_classes
int
required
Number of output categories
hidden_dim
int
default:"256"
Size of first hidden layer (second layer is hidden_dim // 2)

forward(x)

x
torch.Tensor
required
Input tensor of shape [batch_size, input_dim]
Returns: torch.Tensor - Logits of shape [batch_size, num_classes]

Functions

train_mlp(X_train, y_train, X_val, y_val, num_classes, device, epochs=100, lr=1e-3)

Trains an MLP classifier with early stopping and class-weighted loss.
X_train
np.ndarray
required
Training features of shape [N, feature_dim]
y_train
np.ndarray
required
Training labels of shape [N] (integer class indices)
X_val
np.ndarray
required
Validation features of shape [M, feature_dim]
y_val
np.ndarray
required
Validation labels of shape [M]
num_classes
int
required
Number of output categories
device
str
required
Device to train on ("cuda" or "cpu")
epochs
int
default:"100"
Maximum number of training epochs
lr
float
default:"0.001"
Learning rate for Adam optimizer
Returns: tuple - (model, best_val_acc) where model is the trained MLP and best_val_acc is the best validation accuracy achieved
Training Details:
  • Optimizer: Adam with weight decay 1e-4
  • Loss: Cross-entropy with class weights (inverse frequency, normalized)
  • Batch size: 32
  • Early stopping: Patience of 15 epochs
  • Returns model with best validation accuracy
model, val_acc = train_mlp(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    num_classes=5,
    device="cuda",
    epochs=100,
    lr=0.001
)
print(f"Best validation accuracy: {val_acc:.2%}")
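The class-weighted loss described above (inverse frequency, normalized) can be built like this. A sketch only: the exact normalization in train.py may differ, and `class_weights` is an illustrative name.

```python
import numpy as np
import torch

def class_weights(y_train, num_classes):
    """Inverse-frequency class weights, normalized to sum to num_classes."""
    counts = np.bincount(y_train, minlength=num_classes).astype(float)
    weights = 1.0 / np.maximum(counts, 1)            # inverse frequency (guard empty classes)
    weights = weights / weights.sum() * num_classes  # normalize
    return torch.tensor(weights, dtype=torch.float32)

# Passed to the loss, e.g.:
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights(y_train, num_classes))
```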

evaluate(name, y_true, y_pred, label_names)

Prints detailed evaluation metrics including classification report and confusion matrix.
name
str
required
Name/title for this evaluation (e.g., “k-NN”, “MLP”)
y_true
np.ndarray
required
True labels
y_pred
np.ndarray
required
Predicted labels
label_names
list[str]
required
List of category names corresponding to label indices
Returns: float - Overall accuracy
Prints:
  • Per-class precision, recall, F1-score (via sklearn classification_report)
  • Confusion matrix with category names
  • Overall accuracy percentage
accuracy = evaluate(
    name="MLP",
    y_true=y_test,
    y_pred=predictions,
    label_names=["soccer", "funny", "food"]
)
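A minimal sketch of the behavior described above, using the sklearn utilities the docs mention; the real evaluate in train.py may format output differently:

```python
import numpy as np
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

def evaluate(name, y_true, y_pred, label_names):
    """Print per-class metrics and the confusion matrix; return overall accuracy."""
    print(f"=== {name} ===")
    # Per-class precision, recall, F1
    print(classification_report(y_true, y_pred, target_names=label_names, zero_division=0))
    # Rows = true class, columns = predicted class
    print(confusion_matrix(y_true, y_pred))
    acc = accuracy_score(y_true, y_pred)
    print(f"Accuracy: {acc:.1%}")
    return acc
```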

main()

Main execution function that orchestrates model training and selection. Pipeline:
  1. Load labeled embeddings from artifacts/labeled_embeddings.pt
  2. Print dataset statistics (class distribution)
  3. Perform stratified k-fold cross-validation (2-5 folds depending on smallest class)
  4. For each fold, train and evaluate:
    • k-NN classifier (k=5, cosine distance)
    • Logistic Regression (max_iter=1000, balanced class weights)
    • MLP (2 hidden layers, early stopping)
  5. Compare mean cross-validation accuracies
  6. Select best model type
  7. Retrain best model on all labeled data
  8. Save model and configuration to artifacts directory
Saved Configuration (model_config.json):
{
  "model_type": "mlp",
  "input_dim": 1024,
  "num_classes": 5,
  "hidden_dim": 256,
  "label_names": ["funny", "food", "soccer"],
  "feature_dim": 1024,
  "best_cv_accuracy": 0.92
}
For sklearn models (model_type: "knn" or "logreg"), the model is saved to model.pkl using pickle. For MLP (model_type: "mlp"), the model state dict is saved to model.pt using torch.save.
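A downstream loader might branch on model_type like this. A sketch assuming the artifact layout above: `build_mlp` and `load_model` are illustrative names, and the state-dict keys of the real MLP class may differ from this nn.Sequential version.

```python
import json
import pickle
from pathlib import Path

import torch
import torch.nn as nn

def build_mlp(input_dim, num_classes, hidden_dim=256):
    # Same shape as the MLP documented above (sketch).
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.3),
        nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(), nn.Dropout(0.2),
        nn.Linear(hidden_dim // 2, num_classes),
    )

def load_model(artifacts_dir):
    """Reload whichever model train.py saved, based on model_config.json."""
    artifacts_dir = Path(artifacts_dir)
    config = json.loads((artifacts_dir / "model_config.json").read_text())
    if config["model_type"] == "mlp":
        model = build_mlp(config["input_dim"], config["num_classes"], config["hidden_dim"])
        model.load_state_dict(torch.load(artifacts_dir / "model.pt"))
        model.eval()
    else:  # "knn" or "logreg" were saved with pickle
        with open(artifacts_dir / "model.pkl", "rb") as f:
            model = pickle.load(f)
    return model, config["label_names"]
```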
if __name__ == "__main__":
    main()

Usage

# Train classifier on extracted features
python train.py
Output Example:
Loaded 142 samples, 3 classes: ['funny', 'food', 'soccer']
Feature dimension: 1024
Class distribution: {'funny': 5, 'food': 55, 'soccer': 82}

Using 3-fold stratified cross-validation
  Fold 1: kNN=85.0%  LogReg=87.2%  MLP=89.4%
  Fold 2: kNN=83.3%  LogReg=86.7%  MLP=91.1%
  Fold 3: kNN=84.8%  LogReg=88.5%  MLP=90.2%

============================================================
Cross-validation results (mean accuracy):
     knn: 84.4% (+/- 0.7%)
  logreg: 87.5% (+/- 0.8%)
     mlp: 90.2% (+/- 0.8%)

Best model: mlp (90.2%)

Retraining mlp on all 142 samples...
Model saved to artifacts
Done!

Model Selection

The script automatically selects the best model based on cross-validation performance:
  • k-NN: Good baseline, no training required, works well with small datasets
  • Logistic Regression: Fast, interpretable, handles class imbalance well
  • MLP: More expressive, can capture non-linear patterns, best for larger datasets
Class imbalance is handled via:
  • Class-weighted loss for MLP (inverse frequency)
  • class_weight="balanced" for Logistic Regression
  • Cosine distance for k-NN (works well with normalized embeddings)
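The sklearn side of this setup can be sketched directly from the parameters listed above (variable names are illustrative):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier

# k-NN baseline: k=5 with cosine distance, suited to normalized embeddings
knn = KNeighborsClassifier(n_neighbors=5, metric="cosine")

# Logistic Regression: balanced class weights to counter class imbalance
logreg = LogisticRegression(max_iter=1000, class_weight="balanced")
```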

Cross-Validation Strategy

Stratified k-fold ensures each fold maintains the same class distribution:
  • Number of folds: min(5, smallest_class_size), at least 2
  • Each fold gets a representative sample of all classes
  • Final model is retrained on all data for deployment
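The fold-count rule above can be expressed in a few lines; a sketch, with `make_cv` as an illustrative name and the random seed an assumption:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def make_cv(y):
    """Fold count from the smallest class: min(5, smallest_class_size), at least 2."""
    smallest = np.bincount(y).min()
    n_folds = max(2, min(5, smallest))
    return StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Example: class counts {0: 3, 1: 10} -> smallest class is 3 -> 3 folds
```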
