FastAPI Custom Middleware

In FastAPI applications, middleware plays a crucial role in processing requests and responses. While FastAPI provides several built-in middleware options, creating custom middleware allows you to implement specific functionality tailored to your application's needs.

Introduction to Custom Middleware

Custom middleware in FastAPI enables you to execute code before a request is processed by any specific path operation or after a response has been generated. This gives you powerful hooks into the request-response cycle, allowing you to implement features like:

  • Custom authentication or authorization
  • Request logging and monitoring
  • Request or response modification
  • Performance tracking
  • Error handling
  • Input validation

By the end of this tutorial, you'll be able to create your own custom middleware to handle these scenarios and more.

Understanding Middleware in FastAPI

Before diving into custom implementations, let's understand how middleware works in FastAPI:

  1. Middleware receives every request before it's processed by your route handlers
  2. It can perform operations on the request
  3. It then passes the request to the next middleware or route handler
  4. After the response is generated, middleware can also modify the response before it's sent back to the client
  5. Each newly registered middleware wraps the ones added before it, so the last registered middleware sees the request first and the response last

Creating Your First Custom Middleware

Let's start by creating a simple custom middleware that logs information about each request. Here's a basic structure:

python
from fastapi import FastAPI, Request
import time

app = FastAPI()

@app.middleware("http")
async def log_requests(request: Request, call_next):
    start_time = time.time()

    # Process the request and get the response
    response = await call_next(request)

    # Calculate processing time
    process_time = time.time() - start_time

    # Log request details
    print(f"Path: {request.url.path}")
    print(f"Method: {request.method}")
    print(f"Processing time: {process_time:.4f} seconds")

    return response

This middleware:

  1. Records the time when a request arrives
  2. Passes the request to the next middleware or route handler using call_next
  3. Calculates how long the request took to process
  4. Logs information about the request
  5. Returns the response from the route handler

How It Works

Every request to any endpoint in your FastAPI application passes through this middleware, which prints output like this:

Path: /users
Method: GET
Processing time: 0.0021 seconds

Path: /items
Method: POST
Processing time: 0.0035 seconds

Adding Headers with Custom Middleware

Let's create a middleware that adds custom headers to all responses:

python
@app.middleware("http")
async def add_custom_headers(request: Request, call_next):
    response = await call_next(request)

    # Add custom headers to the response
    response.headers["X-Process-By"] = "FastAPI"
    response.headers["X-Custom-Framework"] = "FastAPI Middleware Example"

    return response

When added to your application, this middleware will inject these custom headers into every response. This can be useful for:

  • Adding security headers
  • Tagging responses for debugging
  • Including version information
  • Implementing CORS headers

Authentication Middleware Example

A common use case for custom middleware is authentication. Here's an example of a middleware that checks for an API key:

python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

app = FastAPI()

API_KEYS = ["valid_key_1", "valid_key_2", "test_key"]  # In production, use a more secure storage

@app.middleware("http")
async def api_key_validator(request: Request, call_next):
    # Exclude authentication for documentation paths
    if request.url.path in ["/docs", "/redoc", "/openapi.json"]:
        return await call_next(request)

    # Get the API key from the header
    api_key = request.headers.get("X-API-Key")

    # Return error responses directly: an HTTPException raised inside
    # middleware is not caught by FastAPI's exception handlers and would
    # surface as a 500 error instead
    if not api_key:
        return JSONResponse(
            status_code=HTTP_401_UNAUTHORIZED,
            content={"detail": "API key missing. Please include an X-API-Key header."},
        )

    if api_key not in API_KEYS:
        return JSONResponse(
            status_code=HTTP_403_FORBIDDEN,
            content={"detail": "Invalid API key"},
        )

    # If authentication is successful, process the request
    response = await call_next(request)
    return response

This middleware:

  1. Checks if the current request path should skip authentication (like documentation pages)
  2. Looks for an API key in the request headers
  3. Validates the API key against a list of allowed keys
  4. Either allows the request to proceed or returns a 401 or 403 JSON response
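
The key check in step 3 can be hardened against timing attacks. The sketch below is one possible approach, not part of the middleware above: `is_valid_api_key` and `VALID_API_KEYS` are hypothetical names introduced for illustration, and the comparison relies on the standard library's `secrets.compare_digest`.

```python
import secrets
from typing import Iterable, Optional

# Hypothetical key store for illustration; in production, load keys from a
# secrets manager or environment configuration instead of hard-coding them.
VALID_API_KEYS = ("valid_key_1", "valid_key_2", "test_key")


def is_valid_api_key(candidate: Optional[str], valid_keys: Iterable[str] = VALID_API_KEYS) -> bool:
    """Return True if candidate matches a known key, using constant-time comparison."""
    if not candidate:
        return False
    candidate_bytes = candidate.encode("utf-8")
    matched = False
    # Compare against every key without exiting early, so the time taken
    # does not reveal which key (if any) matched.
    for key in valid_keys:
        if secrets.compare_digest(candidate_bytes, key.encode("utf-8")):
            matched = True
    return matched
```

Inside the middleware, the `api_key not in API_KEYS` test could then be swapped for `not is_valid_api_key(api_key)`.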

Error Handling Middleware

Another useful middleware pattern is global error handling:

python
import traceback
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

@app.middleware("http")
async def exception_handling_middleware(request: Request, call_next):
    try:
        # Process the request normally
        return await call_next(request)
    except Exception as e:
        # Log the error details
        error_details = traceback.format_exc()
        print(f"Error occurred: {error_details}")

        # Return a generic error response to avoid exposing sensitive details
        return JSONResponse(
            status_code=500,
            content={"message": "An internal server error occurred"},
        )

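This middleware catches exceptions that propagate out of your route handlers unhandled, logs the full traceback for debugging, and returns a generic error response. An HTTPException raised in a route is converted to a response by FastAPI's own exception handlers first, so it never reaches the except block.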

Request Modification Middleware

Sometimes you need to modify the request before it reaches your route handlers:

python
from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def request_modifier(request: Request, call_next):
    # Read the raw request body (this buffers the whole body in memory)
    body = await request.body()

    # The body can't be replaced here, but it can be stored on
    # request.state so that route handlers can reuse it
    request.state.raw_body = body

    # Continue with the request
    response = await call_next(request)
    return response

# In your route handler, you can access the raw body:
@app.post("/items/")
async def create_item(request: Request):
    raw_body = request.state.raw_body
    # Process the raw body as needed...
    return {"received": len(raw_body)}

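This pattern is useful when several parts of your application need the raw request body, since the underlying request stream can only be read once. Keep in mind that it buffers the entire body in memory for every request, so it is best suited to small payloads, and it is worth testing against the FastAPI and Starlette versions you deploy.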

Performance Monitoring Middleware

Let's create more advanced middleware for performance monitoring:

python
from fastapi import FastAPI, Request
import time
from datetime import datetime

app = FastAPI()

# Dictionary to store endpoint statistics
endpoint_stats = {}

@app.middleware("http")
async def performance_middleware(request: Request, call_next):
    # Get the route path
    path = request.url.path
    method = request.method
    endpoint = f"{method} {path}"

    # Record start time
    start_time = time.time()

    # Process the request
    response = await call_next(request)

    # Calculate processing time
    process_time = time.time() - start_time

    # Update statistics
    if endpoint not in endpoint_stats:
        endpoint_stats[endpoint] = {
            "count": 0,
            "total_time": 0,
            "min_time": float("inf"),
            "max_time": 0,
            "last_called": None
        }

    stats = endpoint_stats[endpoint]
    stats["count"] += 1
    stats["total_time"] += process_time
    stats["min_time"] = min(stats["min_time"], process_time)
    stats["max_time"] = max(stats["max_time"], process_time)
    stats["last_called"] = datetime.now().isoformat()

    # You could log these stats or expose them via an admin endpoint

    return response

@app.get("/admin/stats")
async def get_stats():
    # An endpoint to expose the collected statistics
    return endpoint_stats

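This middleware collects request counts and timing metrics keyed by HTTP method and URL path, which you can then access through a special administrative endpoint. Because the key is the concrete path, URLs with different path parameter values (such as /items/1 and /items/2) are tracked separately, and the statistics live in the memory of a single process.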
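
The raw totals are easier to read once converted to averages. The helper below is a minimal sketch that assumes the dictionary layout used by `endpoint_stats` above; `summarize_stats` and the sample data are hypothetical and exist only for illustration.

```python
def summarize_stats(endpoint_stats):
    """Return per-endpoint request counts and timings in milliseconds."""
    summary = {}
    for endpoint, stats in endpoint_stats.items():
        count = stats["count"]
        if count == 0:
            # Skip entries that have not recorded a request yet
            continue
        summary[endpoint] = {
            "count": count,
            "avg_ms": round(stats["total_time"] / count * 1000, 2),
            "min_ms": round(stats["min_time"] * 1000, 2),
            "max_ms": round(stats["max_time"] * 1000, 2),
        }
    return summary


# Made-up sample data in the same shape the middleware produces
sample_stats = {
    "GET /items": {
        "count": 4,
        "total_time": 0.02,
        "min_time": 0.002,
        "max_time": 0.01,
        "last_called": None,
    }
}
summary = summarize_stats(sample_stats)
print(summary)
```

The stats endpoint could return `summarize_stats(endpoint_stats)` instead of the raw dictionary.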

Combining Multiple Middleware

In real-world applications, you'll often need multiple middleware components. Each middleware you register wraps the ones registered before it, so the last one added is the first to see the request:

python
from fastapi import FastAPI, Request
import time

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response

@app.middleware("http")
async def add_server_header(request: Request, call_next):
    response = await call_next(request)
    response.headers["Server"] = "FastAPI Custom Server"
    return response

# add_server_header was registered last, so it is the outermost layer:
# Request path:  add_server_header -> add_process_time_header -> route handler
# Response path: route handler -> add_process_time_header -> add_server_header

Remember that requests pass through middleware from the last registered to the first, and responses travel back in the opposite order.
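
The wrapping behaviour can be modelled in a few lines of plain Python. This is a simplified toy, not Starlette's actual implementation: `make_middleware`, `build_stack`, and `route_handler` are hypothetical names, and the only point is that each newly registered middleware becomes the outermost layer.

```python
import asyncio

calls = []


def make_middleware(name):
    """Build a toy middleware that records when it sees the request and the response."""
    async def middleware(request, call_next):
        calls.append(f"{name}: request")
        response = await call_next(request)
        calls.append(f"{name}: response")
        return response
    return middleware


async def route_handler(request):
    calls.append("route handler")
    return "response"


def build_stack(handler, middlewares):
    """Wrap the handler so each newly registered middleware becomes the outermost layer."""
    stack = handler
    for mw in middlewares:  # iterate in registration order
        def wrapped(request, mw=mw, inner=stack):
            return mw(request, inner)
        stack = wrapped
    return stack


stack = build_stack(route_handler, [make_middleware("first"), make_middleware("second")])
result = asyncio.run(stack("request"))
print(calls)
```

Here "second" is registered last, so it records the request before "first" does and records the response after it.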

Creating a Middleware Class

For more complex middleware, you might prefer to use a class-based approach using Starlette's BaseHTTPMiddleware:

python
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

class CustomHeaderMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers["Custom-Header"] = "Custom Value"
        return response

app = FastAPI()
app.add_middleware(CustomHeaderMiddleware)
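
For comparison, the same header injection can be written as a pure ASGI middleware class that does not depend on BaseHTTPMiddleware. The sketch below is one possible shape under stated assumptions: `AddHeadersASGIMiddleware`, `demo_app`, and `run_demo` are hypothetical names, and the stand-in app exists only so the example runs without a server.

```python
import asyncio


class AddHeadersASGIMiddleware:
    """Pure ASGI middleware that appends fixed headers to every HTTP response."""

    def __init__(self, app, headers=None):
        self.app = app
        # ASGI represents headers as a list of (bytes, bytes) pairs
        self.extra_headers = [
            (name.lower().encode("latin-1"), value.encode("latin-1"))
            for name, value in (headers or {}).items()
        ]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            # Pass WebSocket and lifespan traffic through untouched
            await self.app(scope, receive, send)
            return

        async def send_with_headers(message):
            # Headers travel in the "http.response.start" message
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.extend(self.extra_headers)
                message = {**message, "headers": headers}
            await send(message)

        await self.app(scope, receive, send_with_headers)


# Stand-in ASGI app used only to exercise the middleware
async def demo_app(scope, receive, send):
    await send({"type": "http.response.start", "status": 200, "headers": []})
    await send({"type": "http.response.body", "body": b"ok"})


async def run_demo():
    sent = []

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    async def send(message):
        sent.append(message)

    wrapped = AddHeadersASGIMiddleware(demo_app, headers={"Custom-Header": "Custom Value"})
    await wrapped({"type": "http"}, receive, send)
    return sent


messages = asyncio.run(run_demo())
```

With FastAPI, a class like this would be registered the same way, for example `app.add_middleware(AddHeadersASGIMiddleware, headers={"Custom-Header": "Custom Value"})`.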

The class-based approach is particularly useful for middleware that requires initialization parameters:

python
import time

class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, max_requests: int = 10, window_seconds: int = 60):
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.request_log = {}  # client IP -> list of request timestamps

    async def dispatch(self, request: Request, call_next):
        # request.client can be None, so fall back to a placeholder key
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()

        # Clean old records
        self.request_log = {
            ip: times for ip, times in self.request_log.items()
            if any(t > current_time - self.window_seconds for t in times)
        }

        # Check rate limit
        if client_ip in self.request_log:
            times = self.request_log[client_ip]
            times = [t for t in times if t > current_time - self.window_seconds]

            if len(times) >= self.max_requests:
                return Response(
                    content="Rate limit exceeded",
                    status_code=429
                )

            self.request_log[client_ip] = times + [current_time]
        else:
            self.request_log[client_ip] = [current_time]

        return await call_next(request)

# Add the middleware with specific parameters
app.add_middleware(RateLimitMiddleware, max_requests=100, window_seconds=60)

This rate-limiting middleware demonstrates how you can create configurable middleware components.
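
The sliding-window bookkeeping can also be pulled out of the middleware and tested on its own. The class below is a minimal sketch, not a drop-in replacement: `SlidingWindowLimiter` is a hypothetical name, the state is kept in memory, and it would not be shared across multiple worker processes.

```python
from collections import defaultdict, deque


class SlidingWindowLimiter:
    """Allow at most max_requests per key within the last window_seconds."""

    def __init__(self, max_requests=10, window_seconds=60):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._hits = defaultdict(deque)  # key -> timestamps, oldest first

    def allow(self, key, now):
        hits = self._hits[key]
        cutoff = now - self.window_seconds
        # Drop timestamps that have fallen out of the window
        while hits and hits[0] <= cutoff:
            hits.popleft()
        if len(hits) >= self.max_requests:
            return False
        hits.append(now)
        return True


limiter = SlidingWindowLimiter(max_requests=2, window_seconds=60)
first = limiter.allow("203.0.113.5", now=0.0)
second = limiter.allow("203.0.113.5", now=1.0)
third = limiter.allow("203.0.113.5", now=2.0)     # over the limit
later = limiter.allow("203.0.113.5", now=60.5)    # oldest hit has expired
other = limiter.allow("198.51.100.7", now=2.0)    # separate client
```

Passing `now` explicitly keeps the logic deterministic in tests; a middleware's `dispatch` method would call `limiter.allow(client_ip, time.time())` and return a 429 response when it gets `False`.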

Real-world Application: Request Correlation

Here's a practical example of middleware that adds correlation IDs to requests for distributed tracing:

python
import uuid
from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware
from contextvars import ContextVar

# Create a context variable to store the correlation ID
correlation_id = ContextVar("correlation_id", default=None)

class CorrelationMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Generate or extract correlation ID
        request_id = request.headers.get("X-Correlation-ID")
        if not request_id:
            request_id = str(uuid.uuid4())

        # Store in context var for access in route handlers
        correlation_id.set(request_id)

        # Add correlation ID to request state for easy access
        request.state.correlation_id = request_id

        # Process the request
        response = await call_next(request)

        # Add correlation ID to response headers
        response.headers["X-Correlation-ID"] = request_id

        return response

app = FastAPI()
app.add_middleware(CorrelationMiddleware)

# Use the correlation ID in a route handler
@app.get("/items/")
async def read_items(request: Request):
    current_id = correlation_id.get()
    return {
        "correlation_id": current_id,
        "message": "This ID can be used to trace this request across microservices"
    }

This middleware generates a unique ID for each request, which can be used to track requests across multiple services in a microservice architecture.
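
A correlation ID is most useful when it also appears in your logs. The sketch below shows one way to do that with the standard `logging` module, assuming a `correlation_id` context variable like the one above; `CorrelationIdFilter`, the logger name, and the format string are illustrative choices rather than anything FastAPI requires.

```python
import logging
from contextvars import ContextVar

# Same kind of context variable the middleware sets for each request
correlation_id = ContextVar("correlation_id", default=None)


class CorrelationIdFilter(logging.Filter):
    """Attach the current correlation ID to every log record."""

    def filter(self, record):
        record.correlation_id = correlation_id.get() or "-"
        return True  # never drop the record, only annotate it


logger = logging.getLogger("correlation_demo")
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(levelname)s [%(correlation_id)s] %(message)s")
)
handler.addFilter(CorrelationIdFilter())
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Simulate what the middleware does at the start of a request
correlation_id.set("demo-123")
logger.info("Fetching items")
```

Attaching the filter to the handler means every message routed through it carries the ID, without each route handler having to pass it along.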

Summary

In this article, we learned:

  1. How custom middleware works in FastAPI
  2. How to create simple function-based middleware
  3. How to implement class-based middleware with BaseHTTPMiddleware
  4. Various practical use cases for middleware:
    • Request logging and timing
    • Authentication and authorization
    • Error handling
    • Header manipulation
    • Performance monitoring
    • Rate limiting
    • Request correlation/distributed tracing

Custom middleware in FastAPI provides a powerful way to implement cross-cutting concerns in your application. By separating these concerns from your route handlers, you can keep your code clean and maintainable while still adding sophisticated functionality across your entire application.

Exercises

To reinforce your understanding:

  1. Create a middleware that logs detailed information about requests and responses, including headers and status codes
  2. Implement a middleware that adds appropriate CORS headers to your responses
  3. Build a middleware that detects and blocks suspicious requests based on patterns (e.g., SQL injection attempts)
  4. Create a caching middleware that stores responses for GET requests to reduce database load
  5. Implement a middleware that tracks user sessions without using cookies

Additional Resources

By mastering custom middleware in FastAPI, you'll be able to implement powerful, reusable components that can enhance all the endpoints in your application with minimal duplication of code.


