FastAPI Request Validation Middleware
Introduction
Request validation is a critical aspect of building robust APIs. While FastAPI already provides excellent built-in validation through Pydantic models, there are scenarios where you might need additional custom validation logic that applies globally across your application. This is where request validation middleware comes in handy.
In this tutorial, you'll learn how to create custom middleware for validating incoming requests in FastAPI. This approach allows you to implement validation rules that run before your route handlers, giving you more control over the data entering your application.
Prerequisites
Before we dive in, you should have:
- Basic knowledge of Python
- Familiarity with FastAPI fundamentals
- Understanding of HTTP requests and responses
- Python 3.7+ installed
- FastAPI and Uvicorn installed (pip install fastapi uvicorn)
Understanding Middleware in FastAPI
Middleware in FastAPI acts as a layer between the incoming request and your route handlers. It can:
- Process the request before it reaches your endpoint functions
- Modify the request or response
- Stop a request from reaching your endpoint by returning a response early (short-circuiting)
- Perform logging, authentication, or validation
For request validation specifically, middleware helps you enforce rules on incoming data regardless of which endpoint is being accessed.
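Under the hood, FastAPI (through Starlette) runs middleware as ASGI applications that wrap your app: each one receives the connection scope plus receive and send callables, and can either answer the request itself or pass it on. Here is a minimal sketch of that idea with no framework imports, so it can be exercised directly. RequireHeaderMiddleware, hello_app and status_for are hypothetical names chosen for illustration.

```python
import asyncio
import json

BODY_METHODS = {"POST", "PUT", "PATCH"}


class RequireHeaderMiddleware:
    """Pure-ASGI sketch: reject body-carrying requests that lack a header."""

    def __init__(self, app, header_name: bytes = b"content-type"):
        self.app = app  # the wrapped ASGI application (e.g. your FastAPI app)
        self.header_name = header_name

    async def __call__(self, scope, receive, send):
        # Pass through non-HTTP traffic (lifespan, websockets) and body-less methods
        if scope["type"] != "http" or scope["method"] not in BODY_METHODS:
            await self.app(scope, receive, send)
            return
        # ASGI headers arrive as a list of (name, value) byte-string pairs
        names = {name.lower() for name, _ in scope.get("headers", [])}
        if self.header_name not in names:
            detail = f"Missing required header: {self.header_name.decode()}"
            body = json.dumps({"detail": detail}).encode()
            await send({
                "type": "http.response.start",
                "status": 400,
                "headers": [
                    (b"content-type", b"application/json"),
                    (b"content-length", str(len(body)).encode()),
                ],
            })
            await send({"type": "http.response.body", "body": body})
            return  # short-circuit: the wrapped app never runs
        await self.app(scope, receive, send)


async def hello_app(scope, receive, send):
    """Stand-in for the real application: always answers 200 OK."""
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


async def status_for(method, headers):
    """Drive the middleware with a fake request and return the status it sends."""
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    scope = {"type": "http", "method": method, "path": "/", "headers": headers}
    await RequireHeaderMiddleware(hello_app)(scope, receive, send)
    return sent[0]["status"]


if __name__ == "__main__":
    print(asyncio.run(status_for("POST", [])))  # 400
    print(asyncio.run(status_for("POST", [(b"content-type", b"application/json")])))  # 200
    print(asyncio.run(status_for("GET", [])))  # 200
```

Because its constructor takes the wrapped app as its first argument, a class like this can be registered with app.add_middleware(RequireHeaderMiddleware), just like the BaseHTTPMiddleware subclasses used in the rest of this tutorial. The trade-off is that you work with raw ASGI messages instead of Request and Response objects.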
Basic Request Validation Middleware
Let's start with a simple middleware that validates request headers:
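```python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()


class RequestValidationMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Only methods that normally carry a body need a Content-Type header
        if request.method in ("POST", "PUT", "PATCH") and "content-type" not in request.headers:
            # Return a response instead of raising HTTPException: exceptions raised
            # inside middleware bypass FastAPI's HTTPException handler
            return JSONResponse(
                status_code=400,
                content={"detail": "Content-Type header is required"},
            )
        # Continue processing the request
        response = await call_next(request)
        return response


# Add the middleware to the application
app.add_middleware(RequestValidationMiddleware)


@app.get("/")
async def root():
    return {"message": "Hello World"}
```

In this example, the middleware checks that POST, PUT, and PATCH requests include a Content-Type header and returns a 400 Bad Request response if it is missing. Requests without a body, such as a GET request to /, pass straight through. The middleware returns a JSONResponse rather than raising HTTPException, because an exception raised inside middleware is not handled by FastAPI's HTTPException handler and would surface as a 500 Internal Server Error.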
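Checking that Content-Type is present says nothing about its value, and the later examples in this tutorial detect JSON with a substring test ("application/json" in content_type). That test also matches a value such as application/jsonp and misses application/problem+json. A stricter sketch parses the media type first. media_type and is_json are hypothetical helpers, and treating every application/*+json type as JSON follows the "+json" structured syntax suffix convention (RFC 6839).

```python
def media_type(content_type: str) -> str:
    """Return the bare, lowercased media type of a Content-Type header value.

    Parameters such as '; charset=utf-8' are dropped; media types are
    case-insensitive, so the result is normalized to lowercase.
    """
    return content_type.split(";", 1)[0].strip().lower()


def is_json(content_type: str) -> bool:
    """True for application/json and '+json' types such as application/problem+json."""
    mt = media_type(content_type)
    return mt == "application/json" or (mt.startswith("application/") and mt.endswith("+json"))


if __name__ == "__main__":
    for value in ["application/json; charset=utf-8", "application/jsonp", "application/problem+json"]:
        print(value, "->", is_json(value))
```

Inside a middleware, is_json(request.headers.get("content-type", "")) would take the place of the substring test.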
Advanced Request Body Validation
While FastAPI automatically validates request bodies using Pydantic models in your route definitions, sometimes you may want to apply additional global validation. Let's create a middleware that validates JSON payloads for POST and PUT requests:
```python
import json

from fastapi import FastAPI, Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()


class JSONValidationMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method in ["POST", "PUT"]:
            content_type = request.headers.get("content-type", "")
            if "application/json" in content_type:
                # Read the whole body; this consumes the request stream
                body = await request.body()
                if body:
                    try:
                        json.loads(body)
                    except ValueError:  # JSONDecodeError, or bytes that are not valid UTF-8
                        return JSONResponse(
                            status_code=status.HTTP_400_BAD_REQUEST,
                            content={"detail": "Invalid JSON format in request body"},
                        )

                # Replay the already-read body for the route handler. This overrides
                # a private Starlette attribute; recent Starlette releases replay the
                # body on their own, so check your version before relying on it.
                async def receive():
                    return {"type": "http.request", "body": body, "more_body": False}

                request._receive = receive
        return await call_next(request)


app.add_middleware(JSONValidationMiddleware)


@app.post("/items/")
async def create_item(request: Request):
    data = await request.json()
    return {"received_data": data}
```
This middleware intercepts POST and PUT requests, checks if they contain JSON data, and validates that the JSON is properly formatted. If not, it returns a 400 Bad Request response.
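The body-reset step works because a route handler reads the body by calling the ASGI receive callable, which yields http.request messages carrying body and more_body keys. Once the middleware has drained those messages, the handler needs a substitute receive. Here is a minimal, framework-free sketch of such a substitute, which replays the body once and then defers to the server's original receive for anything that follows, such as the eventual http.disconnect message. make_replay_receive and read_body are hypothetical helpers, not FastAPI or Starlette APIs.

```python
import asyncio


def make_replay_receive(body: bytes, original_receive):
    """Return an ASGI receive callable that replays an already-read body.

    The first call yields the whole body as one http.request message with
    more_body=False. Later calls are delegated to the server's original
    receive, which is where the eventual http.disconnect comes from.
    """
    replayed = False

    async def receive():
        nonlocal replayed
        if not replayed:
            replayed = True
            return {"type": "http.request", "body": body, "more_body": False}
        return await original_receive()

    return receive


async def read_body(receive) -> bytes:
    """Reassemble a request body the way a framework does: read until more_body is False."""
    chunks = []
    while True:
        message = await receive()
        if message["type"] != "http.request":
            break  # e.g. http.disconnect
        chunks.append(message.get("body", b""))
        if not message.get("more_body", False):
            break
    return b"".join(chunks)


if __name__ == "__main__":
    async def server_receive():
        # Stand-in for the server: the body is gone, so report a disconnect
        return {"type": "http.disconnect"}

    replay = make_replay_receive(b'{"name": "widget"}', server_receive)
    print(asyncio.run(read_body(replay)))
```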
Middleware for Size Limits and Rate Limiting
Another common validation use case is limiting request sizes and implementing rate limiting. Here's how you could implement request size validation:
```python
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()


class RequestSizeMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, max_size_mb: float = 5):
        super().__init__(app)
        self.max_size_mb = max_size_mb
        self.max_size_bytes = int(max_size_mb * 1024 * 1024)  # Convert MB to bytes

    async def dispatch(self, request: Request, call_next):
        content_length = request.headers.get("content-length")
        if content_length is not None:
            try:
                declared_size = int(content_length)
            except ValueError:
                # A malformed header would otherwise raise and produce a 500
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )
            if declared_size > self.max_size_bytes:
                return JSONResponse(
                    status_code=413,
                    content={
                        "detail": f"Request too large. Maximum size allowed is {self.max_size_mb} MB"
                    },
                )
        response = await call_next(request)
        return response


# Add the middleware with a 2MB limit
app.add_middleware(RequestSizeMiddleware, max_size_mb=2)


@app.post("/upload/")
async def upload_file(request: Request):
    data = await request.body()
    return {"file_size": len(data)}
```
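This middleware checks the content-length header and rejects requests whose declared size exceeds the limit with a 413 response. Because it relies on a header the client controls, requests sent without one (for example, with chunked transfer encoding) are not limited by this check.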
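To enforce the limit on the bytes actually received rather than on the declared size, one option is to wrap the ASGI receive callable and count body bytes as they stream in. A minimal sketch, assuming a pure ASGI setting; limit_receive and BodyTooLarge are hypothetical names.

```python
import asyncio


class BodyTooLarge(Exception):
    """Raised when more body bytes arrive than the configured limit allows."""


def limit_receive(receive, max_bytes: int):
    """Wrap an ASGI receive callable and count body bytes as they stream in.

    Unlike a content-length check, this also covers chunked uploads that
    send no content-length header and clients that under-report it.
    """
    received = 0

    async def limited_receive():
        nonlocal received
        message = await receive()
        if message["type"] == "http.request":
            received += len(message.get("body", b""))
            if received > max_bytes:
                raise BodyTooLarge(f"request body exceeded {max_bytes} bytes")
        return message

    return limited_receive


if __name__ == "__main__":
    parts = iter([b"a" * 600, b"b" * 600])

    async def chunked_receive():
        # Stand-in for a server delivering a chunked body with no content-length
        return {"type": "http.request", "body": next(parts), "more_body": True}

    async def demo():
        receive = limit_receive(chunked_receive, max_bytes=1000)
        await receive()  # 600 bytes so far: allowed
        try:
            await receive()  # 1200 bytes: over the limit
        except BodyTooLarge as exc:
            print("rejected:", exc)

    asyncio.run(demo())
```

A pure ASGI middleware would pass limit_receive(receive, max_bytes) to the wrapped app and catch BodyTooLarge to send a 413. This only yields a clean error if the endpoint reads the body before it starts its response, which is the usual case.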
Combining Multiple Validation Rules
In real-world applications, you might want to apply multiple validation rules. Let's create a comprehensive validation middleware that combines various checks:
```python
import json
import time

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()


class ComprehensiveValidationMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, max_size_mb: float = 5):
        super().__init__(app)
        self.max_size_bytes = max_size_mb * 1024 * 1024
        # Store IP addresses and their last request timestamps (in process memory)
        self.request_history = {}
        # Minimum time between requests from same IP (in seconds)
        self.rate_limit = 1

    async def dispatch(self, request: Request, call_next):
        # 1. Rate limiting check (request.client is None if the server reports no address)
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.monotonic()  # unaffected by wall-clock adjustments
        last_seen = self.request_history.get(client_ip)
        if last_seen is not None and current_time - last_seen < self.rate_limit:
            return JSONResponse(
                status_code=429,
                content={"detail": "Too many requests. Please try again later."},
            )
        self.request_history[client_ip] = current_time

        # 2. Size validation (a malformed header is rejected instead of causing a 500)
        content_length = request.headers.get("content-length")
        if content_length is not None:
            try:
                declared_size = int(content_length)
            except ValueError:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Invalid Content-Length header"},
                )
            if declared_size > self.max_size_bytes:
                return JSONResponse(
                    status_code=413,
                    content={"detail": "Request too large"},
                )

        # 3. Content-Type validation for POST/PUT requests
        if request.method in ["POST", "PUT"]:
            content_type = request.headers.get("content-type", "")
            if not content_type:
                return JSONResponse(
                    status_code=400,
                    content={"detail": "Content-Type header is required for POST and PUT requests"},
                )

            # 4. JSON validation for appropriate content types
            if "application/json" in content_type:
                body = await request.body()
                if body:
                    try:
                        json.loads(body)
                    except ValueError:  # JSONDecodeError, or bytes that are not valid UTF-8
                        return JSONResponse(
                            status_code=400,
                            content={"detail": "Invalid JSON format"},
                        )

                # Replay the body for the route handler (private-attribute workaround;
                # recent Starlette releases replay the body themselves, so check your version)
                async def receive():
                    return {"type": "http.request", "body": body, "more_body": False}

                request._receive = receive

        # Continue with the request if all validations pass
        response = await call_next(request)
        return response


app.add_middleware(ComprehensiveValidationMiddleware, max_size_mb=2)


@app.post("/api/data/")
async def create_data(request: Request):
    data = await request.json()
    return {"message": "Data received successfully", "data": data}
```
This comprehensive middleware implements:
- Rate limiting based on client IP address (at most one request per second, tracked in process memory)
- Request size validation
- Content-Type header verification
- JSON format validation
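The rate limiter above allows one request per second per address and keeps every address it has ever seen. A sliding-window counter is a common alternative that tolerates short bursts up to a limit and can forget idle clients. A minimal in-memory sketch; SlidingWindowLimiter is a hypothetical name, and the injectable clock exists only to make it testable without sleeping.

```python
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict


class SlidingWindowLimiter:
    """Allow at most `limit` requests per `window` seconds for each key (e.g. client IP).

    In-memory sketch: state lives in one process, so several workers each keep
    their own counts. Call prune() periodically so idle keys do not accumulate.
    """

    def __init__(self, limit: int, window: float, clock: Callable[[], float] = time.monotonic):
        self.limit = limit
        self.window = window
        self.clock = clock  # injectable so behavior can be tested without sleeping
        self.hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, key: str) -> bool:
        now = self.clock()
        timestamps = self.hits[key]
        # Forget hits that have slid out of the window
        while timestamps and timestamps[0] <= now - self.window:
            timestamps.popleft()
        if len(timestamps) >= self.limit:
            return False  # rejected requests are not recorded
        timestamps.append(now)
        return True

    def prune(self) -> None:
        """Drop keys whose most recent hit is older than the window."""
        cutoff = self.clock() - self.window
        stale = [key for key, ts in self.hits.items() if not ts or ts[-1] <= cutoff]
        for key in stale:
            del self.hits[key]


if __name__ == "__main__":
    limiter = SlidingWindowLimiter(limit=3, window=1.0)
    print([limiter.allow("203.0.113.7") for _ in range(4)])  # [True, True, True, False]
```

In a middleware, a rejected allow(client_ip) call would map to the same 429 response used above.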
Real-World Example: API Security Middleware
Let's implement a security-focused validation middleware that screens requests before they reach an API:
```python
import json
import re
import secrets

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()

# In a real app, you would check against a database or secret store
VALID_API_KEYS = ["test_key_123", "prod_key_456"]

# Simple, illustrative SQL injection patterns, compiled once at import time.
# They are easy to evade and also match some harmless input (e.g. any ";").
SQL_PATTERNS = [
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r"--",
        r";",
        r"drop\s+table",
        r"insert\s+into",
        r"select\s+\S+\s+from",
        r"delete\s+from",
        r"update\s+\w+\s+set",
        r"union\s+select",
    )
]


class SecurityValidationMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        # Validate API key if endpoint is protected
        if path.startswith("/api/") and "/public/" not in path:
            api_key = request.headers.get("x-api-key")
            if not api_key or not self._is_valid_api_key(api_key):
                return JSONResponse(
                    status_code=401,
                    content={"detail": "Invalid or missing API key"},
                )

        # SQL injection screening for query parameters.
        # multi_items() includes repeated keys (?q=a&q=b); items() keeps one value per key.
        for param, value in request.query_params.multi_items():
            if self._contains_sql_injection(value):
                return JSONResponse(
                    status_code=400,
                    content={"detail": f"Potential SQL injection detected in parameter: {param}"},
                )

        # For POST/PUT with JSON, check for injection in the body
        if request.method in ["POST", "PUT"]:
            content_type = request.headers.get("content-type", "")
            if "application/json" in content_type:
                body = await request.body()
                if body:
                    try:
                        json_body = json.loads(body)
                    except ValueError:  # JSONDecodeError, or bytes that are not valid UTF-8
                        return JSONResponse(
                            status_code=400,
                            content={"detail": "Invalid JSON format"},
                        )
                    if self._check_json_for_injection(json_body):
                        return JSONResponse(
                            status_code=400,
                            content={"detail": "Potential injection attack detected in request body"},
                        )

                # Replay the body for the route handler (private-attribute workaround;
                # recent Starlette releases replay the body themselves, so check your version)
                async def receive():
                    return {"type": "http.request", "body": body, "more_body": False}

                request._receive = receive

        response = await call_next(request)
        return response

    def _is_valid_api_key(self, key: str) -> bool:
        # compare_digest runs in constant time, so response timing does not
        # reveal how much of a guessed key was correct
        return any(
            secrets.compare_digest(key.encode(), valid.encode()) for valid in VALID_API_KEYS
        )

    def _contains_sql_injection(self, value: str) -> bool:
        return any(pattern.search(value) for pattern in SQL_PATTERNS)

    def _check_json_for_injection(self, data) -> bool:
        # Recursively inspect every string in nested dicts and lists
        if isinstance(data, str):
            return self._contains_sql_injection(data)
        if isinstance(data, dict):
            return any(self._check_json_for_injection(value) for value in data.values())
        if isinstance(data, list):
            return any(self._check_json_for_injection(item) for item in data)
        return False


app.add_middleware(SecurityValidationMiddleware)


@app.get("/api/users/")
async def get_users(request: Request):
    # In a real app, you would fetch from a database
    return {"users": ["user1", "user2", "user3"]}


@app.post("/api/users/")
async def create_user(request: Request):
    data = await request.json()
    return {"message": "User created successfully", "user": data}


@app.get("/public/health/")
async def health_check():
    return {"status": "ok"}
```
This security middleware implements:
- API key validation for protected endpoints
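- Pattern-based SQL injection screening of query parameters (a coarse filter that can be evaded and may reject legitimate input)
- The same screening applied to string values in JSON payloads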
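Screening input is no substitute for keeping user data out of the SQL text in the first place. The defense that actually neutralizes SQL injection is a parameterized query in the data layer. A minimal sketch using Python's built-in sqlite3 module; the users table and the find_user helper are illustrative assumptions, and other database drivers use their own placeholder styles.

```python
import sqlite3


def find_user(conn: sqlite3.Connection, username: str):
    # The ? placeholder passes username to SQLite as a bound value, never as
    # SQL text, so quotes or keywords inside it cannot change the query.
    return conn.execute(
        "SELECT id, name FROM users WHERE name = ?", (username,)
    ).fetchall()


# Illustrative in-memory database with a hypothetical users table
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany(
    "INSERT INTO users (name) VALUES (?)",
    [("alice",), ("bob",), ("o'brien; jr",)],
)

print(find_user(conn, "alice"))        # [(1, 'alice')]
print(find_user(conn, "' OR '1'='1"))  # [] because the input is compared as a literal name
print(find_user(conn, "o'brien; jr"))  # [(3, "o'brien; jr")]
```

The last lookup is a legitimate value that the ";" pattern in the middleware above would reject, yet the parameterized query handles it safely.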
Handling Validation Errors Gracefully
To make your validation middleware more user-friendly, you can customize error responses:
```python
import time
import uuid

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

app = FastAPI()


class ValidationErrorHandlingMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        # Last request time per client IP, kept in process memory
        self.last_request = {}

    async def dispatch(self, request: Request, call_next):
        # For demonstration, we'll implement rate limiting with custom error handling
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.monotonic()

        # Simple rate limiting: at most one request per second per IP
        last_seen = self.last_request.get(client_ip)
        if last_seen is not None and current_time - last_seen < 1:
            # Create a unique error reference ID
            error_ref = str(uuid.uuid4())
            # Log the error (in a real app, you'd use a proper logging system)
            print(f"Rate limit exceeded: {client_ip}, Reference: {error_ref}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Too many requests",
                    "detail": "Please wait before making another request",
                    "reference_id": error_ref,
                    "retry_after": 1,
                },
                headers={"Retry-After": "1"},
            )
        self.last_request[client_ip] = current_time

        try:
            return await call_next(request)
        except Exception as e:
            # Exceptions the endpoint does not handle are re-raised from call_next
            error_ref = str(uuid.uuid4())
            # Log the error
            print(f"Unexpected error: {e!r}, Reference: {error_ref}")
            return JSONResponse(
                status_code=500,
                content={
                    "error": "Internal server error",
                    "reference_id": error_ref,
                    "detail": "An unexpected error occurred. Our team has been notified.",
                },
            )


app.add_middleware(ValidationErrorHandlingMiddleware)


@app.get("/api/resource/")
async def get_resource():
    return {"data": "This is a protected resource"}
```
This middleware provides friendly error messages with reference IDs that can help with troubleshooting.
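Reference IDs pay off when the same ID appears in your logs next to the traceback. A small helper using only the standard logging and uuid modules could build the response body and the log entry together; error_payload and the logger name api.errors are hypothetical.

```python
import logging
import uuid
from typing import Optional

logger = logging.getLogger("api.errors")  # hypothetical logger name


def error_payload(error: str, detail: str, exc: Optional[BaseException] = None) -> dict:
    """Build a client-facing error body and log the same reference ID.

    The ID goes into the log message and onto the log record (via `extra`),
    so structured handlers can index it; passing `exc` attaches the traceback.
    """
    reference_id = str(uuid.uuid4())
    logger.error(
        "%s (reference_id=%s)",
        error,
        reference_id,
        exc_info=exc,
        extra={"reference_id": reference_id},
    )
    return {"error": error, "detail": detail, "reference_id": reference_id}


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    try:
        1 / 0
    except ZeroDivisionError as exc:
        print(error_payload("Internal server error", "An unexpected error occurred.", exc))
```

In the middleware's except block, the 500 response could then be built as JSONResponse(status_code=500, content=error_payload(..., exc=e)).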
Performance Considerations
When implementing request validation middleware, keep in mind these performance tips:
- Keep validation logic efficient: Complex validation rules can slow down all requests.
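- Share state across workers: The in-memory dictionaries used for rate limiting above live in a single process and grow with every new client. With multiple workers or servers, keep that state in a shared store such as Redis.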
- Prioritize validations: Perform quick checks first (like header validation) before more resource-intensive ones (like body parsing).
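- Use async properly: dispatch runs on the event loop, so blocking calls inside it (synchronous database queries or HTTP requests) stall every in-flight request. Use async clients and await them.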
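The advice to prioritize validations amounts to running checks in a fixed order and stopping at the first failure. A framework-agnostic sketch follows; run_checks, require_content_type and make_size_check are hypothetical names. A middleware could call run_checks(request.method, request.headers, CHECKS), since Starlette's headers object is a case-insensitive mapping, and turn a returned failure into a JSONResponse.

```python
from typing import Callable, Mapping, Optional, Sequence, Tuple

# A check gets the HTTP method and a lowercase-keyed header mapping and
# returns None on success or a (status_code, detail) tuple on failure.
Failure = Tuple[int, str]
Check = Callable[[str, Mapping[str, str]], Optional[Failure]]


def run_checks(method: str, headers: Mapping[str, str], checks: Sequence[Check]) -> Optional[Failure]:
    """Run checks in order and stop at the first failure."""
    for check in checks:
        failure = check(method, headers)
        if failure is not None:
            return failure
    return None


def require_content_type(method: str, headers: Mapping[str, str]) -> Optional[Failure]:
    # A single header lookup: about as cheap as a check can be
    if method in ("POST", "PUT", "PATCH") and "content-type" not in headers:
        return 400, "Content-Type header is required"
    return None


def make_size_check(max_bytes: int) -> Check:
    def check_size(method: str, headers: Mapping[str, str]) -> Optional[Failure]:
        raw = headers.get("content-length")
        if raw is None:
            return None
        try:
            size = int(raw)
        except ValueError:
            return 400, "Invalid Content-Length header"
        if size < 0:
            return 400, "Invalid Content-Length header"
        if size > max_bytes:
            return 413, "Request too large"
        return None

    return check_size


# Cheap header-only checks first; anything that must read the body belongs at the end.
CHECKS = [require_content_type, make_size_check(2 * 1024 * 1024)]
```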
Summary
Request validation middleware in FastAPI provides a powerful way to:
- Implement global validation rules across your application
- Enhance security by detecting and blocking malicious requests
- Ensure data quality before it reaches your business logic
- Provide consistent error handling for validation failures
By creating custom validation middleware, you can centralize validation logic, reduce code duplication, and ensure that all endpoints adhere to the same validation standards.
Additional Resources
- FastAPI Official Documentation on Middleware
- Starlette Middleware Documentation
- OWASP API Security Project
Exercises
- Basic: Create a middleware that validates a custom header called x-app-version and ensures it follows a semantic versioning format (e.g., "1.2.3").
- Intermediate: Implement a middleware that checks for suspicious user agents and blocks requests with known malicious patterns.
- Advanced: Create a content validation middleware that can sanitize HTML content in request bodies to prevent XSS attacks.
- Challenge: Build a middleware that detects and blocks potential DDoS attacks by tracking request patterns and applying increasingly strict rate limits when suspicious patterns are detected.
Remember that while validation middleware is powerful, it's just one part of a comprehensive security strategy. Always combine it with proper authentication, authorization, and other security best practices.