
FastAPI WebSocket Dependencies

Introduction

When building real-time applications with FastAPI WebSockets, you'll often need to share common functionality across different WebSocket endpoints, validate incoming connections, or implement authentication. This is where WebSocket dependencies come in handy.

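Dependencies in FastAPI WebSocket endpoints are declared with Depends, just as in HTTP routes, and FastAPI resolves them before your endpoint function runs. The main difference is how errors are reported. A WebSocket endpoint does not return an HTTP response, so a dependency rejects a connection by raising WebSocketException with a WebSocket close code (for example 1008, policy violation) rather than HTTPException. They allow you to: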

  • Reuse code between different WebSocket endpoints
  • Perform pre-connection validation and authentication
  • Inject database connections or other services
  • Handle error conditions before establishing the WebSocket connection

In this guide, we'll explore how to use dependencies with WebSocket endpoints in FastAPI to create more maintainable and secure real-time applications.

Basic Dependency Usage

Let's start with a simple example of how to use dependencies with a WebSocket endpoint:

```python
from typing import Optional

from fastapi import Depends, FastAPI, Query, WebSocket, WebSocketDisconnect, WebSocketException, status

app = FastAPI()

async def get_token(token: Optional[str] = Query(default=None)) -> str:
    # FastAPI reads "token" from the query string, e.g. ws://localhost:8000/ws?token=secret_token
    if token != "secret_token":
        # Rejects the handshake with close code 1008 (policy violation)
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return token

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Depends(get_token)):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message text was: {data}, authenticated with token: {token}")
    except WebSocketDisconnect:
        print("Client disconnected")
```

In this example:

  1. We define a get_token dependency that extracts a token from query parameters
  2. It validates the token and raises an exception if invalid
  3. We use Depends(get_token) in our WebSocket endpoint to inject the token

Connection Validation with Dependencies

Dependencies are particularly useful for validating connections before accepting them:

```python
from typing import List, Set

from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect, WebSocketException, status

app = FastAPI()
connected_clients: Set[str] = set()  # IP addresses that currently have an open connection

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        # Copy the list: connections may be added or removed while we await
        for connection in list(self.active_connections):
            await connection.send_text(message)

manager = ConnectionManager()

async def validate_connection(websocket: WebSocket) -> str:
    # Client IP address (websocket.client can be None, depending on the server)
    client = websocket.client.host if websocket.client else "unknown"

    # Reject a second connection from the same IP before it is accepted
    if client in connected_clients:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    # Register the client
    connected_clients.add(client)
    return client

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, client: str = Depends(validate_connection)):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            # Reply to the sender
            await websocket.send_text(f"Client {client} sent: {data}")
            # Broadcast to all connected clients
            await manager.broadcast(f"Client {client}: {data}")
    except WebSocketDisconnect:
        pass
    finally:
        # Always release the slot, however the connection ended
        manager.disconnect(websocket)
        connected_clients.discard(client)
```

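This example uses a dependency to allow only one active connection per client IP address. The registry is an in-memory set, so the rule applies within a single server process, and clients behind the same NAT or proxy share one IP address.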

Advanced Authentication with WebSocket Dependencies

Let's implement a more advanced example that authenticates WebSocket clients with a JWT (it requires the python-jose package):

```python
from typing import Optional

from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect, WebSocketException, status
from jose import JWTError, jwt  # provided by the python-jose package
from pydantic import BaseModel

# Constants for JWT
SECRET_KEY = "your-secret-key"  # In production, load a secure key from configuration
ALGORITHM = "HS256"

app = FastAPI()

class TokenData(BaseModel):
    username: Optional[str] = None

async def get_current_user(websocket: WebSocket) -> str:
    # Any failure rejects the handshake with close code 1008 (policy violation)
    credentials_exception = WebSocketException(
        code=status.WS_1008_POLICY_VIOLATION, reason="Could not validate credentials"
    )

    # Get token from query parameter: ws://localhost:8000/ws?token=<jwt>
    token = websocket.query_params.get("token")
    if not token:
        raise credentials_exception

    try:
        # Verifies the signature and, if present, the "exp" claim
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception from None

    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentials_exception

    token_data = TokenData(username=username)
    # You could fetch the user from a database here
    return token_data.username

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, username: str = Depends(get_current_user)):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message: {data}, User: {username}")
    except WebSocketDisconnect:
        print(f"User {username} disconnected")
```

This example demonstrates how to extract a JWT from the connection's query parameters and validate it before the connection is accepted.
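To try this endpoint you need a token signed with SECRET_KEY. With python-jose you can create one using jwt.encode({"sub": "alice"}, SECRET_KEY, algorithm=ALGORITHM). The stdlib-only sketch below does not replace that library. It only shows what an HS256 token consists of (three base64url segments, the last being an HMAC-SHA256 over the first two) and which checks a decode step performs. The function names are illustrative, and the sketch handles only the exp claim.

```python
import base64
import binascii
import hashlib
import hmac
import json
import time
from typing import Any, Dict, Optional

SECRET_KEY = "your-secret-key"  # must match the key the server uses


def _b64url_encode(data: bytes) -> str:
    # JWT segments use base64url without "=" padding.
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _b64url_decode(segment: str) -> bytes:
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def _sign(signing_input: bytes, key: str) -> bytes:
    return hmac.new(key.encode("utf-8"), signing_input, hashlib.sha256).digest()


def make_hs256_token(claims: Dict[str, Any], key: str = SECRET_KEY) -> str:
    """Build header.payload.signature in the standard HS256 JWT layout."""
    header = {"alg": "HS256", "typ": "JWT"}
    segments = [
        _b64url_encode(json.dumps(part, separators=(",", ":")).encode("utf-8"))
        for part in (header, claims)
    ]
    signing_input = ".".join(segments).encode("ascii")
    return ".".join(segments + [_b64url_encode(_sign(signing_input, key))])


def verify_hs256_token(token: str, key: str = SECRET_KEY) -> Optional[Dict[str, Any]]:
    """Return the claims if the signature is valid and the token is unexpired, else None."""
    try:
        header_b64, payload_b64, signature_b64 = token.split(".")
        signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
        header = json.loads(_b64url_decode(header_b64))
        claims = json.loads(_b64url_decode(payload_b64))
        signature = _b64url_decode(signature_b64)
    except (ValueError, binascii.Error):
        return None  # wrong number of segments, bad base64, or bad JSON
    if not isinstance(header, dict) or header.get("alg") != "HS256":
        return None  # only accept the algorithm we expect (guards against "alg": "none")
    if not hmac.compare_digest(signature, _sign(signing_input, key)):
        return None  # signed with another key, or the header/payload was altered
    if not isinstance(claims, dict):
        return None
    exp = claims.get("exp")
    if exp is not None and (not isinstance(exp, (int, float)) or time.time() >= exp):
        return None  # expired or malformed expiry
    return claims
```

A token from make_hs256_token follows the standard HS256 format. python-jose should therefore accept it in the endpoint when both sides use the same key.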

Class-based Dependencies

You can also use class-based dependencies with WebSockets, which is useful for more complex scenarios:

```python
from typing import Dict, Iterable, List

from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect, WebSocketException, status

app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.connection_count = 0

    async def connect(self, websocket: WebSocket) -> int:
        await websocket.accept()
        self.active_connections.append(websocket)
        self.connection_count += 1
        return self.connection_count

    def disconnect(self, websocket: WebSocket):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def broadcast(self, message: str):
        # Copy the list: connections may come and go while we await
        for connection in list(self.active_connections):
            await connection.send_text(message)

# Create a single instance to be shared across connections
manager = ConnectionManager()

class WebSocketDependency:
    def __init__(self, allowed_rooms: Iterable[str] = ("default", "premium")):
        # Configuration, fixed once when the dependency instance is created
        self.allowed_rooms = set(allowed_rooms)

    async def __call__(self, websocket: WebSocket, room_id: str) -> Dict[str, str]:
        # Called for every connection; FastAPI fills room_id from the URL path
        client_id = websocket.headers.get("client-id")
        if not client_id:
            raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION, reason="Client ID is required")

        # Room-specific validation
        if room_id not in self.allowed_rooms:
            raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION, reason="Invalid room")

        return {"client_id": client_id, "room": room_id}

room_dependency = WebSocketDependency(allowed_rooms=["default", "premium"])

@app.websocket("/ws/{room_id}")
async def websocket_endpoint(
    websocket: WebSocket,
    client_data: Dict[str, str] = Depends(room_dependency),
):
    connection_id = await manager.connect(websocket)
    await websocket.send_text(f"Connected! You are client #{connection_id}")
    await websocket.send_text(f"Room: {client_data['room']}, Client ID: {client_data['client_id']}")

    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"You sent: {data}")
            # Broadcast message with client ID and room (sent to every connection, in any room)
            await manager.broadcast(
                f"Client {client_data['client_id']} in room {client_data['room']} says: {data}"
            )
    except WebSocketDisconnect:
        pass
    finally:
        manager.disconnect(websocket)
```

The class-based dependency above is configured once through its constructor (here with the rooms that may be joined). It is then called for every connection, and FastAPI passes in the room_id path parameter each time. This lets the same class back several endpoints with different settings. Browsers cannot set custom headers such as client-id on a WebSocket handshake, so the client ID check suits non-browser clients. A browser client could send the ID as a query parameter instead.

Dependency Injection with WebSocket and Background Tasks

You can combine WebSocket dependencies with background work for resource-intensive operations. FastAPI's BackgroundTasks are designed to run after an HTTP response has been sent. A WebSocket endpoint has no such response, and FastAPI does not run tasks added to a BackgroundTasks object there. Inside a WebSocket loop, asyncio.create_task does the job:

```python
import asyncio
import time
from datetime import datetime
from typing import Set

from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect

app = FastAPI()

# Strong references to running tasks, so they are not garbage-collected mid-run
background_tasks: Set[asyncio.Task] = set()

async def process_data_in_background(data: str, client_id: str):
    # Simulate a time-consuming process
    await asyncio.sleep(2)
    process_time = datetime.now().strftime("%H:%M:%S")
    print(f"[{process_time}] Processed data '{data}' from client {client_id}")

class DataProcessor:
    async def process(self, data: str, client_id: str) -> str:
        # Process data (simulated)
        process_time = datetime.now().strftime("%H:%M:%S")
        print(f"[{process_time}] Processing data '{data}' from client {client_id}")
        await asyncio.sleep(0.5)  # Simulate some processing time
        return f"Processed: {data} at {process_time}"

processor = DataProcessor()

async def get_processor() -> DataProcessor:
    return processor

async def get_client_id(websocket: WebSocket) -> str:
    client_id = websocket.headers.get("client-id")
    if not client_id:
        client_id = f"anonymous-{time.time()}"
    return client_id

@app.websocket("/ws/process")
async def websocket_endpoint(
    websocket: WebSocket,
    processor: DataProcessor = Depends(get_processor),
    client_id: str = Depends(get_client_id),
):
    await websocket.accept()
    await websocket.send_text(f"Connected as client: {client_id}")

    try:
        while True:
            data = await websocket.receive_text()
            # Process data with our injected processor
            result = await processor.process(data, client_id)
            await websocket.send_text(result)

            # Schedule background work without blocking the receive loop
            task = asyncio.create_task(process_data_in_background(data, client_id))
            background_tasks.add(task)
            task.add_done_callback(background_tasks.discard)
            await websocket.send_text("Data queued for background processing")
    except WebSocketDisconnect:
        print(f"Client {client_id} disconnected")
```

This example demonstrates how to:

  1. Inject a data processor as a dependency
  2. Get the client ID through a dependency
  3. Schedule background work from a WebSocket loop with asyncio.create_task
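Work started from a WebSocket loop often should not outlive the connection. Below is a minimal sketch built on plain asyncio tasks. A hypothetical ConnectionTasks helper keeps a strong reference to each task, because the asyncio documentation advises saving a reference to tasks created with create_task (the event loop holds only weak references). It then cancels whatever is still pending when the client disconnects. The demo coroutine stands in for a real endpoint, and its job names are illustrative.

```python
import asyncio
from typing import Any, Coroutine, List, Set, Tuple


class ConnectionTasks:
    """Hypothetical helper: tracks tasks started on behalf of one WebSocket connection."""

    def __init__(self) -> None:
        self._tasks: Set[asyncio.Task] = set()

    def spawn(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
        task = asyncio.create_task(coro)
        self._tasks.add(task)  # strong reference while the task runs
        task.add_done_callback(self._tasks.discard)  # forget it once it is done
        return task

    @property
    def pending(self) -> int:
        return len(self._tasks)

    async def cancel_all(self) -> None:
        # Call this from the endpoint's finally block, after the client disconnects.
        tasks = list(self._tasks)
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def demo() -> Tuple[List[str], int, int]:
    finished: List[str] = []

    async def job(name: str, delay: float) -> None:
        await asyncio.sleep(delay)
        finished.append(name)

    tasks = ConnectionTasks()
    tasks.spawn(job("quick", 0.01))
    tasks.spawn(job("slow", 10))
    await asyncio.sleep(0.1)  # the quick job completes; the slow one is still sleeping
    pending_before = tasks.pending
    await tasks.cancel_all()  # simulate the client disconnecting
    return finished, pending_before, tasks.pending


if __name__ == "__main__":
    print(asyncio.run(demo()))  # (['quick'], 1, 0)
```

Whether leftover work should be cancelled or allowed to finish is a design choice. For work that must complete after the client leaves, a module-level task set (as in the example above) or a proper job queue is a better fit.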

Real-world Example: Chat Room with User Authentication

Let's put everything together in a more complete example of a chat room with authentication:

```python
import json
from typing import Dict, List

from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect, WebSocketException, status
from fastapi.responses import HTMLResponse

app = FastAPI()

# HTML for a simple chat interface
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>FastAPI WebSocket Chat</title>
        <style>
            #messages { width: 600px; height: 400px; overflow: auto; border: 1px solid #ccc; }
            .message { padding: 5px; margin: 5px 0; border-bottom: 1px solid #eee; }
            .system { color: #999; font-style: italic; }
            .user { color: #333; }
            .me { background-color: #f0f0f0; }
        </style>
    </head>
    <body>
        <h1>WebSocket Chat with Authentication</h1>
        <div>
            <label for="token">Auth Token:</label>
            <input type="text" id="token" placeholder="Enter your token" />
        </div>
        <div>
            <button onclick="connectWebSocket()">Connect</button>
            <button onclick="disconnectWebSocket()">Disconnect</button>
        </div>
        <div id="messages"></div>
        <div>
            <input type="text" id="messageText" placeholder="Type a message"/>
            <button onclick="sendMessage()">Send</button>
        </div>
        <script>
            let ws = null;

            function connectWebSocket() {
                const token = document.getElementById("token").value;
                if (!token) {
                    addMessage("System", "Please enter a token first");
                    return;
                }

                if (ws) {
                    ws.close();
                }

                // Connect back to the server that served this page; URL-encode the token
                const scheme = location.protocol === "https:" ? "wss" : "ws";
                ws = new WebSocket(`${scheme}://${location.host}/ws/chat?token=${encodeURIComponent(token)}`);

                ws.onopen = function(event) {
                    addMessage("System", "Connected to chat server");
                };

                ws.onmessage = function(event) {
                    const data = JSON.parse(event.data);
                    addMessage(data.sender, data.message, data.is_me);
                };

                ws.onclose = function(event) {
                    addMessage("System", `Disconnected: ${event.reason || "No reason provided"}`);
                };

                ws.onerror = function(error) {
                    addMessage("System", "Error occurred");
                };
            }

            function disconnectWebSocket() {
                if (ws) {
                    ws.close();
                    ws = null;
                }
            }

            function sendMessage() {
                if (ws && ws.readyState === WebSocket.OPEN) {
                    const messageText = document.getElementById("messageText").value;
                    if (messageText) {
                        ws.send(messageText);
                        document.getElementById("messageText").value = "";
                    }
                } else {
                    addMessage("System", "Not connected to chat server");
                }
            }

            function addMessage(sender, message, isMe = false) {
                const messagesDiv = document.getElementById("messages");
                const messageDiv = document.createElement("div");
                messageDiv.className = `message ${isMe ? "me" : ""} ${sender === "System" ? "system" : "user"}`;
                messageDiv.textContent = `${sender}: ${message}`;
                messagesDiv.appendChild(messageDiv);
                messagesDiv.scrollTop = messagesDiv.scrollHeight;
            }
        </script>
    </body>
</html>
"""

class ConnectionManager:
    def __init__(self):
        # username -> that user's open connections (e.g. one per browser tab)
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user: str):
        await websocket.accept()
        self.active_connections.setdefault(user, []).append(websocket)

    def disconnect(self, websocket: WebSocket, user: str):
        connections = self.active_connections.get(user, [])
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            self.active_connections.pop(user, None)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)

    async def broadcast(self, message: str, sender: str):
        # Iterate over snapshots: users may join or leave while we await send_text
        for user, connections in list(self.active_connections.items()):
            payload = json.dumps({
                "message": message,
                "sender": sender,
                "is_me": user == sender,
            })
            for connection in list(connections):
                await connection.send_text(payload)

manager = ConnectionManager()

# Demo tokens only; in a real app, decode and verify a JWT as in the earlier example
DEMO_TOKENS = {"user1": "Alice", "user2": "Bob"}

async def get_current_user(websocket: WebSocket) -> str:
    # Get token from query parameter
    token = websocket.query_params.get("token")
    if not token:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION, reason="Token missing")

    username = DEMO_TOKENS.get(token)
    if username is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION, reason="Invalid token")
    return username

@app.get("/")
async def get():
    return HTMLResponse(html)

@app.websocket("/ws/chat")
async def websocket_endpoint(websocket: WebSocket, username: str = Depends(get_current_user)):
    await manager.connect(websocket, username)

    # Announce new user
    await manager.broadcast(f"{username} has joined the chat", "System")

    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(data, username)
    except WebSocketDisconnect:
        manager.disconnect(websocket, username)
        await manager.broadcast(f"{username} has left the chat", "System")
```

This complete example creates a chat room where:

  1. Users authenticate with a simple token
  2. The authentication is handled by a dependency
  3. The connection manager keeps track of all WebSocket connections
  4. Messages are broadcast to all connected users
  5. The UI shows who sent each message
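The broadcast method could be made more robust. If one client's connection has dropped, send_text on it can raise (the exact exception depends on the server and Starlette version). That error then escapes broadcast and interrupts delivery to the remaining users. Below is a minimal sketch of a fix, assuming a hypothetical ChatRegistry with the same data layout as ConnectionManager. A FakeSocket stand-in replaces WebSocket so the example runs without a server; both names are illustrative.

```python
import asyncio
import json
from typing import Dict, List, Protocol


class TextSocket(Protocol):
    async def send_text(self, data: str) -> None: ...


class ChatRegistry:
    """Hypothetical variant of ConnectionManager whose broadcast drops dead sockets."""

    def __init__(self) -> None:
        self.active_connections: Dict[str, List[TextSocket]] = {}

    def add(self, user: str, ws: TextSocket) -> None:
        self.active_connections.setdefault(user, []).append(ws)

    def remove(self, user: str, ws: TextSocket) -> None:
        connections = self.active_connections.get(user, [])
        if ws in connections:
            connections.remove(ws)
        if not connections:
            self.active_connections.pop(user, None)

    async def broadcast(self, message: str, sender: str) -> None:
        # Snapshot the dict and lists: other coroutines may change them during each await.
        for user, connections in list(self.active_connections.items()):
            payload = json.dumps({"message": message, "sender": sender, "is_me": user == sender})
            for ws in list(connections):
                try:
                    await ws.send_text(payload)
                except Exception:
                    # Treat any send failure as a dead connection and stop sending to it.
                    self.remove(user, ws)


class FakeSocket:
    """Stand-in for WebSocket in this demo; a closed one fails on send."""

    def __init__(self, open_: bool = True) -> None:
        self.open = open_
        self.received: List[str] = []

    async def send_text(self, data: str) -> None:
        if not self.open:
            raise RuntimeError("connection closed")
        self.received.append(data)
```

Dropping the socket silently keeps the sketch short. In a real application, logging the failure makes dead connections easier to diagnose.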

Summary

WebSocket dependencies in FastAPI provide a powerful way to organize your code, implement authentication, and reuse common functionality across WebSocket endpoints. Key takeaways include:

  • Dependencies work with WebSocket routes similar to how they work with HTTP routes
  • They're great for connection validation and authentication
  • You can use them to inject services like database connections
  • Class-based dependencies allow for parameterization
  • They can be combined with background processing for resource-intensive work

By using WebSocket dependencies, you can create cleaner, more maintainable, and more secure real-time applications.


Exercises

  1. Create a WebSocket endpoint with a dependency that limits the number of connections per IP address.
  2. Implement a chat system with different "rooms" using dependencies to validate room access permissions.
  3. Create a class-based WebSocket dependency that monitors and logs all messages for a specific user.
  4. Implement rate limiting in a WebSocket endpoint using a dependency.
  5. Create a WebSocket endpoint with a dependency that verifies a database connection before allowing the connection.

