PyTorch REST API
In modern machine learning workflows, developing a model is just the first step. To make your PyTorch models accessible to applications, websites, or other services, you need to deploy them in a way that allows for easy integration. One of the most popular approaches is to create a REST API that serves your PyTorch model, allowing clients to send requests and receive predictions over HTTP.
What is a REST API?
REST (Representational State Transfer) is an architectural style for designing networked applications. A REST API allows different systems to communicate over HTTP using standard methods like GET, POST, PUT, and DELETE. For model deployment, we typically use POST requests to send input data and receive predictions in return.
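For example, a client might POST a JPEG to a /predict endpoint and receive JSON such as {"prediction": "golden retriever"} in return; the exact paths and response schemas are design choices, as the examples below show.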
Why Deploy PyTorch Models as REST APIs?
- Language-agnostic integration: Any application that can make HTTP requests can use your model
- Scalability: APIs can be horizontally scaled to handle increasing loads
- Separation of concerns: Keeps model serving separate from application logic
- Centralized updates: You can update your model without requiring clients to change their code
Prerequisites
Before we start, ensure you have the following installed:
pip install torch torchvision flask fastapi uvicorn pillow numpy
Option 1: Building a PyTorch API with Flask
Flask is a lightweight web framework for Python that's perfect for simple APIs. Let's create a basic API to serve a pre-trained PyTorch model.
Step 1: Load a Pre-trained Model
First, let's load a pre-trained ResNet model from torchvision:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import io
import json
# Load a pre-trained ResNet model
# Note: torchvision >= 0.13 prefers weights=models.ResNet18_Weights.DEFAULT
# over the deprecated pretrained=True; the older form is kept here for brevity
model = models.resnet18(pretrained=True)
model.eval()  # Set to evaluation mode
# Define image transformation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# Load class labels
with open('imagenet_classes.json', 'r') as f:
    labels = json.load(f)
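This assumes imagenet_classes.json is a JSON array of the 1,000 ImageNet class names ordered by class index, so that labels[idx] maps a predicted index to its name.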
Step 2: Create Flask API
Now, let's create a Flask application with an endpoint to accept images and return predictions:
from flask import Flask, request, jsonify
app = Flask(__name__)
@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file provided'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No file selected'}), 400

    try:
        # Read image
        img_bytes = file.read()
        img = Image.open(io.BytesIO(img_bytes))

        # Convert grayscale or RGBA images to RGB before transforming
        if img.mode != "RGB":
            img = img.convert("RGB")

        # Preprocess the image
        img_tensor = transform(img)
        img_tensor = img_tensor.unsqueeze(0)  # Add batch dimension

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            _, predicted = torch.max(outputs, 1)
            prediction_idx = predicted.item()

        # Get top 5 predictions: softmax over all class logits first,
        # then take the top 5 (softmax over only 5 logits would be wrong)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        top5_prob = top5_prob[0].tolist()
        top5_idx = top5_idx[0].tolist()

        top5_predictions = [
            {"label": labels[idx], "probability": prob}
            for idx, prob in zip(top5_idx, top5_prob)
        ]

        return jsonify({
            'prediction': labels[prediction_idx],
            'top5_predictions': top5_predictions
        })
    except Exception as e:
        return jsonify({'error': str(e)}), 500

if __name__ == '__main__':
    app.run(debug=True, host='0.0.0.0', port=5000)
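Save the two snippets above in a single file (say app.py) and start the development server with python app.py. Flask's built-in server is fine for development; for real deployments, a production WSGI server such as gunicorn is recommended.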
Step 3: Test the Flask API
You can test your API using curl or any HTTP client:
curl -X POST -F "file=@path/to/your/image.jpg" http://localhost:5000/predict
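Equivalently, here is a minimal Python client using the requests library (assuming requests is installed and the server above is running):

import requests

# Upload an image to the /predict endpoint; the form field name "file"
# must match what the server reads from request.files['file']
with open("path/to/your/image.jpg", "rb") as f:
    response = requests.post("http://localhost:5000/predict", files={"file": f})

print(response.status_code)
print(response.json())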
Example output:
{
  "prediction": "golden retriever",
  "top5_predictions": [
    {"label": "golden retriever", "probability": 0.8723},
    {"label": "Labrador retriever", "probability": 0.0912},
    {"label": "kuvasz", "probability": 0.0134},
    {"label": "clumber", "probability": 0.0087},
    {"label": "tennis ball", "probability": 0.0032}
  ]
}
Option 2: Building a PyTorch API with FastAPI
FastAPI is a modern, high-performance web framework that's becoming increasingly popular for API development. It offers automatic interactive documentation, data validation, and better performance than Flask.
Step 1: Define the FastAPI Application
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import io
import json
import numpy as np
app = FastAPI(title="PyTorch Model API",
              description="API for serving PyTorch image classification models")
# Load model (same as Flask example)
model = models.resnet18(pretrained=True)
model.eval()
# Define image transformation (same as Flask example)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])
# Load class labels
with open('imagenet_classes.json', 'r') as f:
    labels = json.load(f)
Step 2: Create Prediction Endpoint
@app.post("/predict", response_class=JSONResponse)
async def predict_image(file: UploadFile = File(...)):
    try:
        # Read image
        content = await file.read()
        img = Image.open(io.BytesIO(content))

        # Convert grayscale or RGBA images to RGB before transforming
        if img.mode != "RGB":
            img = img.convert("RGB")

        # Preprocess the image
        img_tensor = transform(img)
        img_tensor = img_tensor.unsqueeze(0)  # Add batch dimension

        # Make prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            _, predicted = torch.max(outputs, 1)
            prediction_idx = predicted.item()

        # Get top 5 predictions: softmax over all class logits, then topk
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        top5_prob, top5_idx = torch.topk(probabilities, 5)
        top5_prob = top5_prob[0].tolist()
        top5_idx = top5_idx[0].tolist()

        top5_predictions = [
            {"label": labels[idx], "probability": prob}
            for idx, prob in zip(top5_idx, top5_prob)
        ]

        return {
            'prediction': labels[prediction_idx],
            'top5_predictions': top5_predictions
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
Step 3: Run the FastAPI Application
Create a file to run the application (this assumes the FastAPI code above is saved as app.py, matching the "app:app" import string):
# run.py
import uvicorn
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
Then execute:
python run.py
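Alternatively, skip run.py and launch uvicorn directly from the command line:

uvicorn app:app --host 0.0.0.0 --port 8000 --reload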
With FastAPI, you automatically get interactive API documentation at http://localhost:8000/docs.
Step 4: Test the FastAPI Application
You can test using curl:
curl -X POST -F "file=@path/to/your/image.jpg" http://localhost:8000/predict
Or use the interactive documentation at http://localhost:8000/docs.
Real-World Deployment Considerations
When deploying your PyTorch model API in a production environment, consider the following:
1. Performance Optimization
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# In prediction function
img_tensor = img_tensor.to(device)
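If you need more speed, TorchScript tracing can reduce Python overhead during inference. This is an optional sketch, worth benchmarking against the plain model before adopting:

# Trace the model with a dummy input matching the preprocessing output shape
# (1 x 3 x 224 x 224 for the transform used in this guide). Tracing bakes in a
# fixed control flow, which is fine for a plain ResNet forward pass.
example_input = torch.randn(1, 3, 224, 224).to(device)
traced_model = torch.jit.trace(model, example_input)
traced_model = torch.jit.freeze(traced_model.eval())

# Then call traced_model(img_tensor) in the prediction function as before.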
2. Batch Processing
For higher throughput, implement batch processing:
from typing import List

@app.post("/predict_batch")
async def predict_batch(files: List[UploadFile] = File(...)):
    batch_tensors = []
    for file in files:
        content = await file.read()
        img = Image.open(io.BytesIO(content)).convert("RGB")  # normalize to RGB
        tensor = transform(img)
        batch_tensors.append(tensor)

    # Stack tensors into a batch
    batch = torch.stack(batch_tensors).to(device)

    with torch.no_grad():
        outputs = model(batch)

    # Process outputs...
    return results
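This pattern relies on the shared transform producing fixed-size tensors so that torch.stack succeeds, and it trades a little per-request latency for overall throughput.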
3. Model Versioning
Implement model versioning to maintain backward compatibility:
# Use a registry name that does not shadow the torchvision.models import
model_registry = {
    "v1": {"model": model_v1, "transform": transform_v1},
    "v2": {"model": model_v2, "transform": transform_v2},
}

@app.post("/predict/{version}")
async def predict(version: str, file: UploadFile = File(...)):
    if version not in model_registry:
        raise HTTPException(status_code=404, detail="Model version not found")

    model_info = model_registry[version]
    # Use version-specific model and transform
    # ...
4. Request Validation
Add validation for acceptable image formats:
def validate_image(file: UploadFile):
    if file.content_type not in ["image/jpeg", "image/png"]:
        raise HTTPException(
            status_code=400,
            detail="Only JPEG and PNG images are accepted"
        )
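Call the helper at the top of the endpoint so invalid uploads are rejected before any image decoding happens:

@app.post("/predict")
async def predict_image(file: UploadFile = File(...)):
    validate_image(file)  # raises a 400 HTTPException for unsupported types
    ...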
5. Containerization with Docker
Create a Dockerfile for your API:
FROM python:3.8-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install -r requirements.txt
COPY . .
# Download the model weights at build time
RUN python -c "import torchvision.models as models; models.resnet18(pretrained=True)"
EXPOSE 8000
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
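Then build the image and run the container (the tag pytorch-api is arbitrary):

docker build -t pytorch-api .
docker run -p 8000:8000 pytorch-api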
Complete Example: Production-Ready Model API
Here's a more complete example that incorporates best practices:
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
from fastapi.responses import JSONResponse
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import io
import json
import os
import uuid
import logging
import time
from typing import List, Dict, Any
from pydantic import BaseModel
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Initialize FastAPI
app = FastAPI(
    title="PyTorch Model API",
    description="Production-ready API for PyTorch model inference",
    version="1.0.0"
)
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load model
@app.on_event("startup")
async def load_model():
    global model, transform, labels
    try:
        model = models.resnet18(pretrained=True)
        model.to(device)
        model.eval()

        # Define image transformation
        transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        # Load class labels
        with open('imagenet_classes.json', 'r') as f:
            labels = json.load(f)

        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise e
# Response models
class Prediction(BaseModel):
    label: str
    probability: float

class PredictionResponse(BaseModel):
    request_id: str
    prediction: str
    top_predictions: List[Prediction]
    processing_time_ms: float
# Prediction endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_image(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
    request_id = str(uuid.uuid4())
    logger.info(f"Processing request {request_id}")

    try:
        # Validate file
        if file.content_type not in ["image/jpeg", "image/png"]:
            raise HTTPException(
                status_code=400,
                detail="Only JPEG and PNG images are accepted"
            )

        # Read and process image
        content = await file.read()
        img = Image.open(io.BytesIO(content))

        # If image is not RGB (like PNG with alpha channel), convert it
        if img.mode != "RGB":
            img = img.convert("RGB")

        img_tensor = transform(img)
        img_tensor = img_tensor.unsqueeze(0).to(device)

        # Measure inference time with time.perf_counter, which works on CPU
        # and GPU alike; synchronize so pending CUDA kernels are included
        if device.type == "cuda":
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            outputs = model(img_tensor)

        if device.type == "cuda":
            torch.cuda.synchronize()
        processing_time = (time.perf_counter() - start_time) * 1000  # ms

        # Get predictions
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        top5_prob, top5_indices = torch.topk(probabilities, 5)

        # Format response
        top_predictions = [
            Prediction(label=labels[idx.item()], probability=prob.item())
            for idx, prob in zip(top5_indices, top5_prob)
        ]

        response = PredictionResponse(
            request_id=request_id,
            prediction=labels[top5_indices[0].item()],
            top_predictions=top_predictions,
            processing_time_ms=processing_time
        )

        # Log request in background
        if background_tasks:
            background_tasks.add_task(log_prediction, request_id, response)

        return response
    except HTTPException:
        # Let validation errors keep their status code instead of becoming 500s
        raise
    except Exception as e:
        logger.error(f"Error processing request {request_id}: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "resnet18"}

# Background task for logging
async def log_prediction(request_id: str, response: PredictionResponse):
    logger.info(f"Request {request_id} predicted as {response.prediction}")
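Once the service is running, a quick liveness check:

curl http://localhost:8000/health

This should return {"status": "healthy", "model": "resnet18"}.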
Serving Multiple Models
For more complex applications, you might need to serve multiple models. Here's a pattern for that:
class ModelManager:
    def __init__(self):
        self.models = {}

    def load_model(self, model_id, model_path):
        """Load a model from a path and store it with an ID"""
        # Assumes the whole model object was saved with torch.save(model, path),
        # not just a state_dict
        model = torch.load(model_path, map_location=device)
        model.eval()
        self.models[model_id] = model
        logger.info(f"Loaded model {model_id}")

    def get_model(self, model_id):
        """Get a model by ID"""
        if model_id not in self.models:
            raise ValueError(f"Model {model_id} not found")
        return self.models[model_id]

    def predict(self, model_id, input_data):
        """Make a prediction with the specified model"""
        model = self.get_model(model_id)
        with torch.no_grad():
            output = model(input_data)
        return output
# Initialize model manager
model_manager = ModelManager()
# Load models at startup
@app.on_event("startup")
async def startup_event():
    model_manager.load_model("resnet18", "models/resnet18.pth")
    model_manager.load_model("mobilenet", "models/mobilenet.pth")

# Endpoint that specifies which model to use
@app.post("/predict/{model_id}")
async def predict_with_model(model_id: str, file: UploadFile = File(...)):
    try:
        # Process image...
        result = model_manager.predict(model_id, img_tensor)
        # Format response...
        return result
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
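Clients then select a model by path, for example:

curl -X POST -F "file=@path/to/your/image.jpg" http://localhost:8000/predict/resnet18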
Summary
In this guide, you've learned how to:
- Build a REST API to serve PyTorch models using Flask and FastAPI
- Properly preprocess input data for inference
- Return formatted predictions as JSON responses
- Implement advanced features like model versioning and batch processing
- Address production concerns like performance, Docker deployment, and error handling
Deploying your PyTorch models as REST APIs makes them accessible to a wide range of applications, from web frontends to mobile apps and other services.
Additional Resources
- PyTorch Documentation
- FastAPI Documentation
- Flask Documentation
- TorchServe - A dedicated model serving framework for PyTorch
Exercises
- Extend the API to accept different image sizes and adjust preprocessing accordingly
- Implement caching to avoid reloading the model for each request
- Add authentication to your API using FastAPI's security utilities
- Implement a rate limiting mechanism to protect your API from abuse
- Create a simple web interface that allows users to upload images and view predictions
Happy deploying!