FastAPI WebSocket Middleware
Introduction
WebSocket middleware is a powerful concept that allows you to execute code before and after WebSocket connections are processed in a FastAPI application. Just as HTTP middleware helps you manage HTTP requests, WebSocket middleware provides a mechanism to intercept, inspect, and potentially modify WebSocket connections and messages.
Middleware can be particularly useful for implementing:
- Authentication and authorization
- Logging and monitoring
- Rate limiting
- Message transformation
- Connection management
In this guide, we'll explore how to create and apply middleware to your FastAPI WebSocket endpoints.
Understanding WebSocket Middleware
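FastAPI's built-in HTTP middleware helpers (the @app.middleware("http") decorator and Starlette's BaseHTTPMiddleware) only see HTTP requests, so WebSocket middleware has to be written as a plain ASGI middleware. That means a class or function that wraps the application, receives the scope, receive, and send arguments of every connection, and adds the desired functionality around them.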
The Basic Structure
A WebSocket middleware typically follows this pattern:
class MyWebSocketMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "websocket":
            # Perform WebSocket-specific middleware operations
            await self.handle_websocket(scope, receive, send)
        else:
            # Pass other types (HTTP, lifespan) to the underlying app
            await self.app(scope, receive, send)

    async def handle_websocket(self, scope, receive, send):
        # Middleware logic goes here
        # For example, authentication, logging, etc.
        # Then pass to the application
        await self.app(scope, receive, send)
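Because a middleware of this shape is just a callable taking scope, receive, and send, it can be exercised without a server. The following is a minimal sketch using only the standard library; ScopeTypeRecorder, echo_app, and drive are illustrative names invented for this example, and the scripted messages are a simplified version of what a real ASGI server would deliver.

```python
import asyncio


class ScopeTypeRecorder:
    """Pure ASGI middleware that notes each scope type it sees (illustrative name)."""

    def __init__(self, app):
        self.app = app
        self.seen = []

    async def __call__(self, scope, receive, send):
        self.seen.append(scope["type"])
        await self.app(scope, receive, send)


async def echo_app(scope, receive, send):
    """Tiny stand-in ASGI app: accepts a WebSocket and echoes one text message."""
    if scope["type"] != "websocket":
        return
    await receive()  # the initial websocket.connect event
    await send({"type": "websocket.accept"})
    message = await receive()  # a websocket.receive event
    await send({"type": "websocket.send", "text": message["text"]})
    await send({"type": "websocket.close", "code": 1000})


async def drive(app, scope, incoming):
    """Feed scripted events into an ASGI app and collect everything it sends."""
    queue = list(incoming)
    sent = []

    async def receive():
        return queue.pop(0)

    async def send(message):
        sent.append(message)

    await app(scope, receive, send)
    return sent


middleware = ScopeTypeRecorder(echo_app)
sent = asyncio.run(drive(
    middleware,
    {"type": "websocket", "path": "/ws"},
    [{"type": "websocket.connect"}, {"type": "websocket.receive", "text": "hi"}],
))
print(middleware.seen)              # ['websocket']
print([m["type"] for m in sent])    # ['websocket.accept', 'websocket.send', 'websocket.close']
```

The same drive helper works for any of the middleware classes in this guide, which makes it a convenient way to check wrapped receive and send functions in isolation.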
Creating Your First WebSocket Middleware
Let's create a simple logging middleware that logs when WebSocket connections are established and closed:
import logging
import time

from fastapi import FastAPI

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class WebSocketLoggingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "websocket":
            # scope["client"] is optional in ASGI and may be None
            client = scope.get("client")
            connection_id = f"{client[0]}:{client[1]}" if client else "unknown"
            logger.info(f"WebSocket connection attempt from {connection_id}")

            # Custom send function to intercept connection events
            async def wrapped_send(message):
                if message["type"] == "websocket.accept":
                    logger.info(f"WebSocket connection accepted: {connection_id}")
                elif message["type"] == "websocket.close":
                    logger.info(f"WebSocket connection closed: {connection_id}")
                await send(message)

            # Custom receive function to log incoming messages
            async def wrapped_receive():
                message = await receive()
                if message["type"] == "websocket.receive":
                    logger.info(f"Received WebSocket message from {connection_id}")
                return message

            # Pass control with our wrapped functions
            start_time = time.time()
            try:
                await self.app(scope, wrapped_receive, wrapped_send)
            finally:
                duration = time.time() - start_time
                logger.info(f"WebSocket connection {connection_id} duration: {duration:.2f} seconds")
        else:
            # Not a WebSocket request, pass through
            await self.app(scope, receive, send)

# Create a FastAPI app with our middleware
app = FastAPI()

# Add the middleware to the app
app.add_middleware(WebSocketLoggingMiddleware)
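One detail worth guarding against: the ASGI specification marks the client entry of the scope as optional, so it can be None (for example behind some proxies or in tests). A small helper along these lines, where the name connection_label is only an assumption for illustration, keeps the logging code from raising in that case:

```python
def connection_label(scope):
    """Return 'host:port' for logging, or 'unknown' when no client info is present."""
    client = scope.get("client")
    if not client:
        return "unknown"
    host, port = client
    return f"{host}:{port}"


print(connection_label({"type": "websocket", "client": ("127.0.0.1", 50432)}))
print(connection_label({"type": "websocket", "client": None}))
```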
Authentication Middleware for WebSockets
A common use case for WebSocket middleware is authentication. Here's how you could create a middleware that authenticates WebSocket connections:
from urllib.parse import parse_qs

import jwt
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

SECRET_KEY = "your-secret-key"  # Placeholder: load a real secret from configuration
ALGORITHM = "HS256"

class WebSocketAuthMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            await self.app(scope, receive, send)
            return

        # Extract the token from the query string; parse_qs handles URL encoding
        # and values that themselves contain "="
        params = parse_qs(scope["query_string"].decode())
        token = params.get("token", [None])[0]

        if not token:
            # No token provided, close connection
            await self.close_connection(send, 1008, "Missing authentication token")
            return

        # Verify the token; only the decode step is inside the try block so that
        # errors raised by the application itself are not swallowed here
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.PyJWTError:
            await self.close_connection(send, 1008, "Invalid authentication token")
            return

        user_id = payload.get("sub")
        if not user_id:
            await self.close_connection(send, 1008, "Invalid token payload")
            return

        # Add user info to the scope
        scope["user_id"] = user_id

        # Continue to the app
        await self.app(scope, receive, send)

    async def close_connection(self, send, code: int, reason: str):
        # Closing before the connection is accepted rejects the handshake;
        # the server answers it with HTTP 403
        await send({
            "type": "websocket.close",
            "code": code,
            "reason": reason
        })

# Add to your FastAPI app
app = FastAPI()
app.add_middleware(WebSocketAuthMiddleware)
Now you can use this middleware with a WebSocket endpoint:
from fastapi import WebSocket, WebSocketDisconnect

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):  # The WebSocket annotation is required
    await websocket.accept()

    # The user_id is available from the middleware
    user_id = websocket.scope.get("user_id")
    await websocket.send_text(f"Hello user {user_id}!")

    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"You sent: {data}")
    except WebSocketDisconnect:
        print(f"User {user_id} disconnected")
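The rejection path of an authentication middleware can be checked without a running server by calling it with a hand-built scope and a send function that records messages. The sketch below follows the same shape as the middleware above but swaps the JWT check for a stand-in so it runs with the standard library only; TokenGate, verify_token, and the "good-token" rule are hypothetical and exist purely for this illustration.

```python
import asyncio
from urllib.parse import parse_qs


def verify_token(token):
    """Stand-in for jwt.decode: returns a user id or None (hypothetical rule)."""
    return "alice" if token == "good-token" else None


class TokenGate:
    """Auth middleware with the JWT check replaced by verify_token (illustrative)."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            await self.app(scope, receive, send)
            return
        params = parse_qs(scope.get("query_string", b"").decode())
        token = params.get("token", [None])[0]
        user_id = verify_token(token) if token else None
        if user_id is None:
            # Closing before accept rejects the handshake
            await send({"type": "websocket.close", "code": 1008})
            return
        scope["user_id"] = user_id
        await self.app(scope, receive, send)


async def inner_app(scope, receive, send):
    """Minimal downstream app that simply accepts the connection."""
    await send({"type": "websocket.accept"})


async def call(app, query_string):
    """Invoke an ASGI app once; return the messages it sent and the final scope."""
    sent = []
    scope = {"type": "websocket", "query_string": query_string}

    async def receive():
        return {"type": "websocket.connect"}

    async def send(message):
        sent.append(message)

    await app(scope, receive, send)
    return sent, scope


gate = TokenGate(inner_app)
rejected, _ = asyncio.run(call(gate, b""))
accepted, accepted_scope = asyncio.run(call(gate, b"token=good-token"))
print(rejected)                               # [{'type': 'websocket.close', 'code': 1008}]
print(accepted, accepted_scope["user_id"])    # [{'type': 'websocket.accept'}] alice
```

With the real middleware, the same harness applies once a valid token is generated with the same secret and algorithm.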
Rate Limiting Middleware
Another useful middleware is for rate limiting WebSocket messages to prevent abuse:
import time
import asyncio
from collections import defaultdict

class WebSocketRateLimitMiddleware:
    def __init__(self, app, messages_per_minute=60):
        self.app = app
        self.rate_limit = messages_per_minute
        self.message_counts = defaultdict(list)

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            await self.app(scope, receive, send)
            return

        client_id = f"{scope['client'][0]}:{scope['client'][1]}"

        # Custom receive function to apply rate limiting
        async def rate_limited_receive():
            message = await receive()

            if message["type"] == "websocket.receive":
                current_time = time.time()

                # Remove timestamps older than 1 minute
                self.message_counts[client_id] = [
                    ts for ts in self.message_counts[client_id]
                    if current_time - ts < 60
                ]

                # Check if rate limit is exceeded
                if len(self.message_counts[client_id]) >= self.rate_limit:
                    # Instead of receiving this message, send a rate limit message
                    await send({
                        "type": "websocket.send",
                        "text": "Rate limit exceeded. Please slow down."
                    })

                    # Pause briefly to slow down clients that ignore the warning
                    await asyncio.sleep(1)

                    # Create a fake "receive" message to keep the connection alive
                    return {
                        "type": "websocket.receive",
                        "text": "__rate_limited__"  # Special marker to be handled by the app
                    }

                # Record this message timestamp
                self.message_counts[client_id].append(current_time)

            return message

        # Pass to the app with our rate-limited receive function
        try:
            await self.app(scope, rate_limited_receive, send)
        finally:
            # Drop this client's history once the connection ends
            self.message_counts.pop(client_id, None)

# Add to your FastAPI app
app = FastAPI()
app.add_middleware(WebSocketRateLimitMiddleware, messages_per_minute=60)
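The sliding-window bookkeeping can also be pulled out into a small class of its own, which makes it easy to test with a fake clock. This is only a sketch of one possible design: SlidingWindowLimiter and its clock parameter are assumptions made for illustration, not part of FastAPI or Starlette.

```python
import time
from collections import defaultdict, deque


class SlidingWindowLimiter:
    """Allow at most `limit` events per client within the last `window_seconds`."""

    def __init__(self, limit, window_seconds=60.0, clock=time.monotonic):
        self.limit = limit
        self.window = window_seconds
        self.clock = clock  # injectable so tests do not need to sleep
        self._events = defaultdict(deque)

    def allow(self, client_id):
        """Record an event and return True, or return False if over the limit."""
        now = self.clock()
        events = self._events[client_id]
        # Discard timestamps that have fallen out of the window
        while events and now - events[0] >= self.window:
            events.popleft()
        if len(events) >= self.limit:
            return False
        events.append(now)
        return True

    def forget(self, client_id):
        """Drop a client's history, for example when its connection closes."""
        self._events.pop(client_id, None)


fake_now = [0.0]
limiter = SlidingWindowLimiter(limit=2, window_seconds=60, clock=lambda: fake_now[0])
results = [limiter.allow("a") for _ in range(3)]  # third call exceeds the limit
fake_now[0] = 61.0
results.append(limiter.allow("a"))  # earlier events are now outside the window
print(results)  # [True, True, False, True]
```

A deque keeps the pruning step cheap because expired timestamps are always at the front, whereas rebuilding a list on every message scans the whole history each time.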
Combining Multiple Middlewares
You might want to use multiple middleware components together. The order of middleware is important because each call to add_middleware wraps everything that was added before it:
app = FastAPI()

# Order matters! The last middleware added is the outermost wrapper and runs first
app.add_middleware(WebSocketRateLimitMiddleware, messages_per_minute=60)
app.add_middleware(WebSocketAuthMiddleware)
app.add_middleware(WebSocketLoggingMiddleware)
In this setup, the logging middleware runs first, then authentication, and finally rate limiting.
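The wrapping order is easy to get backwards, so it can help to simulate it. The sketch below uses only the standard library and assumes that add_middleware prepends each new middleware to an internal list and that the stack is built by wrapping from the end of that list, which matches Starlette's behavior as far as I know; MiniApp and make_middleware are illustrative stand-ins, not real FastAPI classes.

```python
import asyncio

calls = []


def make_middleware(name):
    """Build a pure ASGI middleware class that records when it runs."""

    class Recorder:
        def __init__(self, app):
            self.app = app

        async def __call__(self, scope, receive, send):
            calls.append(name)
            await self.app(scope, receive, send)

    return Recorder


async def endpoint(scope, receive, send):
    calls.append("endpoint")


class MiniApp:
    """Simplified stand-in for the app's middleware bookkeeping (an assumption)."""

    def __init__(self, app):
        self.app = app
        self.user_middleware = []

    def add_middleware(self, cls):
        self.user_middleware.insert(0, cls)  # newest goes to the front

    def build(self):
        app = self.app
        for cls in reversed(self.user_middleware):
            app = cls(app)  # the front of the list ends up outermost
        return app


mini = MiniApp(endpoint)
mini.add_middleware(make_middleware("rate_limit"))
mini.add_middleware(make_middleware("auth"))
mini.add_middleware(make_middleware("logging"))  # added last, runs first

asyncio.run(mini.build()({"type": "websocket"}, None, None))
print(calls)  # ['logging', 'auth', 'rate_limit', 'endpoint']
```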
Real-World Example: A Complete Chat Application with Middleware
Let's put it all together in a more complete example - a chat application that uses all three of our middleware components:
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
import jwt
from datetime import datetime, timedelta, timezone
from typing import List, Dict
import asyncio  # Needed by the rate limiting middleware
import logging
import time
from collections import defaultdict

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authentication settings
SECRET_KEY = "your-secure-secret-key"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# FastAPI app
app = FastAPI(title="WebSocket Chat with Middleware")

# User model and storage (simplified for demo; never store plain-text passwords)
class User(BaseModel):
    username: str
    password: str

users_db = {
    "alice": User(username="alice", password="wonderland"),
    "bob": User(username="bob", password="builder")
}

# Token endpoint
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Authentication functions
def create_access_token(data: dict):
    to_encode = data.copy()
    # Use a timezone-aware timestamp; datetime.utcnow() is deprecated
    expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

@app.post("/token")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = users_db.get(form_data.username)
    if not user or form_data.password != user.password:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = create_access_token(data={"sub": user.username})
    return {"access_token": access_token, "token_type": "bearer"}

# Chat connection manager
class ConnectionManager:
    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = defaultdict(list)

    async def connect(self, websocket: WebSocket, room: str):
        await websocket.accept()
        self.active_connections[room].append(websocket)

    def disconnect(self, websocket: WebSocket, room: str):
        self.active_connections[room].remove(websocket)

    async def broadcast(self, message: str, room: str, sender: str):
        for connection in self.active_connections[room]:
            await connection.send_text(f"{sender}: {message}")

manager = ConnectionManager()

# Middleware definitions (reusing the ones we created earlier)
class WebSocketLoggingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "websocket":
            # scope["client"] is optional in ASGI and may be None
            client = scope.get("client")
            connection_id = f"{client[0]}:{client[1]}" if client else "unknown"
            logger.info(f"WebSocket connection attempt from {connection_id}")

            async def wrapped_send(message):
                if message["type"] == "websocket.accept":
                    logger.info(f"WebSocket connection accepted: {connection_id}")
                elif message["type"] == "websocket.close":
                    logger.info(f"WebSocket connection closed: {connection_id}")
                await send(message)

            async def wrapped_receive():
                message = await receive()
                if message["type"] == "websocket.receive":
                    logger.info(f"Received WebSocket message from {connection_id}")
                return message

            start_time = time.time()
            try:
                await self.app(scope, wrapped_receive, wrapped_send)
            finally:
                duration = time.time() - start_time
                logger.info(f"WebSocket connection {connection_id} duration: {duration:.2f} seconds")
        else:
            await self.app(scope, receive, send)

class WebSocketAuthMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            await self.app(scope, receive, send)
            return

        query_string = scope["query_string"].decode()
        token = None
        if query_string:
            # Split on the first "=" only so values containing "=" do not break parsing
            params = dict(param.split('=', 1) for param in query_string.split('&') if '=' in param)
            token = params.get("token")

        if not token:
            await self.close_connection(send, 1008, "Missing authentication token")
            return

        # Only the decode step is inside the try block so that errors raised
        # by the application itself are not swallowed here
        try:
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except jwt.PyJWTError:
            await self.close_connection(send, 1008, "Invalid authentication token")
            return

        username = payload.get("sub")
        if not username or username not in users_db:
            await self.close_connection(send, 1008, "Invalid user")
            return

        scope["username"] = username
        await self.app(scope, receive, send)

    async def close_connection(self, send, code: int, reason: str):
        await send({
            "type": "websocket.close",
            "code": code,
            "reason": reason
        })

class WebSocketRateLimitMiddleware:
    def __init__(self, app, messages_per_minute=60):
        self.app = app
        self.rate_limit = messages_per_minute
        self.message_counts = defaultdict(list)

    async def __call__(self, scope, receive, send):
        if scope["type"] != "websocket":
            await self.app(scope, receive, send)
            return

        client_id = f"{scope['client'][0]}:{scope['client'][1]}"

        async def rate_limited_receive():
            message = await receive()
            if message["type"] == "websocket.receive":
                current_time = time.time()
                self.message_counts[client_id] = [
                    ts for ts in self.message_counts[client_id]
                    if current_time - ts < 60
                ]
                if len(self.message_counts[client_id]) >= self.rate_limit:
                    await send({
                        "type": "websocket.send",
                        "text": "Server: Rate limit exceeded. Please slow down."
                    })
                    await asyncio.sleep(1)
                    return {
                        "type": "websocket.receive",
                        "text": "__rate_limited__"
                    }
                self.message_counts[client_id].append(current_time)
            return message

        try:
            await self.app(scope, rate_limited_receive, send)
        finally:
            # Drop this client's history once the connection ends
            self.message_counts.pop(client_id, None)
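
# Add middleware to the app. The last one added is the outermost wrapper,
# so logging runs first, then authentication, then rate limiting.
app.add_middleware(WebSocketRateLimitMiddleware, messages_per_minute=10)
app.add_middleware(WebSocketAuthMiddleware)
app.add_middleware(WebSocketLoggingMiddleware)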

# WebSocket chat endpoint
@app.websocket("/ws/{room}")
async def websocket_endpoint(websocket: WebSocket, room: str):
    username = websocket.scope["username"]
    await manager.connect(websocket, room)
    try:
        # Announce user joining
        await manager.broadcast(f"{username} has joined the chat", room, "Server")
        while True:
            data = await websocket.receive_text()
            # Ignore rate-limited messages
            if data == "__rate_limited__":
                continue
            await manager.broadcast(data, room, username)
    except WebSocketDisconnect:
        manager.disconnect(websocket, room)
        await manager.broadcast(f"{username} has left the chat", room, "Server")

# Root endpoint that returns usage instructions as JSON
@app.get("/")
async def get():
    return {
        "message": "WebSocket Chat API is running",
        "instructions": "Get a token via POST /token, then connect to /ws/{room}?token=your_token"
    }
To test this application:
- Start the server
- Get a token by making a POST request to /token with username/password
- Connect to the WebSocket endpoint /ws/{room}?token=your_token
- Send and receive messages
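When building the connection URL for the third step, the room name and token should be URL-encoded. A possible helper is sketched below; build_ws_url and the localhost:8000 address are assumptions for illustration, and the resulting URL can be passed to whichever WebSocket client you prefer.

```python
from urllib.parse import quote, urlencode


def build_ws_url(host, room, token, secure=False):
    """Return the chat endpoint URL with the room and token safely encoded."""
    scheme = "wss" if secure else "ws"
    path = f"/ws/{quote(room, safe='')}"
    return f"{scheme}://{host}{path}?{urlencode({'token': token})}"


print(build_ws_url("localhost:8000", "general", "abc.def.ghi"))
# ws://localhost:8000/ws/general?token=abc.def.ghi
```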
Summary
WebSocket middleware in FastAPI provides a powerful mechanism to add functionality across all your WebSocket connections. In this guide, we've covered:
- The concept of WebSocket middleware
- How to create custom middleware for logging, authentication, and rate limiting
- Combining multiple middleware components
- A complete real-world chat application using middleware
By implementing middleware, you can separate cross-cutting concerns from your WebSocket endpoint logic, resulting in cleaner, more maintainable code.
Further Resources and Exercises
Exercises
- Monitoring Middleware: Create a middleware that tracks the number of active connections and the total number of messages processed.
- Message Transformation: Create a middleware that transforms incoming messages (e.g., censoring certain words, converting to uppercase, etc.).
- Connection Limits: Implement a middleware that limits the number of simultaneous WebSocket connections per user or IP address.
- Custom Protocol: Create a middleware that handles a custom message format (e.g., JSON with specific fields) and validates each message.
- Integration Exercise: Build a complete chat application with rooms, private messaging, and user status indicators using middleware for all cross-cutting concerns.
By working with these exercises, you'll gain a deeper understanding of how middleware can be used to enhance your WebSocket applications.