
FastAPI Caching Middleware

In web applications, caching is a critical technique to improve performance by storing copies of frequently accessed data or responses. By implementing caching in your FastAPI application, you can significantly reduce response times and server load, providing a better experience for your users.

What is Caching Middleware?

Middleware in FastAPI allows you to process requests before they reach your route handlers or modify responses before they're sent back to clients. Caching middleware specifically handles storing and retrieving responses for identical requests, eliminating the need to recompute the same result multiple times.
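
To see the shape of a middleware before adding caching, here is a minimal pass-through example that just times each request. It uses FastAPI's @app.middleware("http") decorator; the X-Process-Time header name is an arbitrary choice for illustration:

python
import time

from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    # Runs before the route handler
    start = time.perf_counter()
    response = await call_next(request)
    # Runs after the route handler, before the response reaches the client
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response

A caching middleware follows the same pattern: inspect the request on the way in, then store or short-circuit the response on the way out.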

Why Use Caching?

  • Improved Response Time: Cached responses are served immediately
  • Reduced Server Load: Computationally expensive operations run less frequently
  • Better User Experience: Faster page loads and API responses
  • Lower Backend Costs: Reduced database queries and CPU usage

Basic Caching Implementation

Let's start with a simple in-memory caching middleware for FastAPI:

python
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable, Dict
import time

class SimpleCacheMiddleware(BaseHTTPMiddleware):
    def __init__(self, app: FastAPI, cache_expiry_time: int = 60):
        super().__init__(app)
        self.cache: Dict[str, dict] = {}
        self.cache_expiry_time = cache_expiry_time

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Don't cache non-GET requests
        if request.method != "GET":
            return await call_next(request)

        # Create a cache key from the full request URL
        cache_key = str(request.url)

        # Check if we have a cached response and it's still valid
        if cache_key in self.cache:
            cached_response = self.cache[cache_key]
            if cached_response["expiry"] > time.time():
                return Response(
                    content=cached_response["content"],
                    media_type=cached_response["media_type"],
                    status_code=cached_response["status_code"],
                )

        # No cache hit, process the request
        response = await call_next(request)

        # Cache the response if it's successful
        if 200 <= response.status_code < 300:
            content = b""
            async for chunk in response.body_iterator:
                content += chunk

            # call_next returns a streaming response, so take the media
            # type from the Content-Type header rather than an attribute
            media_type = response.headers.get("content-type")

            self.cache[cache_key] = {
                "content": content,
                "media_type": media_type,
                "status_code": response.status_code,
                "expiry": time.time() + self.cache_expiry_time,
            }

            # The body iterator is consumed, so build a fresh response
            return Response(
                content=content,
                media_type=media_type,
                status_code=response.status_code,
            )

        return response

Now let's see how to use this middleware in a FastAPI application:

python
import asyncio
import time

from fastapi import FastAPI

app = FastAPI()

# Add our cache middleware with a 30-second expiry time
app.add_middleware(SimpleCacheMiddleware, cache_expiry_time=30)

@app.get("/users/{user_id}")
async def get_user(user_id: int):
    # Simulate a slow database query without blocking the event loop
    await asyncio.sleep(2)  # This would normally be a database call
    return {"user_id": user_id, "name": f"User {user_id}", "timestamp": time.time()}

When you make requests to the /users/1 endpoint, the first request will take about 2 seconds, but subsequent requests within the 30-second window will return instantly from the cache.
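
You can check this yourself with FastAPI's TestClient (it requires the httpx package). A minimal sketch timing a cold request against a warm one:

python
import time

from fastapi.testclient import TestClient

client = TestClient(app)

# First request: cache miss, takes roughly 2 seconds
start = time.perf_counter()
client.get("/users/1")
print(f"cold: {time.perf_counter() - start:.2f}s")

# Second request within the expiry window: served from the cache
start = time.perf_counter()
client.get("/users/1")
print(f"warm: {time.perf_counter() - start:.2f}s")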

Advanced Caching with Redis

For production applications, in-memory caching has limitations:

  • It doesn't persist if your server restarts
  • It doesn't work across multiple servers
  • It can consume excessive memory

Redis is a popular choice for implementing distributed caching. Let's create a Redis-based caching middleware:

python
from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Callable
import pickle
import redis

class RedisCacheMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: FastAPI,
        redis_url: str = "redis://localhost:6379",
        cache_expiry_time: int = 60,
        prefix: str = "fastapi_cache:",
    ):
        super().__init__(app)
        # Note: the synchronous redis client blocks the event loop on each
        # call; consider redis.asyncio for high-traffic deployments
        self.redis_client = redis.from_url(redis_url)
        self.cache_expiry_time = cache_expiry_time
        self.prefix = prefix

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Don't cache non-GET requests
        if request.method != "GET":
            return await call_next(request)

        # Create a cache key from the full request URL
        cache_key = f"{self.prefix}{request.url}"

        # Try to fetch from cache
        cached_response = self.redis_client.get(cache_key)
        if cached_response:
            cached_data = pickle.loads(cached_response)
            return Response(
                content=cached_data["content"],
                media_type=cached_data["media_type"],
                status_code=cached_data["status_code"],
            )

        # Process the request
        response = await call_next(request)

        # Cache successful responses
        if 200 <= response.status_code < 300:
            content = b""
            async for chunk in response.body_iterator:
                content += chunk

            media_type = response.headers.get("content-type")
            cache_data = {
                "content": content,
                "media_type": media_type,
                "status_code": response.status_code,
            }

            # Store in Redis with expiry
            self.redis_client.setex(
                cache_key,
                self.cache_expiry_time,
                pickle.dumps(cache_data),
            )

            return Response(
                content=content,
                media_type=media_type,
                status_code=response.status_code,
            )

        return response

And here's how to use it:

python
from fastapi import FastAPI

app = FastAPI()

# Add Redis cache middleware
app.add_middleware(
    RedisCacheMiddleware,
    redis_url="redis://localhost:6379",
    cache_expiry_time=120,  # 2 minutes
    prefix="myapi_cache:",
)

@app.get("/products/{product_id}")
async def get_product(product_id: int):
    # Normally an expensive database query would happen here
    return {"product_id": product_id, "name": f"Product {product_id}"}

Selective Caching

Sometimes you don't want to cache every endpoint. Let's modify our middleware to be more selective:

python
from typing import Callable, Optional

class SelectiveCacheMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: FastAPI,
        cache_paths: Optional[list[str]] = None,
        exclude_paths: Optional[list[str]] = None,
        redis_url: str = "redis://localhost:6379",
        cache_expiry_time: int = 60,
    ):
        super().__init__(app)
        self.redis_client = redis.from_url(redis_url)
        self.cache_expiry_time = cache_expiry_time
        self.cache_paths = cache_paths or []
        self.exclude_paths = exclude_paths or []

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Check if this path should be cached
        path = request.url.path
        should_cache = (
            request.method == "GET"
            and (not self.cache_paths or any(path.startswith(p) for p in self.cache_paths))
            and not any(path.startswith(p) for p in self.exclude_paths)
        )

        if not should_cache:
            return await call_next(request)

        # Rest of caching logic here (same as RedisCacheMiddleware)
        # ...

Using the selective cache:

python
app = FastAPI()

app.add_middleware(
    SelectiveCacheMiddleware,
    cache_paths=["/products", "/categories"],  # Only cache these paths
    exclude_paths=["/users/me"],  # Never cache this path
    cache_expiry_time=300,
)

Cache Headers and Client-Side Caching

Proper HTTP caching involves more than just server-side caching. You should also send appropriate cache headers to clients:

python
from fastapi import FastAPI, Response
from datetime import datetime, timedelta, timezone

app = FastAPI()

@app.get("/cached-content")
def get_cached_content(response: Response):
    # Set cache control headers
    response.headers["Cache-Control"] = "public, max-age=300"  # Cache for 5 minutes

    # Set expiration header (datetime.utcnow() is deprecated; use an aware UTC datetime)
    expires = datetime.now(timezone.utc) + timedelta(minutes=5)
    response.headers["Expires"] = expires.strftime("%a, %d %b %Y %H:%M:%S GMT")

    return {"content": "This response can be cached by browsers"}

Invalidating Cache

For dynamic data that changes, you need a way to invalidate the cache:

python
@app.post("/products/{product_id}")
async def update_product(product_id: int, product: ProductUpdate):
# Update the product in the database
db.update_product(product_id, product)

# Invalidate the cache for this product
cache_key = f"myapi_cache:/products/{product_id}"
redis_client.delete(cache_key)

# Also invalidate the list endpoint that might include this product
redis_client.delete("myapi_cache:/products")

return {"status": "updated"}

Best Practices for FastAPI Caching

  1. Cache Selectively: Not all endpoints need caching. Focus on those that are expensive to compute or frequently accessed.

  2. Set Appropriate TTL: Time-to-live (TTL) should be based on how frequently your data changes.

  3. Use Cache Keys Carefully: Include query parameters and relevant headers in your cache key when needed (see the key-builder sketch after this list).

  4. Cache Invalidation Strategy: Plan how and when to invalidate cached items when the underlying data changes.

  5. Monitor Cache Performance: Track cache hit rates and response times to ensure your caching strategy is effective.

  6. Handle Errors: Make sure error responses aren't cached inappropriately.
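
For point 3, here is a minimal key-builder sketch. It sorts query parameters so /items?a=1&b=2 and /items?b=2&a=1 share one cache entry, and folds in any headers the response varies on (the vary_headers default is an illustrative choice, not part of the middleware above):

python
from fastapi import Request

def build_cache_key(request: Request, vary_headers: tuple[str, ...] = ("accept-language",)) -> str:
    # Sort query params so parameter order doesn't fragment the cache
    query = "&".join(f"{k}={v}" for k, v in sorted(request.query_params.multi_items()))
    # Fold in headers the cached response depends on
    varied = "|".join(f"{h}={request.headers.get(h, '')}" for h in vary_headers)
    return f"cache:{request.url.path}?{query}#{varied}"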

Complete Example: A Product API with Caching

Here's a more complete example showing caching in a product API:

python
from fastapi import FastAPI, HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
import asyncio
import pickle
import redis
from typing import Callable, List
from pydantic import BaseModel

# Models
class Product(BaseModel):
    id: int
    name: str
    price: float
    description: str
    stock: int

# Fake database
products_db = {
    1: Product(id=1, name="Laptop", price=999.99, description="Powerful laptop", stock=10),
    2: Product(id=2, name="Smartphone", price=499.99, description="Latest model", stock=20),
    3: Product(id=3, name="Headphones", price=149.99, description="Noise cancelling", stock=15),
}

# Redis cache middleware
class RedisCacheMiddleware(BaseHTTPMiddleware):
    def __init__(
        self,
        app: FastAPI,
        redis_url: str = "redis://localhost:6379",
        cache_expiry_time: int = 60,
    ):
        super().__init__(app)
        self.redis_client = redis.from_url(redis_url)
        self.cache_expiry_time = cache_expiry_time

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Only cache GET requests
        if request.method != "GET":
            return await call_next(request)

        # Create a cache key
        cache_key = f"cache:{request.url}"

        # Check cache
        cached_response = self.redis_client.get(cache_key)
        if cached_response:
            cached_data = pickle.loads(cached_response)
            return Response(
                content=cached_data["content"],
                media_type=cached_data["media_type"],
                status_code=cached_data["status_code"],
                headers={"X-Cache": "HIT"},
            )

        # Process request
        response = await call_next(request)

        # Cache successful responses
        if 200 <= response.status_code < 300:
            content = b""
            async for chunk in response.body_iterator:
                content += chunk

            media_type = response.headers.get("content-type")
            cache_data = {
                "content": content,
                "media_type": media_type,
                "status_code": response.status_code,
            }

            self.redis_client.setex(
                cache_key,
                self.cache_expiry_time,
                pickle.dumps(cache_data),
            )

            return Response(
                content=content,
                media_type=media_type,
                status_code=response.status_code,
                headers={"X-Cache": "MISS"},
            )

        return response

# Initialize FastAPI with cache middleware
app = FastAPI(title="Cached Product API")
app.add_middleware(
    RedisCacheMiddleware,
    redis_url="redis://localhost:6379",
    cache_expiry_time=30,  # 30 seconds
)

# Redis client for manual cache operations
redis_client = redis.from_url("redis://localhost:6379")

# Endpoints
@app.get("/products", response_model=List[Product])
async def get_products():
    # Simulate a slow operation without blocking the event loop
    await asyncio.sleep(1)
    return list(products_db.values())

@app.get("/products/{product_id}", response_model=Product)
async def get_product(product_id: int):
    # Simulate a slow operation
    await asyncio.sleep(0.5)
    if product_id not in products_db:
        raise HTTPException(status_code=404, detail="Product not found")
    return products_db[product_id]

@app.post("/products", response_model=Product)
async def create_product(product: Product):
    products_db[product.id] = product

    # Invalidate the product list cache (the key is the full request URL)
    redis_client.delete("cache:http://localhost:8000/products")

    return product

@app.put("/products/{product_id}", response_model=Product)
async def update_product(product_id: int, product: Product):
    if product_id not in products_db:
        raise HTTPException(status_code=404, detail="Product not found")

    products_db[product_id] = product

    # Invalidate specific product cache and product list cache
    redis_client.delete(f"cache:http://localhost:8000/products/{product_id}")
    redis_client.delete("cache:http://localhost:8000/products")

    return product
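
To watch the middleware work, a quick sketch with the TestClient (this assumes a local Redis instance; base_url is set so the cache keys match the invalidation calls above):

python
from fastapi.testclient import TestClient

client = TestClient(app, base_url="http://localhost:8000")

first = client.get("/products/1")
print(first.headers["X-Cache"])   # MISS on a cold cache

second = client.get("/products/1")
print(second.headers["X-Cache"])  # HIT within the 30-second window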

Summary

Caching middleware is a powerful way to improve the performance of your FastAPI applications. By intelligently caching responses, you can:

  1. Drastically reduce response times for frequently accessed endpoints
  2. Lower server load and resource usage
  3. Scale your application more efficiently
  4. Provide better user experience

We've covered several approaches to caching, from simple in-memory caching to more robust Redis-based solutions. Remember that effective caching requires careful consideration of which data to cache, for how long, and how to invalidate the cache when data changes.

Exercises

  1. Implement a simple in-memory cache middleware and test its performance on a computationally expensive endpoint.

  2. Extend the RedisCacheMiddleware to support different cache expiry times for different URL patterns.

  3. Create a middleware that caches based on authentication status (different caches for authenticated vs. anonymous users).

  4. Implement a mechanism to manually purge the entire cache or specific cache entries through an admin endpoint.
