PyTorch State Dictionaries
Introduction
When working with PyTorch models, one of the most important aspects of model development is the ability to save and load your model's progress. At the heart of PyTorch's saving and loading mechanism are state dictionaries (or state_dict), which are Python dictionaries that store all the parameters and persistent buffers used by a model, optimizer, or other PyTorch components.
Understanding state dictionaries is crucial because they:
- Allow you to save your model's learned parameters
- Enable transfer learning and model checkpointing
- Provide flexibility in managing model state
- Let you manipulate individual parameters if needed
In this tutorial, we'll dive into PyTorch state dictionaries, understand their structure, and learn how to effectively work with them.
What is a State Dictionary?
A state dictionary in PyTorch is simply a Python dictionary that maps each parameter's name to its tensor value. For a PyTorch model, the state dictionary contains all the learnable parameters (weights and biases) of each layer, along with any persistent buffers (such as the running statistics of batch normalization layers).
Model State Dictionaries
Let's look at a simple example to understand model state dictionaries:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define a simple neural network
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Initialize model
model = SimpleModel()

# Print model's state_dict
print("Model's state_dict:")
for param_name, param in model.state_dict().items():
    print(f"{param_name} \t {param.shape}")
Output:
Model's state_dict:
fc1.weight torch.Size([20, 10])
fc1.bias torch.Size([20])
fc2.weight torch.Size([2, 20])
fc2.bias torch.Size([2])
In this example:
- Each layer with learnable parameters (fc1 and fc2) has entries in the state dictionary
- The keys follow the pattern layer_name.parameter_type
- The values are tensors containing the actual parameter values
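Note that a state dictionary can contain more than learnable parameters. Layers such as nn.BatchNorm1d also register persistent buffers (running statistics), and these show up in the state dictionary even though they are not returned by model.parameters(). A small sketch (the BatchNormModel class is made up just for illustration):

# A tiny model with a BatchNorm layer to show buffers in the state_dict
class BatchNormModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 20)
        self.bn = nn.BatchNorm1d(20)

    def forward(self, x):
        return self.bn(self.fc(x))

bn_model = BatchNormModel()
for name, tensor in bn_model.state_dict().items():
    print(name, tuple(tensor.shape))
# Prints fc.weight, fc.bias, bn.weight, bn.bias (parameters)
# plus bn.running_mean, bn.running_var, bn.num_batches_tracked (buffers)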
Optimizer State Dictionaries
Optimizers also have state dictionaries, which store:
- The optimizer's hyperparameters (such as the learning rate and momentum)
- The optimizer's current state (like momentum buffers for SGD)
import torch.optim as optim

# Initialize the model
model = SimpleModel()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Print optimizer's state_dict
print("\nOptimizer's state_dict:")
for key, value in optimizer.state_dict().items():
    print(f"{key}")
    if key == 'param_groups':
        for i, group in enumerate(value):
            print(f" Parameter group {i}")
            for k, v in group.items():
                if k != 'params':  # 'params' is just a list of parameter references
                    print(f" {k}: {v}")
Output:
Optimizer's state_dict:
state
param_groups
Parameter group 0
lr: 0.01
momentum: 0.9
dampening: 0
weight_decay: 0
nesterov: False
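In the output above, state is empty because the optimizer has not taken a step yet; SGD with momentum only allocates its momentum buffers after the first update. A minimal sketch, assuming the SimpleModel and optimizer defined above (the random inputs and targets are made up purely to trigger one step):

# Run one dummy update so the optimizer populates its state
inputs = torch.randn(4, 10)   # hypothetical batch of 4 samples
targets = torch.randn(4, 2)

optimizer.zero_grad()
loss = F.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()

# Each parameter now has a momentum buffer stored under an integer id
for param_id, param_state in optimizer.state_dict()['state'].items():
    print(param_id, list(param_state.keys()))   # e.g. 0 ['momentum_buffer']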
Saving and Loading State Dictionaries
Saving a Model's State Dictionary
The most common and recommended way to save a model in PyTorch is to save the model's state_dict using torch.save():
# Training code would be here...
# Save the model state dictionary
torch.save(model.state_dict(), 'model_state_dict.pth')
# If you want to save optimizer state too
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
Loading a State Dictionary
To load a state dictionary back into a model:
# Method 1: Initialize model then load state dict
model = SimpleModel()
model.load_state_dict(torch.load('model_state_dict.pth'))
model.eval() # Set model to evaluation mode if using for inference
# Method 2: Load a checkpoint with both model and optimizer
checkpoint = torch.load('checkpoint.pth')
model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# If continuing training
model.train()
# Or if doing evaluation
# model.eval()
Working with State Dictionaries
Inspecting State Dictionaries
You can iterate through a state dictionary to inspect its contents:
model = SimpleModel()
state_dict = model.state_dict()
# Get the keys
print("State Dictionary Keys:", state_dict.keys())
# Check the data type of a specific parameter
print("\nData type of fc1.weight:", state_dict['fc1.weight'].dtype)
# Check the values of a specific parameter
print("\nValues of fc2.bias:", state_dict['fc2.bias'])
# Get the shape of parameters
print("\nShape of fc1.weight:", state_dict['fc1.weight'].shape)
Output:
State Dictionary Keys: odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
Data type of fc1.weight: torch.float32
Values of fc2.bias: tensor([ 0.0213, -0.1547])
Shape of fc1.weight: torch.Size([20, 10])
Note that the exact values of fc2.bias will differ on every run, because nn.Linear initializes its weights and biases randomly rather than with zeros.
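Since a state dictionary is just a name-to-tensor mapping, simple inspection tasks are one-liners. For example, counting the total number of stored values (using the state_dict from the snippet above):

# Total number of values across all parameters in the state dictionary
total_params = sum(t.numel() for t in state_dict.values())
print("Total parameters:", total_params)   # (10*20 + 20) + (20*2 + 2) = 262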
Modifying State Dictionaries
You might sometimes need to modify a state dictionary directly. Here's an example:
model = SimpleModel()
state_dict = model.state_dict()
# Let's modify the bias of fc2 layer
print("Original fc2.bias:", state_dict['fc2.bias'])
state_dict['fc2.bias'] = torch.tensor([1.0, 2.0])
# Load the modified state dict back to the model
model.load_state_dict(state_dict)
print("Modified fc2.bias:", model.state_dict()['fc2.bias'])
Output:
Original fc2.bias: tensor([-0.0892,  0.1176])   (randomly initialized, so your values will differ)
Modified fc2.bias: tensor([1., 2.])
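A more common real-world reason to edit a state dictionary is renaming its keys. For instance, checkpoints saved from a model wrapped in nn.DataParallel have every key prefixed with module., which will not match the plain model. Here is a minimal sketch; it simulates the prefixed checkpoint in memory instead of loading a real multi-GPU file:

# Simulate a checkpoint saved from a DataParallel-wrapped model:
# every key carries a "module." prefix
wrapped_state_dict = {f"module.{k}": v for k, v in SimpleModel().state_dict().items()}

# Strip the prefix so the keys match an unwrapped SimpleModel
cleaned_state_dict = {k.replace("module.", "", 1): v for k, v in wrapped_state_dict.items()}

plain_model = SimpleModel()
plain_model.load_state_dict(cleaned_state_dict)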
Common Scenarios for Working with State Dictionaries
Transfer Learning
State dictionaries make transfer learning straightforward:
# Load a pretrained model
import torchvision.models as models
pretrained_model = models.resnet18(pretrained=True)
pretrained_state_dict = pretrained_model.state_dict()
# Create your model with the same structure
your_model = models.resnet18(pretrained=False)
your_model_dict = your_model.state_dict()
# 1. Filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_state_dict.items()
                   if k in your_model_dict and 'fc' not in k}
# 2. Update your model state dict
your_model_dict.update(pretrained_dict)
# 3. Load the new state dict
your_model.load_state_dict(your_model_dict)
# Now your model has all the pretrained weights except for the fc layer
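An alternative to filtering and merging the dictionaries by hand is load_state_dict(strict=False), which loads every matching key and simply reports the rest instead of raising an error. A short sketch building on the variables above:

# Load only the keys that match; mismatches are reported, not raised
result = your_model.load_state_dict(pretrained_dict, strict=False)
print("Missing keys (kept at their current values):", result.missing_keys)
print("Unexpected keys (ignored):", result.unexpected_keys)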
Model Checkpointing
State dictionaries enable saving checkpoints during training:
def train(model, optimizer, epochs, checkpoint_interval=5):
    for epoch in range(epochs):
        # Training code here (computes current_loss)...

        # Save checkpoint every checkpoint_interval epochs
        if (epoch + 1) % checkpoint_interval == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Add other states if needed
                'loss': current_loss,
            }
            torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')
            print(f"Checkpoint saved at epoch {epoch+1}")
Handling Devices (CPU/GPU)
When moving models between devices, state dictionaries make it easy:
# Save model trained on GPU
model_gpu = SimpleModel().cuda()
# Training code...
torch.save(model_gpu.state_dict(), 'model_gpu.pth')
# Load to CPU
model_cpu = SimpleModel()
model_cpu.load_state_dict(torch.load('model_gpu.pth', map_location='cpu'))
# Load to specific GPU
model_cuda = SimpleModel().cuda(1) # Load to GPU 1
model_cuda.load_state_dict(torch.load('model_gpu.pth', map_location='cuda:1'))
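A convenient device-agnostic pattern is to pick the target device at load time, so the same loading code works whether or not a GPU is available:

# Pick whatever device is available and load the weights onto it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleModel()
model.load_state_dict(torch.load('model_gpu.pth', map_location=device))
model.to(device)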
Practical Examples
Example 1: Training and Saving a MNIST Classifier
Let's create a complete example showing how to train a model on MNIST, save checkpoints, and load them:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the neural network
class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = torch.relu(torch.max_pool2d(self.conv1(x), 2))
        x = torch.relu(torch.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Set up data loaders
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)  # shuffle for training
test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)

# Initialize model and optimizer
model = MNISTNet()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.CrossEntropyLoss()

# Training function
def train(model, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} Batch: {batch_idx} Loss: {loss.item():.6f}')

# Testing function
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            # Accumulate the summed loss so the division below gives a per-sample average
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy

# Train the model for a few epochs
epochs = 2
best_accuracy = 0
for epoch in range(1, epochs + 1):
    train(model, optimizer, epoch)
    accuracy = test(model)

    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'accuracy': accuracy
    }

    # Save the best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(checkpoint, 'best_mnist_model.pth')
        print(f"Saved best model with accuracy: {accuracy:.2f}%")

    # Save regular checkpoint
    torch.save(checkpoint, f'mnist_checkpoint_epoch_{epoch}.pth')

print("Training complete!")

# How to load the model later
def load_best_model():
    checkpoint = torch.load('best_mnist_model.pth')
    model = MNISTNet()
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']} with accuracy: {checkpoint['accuracy']:.2f}%")
    return model

# Demonstration of loading
loaded_model = load_best_model()
Example 2: Fine-tuning a Pre-trained Model
Here's how to use state dictionaries for transfer learning:
import torch
import torch.nn as nn
import torchvision.models as models

# Load a pre-trained ResNet model
# (newer torchvision versions prefer weights=models.ResNet18_Weights.DEFAULT over pretrained=True)
pretrained_model = models.resnet18(pretrained=True)

# Freeze all the parameters
for param in pretrained_model.parameters():
    param.requires_grad = False

# Replace the final fully connected layer
num_features = pretrained_model.fc.in_features
pretrained_model.fc = nn.Linear(num_features, 10)  # 10 output classes

# Print the model to see the new structure
print(pretrained_model)

# Now only train the final layer
optimizer = torch.optim.SGD(pretrained_model.fc.parameters(), lr=0.001, momentum=0.9)

# Later, you can save just the fine-tuned model
fine_tuned_state_dict = pretrained_model.state_dict()
torch.save(fine_tuned_state_dict, 'fine_tuned_resnet18.pth')

# You can also save a partial state dict (just the modified layer)
fc_state_dict = {k: v for k, v in fine_tuned_state_dict.items() if 'fc' in k}
torch.save(fc_state_dict, 'fine_tuned_fc_only.pth')

# To load the partial state dict back
def load_partial_model():
    new_model = models.resnet18(pretrained=True)
    # Recreate the replaced fc layer first, so its shape matches the saved weights
    new_model.fc = nn.Linear(new_model.fc.in_features, 10)
    # Load only the fc layer
    fc_state = torch.load('fine_tuned_fc_only.pth')
    model_state = new_model.state_dict()
    model_state.update(fc_state)
    new_model.load_state_dict(model_state)
    return new_model
Summary
PyTorch state dictionaries are a powerful mechanism for managing model parameters and optimizer states. In this tutorial, we've covered:
- What state dictionaries are and their structure
- How to save and load model and optimizer state dictionaries
- How to inspect and modify state dictionaries
- Practical applications like transfer learning and model checkpointing
- Complete examples of training, saving, and loading models
Understanding state dictionaries is essential for effective PyTorch development, as they provide the flexibility needed for advanced deep learning workflows like checkpoint management, transfer learning, and model deployment.
Additional Resources
- PyTorch Documentation on Saving and Loading Models
- PyTorch nn.Module Documentation
- PyTorch Optimizer Documentation
Exercises
- Create a simple CNN model and examine its state dictionary structure. How does it differ from a fully connected network?
- Train a model, save checkpoints every 5 epochs, and implement a function that loads the best checkpoint based on validation accuracy.
- Load a pre-trained model from torchvision, modify its architecture, and transfer the compatible weights from the pre-trained model to your modified version.
- Create a function that merges the state dictionaries of two different models that share some layer names but may have different architectures.
- Explore how to convert a PyTorch model's state dictionary to a format compatible with deployment frameworks like ONNX or TorchScript.