Class Imbalance
Class imbalance occurs when the distribution of target classes is highly skewed — one or more classes dominate the dataset. This is the norm, not the exception, in real-world problems: fraud detection (99.9% legitimate), disease diagnosis (95% healthy), spam detection (80% legitimate email).
Why Imbalance Breaks Standard Training
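A model trained with standard cross-entropy loss on an imbalanced dataset is dominated by the majority class: most of the loss comes from majority examples, so the model can reach a low loss and a high accuracy while rarely or never predicting the minority class. With 99% negatives, always predicting negative gives 99% accuracy but 0% recall on positives.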
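To make those numbers concrete, here is a minimal sketch in plain Python. The 990/10 split is an assumption chosen to match the 99% example, and the "model" is a hypothetical one that always predicts the majority class.

```python
# Toy labels: 990 negatives (0) and 10 positives (1), an assumed 99:1 split.
y_true = [0] * 990 + [1] * 10

# A hypothetical "model" that always predicts the majority class.
y_pred = [0] * len(y_true)

correct = sum(t == p for t, p in zip(y_true, y_pred))
true_pos = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
actual_pos = sum(y_true)

accuracy = correct / len(y_true)   # share of all predictions that are right
recall = true_pos / actual_pos     # share of positives that were found

print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")
# prints: accuracy=0.99, recall=0.00
```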
Strategies to Handle Imbalance
1. Use the Right Metric
First and most important: stop using accuracy. Use:
| Metric | Good when |
|---|---|
| F1-Score (macro/weighted) | Binary or multiclass; you want a balance of precision and recall |
| Precision-Recall AUC | Positives are rare and important |
| ROC-AUC | You need threshold-agnostic evaluation (can look optimistic under heavy imbalance) |
| Matthews Correlation Coefficient (MCC) | Binary classification with heavy imbalance |
| G-Mean | Every class's recall matters (geometric mean of per-class recall) |
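As a rough illustration of how these metrics diverge from accuracy, the sketch below computes several of them by hand from a binary confusion matrix. The four counts are hypothetical, picked only to represent a 5%-positive problem; in practice a metrics library would normally compute these for you.

```python
import math

# Hypothetical confusion-matrix counts (50 positives, 950 negatives).
tp, fn, fp, tn = 30, 20, 10, 940

accuracy = (tp + tn) / (tp + fn + fp + tn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)        # recall of the positive class
specificity = tn / (tn + fp)   # recall of the negative class

f1 = 2 * precision * recall / (precision + recall)
g_mean = math.sqrt(recall * specificity)   # geometric mean of per-class recall
mcc = (tp * tn - fp * fn) / math.sqrt(
    (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
)

for name, value in [("accuracy", accuracy), ("f1", f1),
                    ("g_mean", g_mean), ("mcc", mcc)]:
    print(f"{name}: {value:.3f}")
```

With these counts accuracy comes out at 0.97, while F1, G-Mean, and MCC are all noticeably lower because 20 of the 50 positives are missed.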
2. Class Weights
Tell the loss function to penalize errors on the minority class more:
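import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight

# Automatic class weight computation: n_samples / (n_classes * class_count)
classes = np.unique(y_train)
weights = compute_class_weight('balanced', classes=classes, y=y_train)
class_weight = dict(zip(classes, weights))

# For sklearn models: class_weight='balanced' applies the same weights;
# pass the class_weight dict instead if you want to adjust them manually
model = LogisticRegression(class_weight='balanced')

# For PyTorch (binary 0/1 labels in a NumPy array): up-weight positives by neg/pos
pos_count = int((y_train == 1).sum())
neg_count = int((y_train == 0).sum())
pos_weight = torch.tensor([neg_count / pos_count], dtype=torch.float32)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)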
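The `'balanced'` heuristic used above weights each class by `n_samples / (n_classes * class_count)`. The following sketch reimplements that formula in plain Python on toy labels; the helper name `balanced_class_weights` and the 990/10 split are illustrative assumptions, not part of any library.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Hypothetical helper: weight = n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}

# Toy labels with an assumed 99:1 split.
y_toy = [0] * 990 + [1] * 10
weights = balanced_class_weights(y_toy)
print(weights)  # class 1 gets weight 50.0, class 0 roughly 0.505
```

The ratio of the two weights is 99, the same value as `neg_count / pos_count`, which is why `pos_weight` in the PyTorch snippet has a comparable effect.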
3. Oversampling — SMOTE
SMOTE (Synthetic Minority Over-sampling Technique) generates synthetic minority samples by interpolating between existing ones:
from imblearn.over_sampling import SMOTE
smote = SMOTE(sampling_strategy='auto', k_neighbors=5, random_state=42)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
print(f"Before: {dict(zip(*np.unique(y_train, return_counts=True)))}")
print(f"After: {dict(zip(*np.unique(y_resampled, return_counts=True)))}")
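The interpolation step can be sketched in a few lines of NumPy. This is a simplified illustration of the idea, not imbalanced-learn's implementation: the function name `smote_like_sample` and the toy data are assumptions, and it uses brute-force distances instead of a neighbor index.

```python
import numpy as np

def smote_like_sample(X_min, k=5, rng=None):
    """Create one synthetic point from minority samples X_min of shape (n, d)."""
    rng = np.random.default_rng() if rng is None else rng
    i = rng.integers(len(X_min))                      # pick a minority sample
    dists = np.linalg.norm(X_min - X_min[i], axis=1)  # distance to every sample
    # k nearest neighbors; position 0 is the point itself (assumes no duplicates)
    neighbors = np.argsort(dists)[1:k + 1]
    j = rng.choice(neighbors)                         # pick one neighbor
    lam = rng.random()                                # interpolation factor in [0, 1)
    return X_min[i] + lam * (X_min[j] - X_min[i])     # point on the segment i -> j

rng = np.random.default_rng(0)
X_min = rng.normal(size=(20, 2))   # hypothetical minority-class features
x_new = smote_like_sample(X_min, k=5, rng=rng)
print(x_new)
```

Because the new point lies on the line segment between two real minority samples, it always stays inside the region those samples already span.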
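Apply SMOTE to training data only
Never resample the validation or test set. Doing so produces a class distribution that does not reflect production, so any score computed on it is misleading. Split first, then resample the training portion only (inside each fold when cross-validating).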
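A minimal sketch of the order of operations, using NumPy only. Plain random oversampling stands in for SMOTE here to keep the example dependency-free, and the data, split sizes, and seed are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 1,000 samples, 5% positives.
y = np.array([1] * 50 + [0] * 950)
X = rng.normal(size=(len(y), 3))

# 1) Split first (a simple shuffled 80/20 split for illustration; not stratified).
idx = rng.permutation(len(y))
train_idx, test_idx = idx[:800], idx[800:]
X_train, y_train = X[train_idx], y[train_idx]
X_test, y_test = X[test_idx], y[test_idx]

# 2) Resample the training portion only (random oversampling as a stand-in).
pos = np.flatnonzero(y_train == 1)
neg = np.flatnonzero(y_train == 0)
extra = rng.choice(pos, size=len(neg) - len(pos), replace=True)
keep = np.concatenate([neg, pos, extra])
X_train_res, y_train_res = X_train[keep], y_train[keep]

print("train positive rate after resampling:", y_train_res.mean())
print("test positive rate (untouched):", y_test.mean())
```

imbalanced-learn also ships its own `Pipeline` class (`imblearn.pipeline.Pipeline`) that applies samplers during fitting only, which is a convenient way to keep resampling out of validation folds.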
4. Undersampling
Remove majority class samples. Simpler than SMOTE but loses information:
from imblearn.under_sampling import RandomUnderSampler, TomekLinks
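# Random undersampling: drop majority samples until minority:majority = 1:2
rus = RandomUnderSampler(sampling_strategy=0.5, random_state=42)
X_rus, y_rus = rus.fit_resample(X_train, y_train)
# Tomek links: drop majority samples that form a Tomek link (a pair of
# opposite-class points that are each other's nearest neighbors); this cleans
# the class boundary rather than balancing the class counts
tl = TomekLinks()
X_tl, y_tl = tl.fit_resample(X_train, y_train)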
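To show what a ratio of 0.5 means in terms of counts, here is a small NumPy sketch of random undersampling for binary labels. The helper `random_undersample` is hypothetical and only mirrors the idea; it is not the library's code.

```python
import numpy as np

def random_undersample(y, ratio, rng):
    """Hypothetical helper: keep all minority (1) samples and enough
    majority (0) samples so that n_minority / n_majority == ratio."""
    pos = np.flatnonzero(y == 1)
    neg = np.flatnonzero(y == 0)
    n_neg_keep = min(len(neg), int(len(pos) / ratio))
    neg_kept = rng.choice(neg, size=n_neg_keep, replace=False)
    return np.sort(np.concatenate([pos, neg_kept]))

rng = np.random.default_rng(0)
y_toy = np.array([1] * 100 + [0] * 900)   # hypothetical 1:9 imbalance
kept = random_undersample(y_toy, ratio=0.5, rng=rng)
n_pos = int((y_toy[kept] == 1).sum())
n_neg = int((y_toy[kept] == 0).sum())
print(n_pos, n_neg)  # prints: 100 200
```

In this toy case 700 of the 900 majority samples are discarded, which is the information loss mentioned above.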
5. Threshold Tuning
Standard classifiers output probabilities. The default threshold (0.5) may not be optimal for imbalanced data. Tune it on the validation set:
import numpy as np
from sklearn.metrics import precision_recall_curve

probs = model.predict_proba(X_val)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(y_val, probs)
# precisions and recalls have one more entry than thresholds (the final
# precision=1, recall=0 point has no threshold), so drop it before argmax
p, r = precisions[:-1], recalls[:-1]
# Find threshold maximizing F1
f1_scores = 2 * p * r / (p + r + 1e-8)
best_thresh = thresholds[np.argmax(f1_scores)]
y_pred = (probs >= best_thresh).astype(int)
Strategy Decision Guide
flowchart TD
A[Imbalanced dataset] --> B{Imbalance ratio?}
B -->|"< 1:10"| C[Class weights\noften sufficient]
B -->|"1:10 to 1:100"| D{Enough minority samples?}
B -->|"> 1:100"| E[Combine: SMOTE\n+ class weights\n+ threshold tuning]
D -->|"> 500 samples"| F[SMOTE or\ncombined approach]
D -->|"< 500 samples"| G[Class weights\n+ data collection]
C --> H[Evaluate with\nPR-AUC / F1]
F --> H
E --> H
G --> H
Imbalance in Deep Learning
For neural networks with large datasets, class weights are usually the most practical and effective solution. SMOTE on millions of samples is slow and the synthetic samples may not match the learned feature distribution well.
Additional techniques specific to deep learning:
| Technique | Idea |
|---|---|
| Focal Loss | Down-weights easy examples (well-classified majority), focuses on hard minority |
| Label Smoothing | Prevents over-confident predictions on majority class |
| Balanced Batch Sampling | Sample batches so minority and majority classes are roughly equally represented |
| Mixup | Train on convex combinations of random sample pairs and their labels |
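As one example from the table, balanced batch sampling can be approximated by drawing samples with probability inversely proportional to their class frequency. The sketch below uses NumPy only; the labels, batch size, and seed are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
y_toy = np.array([1] * 50 + [0] * 950)   # hypothetical labels, 5% positives

# Per-sample probability proportional to 1 / (size of that sample's class).
class_counts = np.bincount(y_toy)
sample_weights = 1.0 / class_counts[y_toy]
probs = sample_weights / sample_weights.sum()

# Draw one batch of indices with replacement.
batch_idx = rng.choice(len(y_toy), size=64, replace=True, p=probs)
print("positives in batch:", int(y_toy[batch_idx].sum()), "of", len(batch_idx))
```

Each class receives half of the total sampling probability, so batches are balanced in expectation rather than exactly. In PyTorch, per-sample weights like these can be passed to `torch.utils.data.WeightedRandomSampler`.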
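import torch
import torch.nn.functional as F

# Focal Loss (binary); pred holds raw logits, target holds 0/1 floats
class FocalLoss(torch.nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        self.gamma, self.alpha = gamma, alpha

    def forward(self, pred, target):
        bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
        pt = torch.exp(-bce)  # probability assigned to the true class
        # alpha weights positives, (1 - alpha) weights negatives
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
        return (alpha_t * (1 - pt) ** self.gamma * bce).mean()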
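For reference, the code above relies on the identity between binary cross-entropy and the probability of the true class. Writing $p$ for the predicted probability of the positive class and $y$ for the label, the relationship can be sketched as:

```latex
p_t =
\begin{cases}
  p     & \text{if } y = 1 \\
  1 - p & \text{if } y = 0
\end{cases}
\qquad
\mathrm{BCE}(p, y) = -\log p_t
\;\;\Longrightarrow\;\;
p_t = e^{-\mathrm{BCE}(p, y)}

\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \log p_t
```

Here $\alpha_t$ is a weighting factor, commonly $\alpha$ for positives and $1 - \alpha$ for negatives. With $\gamma = 2$, a well-classified example with $p_t = 0.9$ has its loss scaled by $(1 - 0.9)^2 = 0.01$, while a hard example with $p_t = 0.1$ is scaled by $(1 - 0.1)^2 = 0.81$, so hard examples dominate the gradient.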