PyTorch Model Export
After training your PyTorch models, you often need to deploy them in different environments - from web servers to mobile devices and edge computing platforms. This is where model export comes into play. In this tutorial, you'll learn how to export your PyTorch models to various formats that make deployment easier.
Introduction to Model Export
While PyTorch is excellent for research and development, the native PyTorch format may not be the most efficient or compatible choice for production environments. Model export converts your PyTorch models into formats that are:
- Optimized for inference - Faster execution without training overhead
- Environment-compatible - Works in non-Python environments
- Platform-specific - Tailored for specific hardware or platforms
Let's explore the main export options available in the PyTorch ecosystem.
TorchScript Export
TorchScript is a way to create serializable and optimizable models from PyTorch code. It allows PyTorch models to be saved and then loaded in environments where a Python runtime isn't available, such as a C++ server.
Basic TorchScript Export
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create an instance of the model
model = SimpleModel()

# Create example input
example_input = torch.rand(1, 10)

# Export to TorchScript using tracing
traced_model = torch.jit.trace(model, example_input)

# Save the traced model
traced_model.save("traced_model.pt")

# Later, you can load the model without Python dependencies
loaded_model = torch.jit.load("traced_model.pt")

# Use the loaded model
output = loaded_model(torch.rand(1, 10))
print(output)
Output:
tensor([[-0.1942, -0.0882]], grad_fn=<AddmmBackward0>)
Scripting vs Tracing
PyTorch offers two methods to convert models to TorchScript:
- Tracing: Follows the execution of your model with example inputs
- Scripting: Directly analyzes the Python code and converts it
For models with control flow that depends on inputs, scripting is often better:
import torch
import torch.nn as nn

class ModelWithControlFlow(nn.Module):
    def __init__(self):
        super(ModelWithControlFlow, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        if x.sum() > 0:
            return self.fc1(x)
        else:
            return self.fc2(self.fc1(x))

# Create an instance of the model
model = ModelWithControlFlow()

# Script the model (better for control flow)
scripted_model = torch.jit.script(model)

# Save the scripted model
scripted_model.save("scripted_model.pt")
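To see the difference in practice, here is a minimal sketch that loads the scripted model and feeds it inputs that exercise both branches. A traced version of this model would instead have recorded only the branch taken by the example input used for tracing.

import torch

# Load the scripted model and exercise both branches of the control flow
loaded = torch.jit.load("scripted_model.pt")

positive_input = torch.ones(1, 10)    # x.sum() > 0  -> returns fc1(x), shape (1, 5)
negative_input = -torch.ones(1, 10)   # x.sum() <= 0 -> returns fc2(fc1(x)), shape (1, 2)

print(loaded(positive_input).shape)   # torch.Size([1, 5])
print(loaded(negative_input).shape)   # torch.Size([1, 2])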
ONNX Export
ONNX (Open Neural Network Exchange) is an open format built to represent machine learning models. ONNX allows models to be transferred between different frameworks.
Basic ONNX Export
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create a model instance
model = SimpleModel()
model.eval()  # Set to evaluation mode

# Input to the model
x = torch.randn(1, 3, 32, 32, requires_grad=True)

# Export the model to ONNX
torch.onnx.export(model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "simple_model.onnx",       # where to save the model
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})
Verifying the ONNX Model
You can verify the exported ONNX model with:
import onnx
# Load the ONNX model
onnx_model = onnx.load("simple_model.onnx")
# Check that the model is well formed
onnx.checker.check_model(onnx_model)
# Print a human-readable representation of the graph
print(onnx.helper.printable_graph(onnx_model.graph))
Using ONNX with Different Frameworks
After exporting to ONNX, you can use the model in various frameworks:
# With ONNX Runtime
import onnxruntime
import numpy as np
ort_session = onnxruntime.InferenceSession("simple_model.onnx")
# Prepare input data (convert PyTorch tensor to NumPy array)
input_data = x.detach().numpy()
# Get output
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outputs = ort_session.run(None, ort_inputs)
# Compare ONNX Runtime and PyTorch results
print("PyTorch output:")
print(model(x))
print("\nONNX Runtime output:")
print(ort_outputs[0])
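Beyond printing both outputs, you can check numerical agreement programmatically. Continuing the snippet above, a minimal sketch using NumPy's testing utilities (the tolerances are illustrative):

# Verify that PyTorch and ONNX Runtime produce numerically close results
np.testing.assert_allclose(model(x).detach().numpy(), ort_outputs[0], rtol=1e-03, atol=1e-05)
print("PyTorch and ONNX Runtime outputs match")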
PyTorch Mobile Export
For deploying to mobile devices, PyTorch offers PyTorch Mobile, which lets you run models on iOS and Android.
Exporting for Mobile
import torch
import torchvision
# Load a pre-trained model
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
# Create example input
example = torch.rand(1, 3, 224, 224)
# Export to TorchScript via tracing
traced_script_module = torch.jit.trace(model, example)
# Save the TorchScript model for mobile
traced_script_module.save("mobilenet_v2.pt")
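Saving a plain TorchScript file works, but PyTorch Mobile also provides a mobile-specific optimization pass and a lite-interpreter format. Continuing the example above, a sketch assuming a reasonably recent PyTorch release:

from torch.utils.mobile_optimizer import optimize_for_mobile

# Apply mobile-specific optimizations such as operator fusion and dropout removal
optimized_module = optimize_for_mobile(traced_script_module)

# Save in the lite-interpreter format used by the PyTorch Mobile runtimes
optimized_module._save_for_lite_interpreter("mobilenet_v2.ptl")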
Custom C++ Export
For high-performance deployments, you might want to use LibTorch, PyTorch's C++ API:
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# Create model and example input
model = SimpleModel()
example = torch.rand(1, 10)

# Export to TorchScript
scripted_model = torch.jit.script(model)
scripted_model.save("model_for_cpp.pt")
You would then use this in C++ like:
#include <torch/script.h>
#include <iostream>

int main() {
    // Load the model
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("model_for_cpp.pt");
    }
    catch (const c10::Error& e) {
        std::cerr << "Error loading the model\n";
        return -1;
    }

    // Create a vector of inputs
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 10}));

    // Execute the model and get the output
    torch::Tensor output = module.forward(inputs).toTensor();
    std::cout << output << std::endl;

    return 0;
}
Real-World Example: Model Export for Web Deployment
Let's say you've trained an image classification model and want to deploy it in a web application using an in-browser runtime such as ONNX Runtime Web (the successor to ONNX.js):
import torch
import torchvision.models as models
import torch.nn as nn
# Load a pre-trained ResNet model
model = models.resnet18(pretrained=True)
model.eval()
# Example input (batch_size, channels, height, width)
dummy_input = torch.randn(1, 3, 224, 224)
# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "resnet18_web.onnx",
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
print("Model exported to resnet18_web.onnx")
This ONNX model can then be loaded in the browser with ONNX Runtime Web, or converted for use with TensorFlow.js.
Optimizing Exported Models
After export, you can often optimize your models further. The onnxruntime.transformers optimizer is designed for transformer architectures such as BERT and GPT-2; for a CNN like ResNet, you can instead let ONNX Runtime apply its general graph optimizations and write the optimized model to disk:

import onnxruntime as ort

# Configure ONNX Runtime to apply all available graph optimizations
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Ask ONNX Runtime to save the optimized graph to disk
sess_options.optimized_model_filepath = "resnet18_web_optimized.onnx"

# Creating the session applies the optimizations and writes the optimized model
session = ort.InferenceSession("resnet18_web.onnx", sess_options)
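TorchScript exports can be tuned in a similar post-export step. A minimal sketch, assuming a PyTorch version recent enough to provide torch.jit.freeze and torch.jit.optimize_for_inference, applied to the traced model saved earlier in this tutorial:

import torch

# Load the TorchScript model saved earlier and put it in inference mode
ts_model = torch.jit.load("traced_model.pt")
ts_model.eval()

# Freezing inlines parameters and attributes into the graph;
# optimize_for_inference then applies inference-oriented passes such as operator fusion
frozen = torch.jit.freeze(ts_model)
optimized_ts = torch.jit.optimize_for_inference(frozen)

# Run inference with the optimized module
print(optimized_ts(torch.rand(1, 10)))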
Summary
In this tutorial, you've learned how to:
- Export PyTorch models to TorchScript for deployment in non-Python environments
- Convert models to the ONNX format for cross-framework compatibility
- Prepare models for mobile deployment using PyTorch Mobile
- Export models for use with C++ applications
- Optimize exported models for better performance
These export methods give you flexibility to deploy your PyTorch models in various production environments, from web servers to mobile devices and edge computing platforms.
Additional Resources
- PyTorch TorchScript Documentation
- ONNX Official Website
- PyTorch Mobile Documentation
- LibTorch C++ API
Exercises
- Take a pre-trained image classification model and export it to both ONNX and TorchScript formats.
- Compare the inference speed of the same model in PyTorch, ONNX Runtime, and TorchScript.
- Create a simple CNN model, export it to ONNX, and visualize the computational graph using Netron (https://netron.app/).
- Try exporting a model with dynamic input shapes to handle different batch sizes.
- Export a PyTorch model to TorchScript and build a simple C++ application to run inference with it.