PyTorch Model Serialization

When developing machine learning models with PyTorch, being able to save your trained models is essential for deployment, sharing, or continuing training later. PyTorch offers several ways to serialize (convert to a storable format) and deserialize (load back) your models.

What is Model Serialization?

Model serialization is the process of converting a model object into a format that can be stored on disk, transmitted over a network, or saved to a database. For PyTorch models, this means converting complex Python objects containing model architecture and parameters into a file or byte stream.

Why Serialize PyTorch Models?

  • Deployment: Use trained models in production environments
  • Sharing: Distribute models to teammates or the community
  • Checkpointing: Save progress during long training sessions
  • Transfer Learning: Use pre-trained models as starting points
  • Versioning: Keep track of model iterations and improvements

Methods of Serialization in PyTorch

PyTorch provides two primary approaches for model serialization:

  1. torch.save and torch.load: The standard serialization methods
  2. TorchScript: For production-oriented serialization that can run without Python

Let's explore each approach in detail.

Using torch.save and torch.load

This is the most common way to save and load PyTorch models. PyTorch uses Python's pickle module under the hood, with some additional capabilities for handling tensors efficiently.
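
Because torch.save is built on pickle, it isn't limited to modules: any picklable Python object can be saved and loaded the same way. A quick sketch (the file name is just an example):

python
import torch

# torch.save accepts tensors, dicts, lists, etc., not only nn.Module objects
data = {'weights': torch.ones(3), 'step': 42}
torch.save(data, 'data.pt')

restored = torch.load('data.pt')
print(restored['weights'], restored['step'])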

Saving the Entire Model

python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create and train a model
model = SimpleModel()
# ... training code would be here ...

# Save the entire model
torch.save(model, 'entire_model.pth')

# Later, load the entire model
# (PyTorch 2.6+ defaults to weights_only=True, so loading a fully pickled
# model may require torch.load('entire_model.pth', weights_only=False))
loaded_model = torch.load('entire_model.pth')
loaded_model.eval()  # Set to evaluation mode

When saving the entire model, PyTorch pickles the whole module object: the parameters (weights and biases) are stored together with a reference to the model's class. Note that pickle records only a reference, not the class's code, so the class definition must still be importable when you load the file.

Saving Only the State Dictionary

A more flexible approach is to save only the model's state dictionary, which contains all the learnable parameters:

python
# Save just the model state dictionary
torch.save(model.state_dict(), 'model_state_dict.pth')

# To load, you need to create an instance of the model first
new_model = SimpleModel()
new_model.load_state_dict(torch.load('model_state_dict.pth'))
new_model.eval() # Set to evaluation mode

This approach is generally preferred because:

  • It's more lightweight: only the parameter tensors are stored
  • It's more compatible across different versions of PyTorch
  • It separates model architecture from parameters
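
A state dictionary is just an ordered mapping from parameter names to tensors, so you can inspect it directly; this is handy for debugging shape mismatches. A minimal sketch using the SimpleModel defined above:

python
# Inspect the contents of a state dict
state_dict = model.state_dict()
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
# fc1.weight (50, 10)
# fc1.bias (50,)
# fc2.weight (1, 50)
# fc2.bias (1,)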

Saving Additional Information

Often, you want to save more than just the model. For example, you might want to save optimizer state, epoch numbers, or loss values:

python
import torch.optim as optim

model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 10
loss = 0.1

# Save checkpoint with additional information
checkpoint = {
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss
}
torch.save(checkpoint, 'model_checkpoint.pth')

# Load checkpoint
checkpoint = torch.load('model_checkpoint.pth')
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()  # Set to evaluation mode (call model.train() instead if resuming training)

This approach allows you to resume training exactly where you left off.
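
A minimal sketch of resuming from such a checkpoint (the loop body is left as a placeholder):

python
# Resume training from a saved checkpoint
checkpoint = torch.load('model_checkpoint.pth')

model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']

model.train()  # Back to training mode
for epoch in range(start_epoch, start_epoch + 10):
    pass  # ... training code would go here ...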

CPU and GPU Considerations

When saving and loading models across different devices, you may need to specify the device mapping:

python
# Save a model trained on GPU
torch.save(model.state_dict(), 'model_gpu.pth')

# Load to CPU
model = SimpleModel()
model.load_state_dict(torch.load('model_gpu.pth', map_location=torch.device('cpu')))

# Load to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleModel()
model.load_state_dict(torch.load('model_gpu.pth', map_location=device))
model.to(device)  # load_state_dict only copies values, so the model must still be moved

Using TorchScript for Serialization

TorchScript is a way to create serializable and optimizable versions of PyTorch models. It's particularly useful for production deployment:

python
import torch

# Define a simple model
class SimpleScriptableModel(torch.nn.Module):
    def __init__(self):
        super(SimpleScriptableModel, self).__init__()
        self.fc1 = torch.nn.Linear(10, 50)
        self.fc2 = torch.nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create a model instance
model = SimpleScriptableModel()

# Convert to TorchScript via tracing
example_input = torch.rand(1, 10)
traced_script_module = torch.jit.trace(model, example_input)

# Save the TorchScript model
traced_script_module.save('model_torchscript.pt')

# Load the TorchScript model
loaded_model = torch.jit.load('model_torchscript.pt')

# Use the model
test_input = torch.rand(1, 10)
output = loaded_model(test_input)
print(f"Model output: {output}")

With TorchScript, you can deploy your model in environments that don't have Python, like C++ applications.
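
Tracing records the operations executed for one example input, so it can silently miss data-dependent control flow (if statements or loops whose behavior depends on tensor values). For such models, torch.jit.script compiles the model's code directly instead. A minimal sketch:

python
# Scripting compiles the model's code, preserving control flow
scripted_module = torch.jit.script(model)
scripted_module.save('model_scripted.pt')

loaded = torch.jit.load('model_scripted.pt')
print(loaded(torch.rand(1, 10)))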

Real-World Example: Transfer Learning with a Pretrained Model

In real-world scenarios, you might fine-tune a pretrained model on your dataset and then save it for later use:

python
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

# Load a pretrained ResNet model
# (on older torchvision versions, use models.resnet18(pretrained=True))
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify the model for your task (e.g., 5 output classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Train the model (simulated here)
epochs = 5
for epoch in range(epochs):
    running_loss = 0.0
    # Training code would go here
    running_loss = 0.1  # Simulated loss
    print(f"Epoch {epoch+1}, Loss: {running_loss}")

# Save the fine-tuned model
save_path = 'finetuned_resnet.pth'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epochs': epochs,
}, save_path)
print(f"Model saved to {save_path}")

# To load and use the model
def load_model(path, num_classes=5):
    model = models.resnet18(weights=None)  # architecture only, no pretrained weights
    # Update the final fully connected layer
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes)

    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

# Load the fine-tuned model
loaded_model = load_model('finetuned_resnet.pth')

# Use the model for inference
# Create a random test image (3 channels, 224x224 pixels)
test_image = torch.rand(1, 3, 224, 224)
with torch.no_grad():  # no gradients needed for inference
    outputs = loaded_model(test_image)
_, predicted = torch.max(outputs, 1)
print(f"Predicted class: {predicted.item()}")

Best Practices for PyTorch Model Serialization

  1. Prefer state dictionaries over entire models:
    • More efficient and flexible
    • Better compatibility across PyTorch versions
  2. Save checkpoints regularly during training:
    • Store model state, optimizer state, epochs, and other metadata
    • Implement automatic checkpoint saving every N epochs (see the sketch after this list)
  3. Version your saved models:
    • Include version information in filenames or metadata
    • Track hyperparameters and training conditions
  4. Handle device compatibility:
    • Use the map_location parameter to handle CPU/GPU transfers
    • Check device availability before loading models
  5. Choose the serialization format based on the use case:
    • Use standard PyTorch serialization for development and research
    • Use TorchScript for production deployment
  6. Test model loading:
    • Always verify that loaded models produce the expected outputs (see the sketch after this list)
    • Check that performance metrics are preserved after loading
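
A minimal sketch combining practices 2 and 6, reusing the SimpleModel class from the earlier examples (the save_checkpoint helper and file names are illustrative, not a standard API):

python
import torch

model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def save_checkpoint(model, optimizer, epoch, path):
    # Bundle everything needed to resume training
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

# Inside a training loop: checkpoint every N epochs
N = 5
for epoch in range(1, 21):
    # ... training code would go here ...
    if epoch % N == 0:
        save_checkpoint(model, optimizer, epoch, f'checkpoint_epoch_{epoch}.pth')

# Verify that a reloaded model reproduces the original outputs
test_input = torch.rand(1, 10)
model.eval()
with torch.no_grad():
    before = model(test_input)

restored = SimpleModel()
restored.load_state_dict(torch.load('checkpoint_epoch_20.pth')['model_state_dict'])
restored.eval()
with torch.no_grad():
    after = restored(test_input)

assert torch.allclose(before, after), "Loaded model does not match the original"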

Common Issues and Solutions

"Module not found" Error

This occurs when loading a whole pickled model whose custom class is no longer importable; pickle stores only a reference to the class, not its code:

python
# When saving
torch.save(model, 'model.pth')

# When loading from another script, the class must be importable
# under the same module path it had when the model was saved
from your_module import YourModelClass
model = torch.load('model.pth')

Version Mismatch

State dictionaries are usually portable across PyTorch versions, but recording the version used for saving makes mismatches easier to diagnose:

python
# Save with version information
torch.save({
    'model_state_dict': model.state_dict(),
    'pytorch_version': torch.__version__
}, 'model.pth')

# Check version when loading
checkpoint = torch.load('model.pth')
if checkpoint['pytorch_version'] != torch.__version__:
    print(f"Warning: model was saved with PyTorch {checkpoint['pytorch_version']}, "
          f"but you're using {torch.__version__}")

Summary

PyTorch model serialization is a critical skill for any deep learning practitioner. We've covered:

  • Basic model saving and loading with torch.save and torch.load
  • Saving complete models vs. state dictionaries
  • Including additional information in model checkpoints
  • Handling CPU/GPU device differences during serialization
  • Using TorchScript for production deployment
  • Real-world examples and best practices

By mastering these serialization techniques, you'll be able to efficiently store, share, and deploy your PyTorch models in various environments and applications.

Exercises

  1. Create a simple CNN model, train it on MNIST, and save both the entire model and just the state dictionary.
  2. Implement a training loop with automatic checkpointing every 5 epochs.
  3. Fine-tune a pretrained ResNet model on a small custom dataset and save the model for later use.
  4. Convert a trained PyTorch model to TorchScript and load it in a separate script.
  5. Create a function that can resume training from a saved checkpoint.

