
PyTorch Mobile Deployment

Introduction

PyTorch Mobile is a framework that enables you to deploy PyTorch models on mobile and edge devices. It's designed to provide developers with an end-to-end workflow for transitioning from training on servers to on-device inference. With PyTorch Mobile, you can run your deep learning models directly on smartphones, tablets, and other edge devices without requiring an internet connection, which can significantly improve latency, privacy, and reliability.

In this tutorial, we'll explore:

  • Why you might want to deploy models on mobile devices
  • The PyTorch Mobile workflow
  • How to prepare your PyTorch model for mobile deployment
  • Deploying to Android and iOS platforms
  • Optimization techniques for mobile deployment
  • Real-world examples and best practices

Why Deploy to Mobile?

Before diving into the technical details, let's understand why you might want to deploy your PyTorch models to mobile devices:

  1. Reduced latency: On-device inference eliminates network delays.
  2. Offline functionality: Your application can work without internet connectivity.
  3. Privacy: Sensitive data never leaves the user's device.
  4. Cost efficiency: No need for server infrastructure to handle inference requests.
  5. Better user experience: Responses arrive faster because they don't wait on a network round trip.

PyTorch Mobile Workflow

The typical PyTorch Mobile deployment workflow consists of the following steps:

  1. Train your model: Develop and train your PyTorch model as usual on desktop/server.
  2. Convert and optimize the model: Serialize it with TorchScript and, optionally, quantize or prune it.
  3. Integrate with mobile app: Add the model to your Android or iOS application.
  4. Run inference on device: Execute the model on mobile hardware.

Let's explore each of these steps in detail.

Preparing Your PyTorch Model

Step 1: Train Your Model

First, let's create a simple PyTorch model that we'll prepare for mobile deployment. For this tutorial, we'll use a basic image classification model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleConvNet(nn.Module):
    # Expects 3x32x32 inputs: two 2x2 max pools reduce 32x32 to 8x8
    def __init__(self):
        super(SimpleConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)  # 10 classes

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
```

Train this model with your dataset as you normally would.
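The input size of fc1, 32 * 8 * 8, is tied to a 32x32 input: each padded 3x3 convolution keeps the spatial size, and each 2x2 max pool halves it. The sketch below checks that arithmetic with fresh, untrained layers of the same configuration. A different input size (say 64x64) would produce 32 * 16 * 16 features and fail at fc1, which is why images must be resized to 32x32 before they reach the model on the device.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 32, 32)  # one RGB image at the size the app will supply

# Same layer configuration as SimpleConvNet (fresh weights; only shapes matter here)
x = F.max_pool2d(F.relu(nn.Conv2d(3, 16, 3, padding=1)(x)), 2)
print(tuple(x.shape))  # (1, 16, 16, 16): padding=1 keeps 32x32, pooling halves it
x = F.max_pool2d(F.relu(nn.Conv2d(16, 32, 3, padding=1)(x)), 2)
print(tuple(x.shape))  # (1, 32, 8, 8)

flat = torch.flatten(x, 1)
print(flat.shape[1])  # 2048 == 32 * 8 * 8, the in_features of fc1
```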

Step 2: Optimize and Convert the Model

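PyTorch Mobile runs TorchScript models, so the trained model must be converted with either tracing or scripting. Tracing runs the model once on an example input and records the operations it executes; scripting compiles the model's Python code, so it also preserves data-dependent control flow such as if statements and loops.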

Using Tracing

```python
# Assuming model is an instance of your trained model
model.eval()  # Set to evaluation mode

# Create an example input tensor
example_input = torch.rand(1, 3, 32, 32)

# Trace the model
traced_model = torch.jit.trace(model, example_input)

# Save the traced model
traced_model.save("model_mobile.pt")
```
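Because tracing only records what happens for the example input, it is worth checking the result before shipping it. The helper below is a sketch: it traces the model, saves and reloads the file, and compares the reloaded module with eager mode on a different random input. It is demonstrated on a small stand-in model (an assumption, so the block runs on its own); pass your trained SimpleConvNet instead.

```python
import os
import tempfile

import torch
import torch.nn as nn

def check_traced(model: nn.Module, example_input: torch.Tensor, path: str) -> torch.jit.ScriptModule:
    """Trace, save, reload, and confirm the reloaded module matches eager mode."""
    model.eval()
    traced = torch.jit.trace(model, example_input)
    traced.save(path)
    reloaded = torch.jit.load(path)

    # Use a fresh input rather than the one used for tracing
    fresh_input = torch.rand_like(example_input)
    with torch.no_grad():
        expected = model(fresh_input)
        actual = reloaded(fresh_input)
    if not torch.allclose(expected, actual, atol=1e-5):
        raise RuntimeError("traced model disagrees with the eager model")
    return reloaded

# Hypothetical stand-in with a 1x3x32x32 input and 10 outputs, like SimpleConvNet
stand_in = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 32 * 32, 10)
)
with tempfile.TemporaryDirectory() as tmp:
    reloaded = check_traced(stand_in, torch.rand(1, 3, 32, 32), os.path.join(tmp, "model_mobile.pt"))
```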

Using Scripting

```python
# Scripting can handle dynamic control flow better than tracing
scripted_model = torch.jit.script(model)

# Save the scripted model
scripted_model.save("model_mobile.pt")
```
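The difference matters when forward() branches on the data. In this toy module (a hypothetical example, not part of SimpleConvNet), tracing with a positive input records only the `return x` branch, so the traced module gives the wrong answer for a negative input, while the scripted module keeps the `if`.

```python
import warnings

import torch
import torch.nn as nn

class SignGate(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: which path runs depends on the input values
        if x.sum() > 0:
            return x
        return -x

model = SignGate().eval()
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # tracing warns about the data-dependent branch
    traced = torch.jit.trace(model, torch.ones(3))
scripted = torch.jit.script(model)

neg = -torch.ones(3)
print(model(neg))     # tensor([1., 1., 1.])
print(scripted(neg))  # tensor([1., 1., 1.])
print(traced(neg))    # tensor([-1., -1., -1.]): the recorded branch, not the correct one
```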

Step 3: Further Optimization for Mobile

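For better performance on mobile devices, you can quantize your model: storing weights as 8-bit integers instead of 32-bit floats cuts the size of the quantized layers roughly fourfold and can speed up inference. The simplest option is dynamic quantization, which converts weights ahead of time and quantizes activations on the fly during inference: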

```python
import torch.nn as nn
import torch.quantization

model.eval()

# Dynamic quantization converts nn.Linear (and recurrent) layers only; nn.Conv2d
# is not converted by quantize_dynamic, so only fc1 and fc2 change here
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

# Script the quantized model
scripted_quantized_model = torch.jit.script(quantized_model)

# Save the quantized model
scripted_quantized_model.save("model_mobile_quantized.pt")
```
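To see what dynamic quantization buys, compare serialized sizes. This sketch uses a stand-in with the same fully connected layers as SimpleConvNet (fc1 holds most of its parameters) and measures the scripted float and int8 versions in memory rather than on disk. The stand-in is an assumption; run the same comparison on your own model.

```python
import io

import torch
import torch.nn as nn

def serialized_size(module: nn.Module) -> int:
    """Bytes taken by the scripted module when saved to an in-memory buffer."""
    buffer = io.BytesIO()
    torch.jit.save(torch.jit.script(module), buffer)
    return buffer.getbuffer().nbytes

torch.manual_seed(0)
# Hypothetical stand-in for SimpleConvNet's fc1/fc2 head
float_model = nn.Sequential(nn.Linear(32 * 8 * 8, 128), nn.ReLU(), nn.Linear(128, 10)).eval()
quantized_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

float_bytes = serialized_size(float_model)
quant_bytes = serialized_size(quantized_model)
print(f"float32: {float_bytes / 1024:.0f} KiB, dynamic int8: {quant_bytes / 1024:.0f} KiB")
```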

Deploying to Android

Now that we have our optimized model, let's see how to integrate it with an Android application.

Setting Up Your Android Project

  1. Create a new Android project in Android Studio or open an existing one.

  2. Add PyTorch Mobile dependencies to your app's build.gradle file:

```gradle
dependencies {
    implementation 'org.pytorch:pytorch_android:1.10.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
```
  3. Create an assets folder in your Android project (if it doesn't exist) under app/src/main/ and copy your model_mobile.pt file there.

Loading and Running the Model

Here's how to load and use your PyTorch model in an Android application:

```java
import android.content.Context;
import android.graphics.Bitmap;
import android.util.Log;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class ModelRunner {
    // SimpleConvNet was traced with a 1x3x32x32 input, so bitmaps are scaled to 32x32
    private static final int INPUT_SIZE = 32;
    private static final String[] CLASSES = {"class1", "class2", "class3", "class4", "class5",
            "class6", "class7", "class8", "class9", "class10"};

    private Module model;
    // Per-channel normalization; must match the values used during training
    private final float[] mean = {0.485f, 0.456f, 0.406f};
    private final float[] std = {0.229f, 0.224f, 0.225f};

    public ModelRunner(Context context) {
        try {
            // Module.load needs a filesystem path, so the asset is copied out first
            model = Module.load(assetFilePath(context, "model_mobile.pt"));
        } catch (IOException e) {
            Log.e("PyTorchMobile", "Error loading model", e);
        }
    }

    public String classify(Bitmap bitmap) {
        if (model == null) {
            return "Model not loaded";
        }

        // Preprocess: resize, then convert to a normalized 1x3xHxW float tensor
        Bitmap resized = Bitmap.createScaledBitmap(bitmap, INPUT_SIZE, INPUT_SIZE, true);
        Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resized, mean, std);

        // Forward pass
        Tensor outputTensor = model.forward(IValue.from(inputTensor)).toTensor();

        // The predicted class is the index of the highest score
        float[] scores = outputTensor.getDataAsFloatArray();
        int maxScoreIdx = 0;
        float maxScore = -Float.MAX_VALUE;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] > maxScore) {
                maxScore = scores[i];
                maxScoreIdx = i;
            }
        }
        return CLASSES[maxScoreIdx];
    }

    // Copies the model from assets to internal storage and returns its absolute path
    private static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName);
             OutputStream os = new FileOutputStream(file)) {
            byte[] buffer = new byte[4 * 1024];
            int read;
            while ((read = is.read(buffer)) != -1) {
                os.write(buffer, 0, read);
            }
            os.flush();
        }
        return file.getAbsolutePath();
    }
}
```

You would use this class in your Activity or Fragment:

java
ModelRunner modelRunner = new ModelRunner(this);
String prediction = modelRunner.classify(someBitmap);
TextView resultTextView = findViewById(R.id.result_text);
resultTextView.setText("Predicted: " + prediction);

Deploying to iOS

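For iOS deployment the process is similar, with one extra layer: LibTorch is a C++ library, so Swift code cannot call it directly. Apps typically add a small Objective-C++ wrapper class (PyTorch's iOS HelloWorld demo calls it TorchModule) and expose it to Swift through a bridging header.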

Setting Up Your iOS Project

  1. Create a new iOS project in Xcode or open an existing one.

  2. Add PyTorch Mobile dependencies using CocoaPods. Create a Podfile in your project root:

```ruby
target 'YourApp' do
  pod 'LibTorch', '~> 1.10.0'
end
```
  3. Run pod install from the terminal.

  4. Add your model_mobile.pt file to your Xcode project.

Loading and Running the Model

Here's how to use your model in Swift:

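```swift
import UIKit

// LibTorch can't be imported into Swift directly. This code assumes an
// Objective-C++ wrapper named TorchModule (modeled on PyTorch's iOS HelloWorld
// demo) exposed through a bridging header, with:
//   init?(fileAtPath:)               loads the TorchScript file
//   predict(image:) -> [NSNumber]?   wraps the buffer as a 1x3x32x32 float tensor and runs forward()
class ViewController: UIViewController {
    var module: TorchModule?

    override func viewDidLoad() {
        super.viewDidLoad()

        // Load model
        guard let filePath = Bundle.main.path(forResource: "model_mobile", ofType: "pt") else {
            print("Failed to find model path")
            return
        }
        module = TorchModule(fileAtPath: filePath)
        if module == nil {
            print("Error loading model")
        }
    }

    func classifyImage(_ image: UIImage) -> String {
        guard let module = module else {
            return "Model not loaded"
        }
        guard let resizedImage = image.resized(to: CGSize(width: 32, height: 32)) else {
            return "Error resizing image"
        }

        // Convert UIImage to a normalized, planar (CHW) float buffer
        guard var pixelBuffer = resizedImage.normalizedPixelBuffer() else {
            return "Error converting to tensor"
        }

        // Run inference; the raw pointer is only valid inside the closure
        let outputs = pixelBuffer.withUnsafeMutableBytes { raw -> [NSNumber]? in
            guard let base = raw.baseAddress else { return nil }
            return module.predict(image: base)
        }
        guard let scores = outputs?.map({ $0.floatValue }), !scores.isEmpty else {
            return "Error running model"
        }

        // Pick the class with the highest score
        var maxIdx = 0
        for i in 1..<scores.count where scores[i] > scores[maxIdx] {
            maxIdx = i
        }

        let classes = ["class1", "class2", "class3", "class4", "class5",
                       "class6", "class7", "class8", "class9", "class10"]
        return classes[maxIdx]
    }
}

// Helper extensions
extension UIImage {
    func resized(to newSize: CGSize) -> UIImage? {
        // Scale 1.0 so the bitmap is exactly newSize pixels (0.0 would use the screen scale)
        UIGraphicsBeginImageContextWithOptions(newSize, false, 1.0)
        defer { UIGraphicsEndImageContext() }
        draw(in: CGRect(origin: .zero, size: newSize))
        return UIGraphicsGetImageFromCurrentImageContext()
    }

    // Returns pixels as [R..., G..., B...], each scaled to [0, 1] and then
    // normalized with the same mean/std used in training
    func normalizedPixelBuffer(mean: [Float] = [0.485, 0.456, 0.406],
                               std: [Float] = [0.229, 0.224, 0.225]) -> [Float]? {
        guard let cgImage = cgImage else { return nil }
        let width = cgImage.width
        let height = cgImage.height

        // Draw into an RGBA8 buffer (the fourth byte is skipped, not alpha)
        var rgba = [UInt8](repeating: 0, count: width * height * 4)
        let drawn = rgba.withUnsafeMutableBytes { raw -> Bool in
            guard let context = CGContext(data: raw.baseAddress, width: width, height: height,
                                          bitsPerComponent: 8, bytesPerRow: width * 4,
                                          space: CGColorSpaceCreateDeviceRGB(),
                                          bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue) else {
                return false
            }
            context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
            return true
        }
        guard drawn else { return nil }

        // Interleaved RGBA bytes -> planar, normalized floats
        let pixelCount = width * height
        var buffer = [Float](repeating: 0, count: 3 * pixelCount)
        for i in 0..<pixelCount {
            for c in 0..<3 {
                let value = Float(rgba[i * 4 + c]) / 255.0
                buffer[c * pixelCount + i] = (value - mean[c]) / std[c]
            }
        }
        return buffer
    }
}
```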
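The image-to-tensor conversion on iOS has to produce exactly what the model saw during tracing: a 1x3x32x32 float array in planar order (all red values, then all green, then all blue), each value scaled to [0, 1] and then normalized with the per-channel mean and std used in training. On Android, TensorImageUtils.bitmapToFloat32Tensor produces this layout; on iOS you write the conversion yourself. The Python sketch below performs the same conversion on an interleaved RGBA byte buffer, so you can generate reference values on a desktop and compare them with what the app computes. The ImageNet mean/std values are assumed, matching the Android example; the function name is hypothetical.

```python
import numpy as np

# Assumed training normalization (ImageNet statistics, as in the Android code)
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def rgba_to_planar(rgba: bytes, width: int, height: int) -> np.ndarray:
    """Convert interleaved RGBA8 pixels to a normalized (1, 3, H, W) float32 array."""
    pixels = np.frombuffer(rgba, dtype=np.uint8).reshape(height, width, 4)
    rgb = pixels[:, :, :3].astype(np.float32) / 255.0  # drop the 4th byte, scale to [0, 1]
    normalized = (rgb - MEAN) / STD                     # per-channel normalization, still HWC
    return normalized.transpose(2, 0, 1)[np.newaxis]   # HWC -> CHW, then add the batch dim

# A 2x1 image: one pure red pixel and one white pixel, both opaque
example = bytes([255, 0, 0, 255, 255, 255, 255, 255])
tensor = rgba_to_planar(example, width=2, height=1)
print(tensor.shape)  # (1, 3, 1, 2)
```

Flattening the result (`tensor.ravel()`) gives the planar float sequence a wrapper would hand to the model.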

Optimization Techniques

Model Quantization

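The dynamic quantization shown earlier converts only the linear layers. Static (post-training) quantization also converts the convolutions, but it needs a calibration pass over representative data, and in eager mode the model must mark where tensors enter and leave the quantized region with QuantStub and DeQuantStub: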

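```python
import torch

# Static quantization requires eval mode, and in eager mode forward() must begin
# with a QuantStub and end with a DeQuantStub (SimpleConvNet as written has neither)
model.eval()

# Use QNNPACK, the quantized backend that targets the ARM CPUs in phones
torch.backends.quantized.engine = 'qnnpack'

# Define quantization configuration
model.qconfig = torch.quantization.get_default_qconfig('qnnpack')

# Insert observers that record activation ranges
torch.quantization.prepare(model, inplace=True)

# Calibrate the model (run some representative data through it)
# This example assumes you have a dataloader called 'calibration_loader'
with torch.no_grad():
    for inputs, _ in calibration_loader:
        model(inputs)

# Convert to quantized model
torch.quantization.convert(model, inplace=True)

# Script and save
script_quantized_model = torch.jit.script(model)
script_quantized_model.save("quantized_mobile.pt")
```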
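A self-contained sketch of the whole static-quantization flow: SimpleConvNet's layers with a QuantStub at the input and a DeQuantStub at the output (the class name QuantReadyConvNet is ours), calibration on random tensors standing in for real images, conversion, and scripting. It selects QNNPACK when the local PyTorch build includes it and otherwise falls back to the build's default engine, so it also runs on a desktop. Random calibration data only exercises the code path; real calibration should use a representative sample of your images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import DeQuantStub, QuantStub

class QuantReadyConvNet(nn.Module):
    """SimpleConvNet's layers, wrapped with quantize/dequantize stubs."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the input
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 10)
        self.dequant = DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        x = self.quant(x)
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)

# Prefer QNNPACK (the mobile backend) when this build supports it
supported = torch.backends.quantized.supported_engines
engine = "qnnpack" if "qnnpack" in supported else torch.backends.quantized.engine
torch.backends.quantized.engine = engine

model = QuantReadyConvNet().eval()
model.qconfig = torch.quantization.get_default_qconfig(engine)
torch.quantization.prepare(model, inplace=True)

# Calibration: random tensors stand in for real images (hypothetical data)
torch.manual_seed(0)
with torch.no_grad():
    for _ in range(8):
        model(torch.rand(4, 3, 32, 32))

torch.quantization.convert(model, inplace=True)
scripted = torch.jit.script(model)
print(scripted(torch.rand(1, 3, 32, 32)).shape)  # torch.Size([1, 10])
```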

Model Pruning

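Pruning zeroes out low-magnitude weights. Unstructured pruning, as below, leaves tensor shapes unchanged: the zeros are still stored, so the saved file only shrinks once it is compressed, and dense kernels do not run faster just because some weights are zero.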

```python
import torch
import torch.nn.utils.prune as prune

# Zero out the 30% of weights with the smallest L1 magnitude in each linear layer
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.3)

# Make pruning permanent (bake the mask into a plain weight tensor)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        prune.remove(module, 'weight')

# Now optimize, script, and save as before
```
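To see what unstructured pruning does to size, the sketch below prunes a single layer shaped like SimpleConvNet's fc1 and compares its weight bytes before and after, both raw and zlib-compressed. Looking only at the weight tensor, and using zlib as the compressor, are simplifying assumptions.

```python
import zlib

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(32 * 8 * 8, 128)  # same shape as SimpleConvNet.fc1
dense_bytes = layer.weight.detach().numpy().tobytes()

prune.l1_unstructured(layer, name="weight", amount=0.3)
prune.remove(layer, "weight")

pruned_bytes = layer.weight.detach().numpy().tobytes()
sparsity = (layer.weight == 0).float().mean().item()

print(f"zeros: {sparsity:.1%}")
print(f"raw bytes: {len(dense_bytes)} -> {len(pruned_bytes)}")
print(f"zlib bytes: {len(zlib.compress(dense_bytes))} -> {len(zlib.compress(pruned_bytes))}")
```

The raw byte count stays the same; only the compressed size drops, which is why structured pruning (removing whole channels) or a sparse format is needed for an actual smaller model.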

Real-World Example: Image Classification App

Let's walk through a complete example of building a simple image classification app using PyTorch Mobile.

1. Export a Pre-trained Model

```python
import torch
import torchvision.models as models

# Use a pre-trained MobileNet model
# (torchvision 0.13+ prefers weights=models.MobileNet_V2_Weights.DEFAULT over pretrained=True)
model = models.mobilenet_v2(pretrained=True)
model.eval()

# Example input for tracing: the app must supply 224x224 RGB images
example_input = torch.rand(1, 3, 224, 224)

# Trace and save the model
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("mobilenet_v2.pt")

# Save the labels; assumes imagenet_classes.txt holds the 1000 ImageNet class
# names, one per line, in the model's output order
with open("imagenet_classes.txt") as src, open("labels.txt", "w") as f:
    labels = [line.strip() for line in src]
    for label in labels:
        f.write(label + "\n")
```

2. Android Implementation

Create a new Android application with a camera view and a result display:

```java
// MainActivity.java
import android.content.Intent;
import android.graphics.Bitmap;
import android.os.Bundle;
import android.provider.MediaStore;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;

import androidx.appcompat.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {
    private static final int REQUEST_IMAGE_CAPTURE = 1;

    private Button captureButton;
    private ImageView imageView;
    private TextView resultText;
    private ModelRunner modelRunner;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        captureButton = findViewById(R.id.capture_button);
        imageView = findViewById(R.id.image_view);
        resultText = findViewById(R.id.result_text);

        // Initialize the model. For this app, ModelRunner must load mobilenet_v2.pt,
        // resize to 224x224, and map scores to the 1000 labels in labels.txt
        modelRunner = new ModelRunner(this);

        // Launch the camera when the button is tapped
        captureButton.setOnClickListener(v -> dispatchTakePictureIntent());
    }

    // Handle taking a photo
    private void dispatchTakePictureIntent() {
        Intent takePictureIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
        if (takePictureIntent.resolveActivity(getPackageManager()) != null) {
            startActivityForResult(takePictureIntent, REQUEST_IMAGE_CAPTURE);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);

        if (requestCode == REQUEST_IMAGE_CAPTURE && resultCode == RESULT_OK && data != null) {
            // Without EXTRA_OUTPUT, the camera app returns only a small thumbnail in "data"
            Bundle extras = data.getExtras();
            Bitmap imageBitmap = (Bitmap) extras.get("data");
            imageView.setImageBitmap(imageBitmap);

            // Run classification
            String result = modelRunner.classify(imageBitmap);
            resultText.setText("Prediction: " + result);
        }
    }
}
```

3. iOS Implementation

Similarly, here's how you might structure the iOS app:

```swift
import UIKit

class ViewController: UIViewController, UIImagePickerControllerDelegate, UINavigationControllerDelegate {

    @IBOutlet weak var imageView: UIImageView!
    @IBOutlet weak var resultLabel: UILabel!

    // Hypothetical helper class that loads mobilenet_v2.pt and wraps the
    // classification logic shown earlier (224x224 input, 1000 labels)
    var modelRunner: ModelRunner!

    override func viewDidLoad() {
        super.viewDidLoad()
        modelRunner = ModelRunner()
    }

    @IBAction func takePhoto(_ sender: Any) {
        // The camera is not available on the Simulator
        guard UIImagePickerController.isSourceTypeAvailable(.camera) else { return }
        let picker = UIImagePickerController()
        picker.delegate = self
        picker.sourceType = .camera
        present(picker, animated: true)
    }

    // UIImagePickerControllerDelegate method
    func imagePickerController(_ picker: UIImagePickerController,
                               didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey: Any]) {
        picker.dismiss(animated: true)

        guard let image = info[.originalImage] as? UIImage else {
            return
        }

        imageView.image = image

        // Run classification
        let result = modelRunner.classify(image)
        resultLabel.text = "Prediction: \(result)"
    }
}
```

Best Practices for Mobile Deployment

  1. Start small: Begin with simpler, smaller models like MobileNet.
  2. Optimize, then verify: Quantize and consider pruning your models, and re-check accuracy after each step.
  3. Test on real devices: Emulators don't accurately reflect real-world performance.
  4. Profile your app: Use Android Profiler or Instruments (iOS) to identify bottlenecks.
  5. Keep the UI responsive: Run inference in a background thread to avoid freezing the UI.
  6. Consider battery impact: Frequent model runs can drain the battery quickly.
  7. Benchmark different optimization techniques: Compare the trade-offs between model size, accuracy, and speed.
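Before profiling on a device, the size/accuracy/speed comparison can start with a desktop harness like this sketch. It times a float model against its dynamically quantized copy and measures how often their top-1 predictions agree, a label-free proxy for accuracy loss. Desktop timings say little about a phone's absolute speed, so treat them as relative and confirm on device. The stand-in model and random inputs are assumptions; substitute your own float and optimized variants and, if available, a labeled validation set.

```python
import time

import torch
import torch.nn as nn

def mean_latency_ms(model: nn.Module, example: torch.Tensor, warmup: int = 5, runs: int = 50) -> float:
    """Average wall-clock time of one forward pass, after a few warm-up runs."""
    with torch.no_grad():
        for _ in range(warmup):
            model(example)
        start = time.perf_counter()
        for _ in range(runs):
            model(example)
        elapsed = time.perf_counter() - start
    return elapsed / runs * 1000.0

torch.manual_seed(0)
# Hypothetical stand-in for SimpleConvNet's fully connected head
float_model = nn.Sequential(nn.Linear(32 * 8 * 8, 128), nn.ReLU(), nn.Linear(128, 10)).eval()
quant_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

# Accuracy proxy without labels: how often both models pick the same class
inputs = torch.rand(256, 32 * 8 * 8)
with torch.no_grad():
    same = float_model(inputs).argmax(dim=1) == quant_model(inputs).argmax(dim=1)
agreement = same.float().mean().item()

latencies = {name: mean_latency_ms(m, inputs[:1])
             for name, m in [("float32", float_model), ("dynamic int8", quant_model)]}
for name, ms in latencies.items():
    print(f"{name}: {ms:.3f} ms per image")
print(f"top-1 agreement: {agreement:.1%}")
```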

Summary

In this tutorial, we've explored PyTorch Mobile deployment from start to finish:

  • Why on-device model deployment is beneficial
  • How to prepare and optimize PyTorch models for mobile deployment
  • Step-by-step guides for Android and iOS integration
  • Advanced optimization techniques like quantization and pruning
  • A real-world example of an image classification app
  • Best practices for mobile deployment

PyTorch Mobile enables you to take your deep learning models out of the data center and into the hands of users. With proper optimization and careful implementation, even complex neural networks can run efficiently on modern smartphones, opening up countless possibilities for AI-enhanced mobile applications.


Exercises

  1. Convert a pre-trained ResNet model to a mobile-optimized format and measure its size before and after optimization.
  2. Build a simple Android app that classifies images from the device's gallery using a PyTorch model.
  3. Compare the inference speed of a standard model versus its quantized version on a mobile device.
  4. Implement a real-time object detection app using PyTorch Mobile and the device camera.
  5. Create an app that performs style transfer on photos using a PyTorch model deployed on the device.

