FastAPI Caching Middleware
In web applications, caching is a critical technique to improve performance by storing copies of frequently accessed data or responses. By implementing caching in your FastAPI application, you can significantly reduce response times and server load, providing a better experience for your users.
What is Caching Middleware?
Middleware in FastAPI allows you to process requests before they reach your route handlers or modify responses before they're sent back to clients. Caching middleware specifically handles storing and retrieving responses for identical requests, eliminating the need to recompute the same result multiple times.
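To make these hook points concrete, here is a minimal do-nothing middleware in the same BaseHTTPMiddleware style used throughout this article; it simply forwards the request and returns the response unchanged:

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable

class PassthroughMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Inspect or modify the request here, before the route handler runs
        response = await call_next(request)
        # Inspect or modify the response here, before it reaches the client
        return response

app = FastAPI()
app.add_middleware(PassthroughMiddleware)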
Why Use Caching?
- Improved Response Time: Cached responses are served immediately
- Reduced Server Load: Computationally expensive operations run less frequently
- Better User Experience: Faster page loads and API responses
- Lower Backend Costs: Reduced database queries and CPU usage
Basic Caching Implementation
Let's start with a simple in-memory caching middleware for FastAPI:
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, Dict
import time

class SimpleCacheMiddleware(BaseHTTPMiddleware):
    def __init__(self, app: FastAPI, cache_expiry_time: int = 60):
        super().__init__(app)
        self.cache: Dict[str, dict] = {}
        self.cache_expiry_time = cache_expiry_time

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Don't cache non-GET requests
        if request.method != "GET":
            return await call_next(request)

        # Create a cache key from the full request URL
        cache_key = str(request.url)

        # Check if we have a cached response that is still valid
        if cache_key in self.cache:
            cached_response = self.cache[cache_key]
            if cached_response["expiry"] > time.time():
                return Response(
                    content=cached_response["content"],
                    media_type=cached_response["media_type"],
                    status_code=cached_response["status_code"],
                )

        # No cache hit, process the request
        response = await call_next(request)

        # Cache the response if it's successful
        if 200 <= response.status_code < 300:
            # Drain the streaming body so we can store it and re-send it
            content = b''
            async for chunk in response.body_iterator:
                content += chunk
            # The response from call_next has no media_type attribute;
            # read the content type from its headers instead
            media_type = response.headers.get("content-type")
            self.cache[cache_key] = {
                "content": content,
                "media_type": media_type,
                "status_code": response.status_code,
                "expiry": time.time() + self.cache_expiry_time
            }
            return Response(
                content=content,
                media_type=media_type,
                status_code=response.status_code
            )
        return response
Now let's see how to use this middleware in a FastAPI application:
from fastapi import FastAPI
import asyncio
import time

app = FastAPI()

# Add our cache middleware with a 30-second expiry time
app.add_middleware(SimpleCacheMiddleware, cache_expiry_time=30)

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    # Simulate a slow database query without blocking the event loop
    await asyncio.sleep(2)  # This would normally be a database call
    return {"user_id": user_id, "name": f"User {user_id}", "timestamp": time.time()}
When you make requests to the /users/1 endpoint, the first request will take about 2 seconds, but subsequent requests within the 30-second window will return almost instantly from the cache.
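To verify the cache is working, you can time two consecutive requests. This quick check uses the httpx client and assumes the app above is being served locally on port 8000:

import time
import httpx

url = "http://localhost:8000/users/1"  # assumes the app is running locally on port 8000

for attempt in ("cold", "warm"):
    start = time.perf_counter()
    response = httpx.get(url)
    elapsed = time.perf_counter() - start
    print(f"{attempt} request: {response.status_code} in {elapsed:.2f}s")

The first print should show roughly two seconds; the second should show a few milliseconds.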
Advanced Caching with Redis
For production applications, in-memory caching has limitations:
- It doesn't persist if your server restarts
- It doesn't work across multiple servers
- It can consume excessive memory
Redis is a popular choice for implementing distributed caching. For simplicity the example below uses the synchronous redis client, whose calls briefly block the event loop; under heavy production traffic you may prefer the asyncio client that redis-py also ships. Let's create a Redis-based caching middleware:
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
import pickle
import redis

class RedisCacheMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: FastAPI,
        redis_url: str = "redis://localhost:6379",
        cache_expiry_time: int = 60,
        prefix: str = "fastapi_cache:"
    ):
        super().__init__(app)
        self.redis_client = redis.from_url(redis_url)
        self.cache_expiry_time = cache_expiry_time
        self.prefix = prefix

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Don't cache non-GET requests
        if request.method != "GET":
            return await call_next(request)

        # Create a cache key from the full request URL
        cache_key = f"{self.prefix}{request.url}"

        # Try to fetch from cache
        cached_response = self.redis_client.get(cache_key)
        if cached_response:
            cached_data = pickle.loads(cached_response)
            return Response(
                content=cached_data["content"],
                media_type=cached_data["media_type"],
                status_code=cached_data["status_code"]
            )

        # Process the request
        response = await call_next(request)

        # Cache successful responses
        if 200 <= response.status_code < 300:
            content = b''
            async for chunk in response.body_iterator:
                content += chunk
            cache_data = {
                "content": content,
                "media_type": response.headers.get("content-type"),
                "status_code": response.status_code
            }
            # Store in Redis with an expiry (setex sets value and TTL atomically)
            self.redis_client.setex(
                cache_key,
                self.cache_expiry_time,
                pickle.dumps(cache_data)
            )
            return Response(
                content=content,
                media_type=cache_data["media_type"],
                status_code=response.status_code
            )
        return response
And here's how to use it:
from fastapi import FastAPI

app = FastAPI()

# Add Redis cache middleware
app.add_middleware(
    RedisCacheMiddleware,
    redis_url="redis://localhost:6379",
    cache_expiry_time=120,  # 2 minutes
    prefix="myapi_cache:"
)

@app.get("/products/{product_id}")
async def get_product(product_id: int):
    # Normally an expensive database query would happen here
    return {"product_id": product_id, "name": f"Product {product_id}"}
Selective Caching
Sometimes you don't want to cache every endpoint. Let's modify our middleware to be more selective:
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, Optional
import redis

class SelectiveCacheMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: FastAPI,
        cache_paths: Optional[list[str]] = None,
        exclude_paths: Optional[list[str]] = None,
        redis_url: str = "redis://localhost:6379",
        cache_expiry_time: int = 60
    ):
        super().__init__(app)
        self.redis_client = redis.from_url(redis_url)
        self.cache_expiry_time = cache_expiry_time
        self.cache_paths = cache_paths or []
        self.exclude_paths = exclude_paths or []

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Cache only GET requests whose path is allowed and not excluded
        path = request.url.path
        should_cache = (
            request.method == "GET" and
            (not self.cache_paths or any(path.startswith(p) for p in self.cache_paths)) and
            not any(path.startswith(p) for p in self.exclude_paths)
        )
        if not should_cache:
            return await call_next(request)

        # Rest of caching logic here (same as RedisCacheMiddleware)
        # ...
Using the selective cache:
app = FastAPI()

app.add_middleware(
    SelectiveCacheMiddleware,
    cache_paths=["/products", "/categories"],  # Only cache these paths
    exclude_paths=["/users/me"],               # Never cache this path
    cache_expiry_time=300
)
Cache Headers and Client-Side Caching
Proper HTTP caching involves more than just server-side caching. You should also send appropriate cache headers to clients:
from fastapi import FastAPI, Response
from datetime import datetime, timedelta, timezone

app = FastAPI()

@app.get("/cached-content")
def get_cached_content(response: Response):
    # Set cache control headers
    response.headers["Cache-Control"] = "public, max-age=300"  # Cache for 5 minutes

    # Build the Expires header from an aware UTC datetime
    # (datetime.utcnow() is deprecated in recent Python versions)
    expires = datetime.now(timezone.utc) + timedelta(minutes=5)
    response.headers["Expires"] = expires.strftime("%a, %d %b %Y %H:%M:%S GMT")
    return {"content": "This response can be cached by browsers"}
Invalidating Cache
For dynamic data that changes, you need a way to invalidate the cache:
@app.post("/products/{product_id}")
async def update_product(product_id: int, product: ProductUpdate):
# Update the product in the database
db.update_product(product_id, product)
# Invalidate the cache for this product
cache_key = f"myapi_cache:/products/{product_id}"
redis_client.delete(cache_key)
# Also invalidate the list endpoint that might include this product
redis_client.delete("myapi_cache:/products")
return {"status": "updated"}
Best Practices for FastAPI Caching
- Cache Selectively: Not all endpoints need caching. Focus on those that are expensive to compute or frequently accessed.
- Set Appropriate TTL: Time-to-live (TTL) should be based on how frequently your data changes.
- Use Cache Keys Carefully: Include query parameters and relevant headers in your cache key when needed (see the sketch after this list).
- Cache Invalidation Strategy: Plan how and when to invalidate cached items when the underlying data changes.
- Monitor Cache Performance: Track cache hit rates and response times to ensure your caching strategy is effective.
- Handle Errors: Make sure error responses aren't cached inappropriately.
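For the cache-key point above, here is one possible sketch of a key builder that normalizes query parameters and folds in a header that affects the response body (Accept-Language is just an example of such a header):

from hashlib import sha256
from fastapi import Request

def build_cache_key(request: Request, prefix: str = "myapi_cache:") -> str:
    # Sort query parameters so /items?a=1&b=2 and /items?b=2&a=1 share one entry
    query = "&".join(f"{k}={v}" for k, v in sorted(request.query_params.multi_items()))
    # Include headers that change the response body (content negotiation, etc.)
    language = request.headers.get("accept-language", "")
    raw_key = f"{request.url.path}?{query}|lang={language}"
    # Hash so keys stay short and free of characters that are awkward in Redis
    return prefix + sha256(raw_key.encode()).hexdigest()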
Complete Example: A Product API with Caching
Here's a more complete example showing caching in a product API:
from fastapi import FastAPI, HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, List
from pydantic import BaseModel
import asyncio
import pickle
import redis

# Models
class Product(BaseModel):
    id: int
    name: str
    price: float
    description: str
    stock: int

# Fake database
products_db = {
    1: Product(id=1, name="Laptop", price=999.99, description="Powerful laptop", stock=10),
    2: Product(id=2, name="Smartphone", price=499.99, description="Latest model", stock=20),
    3: Product(id=3, name="Headphones", price=149.99, description="Noise cancelling", stock=15),
}
# Redis cache middleware
class RedisCacheMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: FastAPI,
        redis_url: str = "redis://localhost:6379",
        cache_expiry_time: int = 60
    ):
        super().__init__(app)
        self.redis_client = redis.from_url(redis_url)
        self.cache_expiry_time = cache_expiry_time

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Only cache GET requests
        if request.method != "GET":
            return await call_next(request)

        # Create a cache key from the full request URL
        cache_key = f"cache:{request.url}"

        # Check cache
        cached_response = self.redis_client.get(cache_key)
        if cached_response:
            cached_data = pickle.loads(cached_response)
            return Response(
                content=cached_data["content"],
                media_type=cached_data["media_type"],
                status_code=cached_data["status_code"],
                headers={"X-Cache": "HIT"}
            )

        # Process request
        response = await call_next(request)

        # Cache successful responses
        if 200 <= response.status_code < 300:
            content = b''
            async for chunk in response.body_iterator:
                content += chunk
            cache_data = {
                "content": content,
                "media_type": response.headers.get("content-type"),
                "status_code": response.status_code
            }
            self.redis_client.setex(
                cache_key,
                self.cache_expiry_time,
                pickle.dumps(cache_data)
            )
            return Response(
                content=content,
                media_type=cache_data["media_type"],
                status_code=response.status_code,
                headers={"X-Cache": "MISS"}
            )
        return response
# Initialize FastAPI with cache middleware
app = FastAPI(title="Cached Product API")

app.add_middleware(
    RedisCacheMiddleware,
    redis_url="redis://localhost:6379",
    cache_expiry_time=30  # 30 seconds
)

# Redis client for manual cache operations
redis_client = redis.from_url("redis://localhost:6379")

# Endpoints
@app.get("/products", response_model=List[Product])
async def get_products():
    # Simulate a slow operation without blocking the event loop
    await asyncio.sleep(1)
    return list(products_db.values())

@app.get("/products/{product_id}", response_model=Product)
async def get_product(product_id: int):
    # Simulate a slow operation without blocking the event loop
    await asyncio.sleep(0.5)
    if product_id not in products_db:
        raise HTTPException(status_code=404, detail="Product not found")
    return products_db[product_id]

@app.post("/products", response_model=Product)
async def create_product(product: Product):
    products_db[product.id] = product
    # Invalidate the product list cache (keys include the full URL the middleware saw)
    redis_client.delete("cache:http://localhost:8000/products")
    return product

@app.put("/products/{product_id}", response_model=Product)
async def update_product(product_id: int, product: Product):
    if product_id not in products_db:
        raise HTTPException(status_code=404, detail="Product not found")
    products_db[product_id] = product
    # Invalidate the specific product cache and the product list cache
    redis_client.delete(f"cache:http://localhost:8000/products/{product_id}")
    redis_client.delete("cache:http://localhost:8000/products")
    return product
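Assuming the file is saved as main.py, served with uvicorn main:app --port 8000, and Redis is running locally (host and port here are assumptions), a quick client script shows the X-Cache header flipping between MISS and HIT:

import time
import httpx

base = "http://localhost:8000"  # assumes uvicorn main:app --port 8000

for _ in range(2):
    start = time.perf_counter()
    r = httpx.get(f"{base}/products/1")
    elapsed = time.perf_counter() - start
    print(r.headers.get("x-cache"), f"{elapsed:.2f}s")  # expect MISS, then HIT

# Updating the product purges its cache entry, so the next GET is a MISS again
httpx.put(f"{base}/products/1", json={
    "id": 1, "name": "Laptop Pro", "price": 1299.99,
    "description": "Upgraded laptop", "stock": 5,
})
print(httpx.get(f"{base}/products/1").headers.get("x-cache"))  # MISS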
Summary
Caching middleware is a powerful way to improve the performance of your FastAPI applications. By intelligently caching responses, you can:
- Drastically reduce response times for frequently accessed endpoints
- Lower server load and resource usage
- Scale your application more efficiently
- Provide a better user experience
We've covered several approaches to caching, from simple in-memory caching to more robust Redis-based solutions. Remember that effective caching requires careful consideration of which data to cache, for how long, and how to invalidate the cache when data changes.
Exercises
- Implement a simple in-memory cache middleware and test its performance on a computationally expensive endpoint.
- Extend the RedisCacheMiddleware to support different cache expiry times for different URL patterns.
- Create a middleware that caches based on authentication status (different caches for authenticated vs. anonymous users).
- Implement a mechanism to manually purge the entire cache or specific cache entries through an admin endpoint.