PyTorch Early Stopping
Introduction
When training deep learning models, one of the biggest challenges is determining when to stop the training process. Training for too few epochs might result in underfitting, while training for too many epochs can lead to overfitting, where the model performs well on training data but poorly on unseen data.
Early stopping is a regularization technique that helps prevent overfitting by monitoring the model's performance on a validation set during training and stopping when the performance starts to degrade. In this tutorial, we'll learn how to implement early stopping in PyTorch and understand why it's an essential tool in your deep learning toolkit.
Why Use Early Stopping?
Let's first understand why early stopping is important:
- Prevents Overfitting: Stops training before the model starts memorizing the training data
- Saves Computational Resources: Reduces unnecessary training time
- Improves Generalization: Helps select models that perform better on unseen data
- Automatic Model Selection: Effectively tunes the number of training epochs for you
Basic Implementation of Early Stopping
Let's implement a simple early stopping class that can be used in any PyTorch training loop:
import torch

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, verbose=False):
        """
        Early stopping to stop training when validation loss doesn't improve.

        Args:
            patience (int): Number of epochs to wait after the last improvement before stopping
            min_delta (float): Minimum change in the monitored value to qualify as an improvement
            verbose (bool): If True, prints a message for each improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f'Validation loss decreased to {val_loss:.6f}. Saving model...')
            # Save the best model seen so far
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter += 1
            if self.verbose:
                print(f'Early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print('Early stopping triggered')
Using Early Stopping in a Training Loop
Now let's see how to incorporate early stopping into a PyTorch training loop:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a simple network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Prepare data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Split training data into train and validation
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the network, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize early stopping
early_stopping = EarlyStopping(patience=10, verbose=True)

# Training loop
num_epochs = 100  # Set a large number; early stopping will end training at the right time
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)
    train_loss /= len(train_loader.dataset)

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc = 100. * correct / len(val_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.6f} | '
          f'Val Loss: {val_loss:.6f} | Val Acc: {val_acc:.2f}%')

    # Check early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate on the test set
model.eval()
test_loss = 0.0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
test_acc = 100. * correct / len(test_loader.dataset)
print(f'Test Loss: {test_loss:.6f} | Test Acc: {test_acc:.2f}%')
Expected Output:
Running this code will produce output similar to:
Epoch 1/100 | Train Loss: 0.542312 | Val Loss: 0.284919 | Val Acc: 91.73%
Validation loss decreased to 0.284919. Saving model...
Epoch 2/100 | Train Loss: 0.251796 | Val Loss: 0.206234 | Val Acc: 94.22%
Validation loss decreased to 0.206234. Saving model...
...
Epoch 20/100 | Train Loss: 0.054328 | Val Loss: 0.102435 | Val Acc: 97.16%
Validation loss decreased to 0.102435. Saving model...
Epoch 21/100 | Train Loss: 0.051265 | Val Loss: 0.108621 | Val Acc: 96.98%
Early stopping counter: 1 out of 10
...
Epoch 30/100 | Train Loss: 0.034521 | Val Loss: 0.124853 | Val Acc: 96.54%
Early stopping counter: 10 out of 10
Early stopping triggered!
Test Loss: 0.099837 | Test Acc: 97.24%
Advanced Early Stopping
Let's enhance our early stopping implementation with more features:
class EarlyStoppingWithCheckpoint:
    def __init__(self, patience=5, min_delta=0, verbose=False, path='checkpoint.pth', trace_func=print):
        """
        Enhanced early stopping with model checkpointing.

        Args:
            patience (int): How long to wait after the last improvement
            min_delta (float): Minimum change to qualify as an improvement
            verbose (bool): If True, prints a message for each improvement
            path (str): Path to save the checkpoint
            trace_func (callable): Function to use for printing information
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss  # Negate the loss so that a higher score is better
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'Early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model when validation loss decreases."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save({
            'model_state_dict': model.state_dict(),
            'val_loss': val_loss
        }, self.path)
        self.val_loss_min = val_loss

    def load_checkpoint(self, model):
        """Load the best model."""
        checkpoint = torch.load(self.path)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint['val_loss']
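To give a feel for how this class slots into a training loop, here is a minimal usage sketch. The train_one_epoch and validate helpers are hypothetical stand-ins for the training and validation code shown earlier:

# Hypothetical usage sketch: `train_one_epoch` and `validate` stand in for
# the training and validation phases from the loop above.
early_stopping = EarlyStoppingWithCheckpoint(patience=10, verbose=True, path='checkpoint.pth')

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        break

# Restore the best weights and read back the loss they achieved
model, best_val_loss = early_stopping.load_checkpoint(model)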
Monitoring Different Metrics
Early stopping doesn't have to be based only on validation loss. You can use accuracy, F1 score, or any other relevant metric:
class MetricEarlyStopping:
    def __init__(self, patience=5, min_delta=0, mode='min', verbose=False):
        """
        Early stopping based on an arbitrary metric.

        Args:
            patience (int): How long to wait after the last improvement
            min_delta (float): Minimum change to qualify as an improvement
            mode (str): 'min' or 'max', whether the metric should be minimized or maximized
            verbose (bool): If True, prints a message for each improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False
        self.verbose = verbose
        if mode == 'min':
            self.monitor_op = lambda current, best: current < best - self.min_delta
        elif mode == 'max':
            self.monitor_op = lambda current, best: current > best + self.min_delta
        else:
            raise ValueError("Mode must be either 'min' or 'max'")
        self.best_score = float('inf') if mode == 'min' else float('-inf')

    def __call__(self, metric_value, model):
        if self.monitor_op(metric_value, self.best_score):
            self.best_score = metric_value
            self.counter = 0
            if self.verbose:
                print(f'Metric improved to {metric_value:.6f}. Saving model...')
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter += 1
            if self.verbose:
                print(f'Early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print('Early stopping triggered')
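For instance, to stop on macro F1 instead of loss, you can compute the score at the end of each validation pass and hand it to the class in 'max' mode. Here is a minimal sketch, assuming scikit-learn is available and that model, val_loader, and device come from the surrounding training code:

# Sketch: monitor macro F1 instead of loss. Assumes scikit-learn is installed.
import torch
from sklearn.metrics import f1_score

early_stopping = MetricEarlyStopping(patience=5, mode='max', verbose=True)

model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for data, target in val_loader:
        output = model(data.to(device))
        all_preds.extend(output.argmax(dim=1).cpu().tolist())
        all_targets.extend(target.tolist())

val_f1 = f1_score(all_targets, all_preds, average='macro')
early_stopping(val_f1, model)  # higher F1 is better, hence mode='max'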
Practical Example: Image Classification with Early Stopping
Let's see a more complete example with a CNN for image classification:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
full_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

# Split dataset
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize model, criterion, optimizer
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize early stopping (metric: validation accuracy)
early_stopping = MetricEarlyStopping(patience=7, mode='max', verbose=True)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    train_acc = 100 * correct / total
    train_loss = train_loss / len(train_loader)

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    val_acc = 100 * correct / total
    val_loss = val_loss / len(val_loader)

    print(f'Epoch {epoch+1}/{num_epochs} | '
          f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | '
          f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

    # Call early stopping on validation accuracy (higher is better)
    early_stopping(val_acc, model)
    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate on the test set
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
test_acc = 100 * correct / total
print(f'Test Accuracy: {test_acc:.2f}%')
Best Practices for Early Stopping
- Choose the Right Metric:
  - For classification: validation accuracy or F1-score
  - For regression: validation loss (MSE, MAE, etc.)
  - For generative models: domain-specific metrics
- Set Appropriate Patience:
  - Too low: may stop training too early
  - Too high: may not prevent overfitting effectively
  - A good rule of thumb is 5-20 epochs, depending on your dataset size
- Save the Best Model:
  - Always save the model with the best validation performance, not the final model
- Combine with Other Regularization:
  - Use early stopping alongside dropout, batch normalization, or weight decay for better results
- Different Learning Rates:
  - Consider reducing the learning rate before early stopping triggers (see the sketch after this list)
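One way to implement that last point is PyTorch's built-in ReduceLROnPlateau scheduler, which lowers the learning rate when the monitored value plateaus. Giving it a smaller patience than the early stopper means the learning rate drops first, and training only stops if that doesn't help. A minimal sketch, where train_one_epoch and validate_one_epoch are hypothetical stand-ins for your own loop code:

# Sketch: pair a ReduceLROnPlateau scheduler (patience=3) with early stopping
# (patience=10) so the learning rate is cut before training is abandoned.
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
early_stopping = EarlyStopping(patience=10, verbose=True)

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, criterion, optimizer, device)   # hypothetical helper
    val_loss = validate_one_epoch(model, val_loader, criterion, device)  # hypothetical helper
    scheduler.step(val_loss)  # cut the LR after 3 stagnant epochs
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        break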
Summary
Early stopping is a simple yet powerful regularization technique that can significantly improve your model's generalization capabilities. By implementing early stopping in your PyTorch training loops, you can:
- Automatically determine the optimal number of training epochs
- Prevent overfitting and improve generalization
- Save computational resources by not training unnecessarily
- Select the best model based on validation performance
The implementations provided in this tutorial can be easily integrated into any PyTorch project, and the concepts apply to all types of neural networks.
Exercises
- Modify the EarlyStopping class to monitor multiple metrics at once
- Implement learning rate reduction before early stopping in your training loop
- Compare models trained with and without early stopping on a dataset of your choice
- Experiment with different patience values and observe their effects on model performance
- Create a visualization that shows training and validation curves, marking where early stopping would trigger
Happy coding with PyTorch!
If you spot any mistakes on this website, please let me know at [email protected]. I’d greatly appreciate your feedback! :)