
FastAPI WebSocket Security

Introduction

WebSockets provide a powerful way to create real-time, bidirectional communication between clients and servers in FastAPI applications. However, because these connections are long-lived and carry data in both directions, securing them is crucial to protect your application from attacks and unauthorized access.

In this guide, we'll explore different security measures for FastAPI WebSockets, including authentication, authorization, rate limiting, and input validation. By the end, you'll understand how to build secure WebSocket implementations that can be safely deployed in production environments.

Why WebSocket Security Matters

WebSockets maintain persistent connections, which creates unique security challenges compared to regular HTTP endpoints:

  • Persistent connections: Unlike REST APIs, where each request is separate and authenticated individually, WebSockets maintain a long-lived connection, so credentials are typically checked only once, during the opening handshake.
  • Real-time data: WebSockets often transmit sensitive information in real-time.
  • Resource consumption: Malicious users could create many WebSocket connections, leading to denial-of-service.

Authentication for WebSockets

One of the most straightforward approaches is to use cookie-based authentication that integrates with your existing authentication system:

```python
from typing import Optional

from fastapi import Cookie, Depends, FastAPI, WebSocket, WebSocketDisconnect, status

app = FastAPI()

# Simulated user database (illustration only; never store plaintext passwords)
users_db = {
    "user1": {"username": "user1", "password": "password1"},
    "user2": {"username": "user2", "password": "password2"},
}


# Simple authentication dependency: reads the "session_token" cookie sent with the handshake
async def get_current_user(session_token: Optional[str] = Cookie(None)):
    # Placeholder check; look the session up in your session store in a real app
    if session_token == "valid_session_token":
        return {"username": "testuser"}
    return None


@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    current_user: Optional[dict] = Depends(get_current_user),
):
    if current_user is None:
        # Closing before accept() rejects the handshake
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"User {current_user['username']} sent: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
```
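Browsers attach cookies to the WebSocket handshake automatically and do not apply CORS to WebSocket connections. As a result, a page on another site can open a cookie-authenticated socket on the victim's behalf (cross-site WebSocket hijacking). A common mitigation is to compare the handshake's `Origin` header against an allow-list before accepting. The sketch below assumes a hypothetical `ALLOWED_ORIGINS` set and helper name; replace them with the origins that actually serve your frontend.

```python
from typing import Optional

# Hypothetical allow-list: exact origins (scheme://host[:port]) of your frontend.
ALLOWED_ORIGINS = frozenset({"https://example.com", "https://app.example.com"})


def is_origin_allowed(origin: Optional[str], allowed: frozenset = ALLOWED_ORIGINS) -> bool:
    """Return True only if the handshake's Origin header exactly matches an allowed origin.

    Browsers always send Origin on WebSocket handshakes, so a missing value
    indicates a non-browser client. This sketch rejects it for cookie-authenticated
    endpoints. The literal "null" origin (sandboxed iframes, file:// pages) is also
    rejected because it is not in the allow-list.
    """
    if not origin:
        return False
    return origin in allowed


# Usage inside the endpoint above, before calling websocket.accept():
#
#     if not is_origin_allowed(websocket.headers.get("origin")):
#         await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
#         return
```

Non-browser clients can forge `Origin`, but they also cannot read the victim's cookies. This check therefore targets the browser-based attack specifically.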

Token-based Authentication

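For applications that use token-based authentication (like JWT), you can pass the token as a query parameter. This is a common choice because the browser `WebSocket` API cannot attach custom headers such as `Authorization` to the handshake: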

```python
from typing import Optional

import jwt  # PyJWT (pip install pyjwt)
from fastapi import FastAPI, Query, WebSocket, WebSocketDisconnect, status

app = FastAPI()

JWT_SECRET = "your_secret_key"  # Load from an environment variable in production


def verify_token(token: str) -> Optional[dict]:
    try:
        # Verifies the signature and, if present, the "exp" claim
        return jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        return None


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    user = verify_token(token)
    if user is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message processed: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
```
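Web servers and proxies frequently record query strings in their access logs, so a long-lived JWT in the URL can leak. One common mitigation is to have the client fetch a short-lived ticket over an authenticated HTTP request and put only that ticket in the WebSocket URL. The stdlib sketch below uses a hypothetical ticket format (`<user_id>.<expiry>.<signature>`, signed with HMAC-SHA256). It is not a FastAPI or PyJWT feature, the secret is assumed to come from your configuration, and making tickets single-use would additionally require server-side storage.

```python
import base64
import hashlib
import hmac
import time
from typing import Optional


def _sign(payload: bytes, secret: bytes) -> str:
    """HMAC-SHA256 over the payload, encoded as unpadded URL-safe base64."""
    digest = hmac.new(secret, payload, hashlib.sha256).digest()
    return base64.urlsafe_b64encode(digest).decode().rstrip("=")


def issue_ticket(user_id: str, secret: bytes, ttl: float = 30.0,
                 now: Optional[float] = None) -> str:
    """Create a ticket '<user_id>.<expiry>.<signature>' that is valid for ttl seconds."""
    expiry = int((time.time() if now is None else now) + ttl)
    payload = f"{user_id}.{expiry}"
    return f"{payload}.{_sign(payload.encode(), secret)}"


def verify_ticket(ticket: str, secret: bytes, now: Optional[float] = None) -> Optional[str]:
    """Return the user ID if the signature is valid and the ticket has not expired."""
    try:
        # rsplit keeps dots inside user_id; the base64url signature never contains dots
        user_id, expiry_str, signature = ticket.rsplit(".", 2)
        expiry = int(expiry_str)
        provided = signature.encode()
        expected = _sign(f"{user_id}.{expiry_str}".encode(), secret).encode()
    except ValueError:  # wrong shape, non-numeric expiry, or unencodable characters
        return None
    # compare_digest resists timing attacks on the signature comparison
    if not hmac.compare_digest(provided, expected):
        return None
    if (time.time() if now is None else now) >= expiry:
        return None
    return user_id
```

With this approach, the endpoint calls `verify_ticket(token, secret)` in place of `verify_token(token)`. The JWT still authenticates the HTTP request that issues the ticket.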

When connecting from the client, you would include the token:

```javascript
// Client-side code (use wss:// in production)
const token = "your.jwt.token";
const socket = new WebSocket(`ws://localhost:8000/ws?token=${encodeURIComponent(token)}`);
```

Custom Authentication Class

For more complex authentication needs, we can create a custom WebSocket authentication dependency:

```python
import hmac
from typing import Any, Dict, Optional

from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect, status

app = FastAPI()


class WebSocketAuth:
    """Authentication dependency: a callable class passed to Depends(), not ASGI middleware."""

    async def __call__(self, websocket: WebSocket) -> Optional[Dict[str, Any]]:
        # Non-browser clients can send "Authorization: Bearer <token>";
        # browsers cannot set headers, so fall back to a "token" query parameter
        auth_header = websocket.headers.get("authorization", "")
        if auth_header.lower().startswith("bearer "):
            token = auth_header[len("bearer "):]
        else:
            token = websocket.query_params.get("token")

        # Verify token (simplified); compare_digest resists timing attacks
        if not token or not hmac.compare_digest(token.encode(), b"valid_token"):
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return None

        # Return user information
        return {"user_id": "123", "username": "testuser"}


ws_auth = WebSocketAuth()


@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    user: Optional[dict] = Depends(ws_auth),
):
    if user is None:
        return  # The dependency has already closed the connection

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"User {user['username']} sent: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
```
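Because credentials are checked only during the handshake, a connection can outlive the token that opened it. One option is to record the credential's expiry at handshake time (for a JWT, its `exp` claim) and check it before handling each message. A minimal sketch, where `SessionExpiry` is a hypothetical helper rather than a FastAPI API:

```python
import time
from dataclasses import dataclass
from typing import Callable


@dataclass
class SessionExpiry:
    """Tracks when the credential presented during the handshake stops being valid."""

    expires_at: float                       # Unix timestamp, e.g. a JWT "exp" claim
    clock: Callable[[], float] = time.time  # injectable so tests can control time

    def is_expired(self) -> bool:
        # RFC 7519: a token must not be accepted on or after its "exp" time
        return self.clock() >= self.expires_at


# Sketch of use inside the receive loop:
#
#     session = SessionExpiry(expires_at=payload["exp"])
#     while True:
#         data = await websocket.receive_text()
#         if session.is_expired():
#             await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
#             break
```

Clients can then reconnect with a fresh token, so long-running sessions never outlast the credential that authorized them.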

Rate Limiting for WebSockets

Rate limiting helps prevent abuse of your WebSocket endpoints. Here's a simple in-memory, sliding-window rate limiter. Note that `client_id` comes from the URL, so a client could evade the limit by reconnecting under a different ID. In a real application, key the limiter on a server-verified value such as the authenticated user ID. The state also lives in a single process, so each worker keeps its own counts.

```python
import time
from collections import deque
from typing import Deque, Dict

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()


# Simple in-memory sliding-window rate limiter
class RateLimiter:
    def __init__(self, max_messages: int, time_window: float):
        self.max_messages = max_messages
        self.time_window = time_window
        self.message_history: Dict[str, Deque[float]] = {}

    async def is_rate_limited(self, client_id: str) -> bool:
        current_time = time.monotonic()  # unaffected by system clock changes
        history = self.message_history.setdefault(client_id, deque())

        # Drop timestamps that fall outside the time window (oldest first)
        while history and current_time - history[0] > self.time_window:
            history.popleft()

        # Check if rate limit is exceeded
        if len(history) >= self.max_messages:
            return True

        # Record the current message timestamp
        history.append(current_time)
        return False


# Create rate limiter: 5 messages per 3 seconds
rate_limiter = RateLimiter(max_messages=5, time_window=3.0)


@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()

            # Check rate limit (keyed on the client-supplied ID for simplicity)
            if await rate_limiter.is_rate_limited(client_id):
                await websocket.send_text("Rate limit exceeded. Please slow down.")
                continue

            await websocket.send_text(f"Message processed: {data}")
    except WebSocketDisconnect:
        print(f"Client {client_id} disconnected")
```
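The sliding-window log above stores one timestamp per message. A token bucket is a common alternative that keeps constant state per client and tolerates short bursts. Each key holds up to `capacity` tokens that refill at `rate` tokens per second, and each message spends one token. This is a different technique from the one above, and `TokenBucket` is a hypothetical name, not a library class.

```python
import time
from typing import Callable, Dict, Tuple


class TokenBucket:
    """Per-key token bucket: bursts of up to `capacity` messages,
    refilled continuously at `rate` tokens per second."""

    def __init__(self, capacity: float, rate: float,
                 clock: Callable[[], float] = time.monotonic):
        self.capacity = capacity
        self.rate = rate
        self.clock = clock
        # key -> (tokens remaining, time of last update)
        self._buckets: Dict[str, Tuple[float, float]] = {}

    def allow(self, key: str) -> bool:
        """Spend one token for `key` if available; return False when rate limited."""
        now = self.clock()
        tokens, last = self._buckets.get(key, (self.capacity, now))
        # Refill in proportion to elapsed time, never exceeding capacity
        tokens = min(self.capacity, tokens + (now - last) * self.rate)
        if tokens < 1:
            self._buckets[key] = (tokens, now)
            return False
        self._buckets[key] = (tokens - 1, now)
        return True

    def forget(self, key: str) -> None:
        """Drop state for a key, e.g. when its last connection closes."""
        self._buckets.pop(key, None)
```

A bucket of `TokenBucket(capacity=5, rate=5 / 3)` allows roughly the same average rate as "5 messages per 3 seconds". Note the inverted meaning: `allow()` returns `True` when the message may proceed. Key it on an authenticated identity rather than a client-supplied path parameter.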

Input Validation for WebSockets

Validating incoming WebSocket messages is essential for security. The example below uses Pydantic (v2 syntax) for validation:

```python
import json

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ValidationError, field_validator

app = FastAPI()


class ChatMessage(BaseModel):
    message: str
    room_id: str

    @field_validator("message")
    @classmethod
    def validate_message(cls, v: str) -> str:
        if len(v) > 1000:
            raise ValueError("Message too long (max 1000 characters)")
        return v

    @field_validator("room_id")
    @classmethod
    def validate_room_id(cls, v: str) -> str:
        if not v.isalnum():
            raise ValueError("Room ID must be alphanumeric")
        return v


@app.websocket("/ws/chat")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    try:
        while True:
            data = await websocket.receive_text()

            # Validate input
            try:
                parsed_data = json.loads(data)
                # model_validate also rejects valid JSON that is not an object, e.g. [1, 2]
                chat_message = ChatMessage.model_validate(parsed_data)
            except json.JSONDecodeError:
                await websocket.send_text("Error: Invalid JSON format")
                continue
            except ValidationError as e:
                await websocket.send_text(f"Validation error: {e}")
                continue

            # Process validated message
            await websocket.send_text(
                f"Message to room {chat_message.room_id}: {chat_message.message}"
            )

    except WebSocketDisconnect:
        print("Client disconnected")
```
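Pydantic validation runs only after a complete message has been received and decoded. A cheap pre-check can reject oversized messages and JSON that is not an object before any model is built. The sketch below assumes a hypothetical `MAX_MESSAGE_CHARS` limit chosen per endpoint. Independently of this check, the ASGI server caps message size (Uvicorn's `--ws-max-size` option defaults to 16 MiB).

```python
import json
from typing import Any, Dict

MAX_MESSAGE_CHARS = 4096  # hypothetical per-message limit for this endpoint


def parse_json_object(raw: str, max_chars: int = MAX_MESSAGE_CHARS) -> Dict[str, Any]:
    """Parse a text frame into a JSON object, raising ValueError for anything else."""
    if len(raw) > max_chars:
        # Reject before spending CPU on parsing a large payload
        raise ValueError(f"Message too large (max {max_chars} characters)")
    data = json.loads(raw)  # json.JSONDecodeError is a subclass of ValueError
    if not isinstance(data, dict):
        raise ValueError("Expected a JSON object")
    return data
```

The result is always a `dict`, so it can be passed straight to the `ChatMessage` model. A single `except ValueError` covers both the size check and malformed JSON.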

Protection Against Connection Flooding

To prevent denial-of-service attacks through connection flooding, you can limit the number of concurrent connections per client IP. Behind a reverse proxy, `websocket.client.host` is the proxy's address unless the server is configured to trust forwarding headers. The counts below are also kept per process.

```python
from typing import Dict

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status

app = FastAPI()

# Track active connections per client IP (in-memory, per process)
active_connections: Dict[str, int] = {}
MAX_CONNECTIONS_PER_IP = 5


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    client_ip = websocket.client.host if websocket.client else "unknown"

    # Reject the handshake if this IP already has too many connections
    if active_connections.get(client_ip, 0) >= MAX_CONNECTIONS_PER_IP:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # Reserve a slot before the first await so concurrent handshakes are counted
    active_connections[client_ip] = active_connections.get(client_ip, 0) + 1
    try:
        await websocket.accept()
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        pass
    finally:
        # Release the slot however the handler exits, not only on WebSocketDisconnect
        active_connections[client_ip] -= 1
        if active_connections[client_ip] == 0:
            del active_connections[client_ip]
        print(f"Client disconnected, {active_connections.get(client_ip, 0)} connections remaining for IP {client_ip}")
```
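To reuse this per-IP limit across several endpoints, you can package the check, the increment, and the cleanup into an async context manager. Its `finally` block releases the slot however the handler exits. A minimal sketch: `ConnectionLimiter` and `slot` are hypothetical names, and like the dictionary above, the counts are local to one process.

```python
from contextlib import asynccontextmanager
from typing import AsyncIterator, Dict


class ConnectionLimiter:
    """Counts concurrent connections per key (for example, a client IP) in one process."""

    def __init__(self, max_per_key: int):
        self.max_per_key = max_per_key
        self.active: Dict[str, int] = {}

    @asynccontextmanager
    async def slot(self, key: str) -> AsyncIterator[bool]:
        """Yield True if a slot was reserved for `key`, or False if the limit is reached."""
        if self.active.get(key, 0) >= self.max_per_key:
            yield False
            return
        # No await between the check and the increment, so this is race-free in asyncio
        self.active[key] = self.active.get(key, 0) + 1
        try:
            yield True
        finally:
            # Runs on normal exit, on WebSocketDisconnect, and on any other exception
            self.active[key] -= 1
            if self.active[key] == 0:
                del self.active[key]


# Usage sketch inside an endpoint:
#
#     limiter = ConnectionLimiter(max_per_key=5)
#
#     async with limiter.slot(websocket.client.host) as allowed:
#         if not allowed:
#             await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
#             return
#         await websocket.accept()
#         ...  # receive loop
```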

Real-world Example: Secure Chat Application

Let's combine the techniques above into a more complete secure chat application:

```python
import json
import time
from typing import Dict, List, Optional

import jwt  # PyJWT
from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect, status
from pydantic import BaseModel, ValidationError, field_validator

app = FastAPI()

# Note: CORSMiddleware only handles HTTP requests and does not restrict WebSocket
# handshakes. This endpoint authenticates with a token, so a cross-site page cannot
# connect as the user without it; with cookie auth, check the Origin header instead.

# Constants
JWT_SECRET = "your_secret_key"  # Use env variables in production
MAX_CONNECTIONS_PER_USER = 3
MESSAGE_RATE_LIMIT = 5  # messages
RATE_LIMIT_WINDOW = 5.0  # seconds

# Data storage (in-memory, so limits apply per process)
active_connections: Dict[str, List[WebSocket]] = {}
user_message_timestamps: Dict[str, List[float]] = {}


# Models
class ChatMessage(BaseModel):
    message: str

    @field_validator("message")
    @classmethod
    def validate_message(cls, v: str) -> str:
        if not v.strip():
            raise ValueError("Message cannot be empty")
        if len(v) > 500:
            raise ValueError("Message too long (max 500 characters)")
        return v


# Auth functions
async def get_user_from_token(token: str = Query(...)) -> Optional[dict]:
    try:
        payload = jwt.decode(token, JWT_SECRET, algorithms=["HS256"])
        return {"user_id": payload["sub"], "username": payload["username"]}
    except (jwt.InvalidTokenError, KeyError):
        return None


# Rate limiting
async def is_rate_limited(user_id: str) -> bool:
    current_time = time.monotonic()

    # Keep only timestamps inside the window
    timestamps = [
        ts for ts in user_message_timestamps.get(user_id, [])
        if current_time - ts <= RATE_LIMIT_WINDOW
    ]
    user_message_timestamps[user_id] = timestamps

    # Check if rate limited
    if len(timestamps) >= MESSAGE_RATE_LIMIT:
        return True

    # Add current timestamp
    timestamps.append(current_time)
    return False


# Broadcast to all users
async def broadcast_message(sender_username: str, message: str):
    payload = json.dumps({
        "sender": sender_username,
        "message": message,
        "timestamp": time.time(),
    })
    # Snapshot the recipients: the registry may change while we await each send
    recipients = [ws for conns in list(active_connections.values()) for ws in conns]
    for connection in recipients:
        try:
            await connection.send_text(payload)
        except Exception:
            # A failed send to one client must not abort the broadcast or be mistaken
            # for the sender's own disconnect; that client's handler cleans itself up
            pass


@app.websocket("/ws/chat")
async def websocket_endpoint(
    websocket: WebSocket,
    user: Optional[dict] = Depends(get_user_from_token),
):
    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    user_id = user["user_id"]
    connections = active_connections.setdefault(user_id, [])

    # Check for connection limit
    if len(connections) >= MAX_CONNECTIONS_PER_USER:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # Register before the first await so concurrent handshakes respect the limit
    connections.append(websocket)
    try:
        await websocket.accept()

        # Notify about new user
        await broadcast_message("System", f"User {user['username']} has joined the chat.")

        while True:
            raw_data = await websocket.receive_text()

            # Parse and validate input (model_validate also rejects non-object JSON)
            try:
                chat_message = ChatMessage.model_validate(json.loads(raw_data))
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({"error": "Invalid JSON format"}))
                continue
            except ValidationError as e:
                await websocket.send_text(json.dumps({"error": str(e)}))
                continue

            # Check rate limiting
            if await is_rate_limited(user_id):
                await websocket.send_text(
                    json.dumps({
                        "error": "Rate limit exceeded. Please wait before sending more messages."
                    })
                )
                continue

            # Broadcast message
            await broadcast_message(user["username"], chat_message.message)

    except WebSocketDisconnect:
        pass
    finally:
        # Remove from active connections however the handler exits
        connections.remove(websocket)
        if not connections:
            active_connections.pop(user_id, None)

        # Notify about user leaving
        await broadcast_message("System", f"User {user['username']} has left the chat.")
```
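The broadcast above sends to one connection at a time, so a single slow client delays every other recipient. One alternative is to snapshot the recipients and send to all of them concurrently with `asyncio.gather(..., return_exceptions=True)`, bounding each send with `asyncio.wait_for`. In this sketch, `broadcast_concurrently` and `TextSender` are hypothetical names. It works with any object that has an async `send_text` method, such as FastAPI's `WebSocket`, and returns the connections that failed so the caller can drop them.

```python
import asyncio
from typing import Iterable, List, Protocol


class TextSender(Protocol):
    """Anything with an async send_text(str) method, e.g. fastapi.WebSocket."""

    async def send_text(self, data: str) -> None: ...


async def broadcast_concurrently(connections: Iterable[TextSender], payload: str,
                                 timeout: float = 5.0) -> List[TextSender]:
    """Send payload to every connection at once; return those that failed or timed out."""
    targets = list(connections)  # snapshot: the registry may change while we await
    results = await asyncio.gather(
        *(asyncio.wait_for(conn.send_text(payload), timeout) for conn in targets),
        return_exceptions=True,  # one failing send does not cancel the others
    )
    # Successful sends return None; failures come back as exception objects
    return [conn for conn, result in zip(targets, results)
            if isinstance(result, BaseException)]
```

In the chat example, the caller could remove each returned connection from `active_connections` and close it. The 5-second default timeout is an assumption; tune it to your clients.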

To connect to this WebSocket from a client:

```javascript
// Client-side JavaScript (use wss:// in production)
const token = "your.jwt.token";
const socket = new WebSocket(`ws://localhost:8000/ws/chat?token=${encodeURIComponent(token)}`);

socket.onmessage = function (event) {
  const data = JSON.parse(event.data);

  if (data.error) {
    console.error("Error:", data.error);
    return;
  }

  console.log(`${data.sender}: ${data.message}`);
  // When adding to the UI, use textContent (not innerHTML) so messages cannot inject HTML
};

// Function to send messages
function sendMessage(message) {
  socket.send(JSON.stringify({ message: message }));
}
```

HTTPS/WSS Protocol

In production, always ensure your WebSocket connections use the secure WSS protocol (WebSocket Secure) instead of WS, just like you would use HTTPS instead of HTTP. This ensures that data transmitted through the WebSocket is encrypted.

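When deploying your FastAPI application behind a reverse proxy such as Nginx, terminate TLS at the proxy and forward the upgrade headers:

```nginx
server {
    listen 443 ssl;
    server_name example.com;

    ssl_certificate /path/to/cert.pem;
    ssl_certificate_key /path/to/key.pem;

    location / {
        proxy_pass http://localhost:8000;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        # Forward the real client address and scheme (needed for per-IP limits)
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        # Nginx closes a proxied connection if the backend sends nothing for 60s (the default)
        proxy_read_timeout 3600s;
    }
}
```

By default, Uvicorn trusts these forwarding headers only from 127.0.0.1 (see its `--forwarded-allow-ips` option). That trust is what lets `websocket.client.host` report the real client address in this setup.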

Summary

In this guide, we've covered essential security practices for FastAPI WebSockets:

  1. Authentication: Using cookies, tokens, and custom dependencies to authenticate users
  2. Rate limiting: Preventing abuse by limiting message frequency
  3. Input validation: Using Pydantic to validate incoming messages
  4. Connection limiting: Preventing denial-of-service through connection flooding
  5. Secure transport: Using WSS instead of WS in production

By implementing these security measures, you can build WebSocket applications that are not only feature-rich and real-time but also secure and resistant to common attacks.


Exercise: Secure Chat Room

Challenge: Expand the secure chat application to support multiple chat rooms with room-specific permissions.

Requirements:

  1. Users can only join rooms they have permission for
  2. Each room has an owner who can moderate messages
  3. Implement message encryption for private rooms
  4. Add room-specific rate limits

Good luck enhancing your WebSocket security skills!


