FastAPI WebSocket Scaling

Introduction

WebSockets provide powerful real-time communication capabilities in web applications, but as your user base grows, you'll need to consider how to scale your WebSocket infrastructure effectively. In this tutorial, we'll explore various strategies and techniques for scaling WebSocket connections in FastAPI applications.

By the end of this guide, you'll understand:

  • Challenges in scaling WebSocket applications
  • Different scaling architectures for WebSockets
  • Implementation techniques for FastAPI applications
  • Best practices for production environments

Why Scaling WebSockets is Challenging

Unlike HTTP requests, which are stateless and short-lived, WebSocket connections:

  1. Maintain persistent connections
  2. Consume server resources for extended periods
  3. Need to handle state across multiple server instances
  4. Require special consideration for load balancing

Let's dive into solutions for these challenges!

Basic WebSocket Application

Before we discuss scaling, let's review a simple FastAPI WebSocket application:

python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List

app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)

manager = ConnectionManager()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"Message: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)

This works great for a single server with a limited number of connections, but how do we scale it?

Scaling Strategies

1. Vertical Scaling

The simplest approach is to increase the resources of your server:

  • Pros: No architectural changes needed
  • Cons: Limited by hardware constraints, single point of failure

That said, vertical scaling goes further with FastAPI than with many frameworks: it runs on Uvicorn and Starlette, which are built on asyncio, so a single server can hold a large number of concurrent connections efficiently before you need to scale out.
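
To get a feel for why one asyncio event loop stretches so far, here is a small self-contained sketch (not FastAPI-specific) that holds ten thousand idle "connections" as coroutines. Because each is a coroutine rather than a thread, they all complete in roughly the time of one sleep, not 10,000 sleeps:

```python
import asyncio
import time

async def fake_connection(hold_seconds: float) -> None:
    # Stand-in for an idle WebSocket: a suspended coroutine, not a thread
    await asyncio.sleep(hold_seconds)

async def main(n: int = 10_000) -> float:
    start = time.monotonic()
    # All n "connections" are held open concurrently on one event loop
    await asyncio.gather(*(fake_connection(0.1) for _ in range(n)))
    return time.monotonic() - start

elapsed = asyncio.run(main())
print(f"Held 10,000 concurrent 'connections' in ~{elapsed:.2f}s total")
```

Real WebSockets add I/O and memory per socket, so this is only an illustration of the concurrency model, not a capacity benchmark.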

2. Horizontal Scaling with Message Brokers

To scale across multiple servers, we need a way for different instances to communicate:

python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
import redis.asyncio as redis
import asyncio

app = FastAPI()
redis_client = redis.Redis(host='localhost', port=6379, db=0)

class RedisConnectionManager:
    def __init__(self):
        self.active_connections = {}
        self.redis = redis_client
        self.pubsub = None

    async def connect(self, websocket: WebSocket, client_id: str):
        await websocket.accept()
        self.active_connections[client_id] = websocket

        # Start the Redis subscription when the first client connects
        if len(self.active_connections) == 1:
            asyncio.create_task(self.redis_listener())

    def disconnect(self, client_id: str):
        if client_id in self.active_connections:
            del self.active_connections[client_id]

    async def broadcast(self, message: str):
        # Publish the message to Redis so every instance receives it
        await self.redis.publish("broadcast", message)

    async def redis_listener(self):
        self.pubsub = self.redis.pubsub()
        await self.pubsub.subscribe("broadcast")

        async for message in self.pubsub.listen():
            if message["type"] == "message":
                data = message["data"].decode("utf-8")
                # Forward the Redis message to all local WebSocket clients
                websocket_coros = [
                    connection.send_text(data)
                    for connection in self.active_connections.values()
                ]
                await asyncio.gather(*websocket_coros, return_exceptions=True)

manager = RedisConnectionManager()

@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    await manager.connect(websocket, client_id)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"Client {client_id}: {data}")
    except WebSocketDisconnect:
        manager.disconnect(client_id)

In this example:

  1. Redis serves as a pub/sub mechanism
  2. Messages published by any server instance are received by all instances
  3. Each instance forwards messages to its connected WebSocket clients
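
The fan-out pattern itself is independent of Redis. As a purely illustrative stand-in, the same publish/subscribe flow can be sketched in-process with asyncio queues (`LocalPubSub` here is a made-up helper for demonstration, not a real library):

```python
import asyncio

class LocalPubSub:
    """In-process stand-in for Redis pub/sub (illustration only)."""

    def __init__(self):
        self.subscribers: list[asyncio.Queue] = []

    def subscribe(self) -> asyncio.Queue:
        # Each subscriber (think: server instance) gets its own queue
        q: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(q)
        return q

    async def publish(self, message: str) -> None:
        # Every subscriber receives every published message
        for q in self.subscribers:
            await q.put(message)

async def demo():
    bus = LocalPubSub()
    instance_a, instance_b = bus.subscribe(), bus.subscribe()
    await bus.publish("hello")  # published once, delivered to both
    return await instance_a.get(), await instance_b.get()

received = asyncio.run(demo())
```

Redis plays exactly this role across processes and machines: each FastAPI instance is one subscriber, and each then relays what it receives to its own local WebSocket clients.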

3. Using Dedicated WebSocket Services

For large-scale applications, consider specialized WebSocket services:

python
# Using the Broadcaster package for multiple backend support
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from broadcaster import Broadcast
import asyncio

app = FastAPI()
broadcast = Broadcast("redis://localhost:6379")

@app.on_event("startup")
async def startup():
    await broadcast.connect()

@app.on_event("shutdown")
async def shutdown():
    await broadcast.disconnect()

@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    await websocket.accept()

    # Subscribe to the channel
    async with broadcast.subscribe(channel="chat") as subscriber:
        # Forward messages from the broadcast channel to this WebSocket
        task = asyncio.create_task(forward_messages(subscriber, websocket))

        try:
            # Listen for messages from the WebSocket and publish them
            while True:
                data = await websocket.receive_text()
                await broadcast.publish(
                    channel="chat",
                    message=f"Client {client_id}: {data}"
                )
        except WebSocketDisconnect:
            pass
        finally:
            task.cancel()

async def forward_messages(subscriber, websocket):
    try:
        async for event in subscriber:
            await websocket.send_text(event.message)
    except Exception:
        pass

The broadcaster package supports various backends like Redis, PostgreSQL, and memory-based implementations.

Load Balancing Considerations

When deploying a scaled WebSocket application, consider these aspects:

Sticky Sessions

WebSocket connections should stick to the same server instance. Configure your load balancer to use sticky sessions:

nginx
# Nginx configuration example
upstream websocket_servers {
    ip_hash;  # Ensures a client sticks to the same server
    server app1:8000;
    server app2:8000;
}

server {
    listen 80;

    location /ws {
        proxy_pass http://websocket_servers;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
    }

    # Regular HTTP routes
    location / {
        proxy_pass http://websocket_servers;
    }
}

Connection Draining

Implement graceful shutdown to handle existing connections when deploying updates:

python
import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()
shutdown_event = asyncio.Event()
active_connections = set()

# Track connections so we can close them on shutdown
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    active_connections.add(websocket)
    try:
        while not shutdown_event.is_set():
            try:
                data = await asyncio.wait_for(
                    websocket.receive_text(),
                    timeout=1.0
                )
            except asyncio.TimeoutError:
                continue  # No message yet; re-check the shutdown flag
            # Process data...
    except WebSocketDisconnect:
        pass
    finally:
        active_connections.discard(websocket)

# Graceful shutdown handler
@app.on_event("shutdown")
async def shutdown():
    shutdown_event.set()

    # Give clients time to disconnect
    if active_connections:
        # Send a close message to all clients
        for websocket in list(active_connections):
            try:
                await websocket.send_text("Server shutting down")
                await websocket.close(code=1000)
            except Exception:
                pass

        # Wait briefly so clients can process the close frame
        await asyncio.sleep(2)

Monitoring and Limiting Connections

To prevent resource exhaustion, implement connection monitoring and limits:

python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

# Configuration
MAX_CONNECTIONS = 10000
CONNECTIONS_PER_IP = 5
connection_count = 0
ip_counter = {}

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    global connection_count
    client_ip = websocket.client.host

    # Check the global connection limit
    if connection_count >= MAX_CONNECTIONS:
        await websocket.close(code=1008, reason="Server connection limit reached")
        return

    # Check the per-IP limit
    if ip_counter.get(client_ip, 0) >= CONNECTIONS_PER_IP:
        await websocket.close(code=1008, reason="Too many connections from your IP")
        return

    # Accept the connection and increment the counters
    await websocket.accept()
    connection_count += 1
    ip_counter[client_ip] = ip_counter.get(client_ip, 0) + 1

    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message: {data}")
    except WebSocketDisconnect:
        pass
    finally:
        # Clean up the counters on disconnect
        connection_count -= 1
        ip_counter[client_ip] -= 1
        if ip_counter[client_ip] == 0:
            del ip_counter[client_ip]
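
Connection limits cap how many sockets a client may hold, but a single connection can still flood the server with messages. One common complement is a per-connection token bucket for message rate limiting. The sketch below is illustrative (the class name and parameters are made up, and a multi-instance setup might track buckets in a shared store instead):

```python
import time

class TokenBucket:
    """Allow bursts of up to `capacity` messages, refilled at `rate` per second."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last = time.monotonic()

    def allow(self) -> bool:
        now = time.monotonic()
        # Refill tokens based on elapsed time, capped at capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

# Example: a client may burst 3 messages, then gets throttled
bucket = TokenBucket(rate=0.001, capacity=3)
results = [bucket.allow() for _ in range(4)]
```

In the WebSocket receive loop, you would call `bucket.allow()` before processing each incoming message and drop (or close on) messages when it returns False.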

Real-World Example: Chat Application

Let's build a scalable chat application combining these concepts:

python
# chat_app.py
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import redis.asyncio as redis
import json
import asyncio
import os
import time
from typing import Dict

app = FastAPI(title="Scalable WebSocket Chat")

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify actual origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Redis configuration
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
redis_client = redis.from_url(REDIS_URL)

class ChatManager:
    def __init__(self):
        self.active_connections: Dict[str, Dict[str, WebSocket]] = {}
        self.pubsub = None
        self.listener_task = None

    async def start_listener(self):
        # Only start the listener once per process
        if self.listener_task is None:
            self.pubsub = redis_client.pubsub()
            await self.pubsub.subscribe("chat:messages", "chat:user_events")
            self.listener_task = asyncio.create_task(self.listen_to_redis())

    async def connect(self, websocket: WebSocket, user_id: str, room: str):
        # Create the room if it doesn't exist
        if room not in self.active_connections:
            self.active_connections[room] = {}

        # Accept the connection
        await websocket.accept()
        self.active_connections[room][user_id] = websocket

        # Start the Redis listener if needed
        await self.start_listener()

        # Announce that the user joined
        user_event = {
            "type": "join",
            "room": room,
            "user_id": user_id,
            "online_users": list(self.active_connections[room].keys())
        }
        await redis_client.publish("chat:user_events", json.dumps(user_event))

    def disconnect(self, user_id: str, room: str):
        # Remove the user from the room
        if room in self.active_connections and user_id in self.active_connections[room]:
            del self.active_connections[room][user_id]

            # Remove the room if it is empty
            if not self.active_connections[room]:
                del self.active_connections[room]

            # Announce that the user left
            asyncio.create_task(
                redis_client.publish(
                    "chat:user_events",
                    json.dumps({
                        "type": "leave",
                        "room": room,
                        "user_id": user_id
                    })
                )
            )

    async def send_message(self, user_id: str, room: str, content: str):
        message = {
            "user_id": user_id,
            "room": room,
            "content": content,
            "timestamp": time.time()  # Wall-clock timestamp
        }
        await redis_client.publish("chat:messages", json.dumps(message))

    async def listen_to_redis(self):
        try:
            async for message in self.pubsub.listen():
                if message["type"] == "message":
                    channel = message["channel"].decode("utf-8")
                    data = json.loads(message["data"].decode("utf-8"))

                    if channel == "chat:messages":
                        room = data["room"]
                        if room in self.active_connections:
                            websocket_coros = [
                                connection.send_text(json.dumps({
                                    "type": "message",
                                    "data": data
                                }))
                                for connection in self.active_connections[room].values()
                            ]
                            await asyncio.gather(*websocket_coros, return_exceptions=True)

                    elif channel == "chat:user_events":
                        room = data["room"]
                        if room in self.active_connections:
                            websocket_coros = [
                                connection.send_text(json.dumps({
                                    "type": "user_event",
                                    "data": data
                                }))
                                for connection in self.active_connections[room].values()
                            ]
                            await asyncio.gather(*websocket_coros, return_exceptions=True)
        except asyncio.CancelledError:
            # The listener was cancelled; clean up the subscription
            await self.pubsub.unsubscribe()
        except Exception:
            # Restart the listener on unexpected errors
            self.listener_task = asyncio.create_task(self.listen_to_redis())

manager = ChatManager()

@app.websocket("/ws/chat/{room}/{user_id}")
async def websocket_endpoint(websocket: WebSocket, room: str, user_id: str):
    await manager.connect(websocket, user_id, room)
    try:
        while True:
            data = await websocket.receive_json()
            if "message" in data:
                await manager.send_message(user_id, room, data["message"])
    except WebSocketDisconnect:
        manager.disconnect(user_id, room)
    except Exception:
        manager.disconnect(user_id, room)

# HTTP endpoint to get chat history
@app.get("/chat/{room}/history")
async def get_chat_history(room: str, limit: int = 50):
    # In a real app, fetch from a database; this is a placeholder
    return {"messages": []}

# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "ok"}

To run this application at scale:

  1. Deploy multiple instances behind a load balancer
  2. Set up a Redis instance or cluster
  3. Configure the load balancer for WebSocket support and sticky sessions
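
As an illustrative sketch of those three steps, a Docker Compose file along these lines could wire everything together (service names, images, and the nginx.conf mount are assumptions for this example, not part of the tutorial code):

```yaml
# docker-compose.yml -- illustrative sketch, not a tested deployment
services:
  redis:
    image: redis:7

  app1:
    build: .
    command: uvicorn chat_app:app --host 0.0.0.0 --port 8000
    environment:
      - REDIS_URL=redis://redis:6379
    depends_on: [redis]

  app2:
    build: .
    command: uvicorn chat_app:app --host 0.0.0.0 --port 8000
    environment:
      - REDIS_URL=redis://redis:6379
    depends_on: [redis]

  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      # Mount the upstream/sticky-session config from the earlier section
      - ./nginx.conf:/etc/nginx/conf.d/default.conf
    depends_on: [app1, app2]
```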

Best Practices for Scaling WebSockets

  1. Use a message broker (Redis, RabbitMQ, Kafka) for cross-instance communication
  2. Implement heartbeats to detect stale connections
  3. Set appropriate timeouts for WebSocket connections
  4. Monitor connection health and implement reconnection strategies
  5. Limit connections per IP to prevent abuse
  6. Gracefully handle server shutdowns with proper connection draining
  7. Shard connections based on a consistent mechanism (like user ID)
  8. Implement backpressure handling to manage high message volumes
  9. Use rate limiting for message sending
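
As one way to approach heartbeats (item 2), here is a hedged sketch of a ping/pong loop. `FakeWebSocket` is a test double for exercising the logic; in a real FastAPI endpoint the heartbeat would have to be multiplexed with normal message traffic, for example by running it in a separate task:

```python
import asyncio

async def heartbeat(websocket, interval: float = 30.0, timeout: float = 10.0) -> bool:
    """Ping the client periodically; returns False once the peer stops replying."""
    while True:
        await websocket.send_text("ping")
        try:
            reply = await asyncio.wait_for(websocket.receive_text(), timeout=timeout)
        except asyncio.TimeoutError:
            await websocket.close()
            return False  # Stale connection detected
        if reply != "pong":
            await websocket.close()
            return False
        await asyncio.sleep(interval)

class FakeWebSocket:
    """Test double: replies 'pong' a fixed number of times, then goes silent."""

    def __init__(self, pongs: int):
        self.pongs = pongs
        self.closed = False

    async def send_text(self, message: str) -> None:
        pass

    async def receive_text(self) -> str:
        if self.pongs > 0:
            self.pongs -= 1
            return "pong"
        await asyncio.sleep(3600)  # Simulate a client that stopped responding

    async def close(self) -> None:
        self.closed = True

# Two healthy rounds, then silence: the loop detects the stale connection
ws = FakeWebSocket(pongs=2)
alive = asyncio.run(heartbeat(ws, interval=0.01, timeout=0.05))
```

Note that browsers answer WebSocket protocol-level pings automatically, so in practice you may prefer protocol pings (where your server exposes them) over application-level "ping"/"pong" text messages like these.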

Summary

Scaling WebSockets in FastAPI applications requires careful consideration of:

  • Persistent connection management
  • Cross-server communication
  • Load balancing with sticky sessions
  • Connection monitoring and limits

By implementing the strategies discussed in this guide, you can build robust, scalable real-time applications that can handle thousands or even millions of concurrent WebSocket connections.

Exercises

  1. Exercise: Modify the chat application to store message history in a database
  2. Challenge: Implement a reconnection mechanism in the client to handle server restarts
  3. Advanced: Create a system that can scale to handle 100,000 concurrent connections by implementing proper sharding
  4. Project: Build a real-time collaborative editor using the scaling techniques discussed

Now you're equipped to build and scale WebSocket applications with FastAPI for production environments!


