The Problem
Real-world datasets are rarely balanced. In the TikTok Auto Collection Sorter, you might have 82 videos in your “soccer” folder but only 5 in “funny”. Without special handling, the model learns to predict the majority class (soccer) almost always, since it gets rewarded for being correct 82 out of 87 times.
How It Works
The system uses class-weighted cross-entropy loss to give minority classes more importance during training. When the model misclassifies a rare class, the penalty is proportionally larger.
Implementation Details
From train.py:54-58, here’s the exact implementation:
# Class-weighted loss to handle imbalance (soccer=82 vs funny=5)
class_counts = np.bincount(y_train, minlength=num_classes).astype(float)
class_counts = np.maximum(class_counts, 1.0)  # avoid division by zero
weights = 1.0 / class_counts
weights = weights / weights.sum() * num_classes  # normalize to sum to num_classes
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(weights).to(device))
Step-by-Step Breakdown
Count samples per class: np.bincount(y_train) returns [82, 5, 30, ...]
Compute inverse frequencies: weights = 1.0 / class_counts → [0.012, 0.200, 0.033, ...]
Normalize: scale the weights so they sum to num_classes
Apply to loss function: PyTorch multiplies each sample's loss by the weight of its true class
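The four steps above can be reproduced in isolation. A minimal sketch using the hypothetical counts from this page (82 soccer, 5 funny, 30 cooking):

```python
import numpy as np

# Hypothetical labels matching the counts used on this page
y_train = np.array([0] * 82 + [1] * 5 + [2] * 30)  # 0=soccer, 1=funny, 2=cooking
num_classes = 3

# Step 1: count samples per class
class_counts = np.bincount(y_train, minlength=num_classes).astype(float)

# Step 2: inverse frequencies
weights = 1.0 / class_counts

# Step 3: normalize so the weights sum to num_classes
weights = weights / weights.sum() * num_classes

# Step 4 would pass these to nn.CrossEntropyLoss(weight=...)
print(np.round(weights, 3))  # rarest class ("funny") gets the largest weight
```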
Mathematical Intuition
Standard cross-entropy loss:
Loss = -log(p_correct_class)
Class-weighted loss:
Loss = -w_class × log(p_correct_class)
where w_class is inversely proportional to class frequency.
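Plugging in numbers makes the effect concrete. A toy check with unnormalized inverse-frequency weights and an assumed predicted probability of 0.5 for the correct class:

```python
import math

# Hypothetical two-class counts: soccer=80, funny=5
w_soccer = 1.0 / 80  # 0.0125
w_funny = 1.0 / 5    # 0.2

# Standard cross-entropy with p_correct_class = 0.5 is identical for both classes
p = 0.5
base = -math.log(p)  # ≈ 0.693

# Class-weighted loss scales that base penalty by the class weight
loss_soccer = w_soccer * base
loss_funny = w_funny * base

print(round(loss_funny / loss_soccer))  # 16: the class-count ratio, inverted
```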
Example
Suppose you have:
Soccer: 80 videos (weight = 1/80 = 0.0125)
Funny: 5 videos (weight = 1/5 = 0.2)
After normalizing so the weights sum to num_classes (2 here):
Soccer weight: 0.118
Funny weight: 1.882
Now misclassifying a “funny” video costs 16× more than misclassifying a “soccer” video, forcing the model to learn minority classes.
When Class Weighting Helps
Effective for:
Imbalance ratios up to roughly 1:20 (e.g., 100 vs 5 samples)
Minority classes with distinct, learnable features
Small to medium datasets where collecting more minority samples is expensive
Less effective when:
Extreme imbalance (1:100+): consider oversampling or SMOTE instead
Minority class has very high intra-class variance
Minority class features overlap heavily with majority class
Comparing with Alternatives
Oversampling
Duplicate minority class samples to match majority class size:
from sklearn.utils import resample

# Separate classes
X_majority = X_train[y_train == 0]
X_minority = X_train[y_train == 1]

# Upsample minority with replacement to match the majority size
X_minority_upsampled = resample(
    X_minority,
    replace=True,
    n_samples=len(X_majority),
    random_state=42,
)
X_train_balanced = np.vstack([X_majority, X_minority_upsampled])
Pros: Simple, no loss function changes needed
Cons: Can lead to overfitting if the minority class is very small
Undersampling
Reduce majority class to match minority class:
X_majority_downsampled = resample(
    X_majority,
    replace=False,  # sample without replacement to get a true subset
    n_samples=len(X_minority),
    random_state=42,
)
Pros: Prevents majority class dominance
Cons: Throws away potentially useful data
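Note that both snippets above resample the features only; in a real pipeline the labels must be resampled in lockstep, or X and y fall out of alignment. A numpy-only sketch of label-aware oversampling on a made-up 8-sample dataset:

```python
import numpy as np

rng = np.random.default_rng(42)

# Toy dataset: 6 majority samples (class 0), 2 minority samples (class 1)
X = np.arange(16, dtype=float).reshape(8, 2)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])

majority_idx = np.where(y == 0)[0]
minority_idx = np.where(y == 1)[0]

# Draw minority indices with replacement, then index X and y together
upsampled_idx = rng.choice(minority_idx, size=len(majority_idx), replace=True)

X_balanced = np.vstack([X[majority_idx], X[upsampled_idx]])
y_balanced = np.concatenate([y[majority_idx], y[upsampled_idx]])

print(np.bincount(y_balanced))  # [6 6]: classes are now balanced
```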
Why We Use Class Weighting
Class weighting offers the best tradeoff:
Keeps all training data (no undersampling)
No artificial duplication (avoids overfitting)
Simple to implement
Works well with small datasets
The training script prints per-class metrics using classification_report (train.py:103):
              precision    recall  f1-score   support

      soccer       0.92      0.94      0.93        82
       funny       0.85      0.80      0.82         5
     cooking       0.88      0.90      0.89        30
         ...
    accuracy                           0.90       213
Watch the recall column for minority classes. Low recall means the model is still missing many minority samples despite class weighting.
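Per-class recall is simply the fraction of each class's true samples the model catches. A small sketch with made-up predictions showing how to compute it by hand:

```python
import numpy as np

# Made-up labels and predictions for three classes (0=soccer, 1=funny, 2=cooking)
y_true = np.array([0, 0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 0, 0, 1, 0, 2, 2, 2, 0])

recalls = {}
for cls, name in enumerate(["soccer", "funny", "cooking"]):
    mask = y_true == cls                               # true samples of this class
    recalls[name] = float((y_pred[mask] == cls).mean())  # fraction correctly caught
    print(f"{name:8s} recall={recalls[name]:.2f}")
```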
Adjusting Class Weights
If a minority class still performs poorly, you can manually increase its weight:
# Default: inverse frequency
weights = 1.0 / class_counts
# Option 1: Square root (less aggressive)
weights = 1.0 / np.sqrt(class_counts)
# Option 2: Manual boost for specific class
weights[minority_class_idx] *= 2.0
# Always normalize after adjustment
weights = weights / weights.sum() * num_classes
Over-weighting minority classes can hurt majority class performance. Monitor the confusion matrix to ensure majority classes aren’t suffering significantly.
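A confusion matrix makes that tradeoff visible: row i holds the true class, so off-diagonal entries in the majority row show majority samples being sacrificed. A minimal sketch with invented predictions:

```python
import numpy as np

# Invented labels/predictions after aggressively boosting the minority weight
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 1])  # 0=majority, 1=minority
y_pred = np.array([0, 0, 0, 0, 1, 1, 1, 1])

num_classes = 2
cm = np.zeros((num_classes, num_classes), dtype=int)
for t, p in zip(y_true, y_pred):
    cm[t, p] += 1  # row = true class, column = predicted class

print(cm)
# [[4 2]
#  [0 2]]  -> two majority samples are now misclassified as minority
```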
Logistic Regression Alternative
The system also trains a Logistic Regression model with scikit-learn’s class_weight="balanced" (train.py:155):
lr = LogisticRegression(max_iter=1000, C=1.0, class_weight="balanced")
lr.fit(X_train, y_train)
This uses the same inverse frequency weighting internally. If cross-validation selects Logistic Regression as the best model, class imbalance is already handled automatically.
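scikit-learn's balanced mode computes weight_c = n_samples / (n_classes × count_c), which is inverse frequency up to a constant factor. A quick check using the counts from this page:

```python
import numpy as np

# Counts from this page: soccer=82, funny=5
y_train = np.array([0] * 82 + [1] * 5)

n_samples = len(y_train)
n_classes = 2
counts = np.bincount(y_train)

# class_weight="balanced" formula: n_samples / (n_classes * counts)
balanced = n_samples / (n_classes * counts)

print(np.round(balanced, 3))  # the rare class gets the far larger weight
```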
Practical Tips
Check class distribution first: run python train.py and look at the “Class distribution” output
Monitor minority class recall: aim for >70% recall on all classes
Combine with active learning: use active learning (see Active Learning) to collect more minority samples efficiently
Validate on a balanced test set: if possible, manually curate a balanced test set to get accurate per-class metrics
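For the first tip, the distribution check doesn't require the training script; a few lines inspect any label list (the folder names here are examples):

```python
from collections import Counter

# Example label list as it might be loaded from the collection folders
labels = ["soccer"] * 82 + ["funny"] * 5 + ["cooking"] * 30

dist = Counter(labels)
total = sum(dist.values())
for name, count in dist.most_common():
    print(f"{name:10s} {count:4d}  ({count / total:.1%})")
```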
Code Example: Custom Weighting Scheme
Modify train.py to implement custom weights:
def compute_custom_weights(y_train, num_classes, strategy='inverse'):
    """
    Compute class weights with different strategies.

    Args:
        y_train: Training labels
        num_classes: Number of classes
        strategy: 'inverse', 'sqrt', or 'log'
    """
    class_counts = np.bincount(y_train, minlength=num_classes).astype(float)
    class_counts = np.maximum(class_counts, 1.0)

    if strategy == 'inverse':
        weights = 1.0 / class_counts
    elif strategy == 'sqrt':
        weights = 1.0 / np.sqrt(class_counts)
    elif strategy == 'log':
        weights = 1.0 / np.log1p(class_counts)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    # Normalize so the weights sum to num_classes
    weights = weights / weights.sum() * num_classes
    return weights

# In train_mlp, replace lines 56-57 with:
weights = compute_custom_weights(y_train, num_classes, strategy='inverse')
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(weights).to(device))
Related Pages
Active Learning: prioritize minority class samples for labeling
Custom Models: modify the model architecture for better class separation