FastAPI WebSocket Authentication

Introduction

WebSockets provide powerful real-time communication capabilities in web applications, but they also introduce new security concerns. Unlike regular HTTP endpoints, WebSocket connections are persistent and don't follow the same request-response pattern. This means we need special authentication strategies to ensure that only authorized users can establish and maintain WebSocket connections.

In this tutorial, we'll explore different approaches to authenticate WebSocket connections in FastAPI applications. We'll learn how to:

  • Implement cookie-based authentication for WebSockets
  • Use query parameters for authentication
  • Create custom authentication dependencies for WebSockets
  • Handle authentication failures gracefully

Prerequisites

Before we begin, make sure you have:

  • Basic knowledge of FastAPI
  • Understanding of WebSockets fundamentals
  • Python 3.7+ installed
  • Experience with authentication concepts (sessions, tokens, etc.)

Authentication Challenges with WebSockets

WebSockets present unique authentication challenges compared to regular HTTP endpoints:

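  1. Persistent connections: a WebSocket connection stays open after the initial handshake, so credentials are normally checked once, at connection time, rather than on every message
  2. No request-response cycle: HTTP middleware (for example, functions registered with @app.middleware("http")) does not run for WebSocket connections
  3. Limited header access: the browser WebSocket API does not let you set custom headers such as Authorization, so the usual "bearer token in a header" pattern is unavailable to browser clients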
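Challenge 2 applies to HTTP-style middleware. A pure ASGI middleware does receive WebSocket connections (their scope has type "websocket"), so a handshake-level check is still possible. The following is a minimal sketch, not a FastAPI feature: `WebSocketTokenMiddleware` is a hypothetical name, and the `valid_tokens` set stands in for whatever token validation your application actually performs.

```python
from typing import Any, Awaitable, Callable, Dict, Set
from urllib.parse import parse_qs

Scope = Dict[str, Any]
Message = Dict[str, Any]
Receive = Callable[[], Awaitable[Message]]
Send = Callable[[Message], Awaitable[None]]
ASGIApp = Callable[[Scope, Receive, Send], Awaitable[None]]


class WebSocketTokenMiddleware:
    """Hypothetical pure-ASGI middleware that rejects WebSocket handshakes
    lacking a valid ?token= query parameter. Other traffic passes through."""

    def __init__(self, app: ASGIApp, valid_tokens: Set[str]) -> None:
        self.app = app
        self.valid_tokens = valid_tokens  # stand-in for real token validation

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "websocket":
            params = parse_qs(scope.get("query_string", b"").decode("latin-1"))
            token = params.get("token", [""])[0]
            if token not in self.valid_tokens:
                # Per the ASGI spec the first message is "websocket.connect";
                # answering it with "websocket.close" rejects the handshake
                # (servers typically respond with HTTP 403).
                await receive()
                await send({"type": "websocket.close", "code": 1008})
                return
        await self.app(scope, receive, send)
```

With FastAPI it could be registered with `app.add_middleware(WebSocketTokenMiddleware, valid_tokens={"some-token"})`. The per-endpoint approaches below are usually simpler, because they reuse FastAPI's parameter parsing and dependencies.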

Let's explore solutions to these challenges.

Basic WebSocket Authentication Approaches

1. Authentication via Query Parameters

The simplest approach is to include authentication credentials in the WebSocket connection URL as query parameters.

```python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status

app = FastAPI()

# In a real application, you'd use a proper database (and hashed passwords)
USERS = {
    "user1": "password1",
    "user2": "password2",
}

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, username: str, password: str):
    # username and password are read from the query string: /ws?username=...&password=...
    if username not in USERS or USERS[username] != password:
        # Closing before accept() rejects the handshake
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        print(f"Client {username} disconnected")
```

To connect to this WebSocket, clients would use a URL like: ws://example.com/ws?username=user1&password=password1
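Because the credentials travel in the query string, a client has to percent-encode them. A minimal standard-library sketch for a Python client (`build_ws_url` is a hypothetical helper name):

```python
from urllib.parse import urlencode


def build_ws_url(base: str, **params: str) -> str:
    """Append percent-encoded query parameters to a WebSocket URL."""
    return f"{base}?{urlencode(params)}"


print(build_ws_url("ws://example.com/ws", username="user1", password="p@ss word"))
# → ws://example.com/ws?username=user1&password=p%40ss+word
```

Encoding is not encryption: anyone who sees the URL can decode the password, which is the main drawback listed below.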

Advantages:

  • Simple to implement
  • Works in almost any environment

Disadvantages:

  • Very insecure, as credentials are exposed in URLs (which might be logged)
  • Not suitable for production applications

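2. Authentication via Cookies

A more secure approach is to use cookies for authentication, which can be set during a regular HTTP login process. The browser sends the cookies for the WebSocket's host along with the handshake request, so the endpoint can read them like any other request.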

```python
from typing import Optional

from fastapi import Cookie, FastAPI, WebSocket, WebSocketDisconnect, status
from fastapi.responses import HTMLResponse

app = FastAPI()

# In a real application, you'd use a proper session management system
SESSIONS = {"valid_session_id": "user1"}

@app.get("/login")
async def login():
    response = HTMLResponse("<h1>Logged in!</h1>")
    # httponly=True keeps the cookie out of reach of page JavaScript
    response.set_cookie(key="session_id", value="valid_session_id", httponly=True)
    return response

@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    session_id: Optional[str] = Cookie(None),  # read from the handshake's Cookie header
):
    if session_id not in SESSIONS:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    user = SESSIONS[session_id]
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"{user} says: {data}")
    except WebSocketDisconnect:
        print(f"Client {user} disconnected")
```

In this example, users first visit /login to get a session cookie, then connect to the WebSocket.
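The hard-coded `valid_session_id` only keeps the example short. A session store should hand out unguessable IDs and expire them. The following in-memory sketch assumes a single server process; `SessionStore` and its one-hour default lifetime are a hypothetical design, not part of FastAPI.

```python
import secrets
import time
from typing import Dict, Optional, Tuple


class SessionStore:
    """Hypothetical in-memory session store: unguessable IDs plus an expiry.

    Single-process only; a real deployment would typically use a shared
    store such as a database or cache.
    """

    def __init__(self, ttl_seconds: float = 3600.0) -> None:
        self.ttl_seconds = ttl_seconds
        self._sessions: Dict[str, Tuple[str, float]] = {}  # id -> (user, expires_at)

    def create(self, username: str, now: Optional[float] = None) -> str:
        now = time.time() if now is None else now
        session_id = secrets.token_urlsafe(32)  # 32 random bytes, URL-safe text
        self._sessions[session_id] = (username, now + self.ttl_seconds)
        return session_id

    def get_user(self, session_id: Optional[str], now: Optional[float] = None) -> Optional[str]:
        if session_id is None:
            return None
        entry = self._sessions.get(session_id)
        if entry is None:
            return None
        username, expires_at = entry
        now = time.time() if now is None else now
        if now >= expires_at:
            # Expired: forget the session so the ID cannot be used again.
            del self._sessions[session_id]
            return None
        return username

    def revoke(self, session_id: str) -> None:
        self._sessions.pop(session_id, None)
```

`/login` could then call `response.set_cookie(key="session_id", value=store.create("user1"), httponly=True, secure=True, samesite="lax")` (with `secure=True` the browser only sends the cookie over HTTPS/WSS), and the WebSocket endpoint would reject the connection when `store.get_user(session_id)` returns `None` instead of looking the ID up in `SESSIONS`.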

Creating a WebSocket Authentication Dependency

For more complex applications, we can create reusable authentication dependencies for WebSockets:

```python
from typing import Optional

from fastapi import (
    Cookie,
    Depends,
    FastAPI,
    WebSocket,
    WebSocketDisconnect,
    WebSocketException,
    status,
)

app = FastAPI()

# Mock database of sessions: session ID -> username
SESSIONS = {"valid_session_id": "user1"}

async def get_ws_user(session_id: Optional[str] = Cookie(None)) -> str:
    """Reusable WebSocket authentication dependency.

    Raising WebSocketException makes FastAPI close the connection with the
    given code, so the endpoint body never runs for unauthenticated clients.
    """
    if not session_id or session_id not in SESSIONS:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return SESSIONS[session_id]

@app.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    user: str = Depends(get_ws_user),
):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"{user} says: {data}")
    except WebSocketDisconnect:
        print(f"Client {user} disconnected")
```

This dependency pattern allows for clean separation of authentication logic from WebSocket handling.

Token-Based Authentication

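For APIs whose clients already authenticate with bearer tokens, token-based authentication (for example, with JWTs) is often a better fit. Because browsers cannot set an Authorization header on a WebSocket, the token is passed as a query parameter here: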

```python
# Requires PyJWT: pip install pyjwt
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import FastAPI, Query, WebSocket, WebSocketDisconnect, status

app = FastAPI()

# In a real application, store this securely (e.g. in an environment variable)
SECRET_KEY = "your-secret-key"
ALGORITHM = "HS256"

def create_access_token(data: dict) -> str:
    to_encode = data.copy()
    # Timezone-aware expiry; PyJWT stores it as the "exp" claim
    expire = datetime.now(timezone.utc) + timedelta(minutes=15)
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

@app.get("/token/{username}")
async def get_token(username: str):
    # In a real app, verify username/password first
    access_token = create_access_token(data={"sub": username})
    return {"access_token": access_token}

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str = Query(...)):
    try:
        # decode() verifies the signature and rejects expired tokens
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    username = payload.get("sub")
    if not username:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"User {username}: {data}")
    except WebSocketDisconnect:
        print(f"Client {username} disconnected")
```

Clients would first make a GET request to /token/{username} to obtain a JWT token, then connect using: ws://example.com/ws?token=eyJ0eXAiOi...
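PyJWT checks `exp` only when `jwt.decode` runs, that is, at the handshake, so the connection itself can stay open long after the 15-minute token has expired. One way to bound it, sketched under the assumption that you pass in the `exp` value (Unix seconds) from the decoded payload, is to run the receive loop under a timeout equal to the token's remaining lifetime. `run_until_token_expiry` is a hypothetical helper:

```python
import asyncio
import time
from typing import Awaitable, Optional


async def run_until_token_expiry(
    session: Awaitable[None], exp: float, now: Optional[float] = None
) -> bool:
    """Await `session` (for example, the receive/send loop), but cancel it once
    the token's `exp` time (Unix seconds) has passed.

    Returns True if the session ended on its own and False if it was cut off
    because the token expired. Exceptions raised by the session, such as
    WebSocketDisconnect, propagate to the caller.
    """
    now = time.time() if now is None else now
    remaining = max(exp - now, 0.0)
    try:
        await asyncio.wait_for(session, timeout=remaining)
        return True
    except asyncio.TimeoutError:
        return False
```

In the endpoint above you would move the `while True` loop into an inner coroutine, call `await run_until_token_expiry(loop(), payload["exp"])` after `accept()`, and close the connection with `status.WS_1008_POLICY_VIOLATION` when it returns False.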

Real-World Example: Chat Room with Authentication

Let's create a more comprehensive example of a chat room with authentication:

```python
# Requires PyJWT (pip install pyjwt) and python-multipart (for the login form)
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

import jwt
from fastapi import Depends, FastAPI, HTTPException, WebSocket, WebSocketDisconnect, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel

app = FastAPI()

# Mock database
USERS = {
    "john": {
        "username": "john",
        "full_name": "John Doe",
        "email": "john@example.com",
        "hashed_password": "fakehashedsecret",
    }
}

# Security settings
SECRET_KEY = "super-secret-key-replace-in-production"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

# Models
class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None

class Token(BaseModel):
    access_token: str
    token_type: str

# Helper functions
def verify_password(plain_password: str, hashed_password: str) -> bool:
    # In a real app, use proper password hashing (like bcrypt)
    return plain_password == "secret" and hashed_password == "fakehashedsecret"

def get_user(db: dict, username: str) -> Optional[dict]:
    return db.get(username)

def authenticate_user(db: dict, username: str, password: str) -> Optional[dict]:
    user = get_user(db, username)
    if not user or not verify_password(password, user["hashed_password"]):
        return None
    return user

def create_access_token(data: dict, expires_delta: timedelta) -> str:
    to_encode = data.copy()
    to_encode.update({"exp": datetime.now(timezone.utc) + expires_delta})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)

# Token endpoint
@app.post("/token", response_model=Token)
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = authenticate_user(USERS, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token = create_access_token(
        data={"sub": user["username"]},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}

# WebSocket connection manager: username -> that user's open connections
class ConnectionManager:
    def __init__(self):
        self.active_connections: Dict[str, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user: str):
        await websocket.accept()
        self.active_connections.setdefault(user, []).append(websocket)

    def disconnect(self, websocket: WebSocket, user: str):
        self.active_connections[user].remove(websocket)
        if not self.active_connections[user]:
            del self.active_connections[user]

    async def broadcast(self, message: str, sender: str):
        # Iterate over snapshots: other coroutines may connect or disconnect
        # while we await send_text(), which would otherwise modify the
        # dict/list during iteration.
        for connections in list(self.active_connections.values()):
            for connection in list(connections):
                await connection.send_text(f"{sender}: {message}")

manager = ConnectionManager()

# Token check used by the WebSocket endpoint
async def get_websocket_user(token: str) -> Optional[str]:
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        return None
    return payload.get("sub")

# WebSocket endpoint
@app.websocket("/ws/{token}")
async def websocket_endpoint(websocket: WebSocket, token: str):
    user = await get_websocket_user(token)
    if user is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await manager.connect(websocket, user)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(data, user)
    except WebSocketDisconnect:
        manager.disconnect(websocket, user)
        await manager.broadcast(f"{user} has left the chat", "System")
```

To use this chat application:

  1. First, obtain a token with a form-encoded POST request to /token containing username and password fields (the mock user in this example is john with password secret)
  2. Connect to the WebSocket using: ws://example.com/ws/{token}
  3. Messages sent to the WebSocket are broadcast to all connected users
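Because every message is relayed to every connected user, one client can flood the whole room. The best practices below recommend rate limiting; one common per-connection approach is the token-bucket algorithm. A hypothetical sketch (the class name and the example rates are assumptions):

```python
import time
from typing import Optional


class MessageRateLimiter:
    """Hypothetical per-connection rate limiter using the token-bucket algorithm.

    Allows bursts of up to `capacity` messages and refills the allowance at
    `rate` messages per second.
    """

    def __init__(self, rate: float, capacity: float, now: Optional[float] = None) -> None:
        self.rate = rate
        self.capacity = capacity
        self.allowance = capacity  # start with a full bucket
        self.last = time.monotonic() if now is None else now

    def allow(self, now: Optional[float] = None) -> bool:
        """Return True (and consume one unit) if a message may be sent now."""
        now = time.monotonic() if now is None else now
        # Refill in proportion to elapsed time, never above capacity.
        self.allowance = min(self.capacity, self.allowance + (now - self.last) * self.rate)
        self.last = now
        if self.allowance >= 1:
            self.allowance -= 1
            return True
        return False
```

In the chat endpoint you would create one limiter per connection before the `while True` loop, for example `limiter = MessageRateLimiter(rate=5, capacity=10)`, and skip or reject a message whenever `limiter.allow()` returns False.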

Best Practices for WebSocket Authentication

  1. Never send sensitive credentials in the WebSocket URL (except for carefully designed tokens)
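  2. Implement token expiration, and remember that a token is typically checked only at the handshake: an open connection outlives its token unless you re-validate it or close the connection when it expires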
  3. Use secure cookies with the Secure and HttpOnly flags when possible
  4. Consider rate limiting to prevent abuse of your WebSocket endpoints
  5. Implement proper error handling to avoid exposing sensitive information
  6. Log authentication failures to detect potential attacks
  7. Use HTTPS/WSS for all connections to encrypt traffic
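Cookie-based WebSocket authentication has one extra pitfall. Browsers attach cookies to cross-site WebSocket handshakes (subject to the cookie's SameSite setting), and CORS does not apply to WebSockets, so a malicious page could open a connection authenticated as the visiting user. A common mitigation is to check the handshake's Origin header against an allowlist before accepting. A hypothetical helper, where the allowed origin is an assumption:

```python
from typing import Iterable, Optional


def is_origin_allowed(
    origin: Optional[str], allowed_origins: Iterable[str], allow_missing: bool = False
) -> bool:
    """Return True if the handshake's Origin header is on the allowlist.

    Browsers send Origin on WebSocket handshakes; non-browser clients may omit
    it (or set it to anything), so this only guards against cross-site
    connections made from a victim's browser.
    """
    if origin is None:
        return allow_missing
    normalized = {o.rstrip("/").lower() for o in allowed_origins}
    return origin.rstrip("/").lower() in normalized
```

Inside an endpoint this could be called as `is_origin_allowed(websocket.headers.get("origin"), {"https://example.com"})`, closing with `status.WS_1008_POLICY_VIOLATION` when it returns False.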

Summary

Authentication for WebSockets in FastAPI requires a different approach than traditional HTTP endpoints. The most common strategies include:

  • Query parameter-based token authentication
  • Cookie-based authentication leveraging existing session mechanisms
  • Custom dependency-based authentication for more complex scenarios

When implementing WebSocket authentication, remember that security is paramount. Tokens should be short-lived, credentials should never be exposed in URLs (except for one-time tokens), and all communications should be encrypted using WSS (WebSocket Secure).
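The exception for one-time tokens can be made concrete. Instead of putting a long-lived JWT in the URL, an authenticated HTTP endpoint issues a random ticket that is valid for a few seconds and can be redeemed exactly once at the WebSocket handshake, so a URL that ends up in a log is useless afterwards. A hypothetical in-memory sketch (`TicketStore` and the 30-second default lifetime are assumptions):

```python
import secrets
import time
from typing import Dict, Optional, Tuple


class TicketStore:
    """Hypothetical store of short-lived, single-use WebSocket tickets."""

    def __init__(self, ttl_seconds: float = 30.0) -> None:
        self.ttl_seconds = ttl_seconds
        self._tickets: Dict[str, Tuple[str, float]] = {}  # ticket -> (user, expires_at)

    def issue(self, username: str, now: Optional[float] = None) -> str:
        """Call from an authenticated HTTP endpoint; return the ticket to the client."""
        now = time.monotonic() if now is None else now
        ticket = secrets.token_urlsafe(32)
        self._tickets[ticket] = (username, now + self.ttl_seconds)
        return ticket

    def redeem(self, ticket: str, now: Optional[float] = None) -> Optional[str]:
        """Call at the WebSocket handshake. The ticket is removed whether or not
        it has expired, so each ticket can be used at most once."""
        now = time.monotonic() if now is None else now
        entry = self._tickets.pop(ticket, None)
        if entry is None:
            return None
        username, expires_at = entry
        return username if now < expires_at else None
```

The client would fetch a ticket with its bearer token over HTTPS and then connect to a URL such as `wss://example.com/ws?ticket=...`. Tickets that are never redeemed stay in memory, so a real version would purge expired entries periodically, and a multi-process deployment would need a shared store.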

Exercises

  1. Extend the chat room example to include user rooms where users can join specific chat rooms.
  2. Implement a mechanism to refresh tokens without disconnecting the WebSocket.
  3. Add role-based access control to limit certain users' ability to broadcast messages.
  4. Create a system that tracks and limits the number of concurrent WebSocket connections per user.
  5. Implement a heartbeat mechanism that periodically verifies the token is still valid.

By following these patterns and best practices, you can build secure, real-time applications with FastAPI WebSockets that protect your users' data and your application's resources.


