
The Problem

Real-world datasets are rarely balanced. In the TikTok Auto Collection Sorter, you might have 82 videos in your “soccer” folder but only 5 in “funny”. Without special handling, the model learns to predict the majority class (soccer) almost always, since it gets rewarded for being correct 82 out of 87 times.

How It Works

The system uses class-weighted cross-entropy loss to give minority classes more importance during training. When the model misclassifies a rare class, the penalty is proportionally larger.

Implementation Details

From train.py:54-58, here’s the exact implementation:
# Class-weighted loss to handle imbalance (soccer=82 vs funny=5)
class_counts = np.bincount(y_train, minlength=num_classes).astype(float)
class_counts = np.maximum(class_counts, 1.0)  # avoid div by zero
weights = 1.0 / class_counts
weights = weights / weights.sum() * num_classes  # normalize
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(weights).to(device))

Step-by-Step Breakdown

  1. Count samples per class: np.bincount(y_train) returns [82, 5, 30, ...]
  2. Compute inverse frequencies: weights = 1.0 / class_counts gives [0.012, 0.200, 0.033, ...]
  3. Normalize: Scale weights so they sum to num_classes
  4. Apply to loss function: PyTorch automatically multiplies loss by class weight
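The steps above can be run standalone. This sketch reproduces them with NumPy; the labels are synthetic stand-ins matching the counts used on this page (soccer=82, funny=5, cooking=30):

```python
import numpy as np

# Synthetic labels: 82 "soccer" (0), 5 "funny" (1), 30 "cooking" (2)
y_train = np.array([0] * 82 + [1] * 5 + [2] * 30)
num_classes = 3

# 1. Count samples per class
class_counts = np.bincount(y_train, minlength=num_classes).astype(float)
class_counts = np.maximum(class_counts, 1.0)  # guard against empty classes

# 2. Inverse frequencies
weights = 1.0 / class_counts                  # [0.0122, 0.2, 0.0333]

# 3. Normalize so the weights sum to num_classes
weights = weights / weights.sum() * num_classes

print(class_counts)       # [82.  5. 30.]
print(weights.round(3))   # -> [0.149 2.444 0.407]
```

The rare "funny" class ends up with roughly 16× the weight of "soccer", mirroring the 82:5 count ratio.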

Mathematical Intuition

Standard cross-entropy loss:
Loss = -log(p_correct_class)
Class-weighted loss:
Loss = -w_class × log(p_correct_class)
Where w_class is inversely proportional to class frequency.
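To see the effect numerically, fix the model's probability for the true class at 0.6 and vary only the class weight (the weight values here are arbitrary, chosen for illustration):

```python
import math

def weighted_ce(w, p_correct):
    """Class-weighted cross-entropy for a single sample: -w * log(p)."""
    return -w * math.log(p_correct)

p = 0.6  # model's probability for the true class
print(weighted_ce(1.0, p))  # unweighted: ~0.511
print(weighted_ce(4.0, p))  # with a 4x minority weight: ~2.043
```

The same prediction error is penalized four times as heavily, which is exactly what pushes gradient updates toward the minority class.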

Example

Suppose you have:
  • Soccer: 80 videos (weight = 1/80 = 0.0125)
  • Funny: 5 videos (weight = 1/5 = 0.2)
After normalization (scaling the weights so they sum to num_classes = 2, as in the code above):
  • Soccer weight: 0.118
  • Funny weight: 1.882
Now misclassifying a “funny” video costs 16× as much as misclassifying a “soccer” video (1.882 / 0.118 ≈ 16, the inverse of the 80:5 count ratio), forcing the model to learn minority classes.

When Class Weighting Helps

Effective for:
  • Imbalance ratios up to 1:20 (e.g., 100 vs 5 samples)
  • Minority classes with distinct, learnable features
  • Small to medium datasets where collecting more minority samples is expensive
Less effective when:
  • Extreme imbalance (1:100+) — consider oversampling or SMOTE
  • Minority class has very high intra-class variance
  • Minority class features overlap heavily with majority class
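A quick way to check where your dataset falls on this spectrum is to compute the imbalance ratio from the label array (the labels here are illustrative; substitute your own y_train):

```python
import numpy as np

y_train = np.array([0] * 100 + [1] * 5)  # illustrative labels

counts = np.bincount(y_train)
ratio = counts.max() / counts.min()
print(f"imbalance ratio 1:{ratio:.0f}")  # 1:20 -> class weighting is a good fit
```

Above roughly 1:100, consider the resampling alternatives below instead of (or in addition to) weighting.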

Comparing with Alternatives

Oversampling

Duplicate minority class samples to match majority class size:
import numpy as np
from sklearn.utils import resample

# Separate classes
X_majority = X_train[y_train == 0]
X_minority = X_train[y_train == 1]

# Upsample minority (sampling with replacement)
X_minority_upsampled = resample(
    X_minority,
    replace=True,
    n_samples=len(X_majority),
    random_state=42
)

# Rebuild both features and labels for the balanced set
X_train_balanced = np.vstack([X_majority, X_minority_upsampled])
y_train_balanced = np.concatenate([
    np.zeros(len(X_majority), dtype=int),
    np.ones(len(X_minority_upsampled), dtype=int),
])
Pros: Simple, no loss function changes needed
Cons: Can lead to overfitting if minority class is very small

Undersampling

Reduce majority class to match minority class:
X_majority_downsampled = resample(
    X_majority,
    replace=False,  # sample without replacement when downsampling
    n_samples=len(X_minority),
    random_state=42
)
Pros: Prevents majority class dominance
Cons: Throws away potentially useful data

Why We Use Class Weighting

Class weighting offers the best tradeoff:
  • Keeps all training data (no undersampling)
  • No artificial duplication (avoids overfitting)
  • Simple to implement
  • Works well with small datasets

Monitoring Class Performance

The training script prints per-class metrics using classification_report (train.py:103):
              precision    recall  f1-score   support

      soccer       0.92      0.94      0.93        82
       funny       0.85      0.80      0.82         5
     cooking       0.88      0.90      0.89        30

    accuracy                           0.90       213
Watch the recall column for minority classes. Low recall means the model is still missing many minority samples despite class weighting.
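If you want the same per-class report outside the training script, scikit-learn's classification_report produces it from true and predicted label arrays (the arrays below are made up purely for illustration):

```python
import numpy as np
from sklearn.metrics import classification_report

# Illustrative labels; in practice use your validation split and model predictions
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 0])

report = classification_report(
    y_true, y_pred,
    target_names=["soccer", "funny", "cooking"],
    zero_division=0,
)
print(report)
```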

Adjusting Class Weights

If a minority class still performs poorly, you can manually increase its weight:
# Default: inverse frequency
weights = 1.0 / class_counts

# Option 1: Square root (less aggressive)
weights = 1.0 / np.sqrt(class_counts)

# Option 2: Manual boost for specific class
weights[minority_class_idx] *= 2.0

# Always normalize after adjustment
weights = weights / weights.sum() * num_classes
Over-weighting minority classes can hurt majority class performance. Monitor the confusion matrix to ensure majority classes aren’t suffering significantly.
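The schemes above can be compared side by side: square root and log compress the dynamic range, so the minority boost is progressively milder. A quick check, using the counts from this page's example:

```python
import numpy as np

class_counts = np.array([82.0, 5.0, 30.0])
num_classes = 3

ratios = {}
for name, w in {
    "inverse": 1.0 / class_counts,
    "sqrt":    1.0 / np.sqrt(class_counts),
    "log":     1.0 / np.log1p(class_counts),
}.items():
    w = w / w.sum() * num_classes  # normalize as in train.py
    ratios[name] = w.max() / w.min()
    print(f"{name:8s} max/min weight ratio: {ratios[name]:.1f}")
# inverse -> 16.4, sqrt -> 4.0, log -> 2.5
```

Start with 'sqrt' when plain inverse weighting tanks majority-class precision.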

Logistic Regression Alternative

The system also trains a Logistic Regression model with scikit-learn’s class_weight="balanced" (train.py:155):
lr = LogisticRegression(max_iter=1000, C=1.0, class_weight="balanced")
lr.fit(X_train, y_train)
This uses the same inverse frequency weighting internally. If cross-validation selects Logistic Regression as the best model, class imbalance is already handled automatically.
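Specifically, "balanced" computes n_samples / (n_classes * bincount(y)), which is proportional to the inverse-frequency weights used in train.py; only the overall scale differs. A quick check with the same illustrative counts:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 82 + [1] * 5 + [2] * 30)
classes = np.unique(y)

balanced = compute_class_weight("balanced", classes=classes, y=y)
inverse = 1.0 / np.bincount(y).astype(float)

# The two vectors are proportional: their elementwise ratio is constant
print(balanced / inverse)  # -> [39. 39. 39.]  (n_samples / n_classes = 117 / 3)
```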

Practical Tips

  1. Check class distribution first: Run python train.py and look at the “Class distribution” output
  2. Monitor minority class recall: Aim for >70% recall on all classes
  3. Combine with active learning: Use active learning (see Active Learning) to collect more minority samples efficiently
  4. Validate on balanced test set: If possible, manually curate a balanced test set to get accurate per-class metrics
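When manually curating a balanced test set (tip 4) isn't practical, it can be approximated in code by sampling an equal number of test examples per class. The data below is synthetic; X_test and y_test stand in for your existing test split:

```python
import numpy as np

rng = np.random.default_rng(42)

def balanced_subset(X, y, per_class):
    """Sample `per_class` examples from each class, without replacement."""
    idx = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=per_class, replace=False)
        for c in np.unique(y)
    ])
    return X[idx], y[idx]

# Synthetic test split: 40 class-0 and 10 class-1 samples, 4 features each
X_test = rng.normal(size=(50, 4))
y_test = np.array([0] * 40 + [1] * 10)

X_bal, y_bal = balanced_subset(X_test, y_test, per_class=5)
print(np.bincount(y_bal))  # [5 5]
```

Set per_class no higher than the size of your smallest class so sampling without replacement is possible.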

Code Example: Custom Weighting Scheme

Modify train.py to implement custom weights:
def compute_custom_weights(y_train, num_classes, strategy='inverse'):
    """
    Compute class weights with different strategies.
    
    Args:
        y_train: Training labels
        num_classes: Number of classes
        strategy: 'inverse', 'sqrt', or 'log'
    """
    class_counts = np.bincount(y_train, minlength=num_classes).astype(float)
    class_counts = np.maximum(class_counts, 1.0)
    
    if strategy == 'inverse':
        weights = 1.0 / class_counts
    elif strategy == 'sqrt':
        weights = 1.0 / np.sqrt(class_counts)
    elif strategy == 'log':
        weights = 1.0 / np.log1p(class_counts)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")
    
    # Normalize
    weights = weights / weights.sum() * num_classes
    return weights

# In train_mlp function, replace line 56-57 with:
weights = compute_custom_weights(y_train, num_classes, strategy='inverse')
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(weights).to(device))

Related Pages

  • Active Learning — Prioritize minority class samples for labeling
  • Custom Models — Modify model architecture for better class separation
