PyTorch Model Serving
After training a powerful deep learning model with PyTorch, the next crucial step is deploying it so that others can use it. Model serving is the process of making your trained model available to end-users or applications through an API or service.
Introduction to Model Serving
Model serving bridges the gap between developing a machine learning model and making it useful in real-world applications. It involves:
- Exposing your model through an API that applications can call
- Managing model versions and updates
- Handling scaling and performance optimization
- Processing input data and returning predictions efficiently
In this guide, we'll explore different ways to serve PyTorch models, from simple solutions for beginners to more advanced production-ready approaches.
Method 1: Simple Flask API
For beginners, Flask provides an easy way to serve your PyTorch model through a REST API.
Setting Up a Flask API
First, let's install the necessary packages:
pip install flask torch torchvision pillow
Here's a basic example of serving a PyTorch model with Flask:
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
app = Flask(__name__)
# Load the pretrained model
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()
def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return transform(image).unsqueeze(0)

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes)
    with torch.no_grad():
        outputs = model(tensor)
    _, predicted = torch.max(outputs, 1)
    return predicted.item()

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    file = request.files['file']
    img_bytes = file.read()
    prediction = get_prediction(img_bytes)
    return jsonify({'prediction': prediction})

if __name__ == '__main__':
    app.run(debug=True)
Using the Flask API
You can test this API using tools like cURL or Python requests:
import requests

with open("dog.jpg", "rb") as f:
    resp = requests.post("http://localhost:5000/predict", files={"file": f})
print(resp.json())
# Example output: {"prediction": 258}
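For a quick check without Python, the equivalent cURL call (assuming the server is running locally on port 5000 and dog.jpg is in the current directory) looks like this:

curl -X POST -F "file=@dog.jpg" http://localhost:5000/predict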
Pros and Cons of Flask
Pros:
- Easy to set up and understand
- Good for prototyping and small-scale applications
- Flexible and customizable
Cons:
- Limited scalability
- Manual implementation of features like load balancing, versioning
- Not optimized for high-performance serving
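If you outgrow Flask's built-in development server but are not ready for a full serving framework, one common stopgap is to run the app behind a production WSGI server with several worker processes. A minimal sketch, assuming the example above lives in app.py and the Flask instance is named app:

pip install gunicorn
gunicorn --workers 4 --bind 0.0.0.0:5000 app:app

This improves concurrency, but it still does not give you model management, versioning, or metrics, which is where TorchServe comes in.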
Method 2: TorchServe
TorchServe is a flexible and easy-to-use tool for serving PyTorch models, developed by AWS in collaboration with Meta.
Setting Up TorchServe
First, install TorchServe and its dependencies:
pip install torchserve torch-model-archiver torch-workflow-archiver
Creating a Model Archive
TorchServe requires your model to be packaged as a model archive (.mar) file:
# First, save your model as TorchScript so TorchServe can load it
# without needing a separate model definition file
import torch
import torchvision.models as models

# Example: using a pre-trained model
model = models.resnet18(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
scripted_model.save("resnet18.pt")
Next, create a model_handler.py file that defines how to load and use your model:
# model_handler.py
from ts.torch_handler.image_classifier import ImageClassifier

class ResnetImageClassifier(ImageClassifier):
    """
    Custom handler for ResNet18 image classification
    """

    def __init__(self):
        super(ResnetImageClassifier, self).__init__()

    def preprocess(self, data):
        """
        Overriding preprocess to add custom preprocessing
        """
        return super(ResnetImageClassifier, self).preprocess(data)
Now, use the torch-model-archiver command to create a model archive:
mkdir -p model_store

torch-model-archiver --model-name resnet18 \
    --version 1.0 \
    --serialized-file resnet18.pt \
    --handler model_handler.py \
    --export-path model_store
Starting TorchServe
Start the TorchServe server:
torchserve --start --model-store model_store --models resnet18=resnet18.mar
Making Predictions
You can now make predictions using the TorchServe REST API:
curl -X POST http://localhost:8080/predictions/resnet18 -T dog.jpg
Or using Python:
import requests

with open("dog.jpg", "rb") as f:
    img = f.read()

resp = requests.post("http://localhost:8080/predictions/resnet18",
                     data=img,
                     headers={"Content-Type": "application/octet-stream"})
print(resp.json())
# Example output: the top predicted classes (or class indices, if no label
# mapping was packaged) with their probabilities,
# e.g. {"Labrador retriever": 0.934, ...}
TorchServe Benefits
- Management API for model registration/unregistration (see the example after this list)
- Metrics API for monitoring
- Built-in model versioning and A/B testing
- Optimized for PyTorch models
- Scalable with multi-model serving
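For example, the Management API listed above is served on port 8081 by default and lets you inspect, scale, and unregister models at runtime. A few illustrative calls, assuming the server was started as shown earlier (newer TorchServe releases enable token authorization by default, so you may need to supply the generated token or start the server with token auth disabled):

# List registered models
curl http://localhost:8081/models

# Describe the resnet18 model (workers, version, status)
curl http://localhost:8081/models/resnet18

# Scale the model up to two workers
curl -X PUT "http://localhost:8081/models/resnet18?min_worker=2"

# Unregister the model
curl -X DELETE http://localhost:8081/models/resnet18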
Method 3: ONNX Runtime
ONNX (Open Neural Network Exchange) is an open format for representing machine learning models, allowing models to be transferred between different frameworks.
Converting PyTorch Model to ONNX
import torch
import torchvision.models as models
# Load a pretrained model
model = models.resnet18(pretrained=True)
model.eval()
# Create a random input tensor
dummy_input = torch.randn(1, 3, 224, 224)
# Export to ONNX format
torch.onnx.export(
    model,                     # model being run
    dummy_input,               # model input (or a tuple for multiple inputs)
    "resnet18.onnx",           # where to save the model
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=11,          # the ONNX opset version to export the model to
    do_constant_folding=True,  # optimize by folding constant expressions
    input_names=['input'],     # model's input names
    output_names=['output'],   # model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},   # dynamic batch size
        'output': {0: 'batch_size'}
    }
)
print("Model exported to ONNX format!")
Serving with ONNX Runtime
You can serve ONNX models using ONNX Runtime, which is optimized for performance:
from flask import Flask, request, jsonify
import onnxruntime
import numpy as np
from PIL import Image
import io
import torchvision.transforms as transforms
app = Flask(__name__)
# Load the ONNX model
session = onnxruntime.InferenceSession("resnet18.onnx")
def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return transform(image).unsqueeze(0).numpy()

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400
    file = request.files['file']
    img_bytes = file.read()
    input_tensor = transform_image(img_bytes)

    # Get input and output names
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Run inference
    results = session.run([output_name], {input_name: input_tensor})
    prediction = np.argmax(results[0])
    return jsonify({'prediction': int(prediction)})

if __name__ == '__main__':
    app.run(debug=True)
Benefits of ONNX
- Framework-agnostic (works with PyTorch, TensorFlow, etc.)
- Performance optimizations for different hardware
- Simplified deployment across different platforms
- Often reduced model size and improved inference speed (see the comparison sketch below)
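To back up the speed claim for your own model, a quick sanity check helps. The sketch below (assuming the resnet18.onnx file exported above and a CPU-only setup) compares PyTorch and ONNX Runtime outputs and rough latency:

import time

import numpy as np
import onnxruntime
import torch
import torchvision.models as models

model = models.resnet18(pretrained=True).eval()
session = onnxruntime.InferenceSession("resnet18.onnx")

x = torch.randn(1, 3, 224, 224)
input_name = session.get_inputs()[0].name

# Check that both backends produce (nearly) the same output
with torch.no_grad():
    torch_out = model(x).numpy()
onnx_out = session.run(None, {input_name: x.numpy()})[0]
print("Outputs match:", np.allclose(torch_out, onnx_out, atol=1e-4))

# Rough latency comparison over 50 runs
def timeit(fn, runs=50):
    start = time.time()
    for _ in range(runs):
        fn()
    return (time.time() - start) / runs

with torch.no_grad():
    pytorch_ms = timeit(lambda: model(x)) * 1000
onnx_ms = timeit(lambda: session.run(None, {input_name: x.numpy()})) * 1000
print(f"PyTorch: {pytorch_ms:.1f} ms/run, ONNX Runtime: {onnx_ms:.1f} ms/run")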
Method 4: Cloud-based Services
For production applications, cloud services offer managed solutions for model serving.
Using AWS SageMaker
Amazon SageMaker provides a complete solution for deploying PyTorch models:
import torch
import sagemaker
from sagemaker.pytorch import PyTorchModel
# Set up the SageMaker session
session = sagemaker.Session()
role = sagemaker.get_execution_role()
# Create a PyTorch model from your saved model.tar.gz file
pytorch_model = PyTorchModel(
    model_data='s3://your-bucket/model.tar.gz',
    role=role,
    framework_version='1.8.1',
    py_version='py3',
    entry_point='inference.py'  # script defining model_fn / predict_fn (see below)
)

# Deploy the model to a real-time endpoint
predictor = pytorch_model.deploy(
    initial_instance_count=1,
    instance_type='ml.m5.xlarge'
)

# Make predictions (input_data must match what inference.py expects)
response = predictor.predict(input_data)
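The entry_point script tells the SageMaker PyTorch serving container how to load and invoke your model. A minimal sketch of inference.py, assuming the archive contains a TorchScript file named model.pt (the file name and model format are assumptions; adapt them to your artifact):

# inference.py (illustrative sketch for a TorchScript classifier)
import os
import torch

def model_fn(model_dir):
    """Load the model from the unpacked model.tar.gz directory."""
    model = torch.jit.load(os.path.join(model_dir, "model.pt"))
    model.eval()
    return model

def predict_fn(input_data, model):
    """Run inference on the deserialized input."""
    with torch.no_grad():
        tensor = torch.as_tensor(input_data, dtype=torch.float32)
        return model(tensor).numpy().tolist()

Remember to delete the endpoint with predictor.delete_endpoint() when you are done, since it is billed for as long as it runs.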
Other Cloud Solutions
- Google Cloud AI Platform: Offers custom prediction routines for PyTorch models
- Azure Machine Learning: Provides model deployment to Azure Kubernetes Service or Azure Container Instances
- Hugging Face Inference API: Great for NLP models built with PyTorch
Performance Optimization Techniques
When serving models in production, consider these optimization techniques:
1. Quantization
Reduce model size and improve inference speed by quantizing model weights:
import torch

# Load your model (this assumes the file contains the full pickled model object,
# not just a state_dict; otherwise rebuild the model and call load_state_dict first)
model = torch.load("your_model.pth")
model.eval()

# Quantize the model
quantized_model = torch.quantization.quantize_dynamic(
    model,              # the original model
    {torch.nn.Linear},  # the set of layer types to dynamically quantize
    dtype=torch.qint8   # the target dtype for quantized weights
)

# Save the quantized model as TorchScript
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")
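Dynamic quantization mainly shrinks the linear layers, so the savings depend on the architecture. A quick way to see the effect, assuming the file names used above:

import os

original_mb = os.path.getsize("your_model.pth") / 1e6
quantized_mb = os.path.getsize("quantized_model.pt") / 1e6
print(f"Original: {original_mb:.1f} MB, quantized: {quantized_mb:.1f} MB")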
2. Batching Requests
Process multiple requests together for better GPU utilization:
def batch_process(model, input_batch):
    """Process a batch of inputs at once"""
    with torch.no_grad():
        return model(input_batch)
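In a real service, requests arrive one at a time, so batching usually means collecting them for a few milliseconds and running the model once per batch. Below is a minimal sketch of that idea using a queue and a background thread; the function names, the 10 ms window, and the maximum batch size of 8 are all illustrative assumptions:

import queue
import threading
import torch

request_queue = queue.Queue()

def batching_worker(model, max_batch_size=8, max_wait_s=0.01):
    """Collect requests briefly, then run one batched forward pass."""
    while True:
        tensor, slot = request_queue.get()  # block until at least one request arrives
        batch = [(tensor, slot)]
        try:
            while len(batch) < max_batch_size:
                batch.append(request_queue.get(timeout=max_wait_s))
        except queue.Empty:
            pass
        inputs = torch.cat([t for t, _ in batch], dim=0)
        with torch.no_grad():
            outputs = model(inputs)
        for i, (_, s) in enumerate(batch):
            s["output"] = outputs[i]
            s["event"].set()

def predict_batched(tensor):
    """Called by each request handler; blocks until its result is ready."""
    slot = {"event": threading.Event(), "output": None}
    request_queue.put((tensor, slot))
    slot["event"].wait()
    return slot["output"]

# Start the worker once at application startup, e.g.:
# threading.Thread(target=batching_worker, args=(model,), daemon=True).start()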
3. Model Pruning
Remove unnecessary weights to reduce model size:
import torch.nn.utils.prune as prune
# Prune 30% of the weights in a layer
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)
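Note that l1_unstructured only masks the weights via a reparameterization; to bake the zeros into the tensor permanently, call prune.remove afterwards. A short sketch building on the snippet above, using a pretrained ResNet18 as the example model:

import torch
import torch.nn.utils.prune as prune
import torchvision.models as models

model = models.resnet18(pretrained=True)

# Mask 30% of the smallest-magnitude weights in the first conv layer
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)

# Make the pruning permanent (removes the weight_orig / weight_mask buffers)
prune.remove(model.conv1, 'weight')

# Check the resulting sparsity
sparsity = (model.conv1.weight == 0).float().mean().item()
print(f"conv1 sparsity: {sparsity:.1%}")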
Real-World Example: Serving an Image Classification Model
Let's put everything together with a complete example of serving a ResNet model for image classification using TorchServe:
1. Prepare the model
import torch
import torchvision.models as models

# Load a pretrained ResNet model
model = models.resnet50(pretrained=True)
model.eval()

# Save the model as TorchScript so TorchServe can load it directly
scripted_model = torch.jit.script(model)
scripted_model.save("resnet50.pt")
2. Create a custom handler file (resnet_handler.py):
from ts.torch_handler.image_classifier import ImageClassifier
import json
import os
import logging

logger = logging.getLogger(__name__)

class ResnetClassifier(ImageClassifier):
    """
    Custom handler for ResNet image classification
    """

    def __init__(self):
        super(ResnetClassifier, self).__init__()
        self.class_names = None

    def initialize(self, context):
        """
        Initialize model and load class names
        """
        super(ResnetClassifier, self).initialize(context)
        # Load the class mapping shipped alongside the model archive
        model_dir = context.system_properties.get("model_dir")
        with open(os.path.join(model_dir, 'index_to_name.json'), 'r') as f:
            self.class_names = json.load(f)
        logger.info("Model initialized with class names")

    def postprocess(self, data):
        """
        Add human-readable class names to output
        """
        result = super(ResnetClassifier, self).postprocess(data)
        # Add class names to the output (assumes each item carries a "prediction" index)
        for idx, item in enumerate(result):
            class_idx = item.get("prediction")
            if class_idx is not None and self.class_names:
                item["class"] = self.class_names.get(str(class_idx), "Unknown")
        return result
3. Create a mapping file for class names (index_to_name.json):
{
    "0": "tench",
    "1": "goldfish",
    "2": "great_white_shark",
    "...": "..."
}
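You can hand-write this file, but for ImageNet models it is easier to generate it from torchvision's bundled category names. A small sketch, assuming torchvision 0.13 or later (where the weights enums expose meta["categories"]):

import json
from torchvision.models import ResNet50_Weights

categories = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
index_to_name = {str(i): name for i, name in enumerate(categories)}

with open("index_to_name.json", "w") as f:
    json.dump(index_to_name, f, indent=2)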
4. Archive and serve the model:
# Create the model archive
torch-model-archiver --model-name resnet50 \
    --version 1.0 \
    --serialized-file resnet50.pt \
    --handler resnet_handler.py \
    --extra-files index_to_name.json \
    --export-path model_store
# Start TorchServe
torchserve --start --model-store model_store --models resnet50=resnet50.mar
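When you are finished experimenting, shut the server down cleanly:

torchserve --stop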
5. Use the deployed model:
import requests

# Load and send an image
with open("cat.jpg", "rb") as f:
    img_bytes = f.read()

resp = requests.post("http://localhost:8080/predictions/resnet50",
                     data=img_bytes,
                     headers={"Content-Type": "application/octet-stream"})

prediction = resp.json()
# The exact keys depend on the handler; the ImageClassifier base class returns
# the top classes with their probabilities
print("Prediction:", prediction)
Summary
Model serving is the critical final step in the machine learning lifecycle, making your PyTorch models available for real-world use. In this guide, we've covered:
- Simple Flask API - Great for prototypes and learning
- TorchServe - PyTorch's dedicated serving solution
- ONNX Runtime - Framework-agnostic deployment with optimizations
- Cloud Services - Managed solutions for enterprise-scale deployment
- Optimization Techniques - Methods to improve serving performance
When choosing a serving solution, consider factors like:
- Scale and expected traffic
- Latency requirements
- Need for versioning and monitoring
- Available infrastructure and budget
- Development team expertise
Additional Resources
- TorchServe Documentation
- ONNX Runtime Documentation
- PyTorch Mobile for Edge Deployment
- AWS SageMaker Examples for PyTorch
Exercises
- Deploy a simple image classification model using Flask and test it with different images.
- Use TorchServe to deploy a pre-trained PyTorch model from torchvision.
- Convert a custom PyTorch model to ONNX format and compare the inference speed.
- Implement request batching in a Flask API to improve throughput.
- Add model monitoring to track prediction accuracy and latency over time.
By mastering model serving, you complete the machine learning workflow—from development to deployment—and enable your PyTorch models to provide value in real-world applications.