PyTorch Checkpointing
In deep learning, training models can be time-consuming and resource-intensive. PyTorch lets you save the state of your model during training with torch.save and restore it later with torch.load. This is invaluable for long training sessions, for recovering from system failures, and for techniques like early stopping.
What is Checkpointing?
Checkpointing is the practice of saving a model's state at various intervals during training. A checkpoint typically contains:
- The model's parameters (weights and biases), captured by its state_dict
- The optimizer state (learning rates, momentum buffers)
- Other training metadata (current epoch, loss values, etc.)
This allows you to resume training from the exact point where you left off rather than starting from scratch. The bare save-and-load cycle is sketched below.
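As a minimal illustration of that cycle before the full training example: serialize a dictionary of state with torch.save, then read it back with torch.load. The throwaway nn.Linear model and the file name here are placeholders used only to show the mechanics:

import torch
import torch.nn as nn

# A throwaway one-layer model, used only to demonstrate the mechanics
model = nn.Linear(10, 2)

# Save: serialize a dictionary holding whatever state you care about
torch.save({'model_state_dict': model.state_dict()}, 'minimal_checkpoint.pth')

# Load: read the dictionary back and restore the weights into a fresh instance
restored = nn.Linear(10, 2)
state = torch.load('minimal_checkpoint.pth')
restored.load_state_dict(state['model_state_dict'])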
Basic Checkpointing in PyTorch
Let's start with a simple example of saving checkpoints during training; loading them again is covered in the next section:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)
        self.output = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc(x))
        return self.output(x)

# Create model, loss function, and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Training loop (simplified)
for epoch in range(100):
    # Training code here...

    # Save a checkpoint every 10 epochs
    if epoch % 10 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,  # assuming `loss` is computed in your training loop
        }
        torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')
        print(f"Checkpoint saved at epoch {epoch}")
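One practical caveat with the loop above: if the process dies while torch.save is still writing, the file on disk can be left truncated. A common safeguard, sketched here under the assumption that checkpoints are written to a local filesystem (the helper name is illustrative, not part of PyTorch), is to write to a temporary file and then rename it:

import os
import torch

def save_checkpoint_atomically(checkpoint, path):
    # Write to a temporary file first so a crash mid-write cannot
    # corrupt an existing checkpoint already stored at `path`.
    tmp_path = path + '.tmp'
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)  # atomic rename on POSIX filesystems

# Drop-in replacement for the torch.save call in the loop above:
# save_checkpoint_atomically(checkpoint, f'checkpoint_epoch_{epoch}.pth')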
Loading and Resuming Training from a Checkpoint
# Load the checkpoint
checkpoint = torch.load('checkpoint_epoch_30.pth')

# Recreate the model and optimizer before restoring their states
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Load the saved state
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # resume from the next epoch

# Resume the training loop
for epoch in range(start_epoch, 100):
    # Continue training from where you left off
    print(f"Training resumed from epoch {epoch}")
    # ...
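Two details are worth adding when loading. torch.load accepts a map_location argument, so a checkpoint saved on a GPU can be loaded on a CPU-only machine (or moved to a specific device), and the model should be put back into training mode before resuming. A sketch, assuming the same SimpleModel and file name as above:

# Load onto whatever device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load('checkpoint_epoch_30.pth', map_location=device)

model = SimpleModel()
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.train()  # matters if the model uses dropout or batch norm layers

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])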
Comprehensive Checkpoint Content
A good checkpoint should include everything needed to resume training exactly where you left off. Here's what you typically want to save:
import random
import numpy as np
import torch

checkpoint = {
    # Training progress
    'epoch': epoch,
    'global_step': global_step,
    'best_validation_loss': best_val_loss,

    # Model and optimizer states
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),

    # Learning rate scheduler state (if you use one)
    'scheduler_state_dict': scheduler.state_dict() if scheduler else None,

    # Configuration and hyperparameters
    'config': {
        'batch_size': batch_size,
        'learning_rate': learning_rate,
        # other hyperparameters
    },

    # Random states for reproducibility
    'random_state': random.getstate(),
    'np_random_state': np.random.get_state(),
    'torch_random_state': torch.get_rng_state(),
    'cuda_random_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
}

torch.save(checkpoint, 'comprehensive_checkpoint.pth')
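Saving the random states is only half the job: to reproduce the exact shuffling and dropout behavior, you also need to restore them after loading. A sketch of the matching restore calls, assuming the checkpoint dictionary above (newer PyTorch versions may require weights_only=False here, because the RNG states are arbitrary Python objects rather than tensors):

import random
import numpy as np
import torch

checkpoint = torch.load('comprehensive_checkpoint.pth', weights_only=False)

# Restore each RNG to the state it had when the checkpoint was written
random.setstate(checkpoint['random_state'])
np.random.set_state(checkpoint['np_random_state'])
torch.set_rng_state(checkpoint['torch_random_state'])
if checkpoint['cuda_random_state'] is not None and torch.cuda.is_available():
    torch.cuda.set_rng_state_all(checkpoint['cuda_random_state'])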
Best Practices for Checkpointing
1. Save Multiple Checkpoints
Saving only the most recent checkpoint can be risky. Consider these approaches:
# Save the best model based on validation performance
if validation_loss < best_validation_loss:
    best_validation_loss = validation_loss
    torch.save(checkpoint, 'best_model_checkpoint.pth')

# Save regular checkpoints tagged with the epoch number
torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')

# Keep a rotating set of the last N checkpoints (here N = 5);
# checkpoint_0.pth through checkpoint_4.pth are overwritten in turn
torch.save(checkpoint, f'checkpoint_{epoch % 5}.pth')
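If you prefer epoch-numbered filenames over a fixed set of rotating names, you can prune old files explicitly instead. A sketch using only the standard library; the helper name and keep_last_n parameter are illustrative, not a PyTorch API:

import glob
import os
import re

def prune_old_checkpoints(directory='.', keep_last_n=5):
    # Find files matching the naming scheme used above: checkpoint_epoch_<N>.pth
    paths = glob.glob(os.path.join(directory, 'checkpoint_epoch_*.pth'))
    # Sort by the epoch number embedded in each filename
    paths.sort(key=lambda p: int(re.search(r'(\d+)\.pth$', p).group(1)))
    # Remove everything except the newest keep_last_n checkpoints
    for old_path in paths[:-keep_last_n]:
        os.remove(old_path)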