Class Imbalance

Class imbalance occurs when the distribution of target classes is highly skewed: one or more classes vastly outnumber the others. This is the norm, not the exception, in real-world problems, for example fraud detection (often ~99.9% legitimate transactions), disease diagnosis (e.g. 95% healthy), and spam detection (e.g. 80% legitimate email).


Why Imbalance Breaks Standard Training

A model trained with unweighted cross-entropy on an imbalanced dataset sees a loss dominated by majority-class examples, so it learns to favour the majority class. Accuracy hides the problem: with 99% negatives, a model that always predicts negative scores 99% accuracy but 0% recall on positives.
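A minimal sketch of the always-negative baseline described above, using toy labels (990 negatives, 10 positives) chosen purely for illustration and scikit-learn's standard metric functions:

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

# Toy labels: 990 negatives, 10 positives (1% positive rate)
y_true = np.array([0] * 990 + [1] * 10)
# A degenerate "classifier" that always predicts the majority class
y_pred = np.zeros_like(y_true)

acc = accuracy_score(y_true, y_pred)  # 990 / 1000 correct
rec = recall_score(y_true, y_pred)    # 0 of 10 positives found
print(f"accuracy={acc:.2f}, recall={rec:.2f}")  # accuracy=0.99, recall=0.00
```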


Strategies to Handle Imbalance

1. Use the Right Metric

First and most important: stop using accuracy. Use:

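| Metric | Good when |
|---|---|
| F1-score (macro/weighted) | Binary or multiclass; you want a balance of precision and recall |
| Precision-Recall AUC | Positives are rare and are the class you care about |
| ROC-AUC | You need threshold-independent evaluation (can look optimistic when positives are very rare) |
| Matthews Correlation Coefficient (MCC) | Binary classification, very imbalanced; uses all four confusion-matrix cells |
| G-Mean | You want good recall on every class (geometric mean of per-class recall) |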
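The metrics in the table can be computed side by side with scikit-learn. This is a sketch on hypothetical scores: `imbalance_report` is an illustrative helper, not a library function, and PR-AUC is approximated by `average_precision_score`, a common step-wise summary of the precision-recall curve. G-Mean is computed by hand from per-class recall.

```python
import numpy as np
from sklearn.metrics import (average_precision_score, f1_score, matthews_corrcoef,
                             recall_score, roc_auc_score)


def imbalance_report(y_true, y_score, threshold=0.5):
    """Hypothetical helper: imbalance-aware metrics from labels and positive-class scores."""
    y_pred = (y_score >= threshold).astype(int)
    per_class_recall = recall_score(y_true, y_pred, average=None)  # one recall per class
    return {
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "pr_auc": average_precision_score(y_true, y_score),  # threshold-free
        "roc_auc": roc_auc_score(y_true, y_score),            # threshold-free
        "mcc": matthews_corrcoef(y_true, y_pred),
        # G-Mean: geometric mean of the per-class recalls
        "g_mean": float(np.prod(per_class_recall) ** (1 / len(per_class_recall))),
    }


# Hypothetical data: 5% positives whose scores are shifted upward
rng = np.random.default_rng(0)
y_true = np.array([0] * 950 + [1] * 50)
y_score = np.clip(rng.normal(0.3, 0.15, size=1000) + 0.3 * y_true, 0.0, 1.0)

report = imbalance_report(y_true, y_score)
for name, value in report.items():
    print(f"{name:>8}: {value:.3f}")
```

Note that F1, MCC and G-Mean depend on the chosen threshold, while PR-AUC and ROC-AUC summarize the scores across all thresholds.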

2. Class Weights

Tell the loss function to penalize errors on the minority class more:

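from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# Automatic class weight computation ('balanced': n_samples / (n_classes * class_count))
classes = np.unique(y_train)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
class_weight = dict(zip(classes, weights))

# For sklearn models: 'balanced' computes the same weights internally;
# passing class_weight=class_weight explicitly is equivalent
model = LogisticRegression(class_weight='balanced')

# For PyTorch (binary): pos_weight = number of negatives / number of positives
import torch
neg_count = int((y_train == 0).sum())
pos_count = int((y_train == 1).sum())
pos_weight = torch.tensor([neg_count / pos_count])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)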
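To make the `'balanced'` heuristic concrete, here is a small worked example on hypothetical labels with a 9:1 ratio. It checks the formula n_samples / (n_classes * n_c) against `compute_class_weight` and shows that the ratio of the two weights equals the PyTorch `pos_weight` (negatives / positives), so the two approaches agree up to an overall constant factor.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical labels: 900 negatives, 100 positives
y = np.array([0] * 900 + [1] * 100)
classes, counts = np.unique(y, return_counts=True)

# 'balanced' heuristic: w_c = n_samples / (n_classes * n_c)
manual = len(y) / (len(classes) * counts)  # roughly 0.556 and 5.0
auto = compute_class_weight(class_weight="balanced", classes=classes, y=y)

# Each class now contributes the same total weight (500 here)
per_class_total = counts * manual
# Ratio of positive to negative weight: 900 / 100 = 9, i.e. PyTorch's pos_weight
ratio = manual[1] / manual[0]
print(manual, np.allclose(manual, auto), per_class_total, ratio)
```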

3. Oversampling — SMOTE

SMOTE (Synthetic Minority Over-sampling Technique) generates synthetic minority samples by interpolating between an existing minority sample and one of its k nearest minority-class neighbours:

from imblearn.over_sampling import SMOTE
import numpy as np

# sampling_strategy='auto' oversamples every non-majority class up to the majority count
smote = SMOTE(sampling_strategy='auto', k_neighbors=5, random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

print(f"Before: {dict(zip(*np.unique(y_train, return_counts=True)))}")
print(f"After:  {dict(zip(*np.unique(y_resampled, return_counts=True)))}")
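To see what the interpolation step does, here is a simplified from-scratch sketch in NumPy. It is not imblearn's implementation (which uses a proper nearest-neighbour estimator and handles multiple classes); `smote_sketch` is a hypothetical helper that works on the minority-class rows only.

```python
import numpy as np


def smote_sketch(X_min, n_new, k=5, seed=0):
    """Hypothetical simplified SMOTE: new point = x_i + lam * (x_nn - x_i), lam in [0, 1)."""
    if k >= len(X_min):
        raise ValueError("k must be smaller than the number of minority samples")
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority samples
    d = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)                # a point is not its own neighbour
    neighbours = np.argsort(d, axis=1)[:, :k]  # k nearest minority neighbours per point
    synthetic = np.empty((n_new, X_min.shape[1]))
    for j in range(n_new):
        i = rng.integers(len(X_min))           # pick a random minority sample
        nn = X_min[rng.choice(neighbours[i])]  # and one of its k neighbours
        lam = rng.random()                     # interpolation factor in [0, 1)
        synthetic[j] = X_min[i] + lam * (nn - X_min[i])
    return synthetic


rng = np.random.default_rng(1)
X_min = rng.normal(size=(20, 3))  # hypothetical minority-class feature matrix
S = smote_sketch(X_min, n_new=50, k=5)
print(S.shape)  # (50, 3)
```

Because every synthetic point lies on a segment between two real minority points, SMOTE can only fill in regions already spanned by the minority class; it cannot extrapolate beyond them.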

Apply SMOTE only on training data

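Never resample the validation or test set: evaluation data must keep the real class distribution you expect in production. Resample only after splitting, and inside each fold when cross-validating, so that synthetic points are never built from held-out samples.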

4. Undersampling

Remove majority-class samples. Simpler and faster than SMOTE, but it discards data that may carry useful information:

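from imblearn.under_sampling import RandomUnderSampler, TomekLinks

# Random undersampling: a float ratio means n_minority / n_majority after resampling
rus = RandomUnderSampler(sampling_strategy=0.5, random_state=42)  # minority:majority = 1:2
X_res, y_res = rus.fit_resample(X_train, y_train)

# Tomek links: pairs of mutual nearest neighbours from opposite classes lie on the
# class boundary; by default the majority-class member of each pair is removed.
# This is a cleaning step: it usually removes few samples and does not balance the classes
tl = TomekLinks()
X_res, y_res = tl.fit_resample(X_train, y_train)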
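The definition of a Tomek link can be checked by hand on a tiny 1-D example. `tomek_links` below is a hypothetical brute-force helper for illustration only, not the imblearn API:

```python
import numpy as np


def tomek_links(X, y):
    """Hypothetical helper: index pairs (i, j), i < j, of mutual nearest neighbours with different labels."""
    X = np.asarray(X, dtype=float)
    d = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)  # pairwise Euclidean distances
    np.fill_diagonal(d, np.inf)  # exclude each point as its own neighbour
    nn = d.argmin(axis=1)        # index of each point's nearest neighbour
    return [(i, int(j)) for i, j in enumerate(nn)
            if i < j and nn[j] == i and y[i] != y[j]]


X = [[0.0], [1.0], [5.0], [5.4], [10.0]]
y = [0, 1, 0, 0, 1]
links = tomek_links(X, y)
print(links)  # [(0, 1)]
```

Points 0 and 1 are each other's nearest neighbour and carry different labels, so they form a link; points 2 and 3 are also mutual neighbours but share a label. With its default settings, imblearn's `TomekLinks` would drop only the non-minority member of each link (here index 0, since class 0 is the majority).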

5. Threshold Tuning

Many classifiers output probability scores, and the default decision threshold of 0.5 is rarely optimal on imbalanced data. Tune the threshold on the validation set (never on the test set):

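from sklearn.metrics import precision_recall_curve
import numpy as np

# model: an already-fitted classifier that supports predict_proba
probs = model.predict_proba(X_val)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_val, probs)

# Find threshold maximizing F1. precisions/recalls have one more entry than thresholds
# (the final point, precision=1 and recall=0, has no threshold), so drop it
p, r = precisions[:-1], recalls[:-1]
f1_scores = 2 * p * r / (p + r + 1e-8)
best_thresh = thresholds[np.argmax(f1_scores)]
y_pred = (probs >= best_thresh).astype(int)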
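A self-contained version of the same procedure on synthetic data (generated with `make_classification`, roughly 3% positives; the data and model are assumptions for illustration), comparing the tuned threshold with the default 0.5 on the validation set:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=5000, weights=[0.97, 0.03], random_state=0)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

probs = model.predict_proba(X_val)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_val, probs)
p, r = precisions[:-1], recalls[:-1]  # align with thresholds
f1 = 2 * p * r / np.maximum(p + r, 1e-12)
best_thresh = thresholds[np.argmax(f1)]

f1_default = f1_score(y_val, (probs >= 0.5).astype(int), zero_division=0)
f1_tuned = f1_score(y_val, (probs >= best_thresh).astype(int))
print(f"threshold={best_thresh:.3f}  F1@0.5={f1_default:.3f}  F1@tuned={f1_tuned:.3f}")
```

Because the curve evaluates every distinct score as a cutoff, the tuned F1 on the validation set can never be lower than the F1 at 0.5; the real test is whether the chosen threshold holds up on unseen data.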

Strategy Decision Guide

flowchart TD
    A[Imbalanced dataset] --> B{Imbalance ratio?}
    B -->|"< 1:10"| C[Class weights\noften sufficient]
    B -->|"1:10 to 1:100"| D{Enough minority samples?}
    B -->|"> 1:100"| E[Combine: SMOTE\n+ class weights\n+ threshold tuning]
    D -->|"> 500 samples"| F[SMOTE or\ncombined approach]
    D -->|"< 500 samples"| G[Class weights\n+ data collection]
    C --> H[Evaluate with\nPR-AUC / F1]
    F --> H
    E --> H
    G --> H

Imbalance in Deep Learning

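For neural networks with large datasets, class weights (a weighted loss) are usually the most practical and effective solution. SMOTE relies on nearest-neighbour searches that are slow on millions of samples, and interpolating raw inputs such as images or text rarely yields realistic examples, so the synthetic samples may not match the feature distribution the network learns.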

Additional techniques specific to deep learning:

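| Technique | Idea |
|---|---|
| Focal Loss | Down-weights easy examples (typically well-classified majority samples) so training focuses on hard ones, often from the minority class |
| Label Smoothing | Softens one-hot targets, discouraging over-confident predictions on the majority class |
| Balanced Batch Sampling | Sample so that each batch contains roughly equal numbers of minority and majority examples |
| Mixup | Train on interpolations of pairs of samples and their labels, e.g. across different classes |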
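The Focal Loss row can be written out explicitly. In the standard binary formulation, with $p$ the predicted probability of the positive class, $y \in \{0, 1\}$ the label, $\alpha \in [0, 1]$ a class-balancing weight and $\gamma \ge 0$ the focusing parameter:

$$
p_t = \begin{cases} p & y = 1 \\ 1 - p & y = 0 \end{cases}, \qquad
\alpha_t = \begin{cases} \alpha & y = 1 \\ 1 - \alpha & y = 0 \end{cases}, \qquad
\mathrm{FL}(p_t) = -\,\alpha_t \,(1 - p_t)^{\gamma} \log p_t
$$

With $\gamma = 0$ and $\alpha_t = 1$ this reduces to ordinary binary cross-entropy $-\log p_t$. With $\gamma = 2$, an easy example with $p_t = 0.9$ has its cross-entropy scaled by $(1 - 0.9)^2 = 0.01$, while a hard example with $p_t = 0.1$ is scaled by $(1 - 0.1)^2 = 0.81$, so the hard example's relative contribution grows by a factor of 81.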
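import torch
import torch.nn.functional as F

# Focal Loss (binary)
class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, logits, target):
        target = target.float()
        # Per-element BCE equals -log(p_t), where p_t is the probability of the true class
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        pt = torch.exp(-bce)
        # alpha weights positives, (1 - alpha) weights negatives
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
        return (alpha_t * (1 - pt) ** self.gamma * bce).mean()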
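Balanced batch sampling can be sketched with PyTorch's `WeightedRandomSampler`: weight every sample by the inverse of its class size so both classes carry equal total weight, then draw with replacement. The toy tensors below are hypothetical. Balance holds in expectation, so individual batches still vary, and minority samples are repeated many times per epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)
# Hypothetical toy data: 950 negatives, 50 positives, 8 features
X = torch.randn(1000, 8)
y = torch.cat([torch.zeros(950), torch.ones(50)]).long()

# Weight each sample by 1 / (size of its class): both classes get equal total weight
class_counts = torch.bincount(y)
sample_weights = 1.0 / class_counts[y].double()

sampler = WeightedRandomSampler(sample_weights, num_samples=len(y), replacement=True,
                                generator=torch.Generator().manual_seed(0))
loader = DataLoader(TensorDataset(X, y), batch_size=64, sampler=sampler)

# Over one "epoch" of sampled batches, roughly half the labels should be positive
labels = torch.cat([yb for _, yb in loader])
frac_pos = labels.float().mean().item()
print(f"positive fraction after balanced sampling: {frac_pos:.2f}")
```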