FastAPI WebSocket Monitoring

Introduction

WebSockets provide a powerful way to create real-time communication between clients and servers in FastAPI applications. However, as your application scales, monitoring these WebSocket connections becomes crucial for maintaining performance, debugging issues, and ensuring system reliability.

In this tutorial, we'll explore various methods to monitor WebSocket connections in FastAPI applications. Whether you're building a chat application, a real-time dashboard, or a collaborative tool, proper WebSocket monitoring will help you maintain a robust application.

Why Monitor WebSockets?

WebSockets create persistent connections that stay open for extended periods. This introduces unique challenges:

  • Resource management (memory, connections)
  • Performance tracking
  • Debugging disconnections
  • Understanding user behavior
  • Security monitoring

Basic WebSocket Connection Tracking

Let's start with a simple approach to track active WebSocket connections in your FastAPI application.

Creating a Connection Manager

python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List, Dict
import asyncio  # used later by the periodic monitoring task

app = FastAPI()

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        # Cumulative counters, plus a gauge of currently open connections
        self.connection_stats: Dict[str, int] = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_received": 0,
            "messages_sent": 0,
        }

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        self.connection_stats["total_connections"] += 1
        self.connection_stats["active_connections"] += 1

    def disconnect(self, websocket: WebSocket):
        # Guard against a double disconnect, which would raise ValueError
        # and push the active count below its true value
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            self.connection_stats["active_connections"] -= 1

    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)
        self.connection_stats["messages_sent"] += 1

    async def broadcast(self, message: str):
        # Iterate over a copy: other handlers may connect or disconnect
        # while this coroutine is suspended in send_text()
        for connection in list(self.active_connections):
            await connection.send_text(message)
            self.connection_stats["messages_sent"] += 1

    def get_stats(self):
        # Return a copy so callers cannot mutate the live counters
        return dict(self.connection_stats)

manager = ConnectionManager()
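The `broadcast` method above sends to one connection at a time, so a single failed `send_text` call raises inside the sender's handler and the remaining clients never get the message. The following is a minimal sketch of one way to tolerate that, assuming it is acceptable to send concurrently with `asyncio.gather`. `FakeSocket` and `broadcast_safely` are hypothetical names used only for illustration; `FakeSocket` stands in for a real `WebSocket` so the example runs without a server.

```python
import asyncio
from typing import Iterable, List


class FakeSocket:
    """Hypothetical stand-in for a WebSocket; only send_text is modelled."""

    def __init__(self, name: str, broken: bool = False):
        self.name = name
        self.broken = broken
        self.received: List[str] = []

    async def send_text(self, message: str) -> None:
        if self.broken:
            raise RuntimeError(f"{self.name} is closed")
        self.received.append(message)


async def broadcast_safely(connections: Iterable, message: str) -> list:
    """Send to every connection concurrently; return the ones that failed."""
    targets = list(connections)  # snapshot: the registry may change while we await
    results = await asyncio.gather(
        *(conn.send_text(message) for conn in targets),
        return_exceptions=True,  # collect errors instead of aborting the batch
    )
    return [conn for conn, result in zip(targets, results)
            if isinstance(result, Exception)]


alice = FakeSocket("alice")
bob = FakeSocket("bob", broken=True)
carol = FakeSocket("carol")

failed = asyncio.run(broadcast_safely([alice, bob, carol], "hello"))
failed_names = [sock.name for sock in failed]
print(failed_names)
```

The returned list gives the caller a natural place to call `disconnect` for each failed connection, which keeps the active-connection count honest.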

Implementing WebSocket Endpoints

python
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            manager.connection_stats["messages_received"] += 1
            await manager.send_personal_message(f"You sent: {data}", websocket)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(f"Client #{client_id} left the chat")

Adding a Stats Endpoint

To monitor your WebSocket connections, create an HTTP endpoint that provides statistics:

python
@app.get("/ws-stats")
async def get_stats():
    return manager.get_stats()

When you access this endpoint, you'll see output similar to:

json
{
    "total_connections": 10,
    "active_connections": 3,
    "messages_received": 25,
    "messages_sent": 42
}
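Apart from `active_connections`, these values are cumulative counters, so a single reading says little about current load. A rough sketch of turning two readings into per-second rates follows; `message_rates` is a hypothetical helper, and the `earlier` snapshot and the 10-second interval are made-up illustration values.

```python
from typing import Dict


def message_rates(previous: Dict[str, int], current: Dict[str, int],
                  interval_seconds: float) -> Dict[str, float]:
    """Turn two cumulative snapshots into per-second message rates."""
    if interval_seconds <= 0:
        raise ValueError("interval_seconds must be positive")
    return {
        key: (current[key] - previous[key]) / interval_seconds
        for key in ("messages_received", "messages_sent")
    }


# Two hypothetical readings of /ws-stats taken 10 seconds apart
earlier = {"total_connections": 8, "active_connections": 3,
           "messages_received": 15, "messages_sent": 22}
later = {"total_connections": 10, "active_connections": 3,
         "messages_received": 25, "messages_sent": 42}

rates = message_rates(earlier, later, interval_seconds=10.0)
print(rates)
```

Because the counters live in process memory, they reset to zero on restart, so a negative rate usually means the server restarted between the two readings.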

Advanced WebSocket Monitoring

For more detailed monitoring, let's implement additional features:

Per-Connection Metrics

python
import uuid
from datetime import datetime

class ConnectionManager:
    def __init__(self):
        # connection_id -> record holding the socket and its metrics
        self.active_connections: Dict[str, Dict] = {}
        self.connection_stats: Dict[str, int] = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_received": 0,
            "messages_sent": 0,
        }

    async def connect(self, websocket: WebSocket, client_id: str):
        await websocket.accept()
        # A client_id may connect more than once, so key on a unique ID
        connection_id = str(uuid.uuid4())
        now = datetime.now()
        self.active_connections[connection_id] = {
            "websocket": websocket,
            "client_id": client_id,
            "connected_at": now,
            "messages_received": 0,
            "messages_sent": 0,
            "last_activity": now,
        }
        self.connection_stats["total_connections"] += 1
        self.connection_stats["active_connections"] += 1
        return connection_id

    def disconnect(self, connection_id: str):
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            self.connection_stats["active_connections"] -= 1

    async def send_personal_message(self, message: str, connection_id: str):
        conn = self.active_connections.get(connection_id)
        if conn is not None:
            await conn["websocket"].send_text(message)
            conn["messages_sent"] += 1
            conn["last_activity"] = datetime.now()
            self.connection_stats["messages_sent"] += 1

    async def broadcast(self, message: str):
        # Iterate over a snapshot: if another handler connects or disconnects
        # while we await send_text(), iterating the dict directly raises
        # "RuntimeError: dictionary changed size during iteration"
        for connection in list(self.active_connections.values()):
            await connection["websocket"].send_text(message)
            connection["messages_sent"] += 1
            connection["last_activity"] = datetime.now()
            self.connection_stats["messages_sent"] += 1

    def record_message_received(self, connection_id: str):
        conn = self.active_connections.get(connection_id)
        if conn is not None:
            conn["messages_received"] += 1
            conn["last_activity"] = datetime.now()
            self.connection_stats["messages_received"] += 1

    def get_stats(self):
        now = datetime.now()
        connection_details = []
        for conn_id, conn in self.active_connections.items():
            connection_details.append({
                "connection_id": conn_id,
                "client_id": conn["client_id"],
                "connected_at": conn["connected_at"].isoformat(),
                "messages_received": conn["messages_received"],
                "messages_sent": conn["messages_sent"],
                "last_activity": conn["last_activity"].isoformat(),
                # Connection age in seconds
                "duration": (now - conn["connected_at"]).total_seconds(),
            })

        return {
            "global_stats": self.connection_stats,
            "connections": connection_details,
        }

manager = ConnectionManager()
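The manager above records timestamps with `datetime.now()`, which returns naive local time. The resulting ISO strings carry no UTC offset, so a dashboard viewed from another timezone may display them shifted, and durations can jump if the system clock is adjusted. One possible alternative, sketched here under the assumption that you are free to change the record layout, stores the wall-clock time in UTC for display and uses `time.monotonic()` for durations. `ConnectionRecord` and its fields are hypothetical names, not part of the tutorial's manager.

```python
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone


@dataclass
class ConnectionRecord:
    client_id: str
    # Wall-clock time in UTC, for display only
    connected_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc))
    # Monotonic readings, for measuring elapsed time
    _connected_mono: float = field(default_factory=time.monotonic)
    _last_activity_mono: float = field(default_factory=time.monotonic)
    messages_received: int = 0
    messages_sent: int = 0

    def touch(self) -> None:
        """Mark the connection as active right now."""
        self._last_activity_mono = time.monotonic()

    def idle_seconds(self) -> float:
        return time.monotonic() - self._last_activity_mono

    def to_dict(self) -> dict:
        return {
            "client_id": self.client_id,
            "connected_at": self.connected_at.isoformat(),  # includes +00:00
            "duration": time.monotonic() - self._connected_mono,
            "idle_seconds": self.idle_seconds(),
            "messages_received": self.messages_received,
            "messages_sent": self.messages_sent,
        }


record = ConnectionRecord(client_id="42")
record.messages_received += 1
record.touch()
snapshot = record.to_dict()
print(snapshot["connected_at"])
```

With an explicit offset in `connected_at`, JavaScript's `new Date(...)` can convert the value to the viewer's local time unambiguously.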

Updated WebSocket Endpoint

python
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    connection_id = await manager.connect(websocket, client_id)
    try:
        while True:
            data = await websocket.receive_text()
            manager.record_message_received(connection_id)
            await manager.send_personal_message(f"You sent: {data}", connection_id)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(connection_id)
        await manager.broadcast(f"Client #{client_id} left the chat")

Periodic Monitoring Task

To detect inactive connections, you can add a background task. The version below only logs them; the cleanup step is left commented out:

python
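# Note: newer FastAPI releases deprecate on_event in favor of lifespan handlers
@app.on_event("startup")
async def startup_event():
    # Keep a reference to the task; the event loop holds only a weak
    # reference, so an unreferenced task can be garbage-collected mid-run
    app.state.monitor_task = asyncio.create_task(periodic_monitoring())

async def periodic_monitoring():
    while True:
        await asyncio.sleep(30)  # Check every 30 seconds
        now = datetime.now()

        # Find inactive connections (idle for more than 5 minutes)
        inactive_connections = []
        for conn_id, conn in manager.active_connections.items():
            if (now - conn["last_activity"]).total_seconds() > 300:  # 5 minutes
                inactive_connections.append(conn_id)

        # Log inactive connections
        if inactive_connections:
            print(f"Found {len(inactive_connections)} inactive connections")

        # Optionally drop inactive connections from the registry.
        # Note that disconnect() only removes the entry; it does not
        # close the underlying socket.
        # for conn_id in inactive_connections:
        #     manager.disconnect(conn_id)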
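The commented-out cleanup above only removes the registry entry, which would leave the socket itself open. A minimal sketch of closing idle sockets before unregistering them follows. It assumes the per-connection record layout used by the manager (`"websocket"` and `"last_activity"` keys) and relies on the socket exposing an async `close(code=...)` method, as Starlette's `WebSocket` does. `FakeWebSocket` and `close_idle_connections` are hypothetical names, and close code 1000 (normal closure) is an arbitrary choice.

```python
import asyncio
from datetime import datetime, timedelta
from typing import Dict, List


class FakeWebSocket:
    """Hypothetical stand-in for a WebSocket; only close() is modelled."""

    def __init__(self):
        self.closed_with = None

    async def close(self, code: int = 1000) -> None:
        self.closed_with = code


async def close_idle_connections(connections: Dict[str, dict],
                                 max_idle_seconds: float,
                                 now: datetime) -> List[str]:
    """Close and unregister every connection idle longer than the limit."""
    # Build the list first so the dict is not mutated while iterating
    idle_ids = [
        conn_id for conn_id, conn in connections.items()
        if (now - conn["last_activity"]).total_seconds() > max_idle_seconds
    ]
    for conn_id in idle_ids:
        conn = connections.pop(conn_id, None)  # may already be gone
        if conn is None:
            continue
        try:
            await conn["websocket"].close(code=1000)
        except Exception:
            pass  # the peer may have vanished; the entry is removed either way
    return idle_ids


now = datetime(2024, 1, 1, 12, 0, 0)
registry = {
    "a": {"websocket": FakeWebSocket(),
          "last_activity": now - timedelta(minutes=10)},
    "b": {"websocket": FakeWebSocket(),
          "last_activity": now - timedelta(seconds=20)},
}
stale_socket = registry["a"]["websocket"]

closed = asyncio.run(
    close_idle_connections(registry, max_idle_seconds=300, now=now))
print(closed, sorted(registry))
```

This sketch does not touch the global counters; inside the tutorial's manager, the same loop would also need to decrement `active_connections`, for example by calling `manager.disconnect(conn_id)` after the close.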

Integration with Monitoring Tools

For production applications, you might want to integrate with dedicated monitoring tools:

Prometheus Integration

Prometheus is a popular monitoring system that can track various metrics.

First, install the required library:

bash
pip install prometheus-client

Now, set up Prometheus metrics in your FastAPI application:

python
from prometheus_client import Counter, Gauge, start_http_server

# Define metrics
# By Prometheus naming convention the "_total" suffix is reserved for
# counters, so the gauge of open connections gets a plain name
WEBSOCKET_CONNECTIONS = Gauge(
    'websocket_active_connections',
    'Number of active WebSocket connections'
)
WEBSOCKET_MESSAGES_RECEIVED = Counter(
    'websocket_messages_received_total',
    'Number of WebSocket messages received'
)
WEBSOCKET_MESSAGES_SENT = Counter(
    'websocket_messages_sent_total',
    'Number of WebSocket messages sent'
)

# Expose the metrics on a separate port
@app.on_event("startup")
def start_prometheus():
    # 9090 is the Prometheus server's own default port, so this example
    # uses 8001 to avoid a clash when both run on the same host
    start_http_server(8001)  # Metrics will be available on this port

# Update the ConnectionManager methods to track metrics
class ConnectionManager:
    # ...existing code...

    async def connect(self, websocket: WebSocket, client_id: str):
        # ...existing code...
        WEBSOCKET_CONNECTIONS.inc()
        # ...rest of the method...

    def disconnect(self, connection_id: str):
        # ...existing code...
        # Decrement only when an entry was actually removed
        WEBSOCKET_CONNECTIONS.dec()
        # ...rest of the method...

    async def send_personal_message(self, message: str, connection_id: str):
        # ...existing code...
        WEBSOCKET_MESSAGES_SENT.inc()
        # ...rest of the method...

    def record_message_received(self, connection_id: str):
        # ...existing code...
        WEBSOCKET_MESSAGES_RECEIVED.inc()
        # ...rest of the method...

With this setup, you can use Prometheus to monitor and alert on WebSocket metrics, and visualize them using tools like Grafana.

Debugging WebSocket Connections

Sometimes you need to debug specific WebSocket issues. Here's how you can add logging to help:

python
import logging

# Set up logging
# At INFO level the logger.debug() call below is suppressed;
# switch to logging.DEBUG to see per-message logs
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("websocket_monitor")

@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    connection_id = await manager.connect(websocket, client_id)
    logger.info("Client %s connected with connection_id %s", client_id, connection_id)

    try:
        while True:
            data = await websocket.receive_text()
            manager.record_message_received(connection_id)
            # Log only the first 50 characters to keep log lines short
            logger.debug("Received message from %s: %s...", client_id, data[:50])
            await manager.send_personal_message(f"You sent: {data}", connection_id)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        logger.info("Client %s disconnected", client_id)
        manager.disconnect(connection_id)
        await manager.broadcast(f"Client #{client_id} left the chat")
    except Exception as e:
        # exc_info=True attaches the full traceback to the log record
        logger.error("Error handling connection for %s: %s", client_id, e, exc_info=True)
        manager.disconnect(connection_id)
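When many clients are connected, it helps if every log line carries the connection it belongs to without repeating the ID in each message. One way to do this with the standard library is `logging.LoggerAdapter`, sketched below. The logger name `websocket_monitor.example`, the `connection_logger` helper, and the sample IDs are hypothetical; the handler writes to an in-memory buffer only so the result can be inspected.

```python
import io
import logging

# Write to an in-memory stream so the output can be inspected below
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(
    logging.Formatter("%(levelname)s [%(connection_id)s] %(message)s"))

base_logger = logging.getLogger("websocket_monitor.example")
base_logger.setLevel(logging.DEBUG)
base_logger.addHandler(handler)
base_logger.propagate = False  # keep the example's output out of the root logger


def connection_logger(connection_id: str, client_id: str) -> logging.LoggerAdapter:
    """Return a logger that stamps every record with the connection's IDs."""
    return logging.LoggerAdapter(
        base_logger,
        {"connection_id": connection_id, "client_id": client_id},
    )


log = connection_logger("abc123", "42")
log.info("connected")
log.debug("received %d bytes", 17)

output = stream.getvalue()
print(output)
```

In the endpoint, the adapter could be created once, right after `manager.connect(...)` returns the connection ID, and then used in place of the module-level `logger`.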

Real-world Application: Admin Dashboard

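Let's create a simple admin dashboard that displays WebSocket statistics. It polls the /ws-stats endpoint and expects the response shape returned by the per-connection manager (global_stats plus connections):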

html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>WebSocket Admin Dashboard</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
        .stats-container { display: flex; gap: 20px; margin-bottom: 20px; }
        .stat-box { background: #f5f5f5; border-radius: 5px; padding: 15px; flex: 1; }
        .stat-box h3 { margin-top: 0; }
        .stat-value { font-size: 2em; font-weight: bold; color: #2a6fa8; }
        table { width: 100%; border-collapse: collapse; }
        th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
        th { background-color: #f2f2f2; }
        tr:nth-child(even) { background-color: #f9f9f9; }
    </style>
</head>
<body>
    <h1>WebSocket Monitoring Dashboard</h1>

    <div class="stats-container">
        <div class="stat-box">
            <h3>Total Connections</h3>
            <div id="total-connections" class="stat-value">0</div>
        </div>
        <div class="stat-box">
            <h3>Active Connections</h3>
            <div id="active-connections" class="stat-value">0</div>
        </div>
        <div class="stat-box">
            <h3>Messages Received</h3>
            <div id="messages-received" class="stat-value">0</div>
        </div>
        <div class="stat-box">
            <h3>Messages Sent</h3>
            <div id="messages-sent" class="stat-value">0</div>
        </div>
    </div>

    <h2>Active Connections</h2>
    <table id="connections-table">
        <thead>
            <tr>
                <th>Connection ID</th>
                <th>Client ID</th>
                <th>Connected At</th>
                <th>Last Activity</th>
                <th>Duration (s)</th>
                <th>Messages Received</th>
                <th>Messages Sent</th>
            </tr>
        </thead>
        <tbody id="connections-body"></tbody>
    </table>

    <script>
        // Fetch stats every 5 seconds
        function fetchStats() {
            fetch('/ws-stats')
                .then(response => response.json())
                .then(data => {
                    // Update global stats
                    document.getElementById('total-connections').textContent =
                        data.global_stats.total_connections;
                    document.getElementById('active-connections').textContent =
                        data.global_stats.active_connections;
                    document.getElementById('messages-received').textContent =
                        data.global_stats.messages_received;
                    document.getElementById('messages-sent').textContent =
                        data.global_stats.messages_sent;

                    // Update connections table
                    const tbody = document.getElementById('connections-body');
                    tbody.innerHTML = '';

                    data.connections.forEach(conn => {
                        const row = document.createElement('tr');
                        // client_id comes from the URL path, so insert every
                        // value as text rather than HTML to avoid script injection
                        [
                            conn.connection_id,
                            conn.client_id,
                            new Date(conn.connected_at).toLocaleString(),
                            new Date(conn.last_activity).toLocaleString(),
                            conn.duration.toFixed(0),
                            conn.messages_received,
                            conn.messages_sent
                        ].forEach(value => {
                            const cell = document.createElement('td');
                            cell.textContent = value;
                            row.appendChild(cell);
                        });
                        tbody.appendChild(row);
                    });
                })
                .catch(error => console.error('Error fetching stats:', error));
        }

        // Initial fetch and set interval
        fetchStats();
        setInterval(fetchStats, 5000);
    </script>
</body>
</html>

Save this as admin_dashboard.html in the directory you start the application from, and serve it with FastAPI:

python
from pathlib import Path

from fastapi.responses import HTMLResponse

@app.get("/admin", response_class=HTMLResponse)
def admin_dashboard():
    # A plain "def" endpoint runs in FastAPI's threadpool, so the blocking
    # file read does not stall the event loop that serves the WebSockets
    return Path("admin_dashboard.html").read_text(encoding="utf-8")

Summary

Monitoring WebSocket connections is essential for maintaining robust real-time applications. In this tutorial, we've covered:

  1. Basic tracking of WebSocket connections using a connection manager
  2. Advanced monitoring with per-connection metrics
  3. Integration with Prometheus for production-grade monitoring
  4. Debugging techniques for WebSocket connections
  5. Creating an admin dashboard to visualize WebSocket activity

By implementing these monitoring techniques, you'll gain valuable insights into your WebSocket application's behavior, making it easier to debug issues, optimize performance, and provide a better user experience.

Exercises

  1. Extend the connection manager to track the IP address of each client
  2. Create a feature to send warning messages to clients that have been inactive for too long
  3. Implement rate limiting to prevent clients from sending too many messages
  4. Add authentication to the admin dashboard
  5. Create a visualization of connection durations using a charting library like Chart.js

