PyTorch Mobile Export

Introduction

PyTorch Mobile is a runtime that enables on-device machine learning inference with optimized performance on mobile hardware. It lets you deploy PyTorch models to iOS, Android, and other mobile platforms without complex conversions or separate frameworks. In this tutorial, we'll learn how to export PyTorch models for mobile applications, optimize them for better performance, and integrate them into real-world mobile apps.

PyTorch Mobile provides several advantages:

  • On-device inference: Process data directly on the device, improving privacy and reducing latency
  • Offline functionality: Applications can run without internet connectivity
  • Reduced server costs: No need for server-side processing for inference
  • Lower latency: Eliminates network transmission time

Understanding PyTorch Mobile Workflow

The general workflow for deploying models with PyTorch Mobile includes:

  1. Developing and training your model in PyTorch
  2. Optimizing and converting the model for mobile deployment
  3. Integrating the model into your mobile application
  4. Running inference on the mobile device

Let's dive into each of these steps.

Preparing Your Model for Mobile Export

Before exporting your model, you need to ensure it's compatible with PyTorch Mobile. Let's start with a simple example:

python
import torch
import torch.nn as nn
import torchvision.models as models

# Define a simple model (or load your pre-trained model)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(16 * 112 * 112, 10)  # Assuming input is 224x224

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create an instance of our model
model = SimpleModel()

# If you have a trained model, load its state
# model.load_state_dict(torch.load('path_to_weights.pth'))

# Set the model to evaluation mode
model.eval()

Exporting with TorchScript

TorchScript is an intermediate representation of a PyTorch model that can be run both from Python and from a standalone C++ environment such as the mobile runtime. There are two ways to convert a model to TorchScript:

  1. Tracing: Runs example inputs through your model and records the operations
  2. Scripting: Directly analyzes your model code and converts it to TorchScript
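
To see why the distinction matters, here is a small illustrative module (not part of the tutorial's model) with data-dependent control flow. Tracing records only the branch taken for the example input, while scripting preserves the if/else:

python
import torch
import torch.nn as nn

class ClampOrDouble(nn.Module):
    def forward(self, x):
        # Data-dependent branch: tracing bakes in whichever path the example input takes
        if x.sum() > 0:
            return x * 2
        else:
            return x.clamp(min=0)

m = ClampOrDouble()
traced = torch.jit.trace(m, torch.ones(3))   # emits a TracerWarning about the branch
scripted = torch.jit.script(m)               # keeps both branches

negative_input = -torch.ones(3)
print(traced(negative_input))    # tensor([-2., -2., -2.]) -- follows the traced branch
print(scripted(negative_input))  # tensor([0., 0., 0.])    -- takes the correct branch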

Tracing Your Model

python
# Create an example input tensor
example_input = torch.rand(1, 3, 224, 224)

# Trace the model with the example input
traced_model = torch.jit.trace(model, example_input)

# Save the traced model
traced_model.save("simple_model_traced.pt")

print("Traced model saved successfully!")

Scripting Your Model

python
# Script the model (used for models with control flow)
scripted_model = torch.jit.script(model)

# Save the scripted model
scripted_model.save("simple_model_scripted.pt")

print("Scripted model saved successfully!")

Optimizing Your Model for Mobile

PyTorch provides several quantization methods that reduce model size and improve inference speed, which is especially valuable for mobile deployment.

Quantization

Quantization reduces the precision of the weights (and, with static quantization, the activations) from 32-bit floating point to 8-bit integers, significantly reducing model size and speeding up computation.

python
import torch.quantization

# Use the QNNPACK backend, which targets ARM CPUs on mobile devices
torch.backends.quantized.engine = 'qnnpack'

# Static quantization expects QuantStub/DeQuantStub around the model;
# QuantWrapper inserts them for us
model_to_quantize = torch.quantization.QuantWrapper(model)
model_to_quantize.qconfig = torch.quantization.get_default_qconfig('qnnpack')

# Insert observers to collect activation statistics
model_prepared = torch.quantization.prepare(model_to_quantize)

# Calibrate the model (usually with a representative calibration dataset)
# For example:
# calibration_data = get_calibration_data()
# for data, _ in calibration_data:
#     model_prepared(data)

# Convert to a quantized model
model_quantized = torch.quantization.convert(model_prepared)

# Export the quantized model
quantized_scripted_model = torch.jit.script(model_quantized)
quantized_scripted_model.save("simple_model_quantized.pt")

print("Quantized model saved successfully!")

Model Size Comparison

Let's compare the sizes of the original model versus the quantized model:

python
import os

original_size = os.path.getsize("simple_model_scripted.pt") / (1024 * 1024)
quantized_size = os.path.getsize("simple_model_quantized.pt") / (1024 * 1024)

print(f"Original model size: {original_size:.2f} MB")
print(f"Quantized model size: {quantized_size:.2f} MB")
print(f"Size reduction: {100 * (original_size - quantized_size) / original_size:.2f}%")

Example output:

Original model size: 4.35 MB
Quantized model size: 1.18 MB
Size reduction: 72.87%

Additional Optimization Techniques

Here are additional techniques to optimize your model for mobile deployment:

1. Pruning

Pruning zeroes out less important weights. Note that unstructured pruning on its own does not shrink the saved file, since the zeroed weights are still stored densely; it mainly improves compressibility and can be a first step toward structured size reductions:

python
import torch.nn.utils.prune as prune

# Apply pruning to convolutional layers
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)  # Prune 20% of weights

# Make pruning permanent
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.remove(module, 'weight')

# Export the pruned model
pruned_scripted_model = torch.jit.script(model)
pruned_scripted_model.save("simple_model_pruned.pt")
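
To confirm that pruning actually took effect, you can measure the fraction of zeroed weights in each convolutional layer:

python
# Report per-layer sparsity after pruning
for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        total = module.weight.nelement()
        zeros = int((module.weight == 0).sum())
        print(f"{name}: {100.0 * zeros / total:.1f}% of weights pruned")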

2. Knowledge Distillation

Train a smaller "student" model to mimic a larger "teacher" model:

python
class SmallerModel(nn.Module):
    def __init__(self):
        super(SmallerModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # Fewer filters
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(8 * 112 * 112, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create student model
student_model = SmallerModel()

# Knowledge distillation training would happen here
# ...

# Export the student model
student_scripted_model = torch.jit.script(student_model)
student_scripted_model.save("student_model.pt")
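
The distillation step elided in the comment above could look roughly like the following sketch. It is only an illustration: train_loader, the temperature T, and the weighting alpha are assumptions you would adapt to your own data, and the teacher is the trained model from earlier.

python
import torch.nn.functional as F

# Hypothetical knowledge distillation loop (illustrative only)
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)
T, alpha = 4.0, 0.7  # temperature and soft/hard loss weighting (assumed values)

model.eval()            # teacher stays frozen
student_model.train()
for images, labels in train_loader:  # train_loader is assumed to exist
    with torch.no_grad():
        teacher_logits = model(images)
    student_logits = student_model(images)

    # Soft targets: match the teacher's softened output distribution
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean'
    ) * (T * T)
    # Hard targets: standard cross-entropy against the true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    loss = alpha * soft_loss + (1 - alpha) * hard_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()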

Integrating PyTorch Mobile into Mobile Apps

Now that we have optimized our model, let's see how to integrate it into mobile applications.

Android Integration

First, add PyTorch Mobile to your Android app's build.gradle:

gradle
dependencies {
    implementation 'org.pytorch:pytorch_android:1.10.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}

Then, load and use the model in your Java/Kotlin code:

kotlin
import org.pytorch.IValue
import org.pytorch.Module
import org.pytorch.torchvision.TensorImageUtils

// Load the model (assetFilePath copies the asset to a readable file path;
// its implementation is shown in the ImageClassifier example below)
val module = Module.load(assetFilePath(this, "simple_model_quantized.pt"))

// Prepare input tensor from an image
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
    bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
    TensorImageUtils.TORCHVISION_NORM_STD_RGB
)

// Run inference
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()

// Process the output
val scores = outputTensor.dataAsFloatArray

// Find the class with the highest score
val maxScoreIdx = scores.indices.maxByOrNull { scores[it] } ?: -1
println("Predicted class: $maxScoreIdx")

iOS Integration

For iOS, add PyTorch Mobile to your Podfile:

ruby
target 'YourApp' do
  pod 'LibTorch', '~> 1.10.0'
end

Then, use it in your Swift code:

swift
// LibTorch is a C++ library; in Swift you typically call it through an
// Objective-C++ wrapper (e.g., a TorchModule class) exposed via a bridging header.

// Load the model
guard let modelPath = Bundle.main.path(forResource: "simple_model_quantized", ofType: "pt") else {
    fatalError("Failed to find model file")
}
guard let module = try? TorchModule(fileAtPath: modelPath) else {
    fatalError("Failed to load the model")
}

// Process an image (assuming you have a UIImage and a tensorFromImage() helper)
guard let tensor = image.tensorFromImage() else {
    fatalError("Failed to convert image to tensor")
}

// Run inference
guard let outputs = try? module.forward([tensor]) else {
    fatalError("Failed to run inference")
}

// Process the output
let predictions = outputs[0].data()
// Further processing...

Real-World Application: Image Classification

Let's implement a complete image classification example for Android:

  1. First, we'll prepare a MobileNetV2 model:
python
import torch
import torch.nn as nn
import torchvision.models as models

# Load MobileNetV2 pre-trained model
model = models.mobilenet_v2(pretrained=True)
model.eval()

# Example input for tracing
example_input = torch.rand(1, 3, 224, 224)

# Trace the model with the example input
traced_model = torch.jit.trace(model, example_input)

# Save the model for mobile deployment
traced_model.save("mobilenet_v2.pt")

print("MobileNetV2 model saved for mobile deployment!")
  2. Then, implement the Android code for image classification:
kotlin
// ImageClassifier.kt
class ImageClassifier(context: Context) {
    private val module: Module
    private val classes: List<String>

    init {
        // Load model
        module = Module.load(assetFilePath(context, "mobilenet_v2.pt"))

        // Load labels
        val labelReader = BufferedReader(InputStreamReader(
            context.assets.open("imagenet_classes.txt")))
        classes = labelReader.readLines()
        labelReader.close()
    }

    fun classify(bitmap: Bitmap): Pair<String, Float> {
        // Resize and preprocess the image
        val resizedBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, true)

        // Convert bitmap to tensor
        val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
            resizedBitmap,
            floatArrayOf(0.485f, 0.456f, 0.406f),
            floatArrayOf(0.229f, 0.224f, 0.225f)
        )

        // Run inference
        val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
        val scores = outputTensor.dataAsFloatArray

        // Find the best class
        var maxScore = -Float.MAX_VALUE
        var maxScoreIdx = -1
        for (i in scores.indices) {
            if (scores[i] > maxScore) {
                maxScore = scores[i]
                maxScoreIdx = i
            }
        }

        return Pair(classes[maxScoreIdx], maxScore)
    }

    // Helper function to get asset path
    private fun assetFilePath(context: Context, assetName: String): String {
        val file = File(context.filesDir, assetName)
        if (file.exists() && file.length() > 0) {
            return file.absolutePath
        }

        context.assets.open(assetName).use { inputStream ->
            FileOutputStream(file).use { outputStream ->
                val buffer = ByteArray(4 * 1024)
                var read: Int
                while (inputStream.read(buffer).also { read = it } != -1) {
                    outputStream.write(buffer, 0, read)
                }
                outputStream.flush()
            }
        }
        return file.absolutePath
    }
}
  3. Using the classifier in an Activity:
kotlin
// MainActivity.kt
class MainActivity : AppCompatActivity() {
    private lateinit var classifier: ImageClassifier
    private lateinit var imageView: ImageView
    private lateinit var resultTextView: TextView

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        imageView = findViewById(R.id.imageView)
        resultTextView = findViewById(R.id.resultTextView)

        classifier = ImageClassifier(this)

        val selectImageButton: Button = findViewById(R.id.selectImageButton)
        selectImageButton.setOnClickListener {
            openGallery()
        }
    }

    private fun openGallery() {
        val intent = Intent(Intent.ACTION_PICK)
        intent.type = "image/*"
        startActivityForResult(intent, 100)
    }

    override fun onActivityResult(requestCode: Int, resultCode: Int, data: Intent?) {
        super.onActivityResult(requestCode, resultCode, data)
        if (requestCode == 100 && resultCode == RESULT_OK) {
            val imageUri = data?.data
            val bitmap = MediaStore.Images.Media.getBitmap(contentResolver, imageUri)
            imageView.setImageBitmap(bitmap)

            // Run classification in a background thread
            Thread {
                val (className, score) = classifier.classify(bitmap)
                runOnUiThread {
                    resultTextView.text = "Class: $className\nConfidence: ${String.format("%.2f", score)}"
                }
            }.start()
        }
    }
}

Troubleshooting Common Issues

When working with PyTorch Mobile, you might encounter these common issues:

1. Unsupported Operators

Not all PyTorch operators are supported in PyTorch Mobile. If you encounter this error:

RuntimeError: Unsupported operator aten::xxx_yyy

Possible solutions:

  • Check the list of supported operators against the operators your model actually uses (see the snippet below)
  • Replace unsupported operations with supported alternatives
  • Use torch.jit.trace instead of torch.jit.script when possible
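
To see exactly which operators an exported model depends on, you can dump them with torch.jit.export_opnames and compare the result against the supported-operator list (a small sketch, reusing the scripted model saved earlier):

python
import torch

# List the operators used by an exported TorchScript module
scripted_model = torch.jit.load("simple_model_scripted.pt")
for op_name in torch.jit.export_opnames(scripted_model):
    print(op_name)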

2. Model Size Issues

If your model is too large:

  • Apply more aggressive quantization
  • Use pruning techniques
  • Consider model architecture with fewer parameters
  • Use MobileNet, EfficientNet, or other mobile-optimized architectures

3. Performance Issues

If inference is too slow:

  • Use the Android/iOS profiler to identify bottlenecks
  • Ensure you're using the latest PyTorch Mobile version
  • Consider batch processing if appropriate
  • Reduce input resolution if possible without compromising accuracy
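
Before profiling on-device, a rough desktop-side latency baseline for the exported model can help you spot gross regressions between versions. A simple sketch, assuming the traced model saved earlier (desktop numbers are only a proxy; always confirm on real hardware):

python
import time
import torch

model = torch.jit.load("simple_model_traced.pt")
model.eval()
example_input = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    for _ in range(5):          # warm-up runs
        model(example_input)

    runs = 50
    start = time.perf_counter()
    for _ in range(runs):
        model(example_input)
    elapsed_ms = (time.perf_counter() - start) * 1000 / runs

print(f"Average latency: {elapsed_ms:.2f} ms per run")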

Summary

In this tutorial, we've covered how to export PyTorch models for mobile deployment using PyTorch Mobile:

  1. We learned how to convert models to TorchScript using tracing and scripting
  2. We explored optimization techniques like quantization and pruning to reduce model size
  3. We saw how to integrate PyTorch models into Android and iOS applications
  4. We built a complete image classification example for real-world use

PyTorch Mobile enables on-device inference with several benefits including improved privacy, offline functionality, and reduced latency. By properly optimizing your models, you can achieve excellent performance even on resource-constrained mobile devices.

Exercises

  1. Convert a pre-trained ResNet18 model to a mobile-optimized format and measure the size difference before and after optimization.

  2. Implement a simple image classification app using PyTorch Mobile that can identify different types of flowers or animals.

  3. Experiment with different quantization techniques and compare their impact on model size and inference speed.

  4. Create a real-time object detection application using a mobile-optimized model like SSD-MobileNet.

  5. Build a text classification model and deploy it to a mobile app for sentiment analysis of user input text.


