How to Implement Early Stopping in PyTorch
Key Insights
- Early stopping monitors validation metrics during training and halts when performance stops improving, preventing overfitting without requiring manual intervention or predetermined epoch counts
- A proper implementation requires tracking the best validation score, implementing a patience counter, and saving model checkpoints to restore the best weights after training stops
- The patience parameter is critical: too low causes premature stopping before convergence, too high wastes compute and risks overfitting—start with 5-10 epochs for most applications
Introduction to Early Stopping
Early stopping is a regularization technique that monitors your model’s validation performance during training and stops when improvement plateaus. Instead of training for a fixed number of epochs and hoping you picked the right number, early stopping automatically detects when your model has learned as much as it can from the data.
The mechanism is straightforward: after each epoch, you evaluate your model on a validation set. If the validation loss doesn’t improve for a specified number of epochs (the “patience” parameter), training stops. This prevents the model from continuing to train after it starts overfitting—when training loss keeps decreasing but validation loss increases or stagnates.
Early stopping is particularly valuable because it eliminates the need to guess the optimal number of training epochs. It also saves computational resources by not wasting time on epochs that don’t improve generalization. The key is saving model checkpoints so you can restore the weights from the epoch with the best validation performance, not the final epoch.
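To make the mechanism concrete, here is a minimal sketch of the patience rule using made-up validation losses (no real training involved):

```python
# Minimal sketch of the patience rule with made-up validation losses.
best = float('inf')
counter, patience = 0, 3

for epoch, val_loss in enumerate([0.9, 0.7, 0.6, 0.61, 0.62, 0.60]):
    if val_loss < best:
        best = val_loss   # improvement: remember it and reset the counter
        counter = 0
    else:
        counter += 1      # no improvement: one more strike
    if counter >= patience:
        print(f'Stopping at epoch {epoch}')  # prints: Stopping at epoch 5
        break
```

Note that epoch 5's loss of 0.60 only ties the best value, so it still counts as "no improvement" and exhausts the patience budget.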
Basic Early Stopping Implementation
Let’s build a reusable EarlyStopping class that you can drop into any PyTorch project. This implementation tracks validation loss, implements patience logic, and handles model checkpointing.
```python
import numpy as np
import torch


class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How many epochs to wait after last improvement
            verbose (bool): If True, prints a message for each improvement
            delta (float): Minimum change to qualify as an improvement
            path (str): Path to save the checkpoint
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # np.Inf was removed in NumPy 2.0
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decreases."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
```
This class maintains state across epochs. The `__call__` method lets you use it like a function, passing in the current validation loss and model. It automatically handles the patience counter and checkpoint saving.
Integrating Early Stopping into the Training Loop
Here’s how to integrate early stopping into a standard PyTorch training loop. This example shows the complete flow including validation and checkpoint restoration.
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


def train_with_early_stopping(model, train_loader, val_loader, epochs=100, patience=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Initialize early stopping
    early_stopping = EarlyStopping(patience=patience, verbose=True, path='best_model.pt')

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = criterion(output, target)
                val_loss += loss.item()

        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early stopping check
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    # Load the best model weights (weights_only avoids unpickling arbitrary objects)
    model.load_state_dict(torch.load('best_model.pt', weights_only=True))
    return model
```
The critical part is calling `early_stopping(val_loss, model)` after each validation phase. The early stopping object maintains state across epochs and sets its `early_stop` flag when patience is exhausted. After training completes, we reload the best checkpoint, not the final weights.
Advanced Early Stopping Techniques
A production-ready early stopping implementation needs more flexibility. Here’s an enhanced version that supports different metrics, minimum improvement thresholds, and mode switching for metrics where higher is better.
```python
import numpy as np
import torch


class AdvancedEarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt',
                 mode='min', min_delta=0):
        """
        Args:
            patience (int): How many epochs to wait after last improvement
            verbose (bool): If True, prints messages
            delta (float): Minimum change to qualify as improvement (deprecated, use min_delta)
            path (str): Path to save checkpoint
            mode (str): 'min' for loss, 'max' for accuracy/F1
            min_delta (float): Minimum change in monitored value to qualify as improvement
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.early_stop = False
        self.path = path
        self.mode = mode
        self.min_delta = min_delta

        if self.mode == 'min':
            self.monitor_op = lambda current, best: current < (best - self.min_delta)
            self.best_score = np.inf  # np.Inf was removed in NumPy 2.0
        elif self.mode == 'max':
            self.monitor_op = lambda current, best: current > (best + self.min_delta)
            self.best_score = -np.inf
        else:
            raise ValueError(f"Mode {mode} is unknown. Use 'min' or 'max'")

    def __call__(self, metric_value, model):
        if self.monitor_op(metric_value, self.best_score):
            if self.verbose:
                print(f'Metric improved from {self.best_score:.6f} to {metric_value:.6f}. Saving model...')
            self.best_score = metric_value
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), self.path)
```
This version handles both minimization (loss) and maximization (accuracy, F1) metrics. The min_delta parameter requires improvements to exceed a threshold, preventing stopping due to tiny fluctuations that don’t represent real progress.
Usage example monitoring validation accuracy:
```python
early_stopping = AdvancedEarlyStopping(patience=10, mode='max', min_delta=0.001, verbose=True)

# In training loop
val_accuracy = compute_accuracy(model, val_loader)
early_stopping(val_accuracy, model)
```
Early Stopping with PyTorch Lightning
If you’re using PyTorch Lightning, you get early stopping for free with a cleaner API. Lightning handles all the boilerplate and integrates seamlessly with its callback system.
```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = YourModel()  # replace with your own nn.Module
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


# Setup early stopping
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    mode='min',
    min_delta=0.001,
    verbose=True
)

# Train with early stopping
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[early_stop_callback],
    accelerator='auto'
)

model = LitModel()
trainer.fit(model, train_loader, val_loader)
```
Lightning’s implementation includes additional features like divergence thresholds and automatic metric logging. The callback system makes it easy to combine early stopping with other callbacks like learning rate scheduling or model checkpointing.
Best Practices and Common Pitfalls
Choosing Patience Values: Start with a patience between 5 and 10 epochs for most tasks. Computer vision models with large datasets can use lower values (3-5), while NLP models or small datasets benefit from higher patience (10-20). Monitor your validation curves; if they're noisy, increase patience.
Minimum Delta Considerations: Set min_delta to filter out noise in your validation metrics. For loss values around 0.1-1.0, try min_delta=0.001. For accuracies on a 0-1 scale, min_delta=0.001 requires a 0.1 percentage-point improvement. Without this threshold, meaningless fluctuations keep resetting the patience counter and training can run far longer than it should.
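The filtering effect is easy to see with a few illustrative numbers (no real metrics here):

```python
# Sketch: how min_delta filters out noise (illustrative loss values).
min_delta = 0.001
best = 0.4210

results = []
for val_loss in [0.4205, 0.4190, 0.4202]:
    improved = val_loss < best - min_delta   # must beat best by at least min_delta
    results.append(improved)
    if improved:
        best = val_loss

print(results)  # [False, True, False]: only the middle run is a real improvement
```

The 0.0005 dip in the first run and the rebound in the third are both treated as noise; only the 0.0020 drop counts.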
When Early Stopping Hurts: Don’t use early stopping when:
- Your validation set is too small (high variance in metrics)
- Training exhibits periodic oscillations (use longer patience instead)
- You’re doing curriculum learning or staged training
- You need to train for a minimum number of epochs for learning rate schedules
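The last point can be handled with a small warm-up wrapper. This is a hypothetical extension, not part of the classes above: it simply refuses to stop before a minimum epoch count.

```python
# Sketch: early stopping that cannot trigger before min_epochs (hypothetical
# add-on; checkpointing omitted to keep the idea visible).
class WarmupEarlyStopping:
    def __init__(self, patience=7, min_epochs=10):
        self.patience = patience
        self.min_epochs = min_epochs
        self.best = float('inf')
        self.counter = 0
        self.epoch = 0
        self.early_stop = False

    def __call__(self, val_loss):
        self.epoch += 1
        if val_loss < self.best:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        # Only allow stopping once the warm-up period has passed.
        if self.epoch > self.min_epochs and self.counter >= self.patience:
            self.early_stop = True


# Patience is exhausted by epoch 3, but stopping waits for the warm-up to end.
stopper = WarmupEarlyStopping(patience=2, min_epochs=5)
for loss in [1.0, 1.1, 1.2, 1.3, 1.4, 1.5]:
    stopper(loss)
    if stopper.early_stop:
        print(f'Stopped at epoch {stopper.epoch}')  # prints: Stopped at epoch 6
        break
```

This keeps a warm-up-sensitive learning rate schedule intact while preserving normal early stopping afterwards.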
Checkpoint Management: Always save the best model, not the last. Validation loss often increases in the final epochs before early stopping triggers. Your production model should use weights from the best epoch.
```python
# Good: Different configurations for different scenarios
configs = {
    'stable_large_dataset': {'patience': 5, 'min_delta': 0.001},
    'noisy_small_dataset': {'patience': 15, 'min_delta': 0.005},
    'quick_experiments': {'patience': 3, 'min_delta': 0.0},
    'production_training': {'patience': 10, 'min_delta': 0.001}
}

# Choose based on your context
early_stopping = AdvancedEarlyStopping(**configs['production_training'])
```
Monitoring Multiple Metrics: While you can only early stop on one metric, log everything. Sometimes validation loss increases while accuracy improves. Use TensorBoard or Weights & Biases to visualize all metrics and verify your early stopping metric aligns with your actual goal.
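One lightweight pattern for this: record every metric each epoch while only the stopping decision consumes one of them (a plain dict here stands in for TensorBoard or W&B):

```python
# Sketch: log everything, stop on one metric (dict stands in for a real logger).
history = {'val_loss': [], 'val_acc': []}

def log_all(metrics):
    for name, value in metrics.items():
        history[name].append(value)

# Illustrative epochs where loss worsens while accuracy improves.
for metrics in [{'val_loss': 0.50, 'val_acc': 0.80},
                {'val_loss': 0.52, 'val_acc': 0.83}]:
    log_all(metrics)
    # only metrics['val_loss'] would be passed to the early stopper

print(history['val_acc'])  # [0.8, 0.83]
```

Reviewing the full history afterwards tells you whether stopping on val_loss would have discarded a model whose accuracy was still climbing.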
Early stopping is a simple yet powerful technique that belongs in every deep learning practitioner's toolkit. Implement it once properly, and you'll never again waste compute on unnecessary epochs or guess when to stop training.