
PyTorch Checkpointing

In deep learning, training models can be time-consuming and resource-intensive. PyTorch's checkpointing mechanism allows you to save the state of your model during training and resume it later. This is invaluable when dealing with long training sessions, potential system failures, or when implementing techniques like early stopping.

What is Checkpointing?

Checkpointing is the practice of saving a model's state at various intervals during training. A checkpoint typically contains:

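  • The model's learned parameters and buffers (weights, biases, batch-norm running statistics), usually stored as its state_dict. The architecture itself is defined by your code, not by the checkpoint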
  • The optimizer state (learning rates, momentum buffers)
  • Other training metadata (current epoch, loss values, etc.)

This allows you to resume training from the exact point where you left off, rather than starting from scratch.
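It helps to look at what a state_dict actually contains, because this explains why you still need the model class when loading. The sketch below uses a throwaway nn.Linear model and an SGD optimizer, which are illustrative and not part of the later examples. It prints the keys of both state dicts:

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# The model's state_dict maps parameter/buffer names to tensors; it stores no code
print(list(model.state_dict().keys()))  # ['weight', 'bias']

# The optimizer's state_dict holds per-parameter state plus hyperparameter groups
print(list(optimizer.state_dict().keys()))  # ['state', 'param_groups']
print(optimizer.state_dict()['param_groups'][0]['lr'])  # 0.1

# SGD creates momentum buffers lazily, so 'state' is empty until the first step
print(optimizer.state_dict()['state'])  # {}
```

To rebuild a model from a checkpoint, you first construct it from its class and then call load_state_dict. The tensors alone are not enough.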

Basic Checkpointing in PyTorch

Let's start with a simple example of how to save and load a checkpoint in PyTorch:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)
        self.output = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc(x))
        return self.output(x)

# Create model, loss function, and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Dummy data so the example runs end to end (replace with your DataLoader)
inputs = torch.randn(64, 10)
targets = torch.randint(0, 2, (64,))

# Training loop (simplified: one full-batch step per epoch)
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Save a checkpoint every 10 epochs (epochs 0, 10, 20, ...)
    if epoch % 10 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),  # store a plain float rather than a tensor
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')
        print(f"Checkpoint saved at epoch {epoch}")
```

Loading and Resuming Training from a Checkpoint

```python
# Load checkpoint
checkpoint = torch.load('checkpoint_epoch_30.pth')
model = SimpleModel()  # Create a new model instance with the same architecture
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Load the saved state
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # Resume from the next epoch
print(f"Training resumed from epoch {start_epoch}")

# Resume training loop
model.train()  # put dropout/batch-norm layers back into training mode
for epoch in range(start_epoch, 100):
    # Continue training from where you left off
    ...
```
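A quick way to check that a save/load round trip restores everything is to compare the original and restored objects directly. The sketch below is an illustration and not part of the workflow above. It uses an in-memory io.BytesIO buffer instead of a file and a plain nn.Linear in place of SimpleModel. It takes one SGD step so that momentum buffers exist, and then checks that both the weights and the momentum buffer survive the round trip:

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving
loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()

# torch.save/torch.load accept file-like objects as well as paths
buffer = io.BytesIO()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, buffer)
buffer.seek(0)

# Rebuild fresh objects and load the saved state into them
restored_model = nn.Linear(10, 2)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(buffer)
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Compare every parameter tensor and the momentum buffer of the first parameter
params_match = all(torch.equal(a, b) for a, b in
                   zip(model.state_dict().values(), restored_model.state_dict().values()))
momentum_match = torch.equal(
    optimizer.state_dict()['state'][0]['momentum_buffer'],
    restored_optimizer.state_dict()['state'][0]['momentum_buffer'])
print(params_match, momentum_match)  # True True
```

If the optimizer state were not saved, the restored run would restart with zero momentum. It would then follow a different trajectory even though the weights match.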

Comprehensive Checkpoint Content

A good checkpoint should include everything needed to resume training exactly where you left off. Here's what you typically want to save:

```python
import random

import numpy
import torch

# epoch, global_step, best_val_loss, scheduler, batch_size and learning_rate
# are assumed to come from your training loop and configuration.
checkpoint = {
    # Training progress
    'epoch': epoch,
    'global_step': global_step,
    'best_validation_loss': best_val_loss,

    # Model and optimizer states
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),

    # For learning rate schedulers
    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,

    # Configuration and hyperparameters
    'config': {
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        # other hyperparameters
    },

    # Random states for reproducibility
    'random_state': random.getstate(),
    'np_random_state': numpy.random.get_state(),
    'torch_random_state': torch.get_rng_state(),
    'cuda_random_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
}

# NumPy's RNG state contains an ndarray, which the default weights_only=True
# loading in recent PyTorch versions may reject; load this trusted file with
# torch.load(..., weights_only=False).
torch.save(checkpoint, 'comprehensive_checkpoint.pth')
```
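Saving the RNG states is only half the job, because they also have to be restored when you resume. The following is a minimal sketch of the matching restore step, assuming the same four keys as the checkpoint above. The helper names capture_rng_states and restore_rng_states are illustrative. Because NumPy's state contains an ndarray, the sketch loads with weights_only=False, which should only be used for files you trust:

```python
import io
import random

import numpy as np
import torch

def capture_rng_states():
    """Snapshot the Python, NumPy, and PyTorch (CPU and CUDA) RNG states."""
    return {
        'random_state': random.getstate(),
        'np_random_state': np.random.get_state(),
        'torch_random_state': torch.get_rng_state(),
        'cuda_random_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
    }

def restore_rng_states(states):
    """Put every RNG back into a previously captured state."""
    random.setstate(states['random_state'])
    np.random.set_state(states['np_random_state'])
    torch.set_rng_state(states['torch_random_state'])
    if states['cuda_random_state'] is not None and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(states['cuda_random_state'])

# Round trip through torch.save/torch.load, as a checkpoint would
buffer = io.BytesIO()
torch.save(capture_rng_states(), buffer)
first = (random.random(), np.random.rand(), torch.rand(1).item())

buffer.seek(0)
# weights_only=False because NumPy's state contains an ndarray (trusted files only)
restore_rng_states(torch.load(buffer, weights_only=False))
second = (random.random(), np.random.rand(), torch.rand(1).item())
print(first == second)  # True
```

Restoring these states does not by itself make a resumed run bit-identical. Nondeterministic GPU kernels and your position inside a partially finished epoch are not captured by them.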

Best Practices for Checkpointing

1. Save Multiple Checkpoints

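Saving only the most recent checkpoint can be risky. A crash while writing can leave that single file corrupted, and the most recent weights are not necessarily the best ones. Consider these approaches: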

```python
# Save best model based on validation performance
if validation_loss < best_validation_loss:
    best_validation_loss = validation_loss
    torch.save(checkpoint, 'best_model_checkpoint.pth')

# Save regular checkpoints with epoch number
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')

# Keep a rotating set of the last 5 checkpoints: slots 0-4 are overwritten in turn
torch.save(checkpoint, f'checkpoint_{epoch % 5}.pth')
```
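One way to reduce the risk of a corrupted "latest" checkpoint is to never overwrite it in place. In the sketch below, atomic_save is a hypothetical helper and not a PyTorch API. It writes to a temporary file in the same directory and then swaps that file into place with os.replace. Python documents os.replace as an atomic operation on POSIX systems when both paths are on the same filesystem. The helper does not call fsync, so it guards against the training process dying mid-write, not against sudden power loss.

```python
import os
import tempfile

import torch

def atomic_save(obj, path):
    """Save obj to path so readers see either the old file or the complete new one."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file next to the target so os.replace stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    os.close(fd)
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)  # swap the finished file into place
    except BaseException:
        # Remove the partial temp file if saving failed, then re-raise
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

You would call atomic_save(checkpoint, 'checkpoint_latest.pth') wherever the examples above call torch.save for a file that gets overwritten.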

2. Implement Early Stopping with Checkpointing

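Early stopping helps prevent overfitting by halting training once the validation loss has not improved for a set number of consecutive epochs (the patience). If you save a checkpoint every time the validation loss improves, you can return the best model rather than the last one.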

```python
import torch

def train_with_early_stopping(model, optimizer, criterion, train_loader, val_loader,
                              patience=5, max_epochs=100):
    # Both loaders are assumed to yield (inputs, targets) batches
    best_val_loss = float('inf')
    counter = 0  # epochs since the last improvement

    for epoch in range(max_epochs):
        # Training phase
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        # Validation phase: average loss over the validation batches
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                val_loss += criterion(model(inputs), targets).item()
        val_loss /= len(val_loader)

        print(f"Epoch {epoch}, Validation Loss: {val_loss:.4f}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
            # Save checkpoint
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }
            torch.save(checkpoint, 'best_model.pth')
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

    # Load the best model before returning (torch.load, not torch.save)
    checkpoint = torch.load('best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
```

3. Handling Distributed Training Checkpoints

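In distributed data-parallel training, every process holds an identical copy of the model. Save checkpoints from the main process (rank 0) only, so that the ranks do not all write the same file:

```python
import torch
import torch.distributed as dist

# Inside training loop (checkpoint built as in the earlier examples)
if dist.get_rank() == 0:  # Save only on the main process
    torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')
dist.barrier()  # Other ranks wait until rank 0 has finished writing
```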
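A related pitfall is that DistributedDataParallel stores your model as its .module attribute. A state_dict taken from the wrapper therefore has every key prefixed with "module.", and loading it into an unwrapped model fails with missing and unexpected keys. Saving model.module.state_dict() avoids the problem. For checkpoints that already carry the prefix, a small helper can strip it. In the sketch below, strip_module_prefix and the Wrapper class are illustrative stand-ins: Wrapper mimics how DDP registers the model, without needing a process group.

```python
import torch
import torch.nn as nn

def strip_module_prefix(state_dict, prefix='module.'):
    """Return a new dict with a leading wrapper prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

class Wrapper(nn.Module):
    """Hypothetical stand-in for DDP: registers the wrapped model as self.module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

base = nn.Linear(4, 2)
saved = Wrapper(base).state_dict()
print(list(saved.keys()))  # ['module.weight', 'module.bias']

# Loading the prefixed dict directly into nn.Linear would fail; stripped keys load fine
plain = nn.Linear(4, 2)
plain.load_state_dict(strip_module_prefix(saved))
print(torch.equal(plain.weight, base.weight))  # True
```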

4. Automatic Checkpoint Management

In long-running experiments, checkpoint files pile up quickly and can fill a disk. Here's a simple checkpoint manager class that keeps only the most recent few:

```python
import os

import torch

class CheckpointManager:
    def __init__(self, save_dir, max_to_keep=5):
        self.save_dir = save_dir
        self.max_to_keep = max_to_keep
        self.checkpoints = []  # paths written by this manager, oldest first

        # Create directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

    def save_checkpoint(self, state, is_best=False, metric=None):
        filename = f"checkpoint_epoch_{state['epoch']}"
        if metric is not None:
            filename += f"_{metric:.4f}"
        filename += ".pth"

        filepath = os.path.join(self.save_dir, filename)
        torch.save(state, filepath)

        # Save as best if needed (this file is never rotated out)
        if is_best:
            best_path = os.path.join(self.save_dir, 'best_model.pth')
            torch.save(state, best_path)

        # Keep track of checkpoints
        self.checkpoints.append(filepath)

        # Remove the oldest checkpoint once we exceed max_to_keep
        if len(self.checkpoints) > self.max_to_keep:
            old_checkpoint = self.checkpoints.pop(0)
            if os.path.exists(old_checkpoint):
                os.remove(old_checkpoint)

        return filepath
```

Using the checkpoint manager:

```python
checkpoint_manager = CheckpointManager('checkpoints/', max_to_keep=3)
best_metric = float('-inf')  # initialize once, before the training loop

# In training loop
val_metric = validation_accuracy  # or some other metric where higher is better
is_best = val_metric > best_metric
if is_best:
    best_metric = val_metric

checkpoint_manager.save_checkpoint({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'metric': val_metric
}, is_best=is_best, metric=val_metric)
```
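When resuming, you also need to locate the newest file the manager wrote. The following is a minimal sketch, assuming the checkpoint_epoch_<epoch>[_<metric>].pth naming used by CheckpointManager above. The helper find_latest_checkpoint is hypothetical. It compares epochs as integers, because a plain string sort would place checkpoint_epoch_9 after checkpoint_epoch_10. Also note that the manager's self.checkpoints list starts empty in a new process, so max_to_keep does not prune files left over from a previous run.

```python
import os
import re

# Matches e.g. checkpoint_epoch_12.pth and checkpoint_epoch_12_0.8731.pth
CHECKPOINT_RE = re.compile(r'^checkpoint_epoch_(\d+)(?:_.+)?\.pth$')

def find_latest_checkpoint(save_dir):
    """Return the path of the highest-epoch checkpoint in save_dir, or None."""
    if not os.path.isdir(save_dir):
        return None
    latest_epoch, latest_path = -1, None
    for name in os.listdir(save_dir):
        match = CHECKPOINT_RE.match(name)
        if match and int(match.group(1)) > latest_epoch:
            latest_epoch = int(match.group(1))
            latest_path = os.path.join(save_dir, name)
    return latest_path
```

Call it before the training loop. If it returns a path, load that file as in the earlier resume examples. If it returns None, start from scratch.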

Real-World Examples

Example 1: Training a CNN for Image Classification

Let's see how checkpointing works in a real image classification scenario:

```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Set up device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load and transform data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

# Define CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 32x32 -> 16x16
        x = self.pool(torch.relu(self.conv2(x)))  # 16x16 -> 8x8
        x = x.view(-1, 64 * 8 * 8)
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

# Training with checkpointing
best_accuracy = 0
start_epoch = 0
checkpoint_path = 'cnn_cifar10_checkpoint.pth'

# Check if there's a checkpoint to resume from
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_accuracy = checkpoint['best_accuracy']
    print(f"Resuming from epoch {start_epoch} with best accuracy: {best_accuracy:.2f}%")

# Training loop
for epoch in range(start_epoch, 10):
    model.train()
    running_loss = 0.0

    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print(f'Epoch {epoch+1}, Batch {i+1}, Loss: {running_loss/100:.3f}')
            running_loss = 0.0

    # Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    avg_test_loss = test_loss / len(testloader)
    print(f'Epoch {epoch+1}, Test Accuracy: {accuracy:.2f}%, Test Loss: {avg_test_loss:.4f}')

    # Update learning rate scheduler
    scheduler.step(avg_test_loss)

    # Save checkpoint
    is_best = accuracy > best_accuracy
    if is_best:
        best_accuracy = accuracy

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_accuracy': best_accuracy,
        'test_loss': avg_test_loss
    }

    torch.save(checkpoint, checkpoint_path)
    if is_best:
        torch.save(checkpoint, 'cnn_cifar10_best.pth')
        print(f"Saved new best model with accuracy: {best_accuracy:.2f}%")

print(f"Training completed. Best accuracy: {best_accuracy:.2f}%")
```

Example 2: Resuming Training in Case of System Failure

Simulating a system crash and recovery:

```python
import os

import torch
import torch.nn as nn

def train_with_crash_recovery(model, optimizer, train_loader, val_loader, max_epochs=20,
                              criterion=None, simulate_crash=True):
    criterion = criterion or nn.CrossEntropyLoss()
    # Path for our checkpoint
    checkpoint_path = 'crash_recovery_checkpoint.pth'

    # Check if we're resuming training
    start_epoch = 0
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'batch_idx' in checkpoint:
            # Mid-epoch checkpoint: rerun that epoch instead of skipping its remaining batches
            start_epoch = checkpoint['epoch']
        else:
            start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch}")

    try:
        for epoch in range(start_epoch, max_epochs):
            print(f"Starting epoch {epoch}")
            model.train()

            for batch_idx, (data, target) in enumerate(train_loader):
                # Forward, backward, optimize
                optimizer.zero_grad()
                loss = criterion(model(data), target)
                loss.backward()
                optimizer.step()

                # Simulate a crash during epoch 5 (pass simulate_crash=False when resuming)
                if simulate_crash and epoch == 5 and batch_idx == 100:
                    print("Simulating system crash!")
                    # Save checkpoint before crash
                    checkpoint = {
                        'epoch': epoch,
                        'batch_idx': batch_idx,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    }
                    torch.save(checkpoint, checkpoint_path)
                    # Simulate crash
                    raise RuntimeError("System crashed during training!")

            # Save checkpoint at end of epoch
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(checkpoint, checkpoint_path)
            print(f"Completed epoch {epoch}")

    except RuntimeError as e:
        if "System crashed" in str(e):
            print("Crash detected! You can restart training using the saved checkpoint.")
        else:
            raise
```

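This example demonstrates how checkpointing can help recover from system failures during training. In a real failure there is usually no chance to save at the moment of the crash, so training can only resume from the last checkpoint written before it. The regular end-of-epoch save (or a more frequent one) is what actually protects your progress.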

Advanced Checkpointing Techniques

1. Saving and Loading Model Parts

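For large models, or when you want to reuse one component (for example a pretrained encoder) in another model, you can save and load specific submodules separately: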

```python
import torch

# Save specific layers or components (assumes model has encoder and decoder submodules)
torch.save({
    'encoder': model.encoder.state_dict(),
    'decoder': model.decoder.state_dict(),
}, 'model_components.pth')

# Load specific components
checkpoint = torch.load('model_components.pth')
model.encoder.load_state_dict(checkpoint['encoder'])
model.decoder.load_state_dict(checkpoint['decoder'])
```
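Sometimes the saved component no longer lines up one-to-one with the model you are loading into, for example when a pretrained encoder is reused under a new classification head. In that case, load_state_dict(strict=False) accepts the missing keys and reports them instead of raising. Tensors whose shapes do not match still raise an error even with strict=False, so the sketch below first filters the state dict down to the encoder. The Classifier class is hypothetical:

```python
import torch
import torch.nn as nn

class Classifier(nn.Module):
    """Hypothetical model: a shared encoder plus a task-specific head."""
    def __init__(self, num_classes):
        super().__init__()
        self.encoder = nn.Linear(16, 8)
        self.head = nn.Linear(8, num_classes)

pretrained = Classifier(num_classes=10)
new_model = Classifier(num_classes=3)

# Keep only the encoder weights; the 10-class head would not fit the 3-class head
encoder_only = {k: v for k, v in pretrained.state_dict().items() if k.startswith('encoder.')}

# strict=False reports missing/unexpected keys instead of raising
result = new_model.load_state_dict(encoder_only, strict=False)
print(result.missing_keys)     # ['head.weight', 'head.bias']
print(result.unexpected_keys)  # []
```

Check missing_keys after a non-strict load. A typo in a key prefix would otherwise silently leave layers at their random initialization.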

2. Handling GPU/CPU Transitions

When loading checkpoints between different devices:

```python
import torch

# Option 1: Load everything onto the CPU (e.g., a GPU-trained checkpoint on a CPU-only machine)
checkpoint = torch.load('model.pth', map_location=torch.device('cpu'))

# Option 2: Load onto whichever device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('model.pth', map_location=device)
```
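Device placement also matters for the optimizer. In current PyTorch versions, Optimizer.load_state_dict moves each saved state tensor (such as Adam's exp_avg) to the device of the parameter it belongs to. The safe order is therefore:

1. Load with map_location.
2. Move the model to the target device.
3. Create the optimizer over the on-device parameters.
4. Load the optimizer state.

The sketch below follows that order. It uses an in-memory buffer and a throwaway nn.Linear/Adam pair as stand-ins for a real checkpoint:

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in for a checkpoint written elsewhere: one Adam step so state exists
src_model = nn.Linear(4, 2)
src_opt = optim.Adam(src_model.parameters(), lr=1e-3)
src_model(torch.randn(3, 4)).sum().backward()
src_opt.step()
buffer = io.BytesIO()
torch.save({'model_state_dict': src_model.state_dict(),
            'optimizer_state_dict': src_opt.state_dict()}, buffer)
buffer.seek(0)

# 1. Load tensors onto the target device
checkpoint = torch.load(buffer, map_location=device)
# 2. Build the model and move it to the device BEFORE creating the optimizer
model = nn.Linear(4, 2).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
# 3. Create the optimizer over on-device parameters, 4. then load its state
optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

print(optimizer.state[model.weight]['exp_avg'].device.type == device.type)  # True
```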

3. Handling Different PyTorch Versions

Save version information to help with compatibility:

```python
import torch

checkpoint = {
    'model_state_dict': model.state_dict(),
    'pytorch_version': torch.__version__,
}
torch.save(checkpoint, 'version_aware_checkpoint.pth')
```
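Recording the version only helps if you check it when loading. The following is a minimal sketch in which versions_match is a hypothetical helper. It warns when the stored version differs from the running one. It compares plain strings, so a build suffix such as +cu121 also counts as a difference.

```python
import warnings

import torch

def versions_match(checkpoint):
    """Return True if the checkpoint was saved with the running PyTorch version; warn otherwise."""
    saved = checkpoint.get('pytorch_version')
    current = str(torch.__version__)
    if saved is None or str(saved) != current:
        warnings.warn(f"Checkpoint was saved with PyTorch {saved}, running {current}; "
                      "check the release notes if loading fails.")
        return False
    return True
```

Call it right after torch.load and before load_state_dict, so that a version mismatch is reported before any loading error.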

Summary

PyTorch checkpointing is a crucial technique for managing deep learning training processes. It allows you to:

  • Save and resume training sessions
  • Implement early stopping to prevent overfitting
  • Recover from system failures
  • Share trained models with others
  • Deploy models to production environments

By following the best practices in this guide, you can make your training process more robust and efficient.


Exercises

  1. Implement a checkpointing system for a model that saves every N batches instead of every epoch.
  2. Create a script that loads a checkpoint and continues training but with a different learning rate.
  3. Implement a checkpoint system that saves the K best models according to validation performance.
  4. Create a checkpoint manager that can handle multiple models training simultaneously.
  5. Implement a system that can recover training from arbitrary points in an epoch (not just at epoch boundaries).

