PyTorch Mobile Export
Introduction
PyTorch Mobile is a runtime environment that enables on-device machine learning inference with optimized performance on mobile devices. It allows you to deploy PyTorch models to iOS, Android, and other mobile platforms without requiring complex conversions or separate frameworks. In this tutorial, we'll learn how to export PyTorch models for mobile applications, optimize them for better performance, and implement them in real-world mobile scenarios.
PyTorch Mobile provides several advantages:
- On-device inference: Process data directly on the device, improving privacy and reducing latency
- Offline functionality: Applications can run without internet connectivity
- Reduced server costs: No need for server-side processing for inference
- Lower latency: Eliminates network transmission time
Understanding PyTorch Mobile Workflow
The general workflow for deploying models with PyTorch Mobile includes:
- Developing and training your model in PyTorch
- Optimizing and converting the model for mobile deployment
- Integrating the model into your mobile application
- Running inference on the mobile device
Let's dive into each of these steps.
Preparing Your Model for Mobile Export
Before exporting your model, you need to ensure it's compatible with PyTorch Mobile. Let's start with a simple example:
import torch
import torch.nn as nn
import torchvision.models as models
# Define a simple model (or load your pre-trained model)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(16 * 112 * 112, 10)  # Assuming input is 224x224
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Create an instance of our model
model = SimpleModel()
# If you have a trained model, load its state
# model.load_state_dict(torch.load('path_to_weights.pth'))
# Set the model to evaluation mode
model.eval()
Exporting with TorchScript
TorchScript is an intermediate representation of a PyTorch model that can be saved from Python and then run in a high-performance environment such as C++, which is what the PyTorch Mobile runtime uses. There are two ways to convert a model to TorchScript:
- Tracing: Runs example inputs through your model and records the operations
- Scripting: Directly analyzes your model code and converts it to TorchScript
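The practical difference shows up with data-dependent control flow: tracing only records the branch taken for the example input, while scripting preserves the Python if/else. Here is a minimal illustration (the ControlFlowModel class is a hypothetical toy, used only for demonstration):
import torch
import torch.nn as nn
class ControlFlowModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:  # data-dependent branch
            return x + 1
        return x - 1
m = ControlFlowModel()
traced = torch.jit.trace(m, torch.ones(3))   # records only the "x + 1" path (emits a TracerWarning)
scripted = torch.jit.script(m)               # keeps both branches
neg = -torch.ones(3)
print(traced(neg))    # tensor([0., 0., 0.])  -- the traced branch is baked in
print(scripted(neg))  # tensor([-2., -2., -2.])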
Tracing Your Model
# Create an example input tensor
example_input = torch.rand(1, 3, 224, 224)
# Trace the model with the example input
traced_model = torch.jit.trace(model, example_input)
# Save the traced model
traced_model.save("simple_model_traced.pt")
print("Traced model saved successfully!")
Scripting Your Model
# Script the model (used for models with control flow)
scripted_model = torch.jit.script(model)
# Save the scripted model
scripted_model.save("simple_model_scripted.pt")
print("Scripted model saved successfully!")
Optimizing Your Model for Mobile
To optimize your model for mobile deployment, PyTorch provides graph-level optimizations as well as several quantization methods that reduce model size and improve inference speed.
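One graph-level pass worth knowing about is torch.utils.mobile_optimizer.optimize_for_mobile, which rewrites a traced or scripted module with mobile-friendly transformations (operator fusion, dropout removal, and similar). A minimal sketch, continuing from the traced model saved above:
from torch.utils.mobile_optimizer import optimize_for_mobile
# Apply mobile-specific graph optimizations to the TorchScript module
optimized_model = optimize_for_mobile(traced_model)
optimized_model.save("simple_model_optimized.pt")
# If you target the lite interpreter runtime, newer releases can also save a
# .ptl file via optimized_model._save_for_lite_interpreter(...)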
Quantization
Quantization reduces the precision of weights from 32-bit floating point to 8-bit integers, significantly reducing model size and speeding up computations.
import torch.quantization
# Select the QNNPACK quantized backend used on ARM/mobile devices
torch.backends.quantized.engine = 'qnnpack'
# Define a quantization configuration
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
# Note: eager-mode static quantization also expects the model's forward pass
# to be wrapped with torch.quantization.QuantStub()/DeQuantStub() so that
# inputs and outputs are converted between float and int8.
# Prepare for quantization (inserts observers)
model_prepared = torch.quantization.prepare(model)
# Calibrate the model (usually with a representative calibration dataset)
# For example:
# calibration_data = get_calibration_data()
# for data, _ in calibration_data:
#     model_prepared(data)
# Convert to a quantized model
model_quantized = torch.quantization.convert(model_prepared)
# Export the quantized model
quantized_scripted_model = torch.jit.script(model_quantized)
quantized_scripted_model.save("simple_model_quantized.pt")
print("Quantized model saved successfully!")
Model Size Comparison
Let's compare the sizes of the original model versus the quantized model:
import os
original_size = os.path.getsize("simple_model_scripted.pt") / (1024 * 1024)
quantized_size = os.path.getsize("simple_model_quantized.pt") / (1024 * 1024)
print(f"Original model size: {original_size:.2f} MB")
print(f"Quantized model size: {quantized_size:.2f} MB")
print(f"Size reduction: {100 * (original_size - quantized_size) / original_size:.2f}%")
Example output:
Original model size: 4.35 MB
Quantized model size: 1.18 MB
Size reduction: 72.87%
Additional Optimization Techniques
Here are additional techniques to optimize your model for mobile deployment:
1. Pruning
Pruning zeroes out less important weights, which can reduce model size and compute when combined with sparse storage or structured pruning:
import torch.nn.utils.prune as prune
# Apply pruning to convolutional layers
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)  # Prune 20% of weights
# Make pruning permanent
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.remove(module, 'weight')
# Export the pruned model
pruned_scripted_model = torch.jit.script(model)
pruned_scripted_model.save("simple_model_pruned.pt")
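Note that the zeroed weights alone do not shrink the saved file; the benefit comes from sparse-aware storage or kernels. You can at least verify how sparse the pruned layers actually became:
# Check the fraction of zeroed weights after pruning
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        weight = module.weight
        sparsity = (weight == 0).sum().item() / weight.numel()
        print(f"{name}: {sparsity:.1%} of weights are zero")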
2. Knowledge Distillation
Train a smaller "student" model to mimic a larger "teacher" model:
class SmallerModel(nn.Module):
    def __init__(self):
        super(SmallerModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # Fewer filters
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(8 * 112 * 112, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
# Create student model
student_model = SmallerModel()
# Knowledge distillation training would happen here
# ...
# Export the student model
student_scripted_model = torch.jit.script(student_model)
student_scripted_model.save("student_model.pt")
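The elided training step typically minimizes a weighted combination of the ordinary cross-entropy loss and a KL-divergence term between temperature-softened teacher and student outputs. Here is a minimal sketch of that objective, assuming a trained teacher_model and a train_loader exist (both are placeholders):
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    # Soft targets: match the teacher's temperature-softened distribution (scaled by T^2)
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    # Hard targets: standard cross-entropy on the true labels
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)
teacher_model.eval()
for images, labels in train_loader:  # assumed DataLoader
    with torch.no_grad():
        teacher_logits = teacher_model(images)
    student_logits = student_model(images)
    loss = distillation_loss(student_logits, teacher_logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()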
Integrating PyTorch Mobile into Mobile Apps
Now that we have optimized our model, let's see how to integrate it into mobile applications.
Android Integration
First, add PyTorch Mobile to your Android app's build.gradle:
dependencies {
    implementation 'org.pytorch:pytorch_android:1.10.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
Then, load and run the model in your Kotlin code (the Java API is analogous):
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils
// Load the model
val module = Module.load(assetFilePath(this, "simple_model_quantized.pt"))
// Prepare the input tensor from an image
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
    bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
    TensorImageUtils.TORCHVISION_NORM_STD_RGB
)
// Run inference
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
// Process the output
val scores = outputTensor.dataAsFloatArray
// Find the index of the class with the highest score
val predictedClass = scores.indices.maxByOrNull { scores[it] } ?: -1
println("Predicted class: $predictedClass")
iOS Integration
For iOS, add PyTorch Mobile to your Podfile:
target 'YourApp' do
  pod 'LibTorch', '~> 1.10.0'
end
LibTorch itself is a C++ library, so Swift apps typically call it through a small Objective-C++ wrapper (for example, a TorchModule class like the one in the PyTorch iOS demo app) exposed via a bridging header. With such a wrapper in place, the Swift side looks roughly like this:
// TorchModule and tensorFromImage() are assumed to come from your
// Objective-C++ wrapper and image helper, not from a built-in Swift module.
// Load the model
guard let modelPath = Bundle.main.path(forResource: "simple_model_quantized", ofType: "pt") else {
    fatalError("Failed to find model file")
}
guard let module = try? TorchModule(fileAtPath: modelPath) else {
    fatalError("Failed to load model")
}
// Process an image (assuming you have a UIImage)
guard let tensor = image.tensorFromImage() else {
    fatalError("Failed to convert image to tensor")
}
// Run inference
guard let outputs = try? module.forward([tensor]) else {
    fatalError("Failed to run inference")
}
// Process the output
let predictions = outputs[0].data()
// Further processing...
Real-World Application: Image Classification
Let's implement a complete image classification example for Android:
- First, we'll prepare a MobileNetV2 model:
import torch
import torch.nn as nn
import torchvision.models as models
# Load MobileNetV2 pre-trained model
model = models.mobilenet_v2(pretrained=True)
model.eval()
# Example input for tracing
example_input = torch.rand(1, 3, 224, 224)
# Trace the model with the example input
traced_model = torch.jit.trace(model, example_input)
# Save the model for mobile deployment
traced_model.save("mobilenet_v2.pt")
print("MobileNetV2 model saved for mobile deployment!")
- Then, implement Android code for image classification:
// ImageClassifier.kt
import android.content.Context
import android.graphics.Bitmap
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils
import java.io.BufferedReader
import java.io.File
import java.io.FileOutputStream
import java.io.InputStreamReader

class ImageClassifier(context: Context) {
    private val module: Module
    private val classes: List<String>

    init {
        // Load model
        module = Module.load(assetFilePath(context, "mobilenet_v2.pt"))
        // Load labels
        val labelReader = BufferedReader(InputStreamReader(
            context.assets.open("imagenet_classes.txt")))
        classes = labelReader.readLines()
        labelReader.close()
    }

    fun classify(bitmap: Bitmap): Pair<String, Float> {
        // Resize and preprocess the image
        val resizedBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true)
        // Convert bitmap to tensor (ImageNet mean/std normalization)
        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            resizedBitmap,
            floatArrayOf(0.485f, 0.456f, 0.406f),
            floatArrayOf(0.229f, 0.224f, 0.225f)
        )
        // Run inference
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
        val scores = outputTensor.dataAsFloatArray
        // Find the best class (scores are raw logits; apply softmax if you need probabilities)
        var maxScore = -Float.MAX_VALUE
        var maxScoreIdx = -1
        for (i in scores.indices) {
            if (scores[i] > maxScore) {
                maxScore = scores[i]
                maxScoreIdx = i
            }
        }
        return Pair(classes[maxScoreIdx], maxScore)
    }

    // Helper function to copy the model from assets to the files dir and return its path
    private fun assetFilePath(context: Context, assetName: String): String {
        val file = File(context.filesDir, assetName)
        if (file.exists() && file.length() > 0) {
            return file.absolutePath
        }
        context.assets.open(assetName).use { inputStream ->
            FileOutputStream(file).use { outputStream ->
                val buffer = ByteArray(4 * 1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
                outputStream.flush()
            }
        }
        return file.absolutePath
    }
}
- Using the classifier in an Activity:
// MainActivity.kt
class MainActivity : AppCompatActivity() {
    private lateinit var classifier: ImageClassifier
    private lateinit var imageView: ImageView
    private lateinit var resultTextView: TextView

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        imageView = findViewById(R.id.imageView)
        resultTextView = findViewById(R.id.resultTextView)
        classifier = ImageClassifier(this)
        val selectImageButton: Button = findViewById(R.id.selectImageButton)
        selectImageButton.setOnClickListener {
            openGallery()
        }
    }

    private fun openGallery() {
        val intent = Intent(Intent.ACTION_PICK)
        intent.type = "image/*"
        startActivityForResult(intent, 100)
    }

    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (requestCode == 100 && resultCode == RESULT_OK) {
            val imageUri = data?.data
            val bitmap = MediaStore.Images.Media.getBitmap(contentResolver, imageUri)
            imageView.setImageBitmap(bitmap)
            // Run classification in a background thread
            Thread {
                val (className, score) = classifier.classify(bitmap)
                runOnUiThread {
                    resultTextView.text = "Class: $className\nConfidence: ${String.format("%.2f", score)}"
                }
            }.start()
        }
    }
}
Troubleshooting Common Issues
When working with PyTorch Mobile, you might encounter these common issues:
1. Unsupported Operators
Not all PyTorch operators are supported in PyTorch Mobile. If you encounter this error:
RuntimeError: Unsupported operator aten::xxx_yyy
Possible solutions:
- Check the list of supported operators (you can dump the operators your model actually needs, as shown below)
- Replace unsupported operations with supported alternatives
- Use torch.jit.trace instead of torch.jit.script when possible
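If you are unsure which operators your exported module depends on (for example, to compare against the supported-operator list or to drive a selective build), torch.jit.export_opnames can list them:
import torch
scripted = torch.jit.load("mobilenet_v2.pt")
# Dump the root operator names the module requires
print(torch.jit.export_opnames(scripted))  # e.g. ['aten::_convolution', ...]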
2. Model Size Issues
If your model is too large:
- Apply more aggressive quantization
- Use pruning techniques
- Consider a model architecture with fewer parameters
- Use MobileNet, EfficientNet, or other mobile-optimized architectures
3. Performance Issues
If inference is too slow (see the quick benchmarking sketch after this list):
- Use the Android/iOS profiler to identify bottlenecks
- Ensure you're using the latest PyTorch Mobile version
- Consider batch processing if appropriate
- Reduce input resolution if possible without compromising accuracy
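Desktop timings do not transfer directly to phone hardware, but a quick local benchmark is still a useful sanity check that an optimization (quantization, pruning, a smaller architecture) actually reduced compute. A minimal sketch comparing two of the exported files from earlier:
import time
import torch
def benchmark(path, n_runs=50):
    module = torch.jit.load(path)
    module.eval()
    x = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        for _ in range(5):  # warm-up runs
            module(x)
        start = time.perf_counter()
        for _ in range(n_runs):
            module(x)
    return (time.perf_counter() - start) / n_runs * 1000  # ms per inference
print(f"original model: {benchmark('simple_model_scripted.pt'):.2f} ms")
print(f"student model:  {benchmark('student_model.pt'):.2f} ms")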
Summary
In this tutorial, we've covered how to export PyTorch models for mobile deployment using PyTorch Mobile:
- We learned how to convert models to TorchScript using tracing and scripting
- We explored optimization techniques like quantization and pruning to reduce model size
- We saw how to integrate PyTorch models into Android and iOS applications
- We built a complete image classification example for real-world use
PyTorch Mobile enables on-device inference with several benefits including improved privacy, offline functionality, and reduced latency. By properly optimizing your models, you can achieve excellent performance even on resource-constrained mobile devices.
Additional Resources
- PyTorch Mobile Official Documentation
- PyTorch Mobile Demo Applications
- PyTorch iOS Getting Started Guide
- PyTorch Android Getting Started Guide
- Model Optimization Techniques
Exercises
- Convert a pre-trained ResNet18 model to a mobile-optimized format and measure the size difference before and after optimization.
- Implement a simple image classification app using PyTorch Mobile that can identify different types of flowers or animals.
- Experiment with different quantization techniques and compare their impact on model size and inference speed.
- Create a real-time object detection application using a mobile-optimized model like SSD-MobileNet.
- Build a text classification model and deploy it to a mobile app for sentiment analysis of user input text.