
PyTorch Loading Models

In deep learning workflows, training models can be time-consuming and computationally expensive. The ability to save and subsequently load models is crucial for effective development. In this tutorial, we'll explore how to load pre-trained PyTorch models that have been saved to disk.

Introduction

After you've saved your PyTorch models, you'll need to know how to load them back into memory to:

  • Continue training from where you left off
  • Deploy models to production
  • Transfer knowledge from one model to another
  • Share models with others

PyTorch offers several methods to load models, each suited to different scenarios. Let's dive into these approaches.

Basic Model Loading Techniques

Loading an Entire Model

The simplest approach is loading an entire model that was saved using torch.save(model, PATH).

```python
import torch
import torch.nn as nn

# Define the model class exactly as it was when the model was saved;
# unpickling needs to find it under the same module and name
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load the entire model. Unpickling a full nn.Module requires
# weights_only=False (the default became True in PyTorch 2.6);
# only do this for files you trust.
model_path = "entire_model.pth"
loaded_model = torch.load(model_path, weights_only=False)
loaded_model.eval()  # Set to evaluation mode

print(type(loaded_model))  # Confirm it's loaded as a model
```

Output:

<class '__main__.SimpleModel'>

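This method is simple but has a drawback: the model is serialized with Python's pickle, which stores a reference to the class (its module path and name) rather than the class code. The class definition must therefore be importable from the same location when you load the file, so renaming or moving the class, or restructuring your project, can break loading. Because unpickling can execute arbitrary code, only load full-model files from sources you trust.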
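To see this in action, here is a minimal round-trip sketch that saves a whole model with torch.save(model, PATH) and loads it back in the same script. It assumes PyTorch 1.13 or later (for the weights_only argument) and writes to a temporary file whose name is illustrative. Loading succeeds here only because SimpleModel is defined in the script that does the loading.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    """Same architecture as above; pickle stores only a reference to this class."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


torch.manual_seed(0)
original = SimpleModel().eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "entire_model.pth")  # hypothetical file name
    torch.save(original, path)  # pickles the whole module object
    # weights_only=False is required to unpickle a full nn.Module;
    # only use it for files you trust
    restored = torch.load(path, weights_only=False).eval()

x = torch.randn(2, 784)
with torch.no_grad():
    same_output = torch.equal(original(x), restored(x))

print(type(restored).__name__, same_output)  # prints: SimpleModel True
```

If you moved SimpleModel into another module before loading, the same torch.load call would fail, which is the coupling described above.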

Loading State Dictionaries

A more flexible approach is loading just the model's state dictionary:

```python
import torch
import torch.nn as nn

# Define the model architecture
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create a model instance
model = SimpleModel()

# Load the state dictionary
model_path = "model_state_dict.pth"
state_dict = torch.load(model_path)
model.load_state_dict(state_dict)
model.eval()  # Set to evaluation mode

# Verify some weights are loaded
print(model.fc1.weight[0, :5])  # Show first 5 weights of first neuron
```

Output:

tensor([-0.0193, 0.0118, -0.0083, 0.0246, 0.0164], grad_fn=<SliceBackward0>)

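Because a state dictionary contains only the parameter and buffer tensors, keyed by name, it does not depend on how your code is organized; it only requires a model whose keys and tensor shapes match. This makes it the recommended way to load models for inference or continued training.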
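A state dictionary is an ordinary ordered dictionary mapping names to tensors. The sketch below, which assumes PyTorch 1.13+ and uses a temporary file, prints the keys and shapes of SimpleModel's state dictionary, then loads them into a fresh instance and checks that every tensor matches. The helper name state_dicts_equal is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def state_dicts_equal(a, b):
    """Hypothetical helper: True if both dicts have the same keys and identical tensors."""
    return a.keys() == b.keys() and all(torch.equal(a[k], b[k]) for k in a)


trained = SimpleModel()
for name, tensor in trained.state_dict().items():
    print(name, tuple(tensor.shape))
# fc1.weight (128, 784)
# fc1.bias (128,)
# fc2.weight (10, 128)
# fc2.bias (10,)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_state_dict.pth")
    torch.save(trained.state_dict(), path)
    # weights_only=True restricts unpickling to tensors and primitive containers
    state_dict = torch.load(path, weights_only=True)

fresh = SimpleModel()  # randomly initialized until we load
fresh.load_state_dict(state_dict)
print(state_dicts_equal(trained.state_dict(), fresh.state_dict()))  # True
```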

Loading Models to Different Devices

The map_location argument of torch.load controls which device (CPU or GPU) the loaded tensors are placed on:

```python
import torch

# These calls load an entire pickled model, so they need
# weights_only=False on PyTorch 2.6+ (trusted files only)

# Option 1: Load to CPU
device = torch.device('cpu')
model = torch.load('model.pth', map_location=device, weights_only=False)

# Option 2: Load to a specific GPU (requires CUDA)
device = torch.device('cuda:0')  # First GPU
model = torch.load('model.pth', map_location=device, weights_only=False)

# Option 3: Load to whichever device is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('model.pth', map_location=device, weights_only=False)

print(f"Model loaded to: {next(model.parameters()).device}")
```

Output:

Model loaded to: cuda:0

(On a machine without a GPU, this prints cpu instead.)
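The options above load a pickled model object. With a state dictionary, the same idea looks like the sketch below: map every tensor to the chosen device while loading, then move the freshly constructed model there too. The function name load_on_best_device and the file name are hypothetical, and the code also runs on CPU-only machines.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_on_best_device(model: nn.Module, path: str) -> nn.Module:
    """Hypothetical helper: load a state_dict onto the GPU if present, else the CPU."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location relocates every tensor as it is deserialized
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # load_state_dict copies values into the model's existing parameters,
    # so the model itself must also be moved to the target device
    return model.to(device).eval()


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_state_dict.pth")
    torch.save(nn.Linear(4, 2).state_dict(), path)
    model = load_on_best_device(nn.Linear(4, 2), path)

print(f"Model loaded to: {next(model.parameters()).device}")
```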

Loading Checkpoints

For models saved as checkpoints (including optimizer state, epoch info, etc.), use this approach:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Create a model and optimizer
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Load checkpoint
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Choose whether to be in training or evaluation mode
model_mode = checkpoint.get('mode', 'train')  # Default to train if not found
if model_mode == 'train':
    model.train()
else:
    model.eval()

print(f"Loaded model from epoch {epoch} with loss {loss:.4f}")
print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")
```

Output:

Loaded model from epoch 10 with loss 0.2345
Current learning rate: 0.01
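For context, the loading code above expects a dictionary like the one written in the sketch below. The key names mirror the example ('epoch', 'model_state_dict', 'optimizer_state_dict', 'loss'); the helpers save_checkpoint and resume are hypothetical. Storing the loss as a plain Python float keeps the checkpoint loadable with weights_only=True (PyTorch 1.13+ assumed).

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, epoch, loss):
    """Hypothetical helper: bundle everything needed to resume training."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": float(loss),
    }, path)


def resume(path, model, optimizer):
    """Hypothetical helper: restore model/optimizer state and return the next epoch."""
    checkpoint = torch.load(path, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    model.train()  # we are resuming training, so use training mode
    return checkpoint["epoch"] + 1


torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# One training step so the optimizer has momentum buffers worth saving
loss = nn.functional.mse_loss(model(torch.randn(8, 10)), torch.randn(8, 1))
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    save_checkpoint(path, model, optimizer, epoch=10, loss=loss.item())

    new_model = nn.Linear(10, 1)
    new_optimizer = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
    start_epoch = resume(path, new_model, new_optimizer)

print(f"Resuming at epoch {start_epoch}")  # Resuming at epoch 11
```

Saving the optimizer state matters for optimizers like SGD with momentum or Adam, whose internal buffers would otherwise restart from zero.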

Loading Models Trained on Different Devices

Sometimes you need to load a model that was trained on a different device:

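```python
import torch

# Both calls load entire pickled models, so weights_only=False is
# needed on PyTorch 2.6+ (trusted files only)

# Model was saved on GPU, but you're loading on a CPU-only machine
model = torch.load('gpu_model.pth', map_location=torch.device('cpu'), weights_only=False)

# Model was saved on CPU, but you're loading on GPU
device = torch.device('cuda')
model = torch.load('cpu_model.pth', map_location=device, weights_only=False)
# Make sure the model is on the right device (a no-op here,
# since map_location already placed the tensors on the GPU)
model = model.to(device)
```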
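Besides a device, map_location also accepts a dictionary that remaps saved locations, or a function that is called for every storage. The sketch below uses both forms. The dictionary form sends tensors recorded as 'cuda:0' to the CPU and leaves other locations alone, so it is safe to run on a CPU-only machine; the file name is hypothetical.

```python
import os
import tempfile

import torch

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pth")
    torch.save({"w": torch.ones(3)}, path)

    # Dictionary form: tensors saved on cuda:0 would be loaded onto the CPU
    remapped = torch.load(path, map_location={"cuda:0": "cpu"}, weights_only=True)

    # Callable form: returning the (CPU) storage unchanged keeps everything on the CPU
    on_cpu = torch.load(path, map_location=lambda storage, loc: storage, weights_only=True)

print(remapped["w"].device, on_cpu["w"].device)  # cpu cpu
```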

Handling Model Architecture Changes

If you've made changes to your model architecture but still want to load some of the pre-trained weights:

```python
import torch
import torch.nn as nn

# Original model had 128 hidden units
class OldModel(nn.Module):
    def __init__(self):
        super(OldModel, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# New model has 256 hidden units
class NewModel(nn.Module):
    def __init__(self):
        super(NewModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)  # Changed size
        self.fc2 = nn.Linear(256, 10)   # Changed size

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load old state dict
old_state_dict = torch.load('old_model_state_dict.pth')

# Create new model and fetch its state dict once
new_model = NewModel()
new_state_dict = new_model.state_dict()

# Keep only entries whose name exists in the new model and whose shape matches
filtered_state_dict = {k: v for k, v in old_state_dict.items()
                       if k in new_state_dict and new_state_dict[k].shape == v.shape}

# Load compatible weights; strict=False tolerates the entries we left out
new_model.load_state_dict(filtered_state_dict, strict=False)
print(f"Loaded {len(filtered_state_dict)}/{len(new_state_dict)} tensors")
```

Output:

Loaded 1/4 tensors

In this example, only fc2.bias (shape [10]) is compatible: the other three tensors changed shape because the hidden layer grew from 128 to 256 units. In real scenarios, such as replacing only the final layer, most tensors usually remain compatible and can be loaded.
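A reusable version of this filtering step might look like the sketch below. The helper load_compatible is hypothetical; it relies on the fact that load_state_dict(..., strict=False) returns the names of missing and unexpected keys, so you can see which tensors kept their random initialization. The example changes only the output layer, a common fine-tuning scenario in which most tensors stay compatible.

```python
import torch
import torch.nn as nn


def load_compatible(model: nn.Module, state_dict: dict) -> list:
    """Hypothetical helper: load tensors whose name and shape match; return skipped names."""
    target = model.state_dict()
    compatible = {k: v for k, v in state_dict.items()
                  if k in target and target[k].shape == v.shape}
    result = model.load_state_dict(compatible, strict=False)
    # missing_keys lists model tensors that received no value from state_dict
    return sorted(result.missing_keys)


def make_model(num_classes):
    # 784 inputs -> 128 hidden units -> num_classes outputs
    return nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, num_classes))


old_model = make_model(num_classes=10)
new_model = make_model(num_classes=5)  # only the output layer changes shape

skipped = load_compatible(new_model, old_model.state_dict())
print(skipped)  # ['2.bias', '2.weight']
```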

Real-world Example: Loading a Pre-trained Classification Model

Let's put together what we've learned to load a pre-trained model for image classification:

```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# Define the transformation pipeline (ImageNet normalization statistics)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load our fine-tuned model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.resnet18()  # Create a base model with the right architecture
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 5)  # 5 output classes

# Load the state dict
model.load_state_dict(torch.load('finetuned_resnet18.pth', map_location=device))
model.to(device)
model.eval()

# Load and preprocess an image; convert('RGB') handles grayscale or RGBA files
img = Image.open('sample_image.jpg').convert('RGB')
img_tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension

# Perform inference
with torch.no_grad():
    outputs = model(img_tensor)
    _, predicted = torch.max(outputs, 1)

class_labels = ['dog', 'cat', 'car', 'airplane', 'flower']
print(f"Prediction: {class_labels[predicted.item()]}")
```

Output:

Prediction: cat
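torch.max returns only the top-scoring class. If you also want confidence scores, applying a softmax to the logits and then torch.topk gives the k most likely labels with their probabilities. The sketch below uses a stand-in linear model with fixed identity weights instead of the fine-tuned ResNet, so it runs without image or weight files; predict_topk is a hypothetical helper name.

```python
import torch
import torch.nn as nn


def predict_topk(model, batch, labels, k=3):
    """Hypothetical helper: return [(label, probability), ...] for the first item in batch."""
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)  # logits -> probabilities per row
        top_probs, top_idx = torch.topk(probs, k, dim=1)
    return [(labels[i], p) for i, p in zip(top_idx[0].tolist(), top_probs[0].tolist())]


class_labels = ['dog', 'cat', 'car', 'airplane', 'flower']

# Stand-in "classifier": identity weights, so the logits equal the input
stand_in = nn.Linear(5, 5, bias=False)
with torch.no_grad():
    stand_in.weight.copy_(torch.eye(5))

batch = torch.tensor([[0.1, 2.0, 0.3, 1.0, -1.0]])  # a batch of one
for label, prob in predict_topk(stand_in, batch, class_labels, k=2):
    print(f"{label}: {prob:.2f}")
# prints:
# cat: 0.57
# airplane: 0.21
```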

Loading Models from Torch Hub

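PyTorch Hub (torch.hub) loads pre-trained models from GitHub repositories that publish a hubconf.py file; downloaded code and weights are cached locally for later runs: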

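```python
import torch

# Load a pre-trained BERT model (entrypoint published by Hugging Face)
bert_model = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')

# Load a pre-trained vision model; torchvision 0.13+ takes `weights`
# instead of the deprecated `pretrained=True`
vision_model = torch.hub.load('pytorch/vision', 'resnet50', weights='IMAGENET1K_V2')
print(f"Model loaded: {type(vision_model).__name__}")
```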

Output:

Model loaded: ResNet
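torch.hub.load can also read a local directory when you pass source='local', which works offline and is handy for testing your own hubconf.py before publishing it. In the sketch below, the entrypoint name tiny_mlp and its arguments are hypothetical; the dependencies list is the convention hubconf.py files use to declare the packages they need.

```python
import os
import tempfile
import textwrap

import torch

# Contents of a minimal, hypothetical hubconf.py
HUBCONF = textwrap.dedent('''
    dependencies = ["torch"]

    import torch.nn as nn


    def tiny_mlp(hidden=16, num_classes=3):
        """Hypothetical entrypoint: returns an untrained two-layer MLP."""
        return nn.Sequential(nn.Linear(8, hidden), nn.ReLU(), nn.Linear(hidden, num_classes))
''')

with tempfile.TemporaryDirectory() as repo_dir:
    with open(os.path.join(repo_dir, "hubconf.py"), "w") as f:
        f.write(HUBCONF)

    # Extra keyword arguments are forwarded to the entrypoint function
    model = torch.hub.load(repo_dir, "tiny_mlp", source="local", num_classes=5)

print(f"Model loaded: {type(model).__name__}")  # Model loaded: Sequential
```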

Troubleshooting Common Issues

Missing Keys or Unexpected Keys

When a state dictionary's keys don't match the model's, load_state_dict (which defaults to strict=True) raises a RuntimeError listing the missing or unexpected keys:

```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        # Missing fc2 compared to the saved model

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return x

model = SimpleModel()

# A state dict that has more keys than the model
state_dict = {
    'fc1.weight': torch.randn(5, 10),
    'fc1.bias': torch.randn(5),
    'fc2.weight': torch.randn(2, 5),  # This key doesn't exist in our model
    'fc2.bias': torch.randn(2)        # This key doesn't exist in our model
}

# The default strict=True raises a RuntimeError
try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    print(f"Error with strict=True: {e}")

# strict=False skips keys that don't match
model.load_state_dict(state_dict, strict=False)
print("Loaded with strict=False successfully")
```

Output:

Error with strict=True: Error(s) in loading state_dict for SimpleModel:
	Unexpected key(s) in state_dict: "fc2.weight", "fc2.bias". 
Loaded with strict=False successfully
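With strict=False, mismatches pass silently unless you check for them. load_state_dict returns a named tuple with missing_keys and unexpected_keys that you can log or assert on. A frequent source of unexpected keys is a checkpoint saved from a model wrapped in nn.DataParallel, whose keys carry a "module." prefix; the sketch below strips it with a hypothetical strip_prefix helper before loading.

```python
import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: remove a leading prefix (e.g. from nn.DataParallel) from keys."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}


model = nn.Linear(10, 5)

# Keys as they would appear if the model had been saved while wrapped in DataParallel
wrapped_state = {"module.weight": torch.randn(5, 10), "module.bias": torch.randn(5)}

# Loading as-is matches nothing: both names are unexpected, both parameters are missing
result = model.load_state_dict(wrapped_state, strict=False)
print("missing:", result.missing_keys)        # missing: ['weight', 'bias']
print("unexpected:", result.unexpected_keys)  # unexpected: ['module.weight', 'module.bias']

# After stripping the prefix, every key lines up and strict loading succeeds
result = model.load_state_dict(strip_prefix(wrapped_state))
print(result.missing_keys, result.unexpected_keys)  # [] []
```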

Version Compatibility Issues

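Files saved with a newer PyTorch version might not load in an older one. For example, torch.save has written a zip-based format by default since PyTorch 1.6, which earlier releases cannot read. The fix is to re-save the file from the newer environment: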

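```python
import torch

# In the OLDER environment: loading a file written by a newer version may fail
try:
    model = torch.load("newer_version_model.pth")
except Exception as e:
    print(f"Error loading model: {e}")

# Workaround, run in the NEWER environment: re-save the object in the
# legacy (pre-1.6) serialization format. Note that pickle_protocol is an
# argument of torch.save, not torch.load.
model = torch.load("newer_version_model.pth", weights_only=False)  # trusted file
torch.save(model, "legacy_format_model.pth", _use_new_zipfile_serialization=False)

# The older PyTorch installation can then load "legacy_format_model.pth"
```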
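Two habits make version problems easier to avoid and diagnose: prefer state dictionaries over pickled model objects, and record which PyTorch version wrote each file. The sketch below stores the version string next to the weights; the key names and both helper functions are hypothetical. The version is converted with str() so the file contains only plain strings and tensors and loads with weights_only=True.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_with_metadata(model, path):
    """Hypothetical helper: save weights plus the PyTorch version that produced them."""
    torch.save({"torch_version": str(torch.__version__),
                "model_state_dict": model.state_dict()}, path)


def load_with_metadata(model, path):
    """Hypothetical helper: report the saving version, then load the weights."""
    payload = torch.load(path, map_location="cpu", weights_only=True)
    print(f"Saved with PyTorch {payload['torch_version']}, loading with {torch.__version__}")
    model.load_state_dict(payload["model_state_dict"])
    return payload["torch_version"]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "versioned_model.pth")
    save_with_metadata(nn.Linear(3, 1), path)
    saved_version = load_with_metadata(nn.Linear(3, 1), path)
```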

Summary

In this tutorial, we've covered various ways to load PyTorch models:

  1. Basic loading of entire models and state dictionaries
  2. Device management for loading models to CPU or GPU
  3. Loading checkpoints with training state
  4. Handling device differences between saved and loaded models
  5. Managing architecture changes in models
  6. Using PyTorch Hub for pre-trained models
  7. Troubleshooting common loading issues

Being able to effectively load models is essential for both continuing model development and deploying trained models to production environments.

Exercise Ideas

  1. Save a model, modify its architecture slightly, and practice loading it with strict=False
  2. Create a checkpoint saving system that automatically saves models at regular intervals
  3. Build a script that loads models trained on either CPU or GPU and adapts them to the available device
  4. Load a pre-trained model from PyTorch Hub and fine-tune it on a custom dataset
  5. Create a model versioning system that saves model metadata alongside weights


Happy model loading!


