PyTorch Early Stopping

Introduction

When training deep learning models, one of the biggest challenges is determining when to stop the training process. Training for too few epochs might result in underfitting, while training for too many epochs can lead to overfitting, where the model performs well on training data but poorly on unseen data.

Early stopping is a regularization technique that helps prevent overfitting by monitoring the model's performance on a validation set during training and stopping when the performance starts to degrade. In this tutorial, we'll learn how to implement early stopping in PyTorch and understand why it's an essential tool in your deep learning toolkit.

Why Use Early Stopping?

Let's first understand why early stopping is important:

  1. Prevents Overfitting: Stops training before the model starts memorizing the training data
  2. Saves Computational Resources: Reduces unnecessary training time
  3. Improves Generalization: Helps select models that perform better on unseen data
  4. Automatic Model Selection: Effectively tunes the number of training epochs for you, removing one hyperparameter from manual search

Basic Implementation of Early Stopping

Let's implement a simple early stopping class that can be used in any PyTorch training loop:

```python
import torch

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0, verbose=False):
        """
        Early stopping to stop training when validation loss doesn't improve.

        Args:
            patience (int): Number of epochs to wait after the last improvement before stopping
            min_delta (float): Minimum decrease in the monitored value to qualify as improvement
            verbose (bool): If True, prints a message for each improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.verbose:
                print(f'Validation loss decreased to {val_loss:.6f}. Saving model...')
            # Save the best model so far
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter += 1
            if self.verbose:
                print(f'Early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print('Early stopping triggered')
```

Using Early Stopping in a Training Loop

Now let's see how to incorporate early stopping into a PyTorch training loop:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define a simple network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Prepare data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Split training data into train and validation
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize the network, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize early stopping
early_stopping = EarlyStopping(patience=10, verbose=True)

# Training loop
num_epochs = 100  # Set a large number, early stopping will prevent overfitting

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * data.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_loss /= len(val_loader.dataset)
    val_acc = 100. * correct / len(val_loader.dataset)

    print(f'Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.6f} | '
          f'Val Loss: {val_loss:.6f} | Val Acc: {val_acc:.2f}%')

    # Check early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate on test set
model.eval()
test_loss = 0.0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += criterion(output, target).item() * data.size(0)
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
test_acc = 100. * correct / len(test_loader.dataset)
print(f'Test Loss: {test_loss:.6f} | Test Acc: {test_acc:.2f}%')
```

Expected Output:

Running this code will produce output similar to:

```
Epoch 1/100 | Train Loss: 0.542312 | Val Loss: 0.284919 | Val Acc: 91.73%
Validation loss decreased to 0.284919. Saving model...
Epoch 2/100 | Train Loss: 0.251796 | Val Loss: 0.206234 | Val Acc: 94.22%
Validation loss decreased to 0.206234. Saving model...
...
Epoch 20/100 | Train Loss: 0.054328 | Val Loss: 0.102435 | Val Acc: 97.16%
Validation loss decreased to 0.102435. Saving model...
Epoch 21/100 | Train Loss: 0.051265 | Val Loss: 0.108621 | Val Acc: 96.98%
Early stopping counter: 1 out of 10
...
Epoch 30/100 | Train Loss: 0.034521 | Val Loss: 0.124853 | Val Acc: 96.54%
Early stopping counter: 10 out of 10
Early stopping triggered!
Test Loss: 0.099837 | Test Acc: 97.24%
```

Advanced Early Stopping

Let's enhance our early stopping implementation with more features:

```python
import torch

class EarlyStoppingWithCheckpoint:
    def __init__(self, patience=5, min_delta=0, verbose=False, path='checkpoint.pth', trace_func=print):
        """
        Enhanced early stopping with model checkpointing.

        Args:
            patience (int): How long to wait after the last improvement
            min_delta (float): Minimum change to qualify as an improvement
            verbose (bool): If True, prints a message for each improvement
            path (str): Path to save the checkpoint
            trace_func (callable): Function to use for printing information
        """
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss  # Negate the loss so that a higher score is always better

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'Early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save model when validation loss decreases."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save({
            'model_state_dict': model.state_dict(),
            'val_loss': val_loss
        }, self.path)
        self.val_loss_min = val_loss

    def load_checkpoint(self, model):
        """Load the best model."""
        checkpoint = torch.load(self.path)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, checkpoint['val_loss']
```
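
The `load_checkpoint` method lets you restore the best weights after training ends. Here is a minimal usage sketch; `train_one_epoch` and `evaluate` are hypothetical helpers standing in for the training and validation phases shown earlier:

```python
# Minimal usage sketch: train_one_epoch and evaluate are hypothetical
# helpers standing in for the training/validation phases shown above.
early_stopping = EarlyStoppingWithCheckpoint(patience=10, verbose=True, path='checkpoint.pth')

for epoch in range(num_epochs):
    train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss = evaluate(model, val_loader, criterion)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        break

# Restore the best weights and the validation loss they achieved
model, best_val_loss = early_stopping.load_checkpoint(model)
```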

Monitoring Different Metrics

Early stopping doesn't have to be based only on validation loss. You can use accuracy, F1 score, or any other relevant metric:

```python
import torch

class MetricEarlyStopping:
    def __init__(self, patience=5, min_delta=0, mode='min', verbose=False):
        """
        Early stopping based on an arbitrary metric.

        Args:
            patience (int): How long to wait after the last improvement
            min_delta (float): Minimum change to qualify as an improvement
            mode (str): 'min' or 'max', whether the metric should be minimized or maximized
            verbose (bool): If True, prints a message for each improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False
        self.verbose = verbose

        if mode == 'min':
            self.monitor_op = lambda current, best: current < best - self.min_delta
        elif mode == 'max':
            self.monitor_op = lambda current, best: current > best + self.min_delta
        else:
            raise ValueError("Mode must be either 'min' or 'max'")

        self.best_score = float('inf') if mode == 'min' else float('-inf')

    def __call__(self, metric_value, model):
        if self.monitor_op(metric_value, self.best_score):
            self.best_score = metric_value
            self.counter = 0
            if self.verbose:
                print(f'Metric improved to {metric_value:.6f}. Saving model...')
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            self.counter += 1
            if self.verbose:
                print(f'Early stopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print('Early stopping triggered')
```
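
For example, to stop on a macro-averaged F1 score rather than loss, you can compute the metric at the end of each validation pass. A minimal sketch, assuming scikit-learn is installed and that `model`, `val_loader`, `device`, and `num_epochs` come from a setup like the earlier examples:

```python
import torch
from sklearn.metrics import f1_score  # assumes scikit-learn is installed

# Stop when macro F1 on the validation set stops improving
early_stopping = MetricEarlyStopping(patience=5, mode='max', verbose=True)

for epoch in range(num_epochs):
    # ... training phase, as in the earlier examples ...

    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data.to(device))
            all_preds.extend(output.argmax(dim=1).cpu().tolist())
            all_targets.extend(target.tolist())

    val_f1 = f1_score(all_targets, all_preds, average='macro')
    early_stopping(val_f1, model)  # higher F1 is better, hence mode='max'
    if early_stopping.early_stop:
        break
```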

Practical Example: Image Classification with Early Stopping

Let's see a more complete example with a CNN for image classification:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
full_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)

# Split dataset
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize model, criterion, optimizer
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Initialize early stopping (metric: validation accuracy)
early_stopping = MetricEarlyStopping(patience=7, mode='max', verbose=True)

# Training loop
num_epochs = 50

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_acc = 100 * correct / total
    train_loss = train_loss / len(train_loader)

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_acc = 100 * correct / total
    val_loss = val_loss / len(val_loader)

    print(f'Epoch {epoch+1}/{num_epochs} | '
          f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | '
          f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

    # Call early stopping
    early_stopping(val_acc, model)  # Using accuracy (higher is better)

    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break

# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

# Evaluate on test set
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

test_acc = 100 * correct / total
print(f'Test Accuracy: {test_acc:.2f}%')
```

Best Practices for Early Stopping

  1. Choose the Right Metric:

    • For classification: Validation accuracy or F1-score
    • For regression: Validation loss (MSE, MAE, etc.)
    • For generative models: Domain-specific metrics
  2. Set Appropriate Patience:

    • Too low: May stop training too early
    • Too high: May not prevent overfitting effectively
    • A good rule of thumb is 5-20 epochs, depending on your dataset size
  3. Save the Best Model:

    • Always save the model with the best validation performance, not the final model
  4. Combine with Other Regularization:

    • Use with dropout, batch normalization, or weight decay for better results
  5. Combine with Learning Rate Scheduling:

    • Consider reducing the learning rate when validation performance plateaus, before early stopping triggers (see the sketch after this list)
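
A convenient way to implement point 5 is PyTorch's built-in `torch.optim.lr_scheduler.ReduceLROnPlateau`, which lowers the learning rate when the monitored value stalls. A minimal sketch, assuming the `model`, validation loop, and `EarlyStopping` class from the examples above:

```python
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the learning rate after 3 stagnant epochs...
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
# ...but only stop training after 10 stagnant epochs
early_stopping = EarlyStopping(patience=10, verbose=True)

for epoch in range(num_epochs):
    # ... training and validation phases, producing val_loss ...
    scheduler.step(val_loss)         # may reduce the learning rate
    early_stopping(val_loss, model)  # may stop training
    if early_stopping.early_stop:
        break
```

Setting the scheduler's patience below the early stopping patience gives the reduced learning rate a chance to improve the validation loss before training is abandoned.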

Summary

Early stopping is a simple yet powerful regularization technique that can significantly improve your model's generalization capabilities. By implementing early stopping in your PyTorch training loops, you can:

  1. Automatically determine the optimal number of training epochs
  2. Prevent overfitting and improve generalization
  3. Save computational resources by not training unnecessarily
  4. Select the best model based on validation performance

The implementations provided in this tutorial can be easily integrated into any PyTorch project, and the concepts apply to all types of neural networks.

Additional Resources

  1. PyTorch Documentation
  2. Early Stopping in Neural Networks research paper

Exercises

  1. Modify the EarlyStopping class to monitor multiple metrics at once
  2. Implement learning rate reduction before early stopping in your training loop
  3. Compare models trained with and without early stopping on a dataset of your choice
  4. Experiment with different patience values and observe their effects on model performance
  5. Create a visualization that shows training and validation curves, marking where early stopping would trigger

Happy coding with PyTorch!


