PyTorch TorchScript
Introduction
TorchScript is a way to create serializable and optimizable models from PyTorch code. It lets you save models that can be loaded in environments where Python is not available, enabling deployment in production settings such as C++ applications. TorchScript bridges eager-mode development and graph-based optimization and deployment.
In this tutorial, we'll explore:
- What TorchScript is and why it's useful
- How to convert PyTorch models to TorchScript using tracing and scripting
- How to save and load TorchScript models
- Common use cases and best practices
What is TorchScript?
TorchScript is an intermediate representation of a PyTorch model that can be run in a high-performance environment like C++. It's essentially a way to serialize your PyTorch models so they can be used outside of Python.
There are two ways to convert a PyTorch model to TorchScript:
- Tracing: Runs example inputs through your model and records the operations
- Scripting: Directly analyzes your Python code and converts it to TorchScript
Converting Models with Tracing
Tracing works by recording operations as you execute your model with example inputs. This is the simplest way to convert a model, but has limitations with control flow.
Let's create a simple model and trace it:
import torch
import torch.nn as nn
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(100, 50)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Create an instance of the model
model = SimpleModel()
# Create example input
example_input = torch.rand(1, 100)
# Trace the model
traced_model = torch.jit.trace(model, example_input)
print(traced_model)
Output (the exact formatting varies by PyTorch version):
SimpleModel(
  original_name=SimpleModel
  (fc1): Linear(original_name=Linear)
  (relu): ReLU(original_name=ReLU)
  (fc2): Linear(original_name=Linear)
)
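Before saving anything, it's worth a quick sanity check (a minimal sketch, reusing the objects defined above) that the traced module reproduces the eager model's output:
with torch.no_grad():
    eager_out = model(example_input)
    traced_out = traced_model(example_input)

# The traced module should reproduce the eager output for this input
print(torch.allclose(eager_out, traced_out))  # Expected: True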
The traced model can now be saved and loaded without Python dependencies:
# Save the traced model
traced_model.save("traced_model.pt")
# Load the traced model
loaded_model = torch.jit.load("traced_model.pt")
# Use the loaded model for inference
test_input = torch.rand(1, 100)
output = loaded_model(test_input)
print(output.shape) # Should print torch.Size([1, 10])
Limitations of Tracing
Tracing has some limitations you should be aware of:
- Control flow is not captured: If your model has if-statements or loops that depend on the input data, tracing only records the operations executed for the example input you provide.
- Dynamic operations: Operations whose behavior changes from run to run won't be captured correctly.
Let's see an example where tracing might fail:
class ModelWithControlFlow(nn.Module):
    def __init__(self):
        super(ModelWithControlFlow, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        if x.sum() > 0:
            return self.fc(x)
        else:
            return x

# Tracing this model only records the branch taken for the example input,
# because the if-statement depends on the input data.
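To make this concrete, here is a short sketch (reusing the class above) of what happens: the trace bakes in whichever branch the example input takes, and PyTorch typically emits a TracerWarning about the data-dependent condition:
model = ModelWithControlFlow()

# The example input has a positive sum, so only the `self.fc(x)` branch is recorded
traced = torch.jit.trace(model, torch.ones(1, 10))

negative_input = -torch.ones(1, 10)
# The eager model returns `x` unchanged for a non-positive sum,
# but the traced model always applies the linear layer
print(torch.allclose(model(negative_input), traced(negative_input)))  # Likely False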
Converting Models with Scripting
Scripting directly analyzes your Python code and converts it to TorchScript. It can handle control flow better than tracing:
import torch
import torch.nn as nn
# Define a model with control flow
class ModelWithControlFlow(nn.Module):
    def __init__(self):
        super(ModelWithControlFlow, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        if x.sum() > 0:
            return self.fc(x)
        else:
            return x
# Create an instance and script it
model = ModelWithControlFlow()
scripted_model = torch.jit.script(model)
print(scripted_model)
Output:
RecursiveScriptModule(
  original_name=ModelWithControlFlow
  (fc): RecursiveScriptModule(original_name=Linear)
)
The scripted model preserves control flow constructs from your Python code:
# Save the scripted model
scripted_model.save("scripted_model.pt")
# Load the scripted model
loaded_model = torch.jit.load("scripted_model.pt")
# Test with different inputs
positive_input = torch.ones(1, 10)
negative_input = -torch.ones(1, 10)
print("Positive input output shape:", loaded_model(positive_input).shape)
print("Negative input output shape:", loaded_model(negative_input).shape)
Limitations of Scripting
Scripting also has limitations:
- Python features: Not all Python language features are supported.
- Dynamic typing: TorchScript requires more explicit typing than Python.
- External libraries: You can't directly call arbitrary Python functions.
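As a small illustration of the typing restriction (a sketch using a hypothetical BadTyping module and the imports from above), scripting fails if a local variable takes on two different types, because TorchScript assigns each variable a single static type:
class BadTyping(nn.Module):
    def forward(self, x):
        y = 0        # y starts out as an int...
        y = x + 1    # ...and is then re-assigned a Tensor, which TorchScript rejects
        return y

try:
    torch.jit.script(BadTyping())
except Exception as e:
    print("Scripting failed:", e)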
Combining Tracing and Scripting
You can combine both approaches for complex models. For example, you might script a model with control flow but trace individual submodules:
import torch
import torch.nn as nn
# A module we'll trace
class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc(x)

# The main module with control flow
class MainModule(nn.Module):
    def __init__(self):
        super(MainModule, self).__init__()
        self.submodule = SubModule()
        self.another_fc = nn.Linear(10, 5)

    def forward(self, x):
        # The submodule (which we will trace) is only used in this branch
        if x.sum() > 0:
            x = self.submodule(x)
            return self.another_fc(x)
        else:
            return x
# Create the module
model = MainModule()
# Trace the submodule first
submodule_traced = torch.jit.trace(model.submodule, torch.rand(1, 10))
model.submodule = submodule_traced
# Now script the whole model
scripted_model = torch.jit.script(model)
# Save the model
scripted_model.save("hybrid_model.pt")
Real-World Use Case: Deployment in C++
One of the main benefits of TorchScript is deployment in C++ environments. Here's a simplified example of how you might use a TorchScript model in C++:
#include <torch/script.h>
#include <iostream>
int main() {
  // Load the model
  torch::jit::script::Module module;
  try {
    module = torch::jit::load("traced_model.pt");
  }
  catch (const c10::Error& e) {
    std::cerr << "Error loading the model\n";
    return -1;
  }

  // Create input tensor
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 100}));

  // Execute the model
  at::Tensor output = module.forward(inputs).toTensor();
  std::cout << output.sizes() << std::endl;

  return 0;
}
This C++ code loads our TorchScript model and runs inference with it - all without requiring Python.
Real-World Use Case: Mobile Deployment
TorchScript is also useful for deploying models to mobile devices:
# Create a model for mobile deployment
class MobileModel(nn.Module):
    def __init__(self):
        super(MobileModel, self).__init__()
        # Use lightweight layers for mobile
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Create model instance
mobile_model = MobileModel()
# Convert to TorchScript
example_input = torch.rand(1, 3, 224, 224)
traced_mobile_model = torch.jit.trace(mobile_model, example_input)
# Apply mobile-specific graph optimizations (note: this does not quantize the model)
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_model = optimize_for_mobile(traced_mobile_model)
# Save for mobile deployment
optimized_model.save("mobile_model.pt")
This model can then be loaded in Android or iOS applications using the PyTorch Mobile libraries.
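Depending on the mobile runtime you target, you may need the lite-interpreter format rather than the full TorchScript archive saved above; a common pattern (worth confirming against the docs for your PyTorch version) looks like this:
# Save in the lite-interpreter format used by the PyTorch Mobile / Lite runtimes;
# the plain .save() call above produces an archive for the full JIT runtime
optimized_model._save_for_lite_interpreter("mobile_model.ptl")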
Tips and Best Practices
- Start with tracing: It's simpler and works for most straightforward models.
- Use scripting for control flow: If your model has data-dependent conditional logic, use scripting instead.
- Test your models: Always test TorchScript models with various inputs to verify they behave the same as the original models:
def test_script_model(orig_model, script_model, test_input):
    orig_output = orig_model(test_input)
    script_output = script_model(test_input)
    # Check if outputs are close
    if torch.allclose(orig_output, script_output):
        print("✅ Model outputs match!")
    else:
        print("❌ Model outputs differ!")
        print(f"Original: {orig_output}")
        print(f"Scripted: {script_output}")
- Annotations for scripting: For complex models, add type annotations to help the TorchScript compiler:
import torch
import torch.nn as nn
from typing import Dict, List, Tuple
class TypedModel(nn.Module):
    def forward(self, x: torch.Tensor, indices: List[int]) -> Dict[str, torch.Tensor]:
        result = {"output": torch.zeros_like(x)}
        for i in indices:
            result["output"] += x * i
        return result
# Script the model with proper annotations
typed_model = TypedModel()
scripted_typed_model = torch.jit.script(typed_model)
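Calling the scripted module works just like calling the original, for example:
# The scripted module takes the same (typed) arguments as the original
out = scripted_typed_model(torch.ones(2, 3), [1, 2, 3])
print(out["output"].shape)  # torch.Size([2, 3])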
Summary
TorchScript provides a powerful way to transition PyTorch models from research to production:
- Tracing captures operations by running example inputs through your model
- Scripting directly analyzes Python code to handle more complex models with control flow
- TorchScript models can be deployed in C++, mobile apps, and other non-Python environments
- It allows for optimization not possible in eager mode execution
TorchScript bridges the gap between PyTorch's flexible development experience and the requirements of production deployment.
Exercises
- Convert a pre-trained model (like ResNet) to TorchScript and compare inference times between the original and TorchScript versions.
- Create a model with complex control flow and test both tracing and scripting approaches. Which one works better?
- Practice optimizing a TorchScript model for mobile deployment using optimize_for_mobile().
- Try writing a simple C++ program that loads a TorchScript model and runs inference (if you have a C++ environment available).