How to Perform Stratified K-Fold in Python

Key Insights

  • Stratified K-Fold maintains class distribution across all folds, preventing evaluation bias in imbalanced datasets where minority classes might be completely absent from some folds
  • scikit-learn’s StratifiedKFold is a drop-in replacement for KFold, but it keeps each fold’s class proportions as close to the full dataset’s as the fold sizes allow
  • Always use stratified splitting for classification tasks—the computational overhead is negligible and the evaluation reliability improvement is substantial

Introduction to Cross-Validation and Class Imbalance

Standard K-Fold cross-validation splits your dataset into K equal parts without considering class distribution. This works fine when your classes are balanced, but falls apart with imbalanced datasets. Imagine a fraud detection dataset with 95% legitimate transactions and 5% fraudulent ones. With standard 5-fold cross-validation, a fold can end up with far fewer fraudulent cases than expected — on a small dataset, even zero — making meaningful evaluation impossible.

Stratified K-Fold solves this by ensuring each fold maintains approximately the same percentage of samples for each class as the complete dataset. If your original dataset is 95/5, each fold will be roughly 95/5. This preserves the statistical properties of your data and gives you reliable performance estimates across all folds.

Understanding Stratified K-Fold

The stratification process works by sampling from each class independently. Instead of randomly shuffling all samples and splitting them into K parts, stratified K-Fold:

  1. Separates samples by class
  2. Splits each class into K parts
  3. Combines the corresponding parts from each class to form each fold
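The three steps above can be sketched directly in NumPy. This is a simplified illustration of the idea (not scikit-learn's actual implementation, which also handles shuffling and uneven remainders):

```python
import numpy as np

def manual_stratified_folds(y, n_splits):
    """Assign each sample to a fold, class by class, so every fold
    receives roughly len(class) / n_splits samples of each class."""
    fold_assignment = np.empty(len(y), dtype=int)
    for cls in np.unique(y):                           # 1. separate samples by class
        cls_indices = np.where(y == cls)[0]
        parts = np.array_split(cls_indices, n_splits)  # 2. split each class into K parts
        for fold, part in enumerate(parts):            # 3. combine corresponding parts
            fold_assignment[part] = fold
    return fold_assignment

# 90/10 imbalanced labels, as in the example below
y = np.array([0] * 90 + [1] * 10)
folds = manual_stratified_folds(y, n_splits=5)

for fold in range(5):
    y_fold = y[folds == fold]
    print(f"Fold {fold + 1}: {len(y_fold)} samples, "
          f"{(y_fold == 1).sum()} minority")  # every fold gets 2 minority samples
```

Each of the 5 folds ends up with 18 majority and 2 minority samples — exactly the overall 90/10 ratio.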

This guarantees proportional representation. Let’s visualize the difference:

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold
import matplotlib.pyplot as plt

# Create imbalanced dataset: 90% class 0, 10% class 1
np.random.seed(42)
X = np.random.randn(100, 2)
y = np.array([0] * 90 + [1] * 10)

# Standard K-Fold
kf = KFold(n_splits=5, shuffle=True, random_state=42)
standard_distributions = []

for fold_idx, (train_idx, test_idx) in enumerate(kf.split(X)):
    y_test = y[test_idx]
    class_1_percentage = (y_test == 1).sum() / len(y_test) * 100
    standard_distributions.append(class_1_percentage)
    print(f"Standard K-Fold {fold_idx + 1}: {class_1_percentage:.1f}% minority class")

print("\n")

# Stratified K-Fold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
stratified_distributions = []

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    y_test = y[test_idx]
    class_1_percentage = (y_test == 1).sum() / len(y_test) * 100
    stratified_distributions.append(class_1_percentage)
    print(f"Stratified K-Fold {fold_idx + 1}: {class_1_percentage:.1f}% minority class")

# Visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.bar(range(1, 6), standard_distributions)
ax1.axhline(y=10, color='r', linestyle='--', label='True distribution (10%)')
ax1.set_xlabel('Fold')
ax1.set_ylabel('Minority Class %')
ax1.set_title('Standard K-Fold')
ax1.legend()

ax2.bar(range(1, 6), stratified_distributions)
ax2.axhline(y=10, color='r', linestyle='--', label='True distribution (10%)')
ax2.set_xlabel('Fold')
ax2.set_ylabel('Minority Class %')
ax2.set_title('Stratified K-Fold')
ax2.legend()

plt.tight_layout()
plt.show()

You’ll see that standard K-Fold produces wildly varying distributions (some folds might have 0%, 5%, or 15% minority class), while stratified K-Fold consistently maintains the 10% proportion.

Basic Implementation with scikit-learn

The StratifiedKFold class from scikit-learn makes implementation straightforward. Here’s the basic pattern:

from sklearn.model_selection import StratifiedKFold
from sklearn.datasets import make_classification

# Create imbalanced dataset
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=15,
    n_redundant=5,
    n_classes=3,
    weights=[0.7, 0.2, 0.1],  # Imbalanced classes
    random_state=42
)

print("Original class distribution:")
unique, counts = np.unique(y, return_counts=True)
for cls, count in zip(unique, counts):
    print(f"  Class {cls}: {count} ({count/len(y)*100:.1f}%)")

# Initialize StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

print("\nFold distributions:")
for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    y_train, y_test = y[train_idx], y[test_idx]
    
    print(f"\nFold {fold_idx + 1}:")
    print(f"  Train size: {len(y_train)}, Test size: {len(y_test)}")
    
    # Check test set distribution
    unique_test, counts_test = np.unique(y_test, return_counts=True)
    for cls, count in zip(unique_test, counts_test):
        print(f"  Test Class {cls}: {count} ({count/len(y_test)*100:.1f}%)")

Notice that split() requires both X and y for stratified splitting, unlike standard K-Fold, which needs only X. The shuffle=True parameter randomizes the sample order before splitting, and random_state makes the shuffling reproducible.

Practical Example: Model Evaluation

Here’s a complete workflow using stratified K-Fold for model evaluation on an imbalanced dataset:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
import pandas as pd

# Load dataset
data = load_breast_cancer()
X, y = data.data, data.target

# Make it more imbalanced by randomly dropping about half of the malignant cases
np.random.seed(0)
mask = (y == 1) | (np.random.rand(len(y)) > 0.5)
X, y = X[mask], y[mask]

print(f"Dataset size: {len(y)}")
print(f"Class distribution: {np.bincount(y)}")

# Initialize stratified K-Fold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Store metrics
fold_metrics = []

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    
    # Train model
    clf = RandomForestClassifier(n_estimators=100, random_state=42)
    clf.fit(X_train, y_train)
    
    # Predict and evaluate
    y_pred = clf.predict(X_test)
    
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred, average='weighted')
    
    fold_metrics.append({
        'fold': fold_idx + 1,
        'accuracy': accuracy,
        'f1_score': f1
    })
    
    print(f"\nFold {fold_idx + 1}:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  F1-Score: {f1:.4f}")

# Aggregate results
metrics_df = pd.DataFrame(fold_metrics)
print("\n" + "="*50)
print(f"Mean Accuracy: {metrics_df['accuracy'].mean():.4f} (+/- {metrics_df['accuracy'].std():.4f})")
print(f"Mean F1-Score: {metrics_df['f1_score'].mean():.4f} (+/- {metrics_df['f1_score'].std():.4f})")

This pattern gives you robust performance estimates with confidence intervals. The standard deviation tells you how stable your model is across different data splits.
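As a sketch of turning per-fold scores into an interval estimate (using hypothetical fold accuracies as stand-ins for the metrics_df values above, and a normal approximation that is rough with only 5 folds):

```python
import numpy as np

# Hypothetical per-fold accuracies, stand-ins for metrics_df['accuracy']
fold_scores = np.array([0.956, 0.948, 0.965, 0.952, 0.960])

mean = fold_scores.mean()
# Standard error of the mean: sample std (ddof=1) over sqrt(number of folds)
sem = fold_scores.std(ddof=1) / np.sqrt(len(fold_scores))
ci_low, ci_high = mean - 1.96 * sem, mean + 1.96 * sem

print(f"Mean accuracy: {mean:.4f}")
print(f"95% CI (normal approx.): [{ci_low:.4f}, {ci_high:.4f}]")
```

With so few folds, a t-based interval would be wider and more honest; the point is that the fold-to-fold spread, not just the mean, belongs in your report.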

Advanced Techniques

Stratified Train-Test Split

For a simple train-test split (not K-Fold), use the stratify parameter:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    stratify=y,  # This ensures stratification
    random_state=42
)

print("Training set distribution:", np.bincount(y_train))
print("Test set distribution:", np.bincount(y_test))

Handling Multi-Label Classification

For multi-label problems where each sample can have multiple labels, you need a different approach:

from sklearn.model_selection import StratifiedKFold
import pandas as pd

# Simulate multi-label data (independent of the earlier examples)
n_samples = 1000
rng = np.random.default_rng(42)
X_ml = rng.standard_normal((n_samples, 5))
y_multilabel = rng.integers(0, 2, size=(n_samples, 3))

# Collapse each row of labels into a single string (e.g. '101') and stratify on that
y_combined = pd.DataFrame(y_multilabel).apply(lambda x: ''.join(x.astype(str)), axis=1)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

for fold_idx, (train_idx, test_idx) in enumerate(skf.split(X_ml, y_combined)):
    print(f"Fold {fold_idx + 1}: Train={len(train_idx)}, Test={len(test_idx)}")

This trick treats every label combination as its own class, so it only works when each combination appears at least n_splits times. With many labels the number of combinations explodes, and dedicated iterative stratification tools are a better fit.

Stratified Group K-Fold

When you have grouped data (e.g., multiple samples from the same patient), use StratifiedGroupKFold to ensure groups don’t leak across folds while maintaining stratification:

from sklearn.model_selection import StratifiedGroupKFold

# Create groups (e.g., patient IDs)
groups = np.repeat(np.arange(100), 10)  # 100 groups, 10 samples each
X_grouped = np.random.randn(1000, 20)
y_grouped = np.random.randint(0, 2, 1000)

sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

for fold_idx, (train_idx, test_idx) in enumerate(sgkf.split(X_grouped, y_grouped, groups)):
    train_groups = set(groups[train_idx])
    test_groups = set(groups[test_idx])
    
    # Verify no group leakage
    assert len(train_groups & test_groups) == 0
    print(f"Fold {fold_idx + 1}: No group leakage verified")

Common Pitfalls and Best Practices

Always use stratified splitting for classification tasks. The only exception is when you’re specifically testing model robustness to distribution shift. The computational cost is essentially the same as standard K-Fold.

Set random_state for reproducibility. Without it, your results will vary between runs, making debugging and comparison impossible.

Watch for very small classes. If a class has fewer samples than the number of folds, scikit-learn cannot place that class in every fold: it issues a warning and some folds end up with no samples of that class at all. Either reduce the number of folds or combine rare classes.
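A minimal sketch of that failure mode, using synthetic data with only 3 minority samples and 5 folds: scikit-learn emits a UserWarning, and at least two test folds receive no minority samples at all.

```python
import warnings
import numpy as np
from sklearn.model_selection import StratifiedKFold

# 3 minority samples cannot be spread across 5 folds
X_small = np.random.randn(50, 2)
y_small = np.array([0] * 47 + [1] * 3)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    folds = list(skf.split(X_small, y_small))

for w in caught:
    print(w.message)  # warns that the least populated class has < n_splits members

for fold_idx, (_, test_idx) in enumerate(folds):
    n_minority = (y_small[test_idx] == 1).sum()
    print(f"Fold {fold_idx + 1}: {n_minority} minority sample(s) in test set")
```

The folds with zero minority samples make any minority-sensitive metric (recall, F1, AUC) undefined or meaningless for those folds.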

Here’s a comparison showing the performance difference:

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

# Create highly imbalanced dataset
X, y = make_classification(
    n_samples=1000,
    n_classes=2,
    weights=[0.95, 0.05],
    random_state=42
)

def evaluate_cv(cv_splitter, X, y, name):
    scores = []
    for train_idx, test_idx in cv_splitter.split(X, y):
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X[train_idx], y[train_idx])
        y_pred_proba = clf.predict_proba(X[test_idx])[:, 1]
        scores.append(roc_auc_score(y[test_idx], y_pred_proba))
    
    print(f"{name}:")
    print(f"  Mean AUC: {np.mean(scores):.4f} (+/- {np.std(scores):.4f})")
    print(f"  Min/Max: {np.min(scores):.4f} / {np.max(scores):.4f}")
    return scores

kf = KFold(n_splits=5, shuffle=True, random_state=42)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

standard_scores = evaluate_cv(kf, X, y, "Standard K-Fold")
print()
stratified_scores = evaluate_cv(skf, X, y, "Stratified K-Fold")

You’ll typically see that stratified K-Fold produces more stable results with lower variance, especially on imbalanced datasets. This translates to more reliable model selection and hyperparameter tuning.

Use stratified K-Fold as your default. Your evaluation metrics will thank you.
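In everyday use you rarely need to write the fold loop by hand: cross_val_score accepts any splitter through its cv argument, and for classifiers an integer cv already defaults to (unshuffled) stratified folds. A compact version of the earlier evaluation loop:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=42)

# Pass an explicit splitter to control shuffling and the random seed
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                         cv=skf, scoring="f1")

print(f"F1 per fold: {np.round(scores, 4)}")
print(f"Mean F1: {scores.mean():.4f} (+/- {scores.std():.4f})")
```

Passing the splitter explicitly keeps shuffling and the random seed under your control, which matters when comparing models.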
