
PyTorch Model Serving

After training a powerful deep learning model with PyTorch, the next crucial step is deploying it so that others can use it. Model serving is the process of making your trained model available to end-users or applications through an API or service.

Introduction to Model Serving

Model serving bridges the gap between developing a machine learning model and making it useful in real-world applications. It involves:

  • Exposing your model through an API that applications can call
  • Managing model versions and updates
  • Handling scaling and performance optimization
  • Processing input data and returning predictions efficiently

In this guide, we'll explore different ways to serve PyTorch models, from simple solutions for beginners to more advanced production-ready approaches.

Method 1: Simple Flask API

For beginners, Flask provides an easy way to serve your PyTorch model through a REST API.

Setting Up a Flask API

First, let's install the necessary packages:

bash
pip install flask torch torchvision pillow

Here's a basic example of serving a PyTorch model with Flask:

python
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io

app = Flask(__name__)

# Load the pretrained model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()

def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return transform(image).unsqueeze(0)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    with torch.no_grad():
        outputs = model(tensor)
    _, predicted = torch.max(outputs, 1)
    return predicted.item()

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    img_bytes = file.read()
    prediction = get_prediction(img_bytes)

    return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run(debug=True)

Using the Flask API

You can test this API using tools like cURL or Python requests:

python
import requests

with open("dog.jpg", "rb") as f:
    resp = requests.post("http://localhost:5000/predict",
                         files={"file": f})

print(resp.json())
# Example output: {"prediction": 258}
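
The prediction is a raw ImageNet class index. To display a human-readable label, you can map the index through a labels file; here is a minimal sketch, assuming you have a local imagenet_classes.txt with one label per line:

python
# Map the predicted index to a human-readable ImageNet label
# (assumes a local imagenet_classes.txt with one label per line)
with open("imagenet_classes.txt") as f:
    labels = [line.strip() for line in f]

print(labels[resp.json()["prediction"]])  # prints the label for the predicted index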

Pros and Cons of Flask

Pros:

  • Easy to set up and understand
  • Good for prototyping and small-scale applications
  • Flexible and customizable

Cons:

  • Limited scalability
  • Manual implementation of features like load balancing, versioning
  • Not optimized for high-performance serving

Method 2: TorchServe

TorchServe is a flexible and easy-to-use tool for serving PyTorch models, developed by AWS in collaboration with Meta.

Setting Up TorchServe

First, install TorchServe and its dependencies:

bash
pip install torchserve torch-model-archiver torch-workflow-archiver

Creating a Model Archive

TorchServe requires your model to be packaged as a model archive (.mar) file:

python
# First, export your model as TorchScript so TorchServe can load it
# without a separate model definition file
import torch
import torchvision.models as models

# Example: using a pre-trained model
model = models.resnet18(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save("resnet18.pt")

Next, create a model_handler.py file that defines how to load and use your model:

python
# model_handler.py
from ts.torch_handler.image_classifier import ImageClassifier

class ResnetImageClassifier(ImageClassifier):
    """
    Custom handler for ResNet18 image classification
    """
    def __init__(self):
        super(ResnetImageClassifier, self).__init__()

    def preprocess(self, data):
        """
        Overriding preprocess to add custom preprocessing
        """
        return super(ResnetImageClassifier, self).preprocess(data)

Now, use the torch-model-archiver command to create a model archive:

bash
mkdir -p model_store

torch-model-archiver --model-name resnet18 \
    --version 1.0 \
    --serialized-file resnet18.pt \
    --handler model_handler.py \
    --export-path model_store

Starting TorchServe

Start the TorchServe server:

bash
torchserve --start --model-store model_store --models resnet18=resnet18.mar
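
Once the server is up, you can confirm it is healthy before sending any predictions. A quick check against TorchServe's ping endpoint on the inference port (8080 by default):

python
import requests

# Health check against the inference API (default port 8080)
resp = requests.get("http://localhost:8080/ping")
print(resp.json())
# Expected output: {"status": "Healthy"}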

Making Predictions

You can now make predictions using the TorchServe REST API:

bash
curl -X POST http://localhost:8080/predictions/resnet18 -T dog.jpg

Or using Python:

python
import requests

with open("dog.jpg", "rb") as f:
    img = f.read()

resp = requests.post("http://localhost:8080/predictions/resnet18",
                     data=img,
                     headers={"Content-Type": "application/octet-stream"})

print(resp.json())
# Example output: {"class": "Labrador retriever", "confidence": 0.934}

TorchServe Benefits

  • Management API for model registration/unregistration (see the sketch after this list)
  • Metrics API for monitoring
  • Built-in model versioning and A/B testing
  • Optimized for PyTorch models
  • Scalable with multi-model serving
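
As a quick illustration of the management API mentioned above, here is a minimal sketch that lists the currently registered models; the management API listens on port 8081 by default:

python
import requests

# List the models registered with the management API (default port 8081)
resp = requests.get("http://localhost:8081/models")
print(resp.json())
# Example output: {"models": [{"modelName": "resnet18", "modelUrl": "resnet18.mar"}]}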

Method 3: ONNX Runtime

ONNX (Open Neural Network Exchange) is an open format for representing machine learning models, allowing models to be transferred between different frameworks.

Converting PyTorch Model to ONNX

python
import torch
import torchvision.models as models

# Load a pretrained model
model = models.resnet18(pretrained=True)
model.eval()

# Create a random input tensor
dummy_input = torch.randn(1, 3, 224, 224)

# Export to ONNX format
torch.onnx.export(
    model,                      # model being run
    dummy_input,                # model input (or a tuple for multiple inputs)
    "resnet18.onnx",            # where to save the model
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=11,           # the ONNX opset version to export the model to
    do_constant_folding=True,   # optimize the model by folding constants
    input_names=['input'],      # the model's input names
    output_names=['output'],    # the model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},   # dynamic batch size
        'output': {0: 'batch_size'}
    }
)
print("Model exported to ONNX format!")

Serving with ONNX Runtime

You can serve ONNX models using ONNX Runtime, which is optimized for performance:

python
from flask import Flask, request, jsonify
import onnxruntime
import numpy as np
from PIL import Image
import io
import torchvision.transforms as transforms

app = Flask(__name__)

# Load the ONNX model
session = onnxruntime.InferenceSession("resnet18.onnx")

def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return transform(image).numpy()

@app.route('/predict', methods=['POST'])
def predict():
    file = request.files['file']
    img_bytes = file.read()
    input_tensor = transform_image(img_bytes)

    # Get input and output names
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Run inference
    results = session.run([output_name], {input_name: input_tensor.reshape(1, 3, 224, 224)})
    prediction = np.argmax(results[0])

    return jsonify({'prediction': int(prediction)})

if __name__ == '__main__':
    app.run(debug=True)

Benefits of ONNX

  • Framework-agnostic (works with PyTorch, TensorFlow, etc.)
  • Performance optimizations for different hardware
  • Simplified deployment across different platforms
  • Reduced model size and improved inference speed

Method 4: Cloud-based Services

For production applications, cloud services offer managed solutions for model serving.

Using AWS SageMaker

Amazon SageMaker provides a complete solution for deploying PyTorch models:

python
import torch
import sagemaker
from sagemaker.pytorch import PyTorchModel

# Set up the SageMaker session
session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Create a PyTorch model from your saved model.tar.gz file
pytorch_model = PyTorchModel(
    model_data='s3://your-bucket/model.tar.gz',
    role=role,
    framework_version='1.8.1',
    py_version='py3',
    entry_point='inference.py'  # Script defining model_fn (and optionally input_fn/predict_fn/output_fn)
)

# Deploy the model
predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge'
)

# Make predictions
response = predictor.predict(input_data)
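
The entry_point script is where SageMaker looks for your model-loading and prediction logic. Here is a minimal sketch of inference.py following the SageMaker PyTorch serving contract; the model.pth file name inside model.tar.gz and the ResNet-18 architecture are assumptions for illustration:

python
# inference.py -- minimal sketch of a SageMaker PyTorch entry point
import os
import torch
import torchvision.models as models

def model_fn(model_dir):
    """Load the model from the extracted model.tar.gz directory."""
    model = models.resnet18()
    state_dict = torch.load(os.path.join(model_dir, "model.pth"), map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

def predict_fn(input_data, model):
    """Run inference on the deserialized input tensor."""
    with torch.no_grad():
        return model(input_data)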

Other Cloud Solutions

  • Google Cloud AI Platform: Offers custom prediction routines for PyTorch models
  • Azure Machine Learning: Provides model deployment to Azure Kubernetes Service or Azure Container Instances
  • Hugging Face Inference API: Great for NLP models built with PyTorch

Performance Optimization Techniques

When serving models in production, consider these optimization techniques:

1. Quantization

Reduce model size and improve inference speed by quantizing model weights:

python
import torch

# Load your model (assumes the full model object was saved, not just a state_dict)
model = torch.load("your_model.pth")
model.eval()

# Quantize the model
quantized_model = torch.quantization.quantize_dynamic(
    model,               # the original model
    {torch.nn.Linear},   # the set of layer types to dynamically quantize
    dtype=torch.qint8    # the target dtype for quantized weights
)

# Save the quantized model
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")
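
To see what quantization actually buys you, compare the file size and a rough CPU latency number for the saved quantized model. A minimal sketch; the (1, 3, 224, 224) input shape is an assumption and should be adjusted to your model:

python
import os
import time
import torch

quantized_model = torch.jit.load("quantized_model.pt")
sample = torch.randn(1, 3, 224, 224)  # adjust to your model's input shape

print(f"Size on disk: {os.path.getsize('quantized_model.pt') / 1e6:.1f} MB")

# Rough CPU latency over a few runs
with torch.no_grad():
    start = time.perf_counter()
    for _ in range(50):
        quantized_model(sample)
    elapsed = time.perf_counter() - start

print(f"Average latency: {elapsed / 50 * 1000:.2f} ms")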

2. Batching Requests

Process multiple requests together for better GPU utilization:

python
def batch_process(model, input_batch):
    """Process a batch of inputs at once"""
    with torch.no_grad():
        return model(input_batch)
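
In practice, batching means buffering several incoming requests and running them through the model in one forward pass. A minimal sketch, reusing transform_image and model from the Flask example; pending_requests is a hypothetical list of raw image bytes collected from waiting requests:

python
import torch

# Each transform_image call returns a (1, 3, 224, 224) tensor
pending_inputs = [transform_image(b) for b in pending_requests]

# Concatenate into one (N, 3, 224, 224) batch and run a single forward pass
input_batch = torch.cat(pending_inputs, dim=0)
outputs = batch_process(model, input_batch)
predictions = outputs.argmax(dim=1).tolist()  # one class index per request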

3. Model Pruning

Remove unnecessary weights to reduce model size:

python
import torch.nn.utils.prune as prune

# Prune 30% of the weights in a layer
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)
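
l1_unstructured only attaches a pruning mask to the layer. To make the pruning permanent and inspect the result, remove the reparameterization and check the layer's sparsity:

python
# Make the pruning permanent by removing the reparameterization
prune.remove(model.conv1, 'weight')

# Check how sparse the layer is now
weight = model.conv1.weight
sparsity = float((weight == 0).sum()) / weight.numel()
print(f"conv1 sparsity: {sparsity:.1%}")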

Real-World Example: Serving an Image Classification Model

Let's put everything together with a complete example of serving a ResNet model for image classification using TorchServe:

1. Prepare the model

python
import torch
import torchvision.models as models

# Load a pretrained ResNet model
model = models.resnet50(pretrained=True)
model.eval()

# Export as TorchScript so TorchServe can load it without a separate model definition file
scripted_model = torch.jit.script(model)
scripted_model.save("resnet50.pt")

2. Create a custom handler file (resnet_handler.py):

python
# resnet_handler.py
from ts.torch_handler.image_classifier import ImageClassifier
import logging

logger = logging.getLogger(__name__)

class ResnetClassifier(ImageClassifier):
    """
    Custom handler for ResNet image classification.

    The base ImageClassifier loads the index_to_name.json mapping supplied
    via --extra-files and returns the top predictions as human-readable
    class names with probabilities; this handler reduces that output to a
    single top-1 result.
    """

    def postprocess(self, data):
        """
        Keep only the most likely class and its confidence
        """
        results = super(ResnetClassifier, self).postprocess(data)

        top1 = []
        for result in results:
            # Each result maps class names to probabilities
            class_name, confidence = max(result.items(), key=lambda kv: kv[1])
            top1.append({"class": class_name, "confidence": confidence})

        logger.info("Returning top-1 prediction")
        return top1

3. Create a mapping file for class names (index_to_name.json):

json
{
"0": "tench",
"1": "goldfish",
"2": "great_white_shark",
"...": "..."
}

4. Archive and serve the model:

bash
# Create the model archive
mkdir -p model_store
torch-model-archiver --model-name resnet50 \
    --version 1.0 \
    --serialized-file resnet50.pt \
    --handler resnet_handler.py \
    --extra-files index_to_name.json \
    --export-path model_store

# Start TorchServe
torchserve --start --model-store model_store --models resnet50=resnet50.mar

5. Use the deployed model:

python
import requests

# Load and send an image
with open("cat.jpg", "rb") as f:
    img_bytes = f.read()

resp = requests.post("http://localhost:8080/predictions/resnet50",
                     data=img_bytes,
                     headers={"Content-Type": "application/octet-stream"})

prediction = resp.json()
print(f"The image is classified as: {prediction['class']}")
print(f"Confidence: {prediction['confidence']:.4f}")

Summary

Model serving is the critical final step in the machine learning lifecycle, making your PyTorch models available for real-world use. In this guide, we've covered:

  1. Simple Flask API - Great for prototypes and learning
  2. TorchServe - PyTorch's dedicated serving solution
  3. ONNX Runtime - Framework-agnostic deployment with optimizations
  4. Cloud Services - Managed solutions for enterprise-scale deployment
  5. Optimization Techniques - Methods to improve serving performance

When choosing a serving solution, consider factors like:

  • Scale and expected traffic
  • Latency requirements
  • Need for versioning and monitoring
  • Available infrastructure and budget
  • Development team expertise


Exercises

  1. Deploy a simple image classification model using Flask and test it with different images.
  2. Use TorchServe to deploy a pre-trained PyTorch model from torchvision.
  3. Convert a custom PyTorch model to ONNX format and compare the inference speed.
  4. Implement request batching in a Flask API to improve throughput.
  5. Add model monitoring to track prediction accuracy and latency over time.

By mastering model serving, you complete the machine learning workflow—from development to deployment—and enable your PyTorch models to provide value in real-world applications.


