FastAPI WebSocket Monitoring
Introduction
WebSockets provide a powerful way to create real-time communication between clients and servers in FastAPI applications. However, as your application scales, monitoring these WebSocket connections becomes crucial for maintaining performance, debugging issues, and ensuring system reliability.
In this tutorial, we'll explore various methods to monitor WebSocket connections in FastAPI applications. Whether you're building a chat application, a real-time dashboard, or a collaborative tool, proper WebSocket monitoring will help you maintain a robust application.
Why Monitor WebSockets?
WebSockets create persistent connections that stay open for extended periods. This introduces unique challenges:
- Resource management (memory, connections)
- Performance tracking
- Debugging disconnections
- Understanding user behavior
- Security monitoring
Basic WebSocket Connection Tracking
Let's start with a simple approach to track active WebSocket connections in your FastAPI application.
Creating a Connection Manager
```python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List, Dict
import asyncio  # used by the background monitoring task later on

app = FastAPI()


class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.connection_stats: Dict = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_received": 0,
            "messages_sent": 0,
        }

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        self.connection_stats["total_connections"] += 1
        self.connection_stats["active_connections"] += 1

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)
        self.connection_stats["active_connections"] -= 1

    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)
        self.connection_stats["messages_sent"] += 1

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)
            self.connection_stats["messages_sent"] += 1

    def get_stats(self):
        return self.connection_stats


manager = ConnectionManager()
```
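The `broadcast` method above awaits each `send_text` in turn, so one client whose connection has already dropped raises an exception and stops the loop: later clients miss the message and the failure is not counted anywhere. Below is a minimal sketch of a more tolerant fan-out, assuming a failed send raises an exception. To keep it runnable without a server, each connection is represented by its async send function (in the app you would pass `websocket.send_text`); `broadcast_tolerant`, `healthy` and `dropped` are hypothetical names.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List

# Hypothetical stand-in for a connection: an async "send text" function
SendFn = Callable[[str], Awaitable[None]]


async def broadcast_tolerant(connections: Dict[str, SendFn], message: str) -> List[str]:
    """Send `message` to every connection concurrently.

    Returns the ids whose send raised, after removing them from `connections`,
    so the caller can log them or count them as failed deliveries.
    """
    items = list(connections.items())  # snapshot, because we mutate the dict below
    results = await asyncio.gather(
        *(send(message) for _, send in items),
        return_exceptions=True,  # collect failures instead of stopping at the first
    )
    failed = [
        conn_id
        for (conn_id, _), result in zip(items, results)
        if isinstance(result, Exception)
    ]
    for conn_id in failed:
        connections.pop(conn_id, None)
    return failed


# Usage with fake connections: "b" behaves like a client that went away
delivered: List[str] = []


async def healthy(msg: str) -> None:
    delivered.append(msg)


async def dropped(msg: str) -> None:
    raise ConnectionError("client went away")


conns: Dict[str, SendFn] = {"a": healthy, "b": dropped, "c": healthy}
print(asyncio.run(broadcast_tolerant(conns, "hello")))  # ['b']
print(sorted(conns), delivered)  # ['a', 'c'] ['hello', 'hello']
```

Using `asyncio.gather` sends to all clients concurrently, so one slow client does not delay the rest; the returned list of failed ids is a natural thing to log or to feed into a "failed deliveries" counter.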
Implementing WebSocket Endpoints
```python
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    await manager.connect(websocket)
    try:
        while True:
            data = await websocket.receive_text()
            manager.connection_stats["messages_received"] += 1
            await manager.send_personal_message(f"You sent: {data}", websocket)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast(f"Client #{client_id} left the chat")
```
Adding a Stats Endpoint
To monitor your WebSocket connections, create an HTTP endpoint that provides statistics:
```python
@app.get("/ws-stats")
async def get_stats():
    return manager.get_stats()
```
When you access this endpoint, you'll see output similar to:
```json
{
    "total_connections": 10,
    "active_connections": 3,
    "messages_received": 25,
    "messages_sent": 42
}
```
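These counters are cumulative since the server started, so a couple of derived numbers are often easier to read at a glance. In the sketch below, `summarize_stats` is a hypothetical helper that takes the JSON shown above: closed connections are `total_connections` minus `active_connections` (every accept raises the total, every disconnect lowers the active count), and the sent/received ratio shows how much each incoming message fans out through echoes and broadcasts.

```python
from typing import Dict, Optional, Union


def summarize_stats(stats: Dict[str, int]) -> Dict[str, Optional[Union[int, float]]]:
    """Derive at-a-glance numbers from a /ws-stats snapshot (hypothetical helper)."""
    received = stats["messages_received"]
    return {
        # Every accept bumps the total; every disconnect lowers the active count
        "closed_connections": stats["total_connections"] - stats["active_connections"],
        # Fan-out per incoming message; None until something has been received
        "sent_per_received": round(stats["messages_sent"] / received, 2) if received else None,
    }


sample = {
    "total_connections": 10,
    "active_connections": 3,
    "messages_received": 25,
    "messages_sent": 42,
}
print(summarize_stats(sample))  # {'closed_connections': 7, 'sent_per_received': 1.68}
```

With the per-connection manager introduced later, these counters are nested under `"global_stats"`, so you would pass `stats["global_stats"]` instead.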
Advanced WebSocket Monitoring
For more detailed monitoring, let's implement additional features:
Per-Connection Metrics
```python
import uuid
from datetime import datetime
from typing import Dict

from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        # connection_id -> metadata about that connection
        self.active_connections: Dict[str, Dict] = {}
        self.connection_stats: Dict = {
            "total_connections": 0,
            "active_connections": 0,
            "messages_received": 0,
            "messages_sent": 0,
        }

    async def connect(self, websocket: WebSocket, client_id: str) -> str:
        await websocket.accept()
        connection_id = str(uuid.uuid4())
        now = datetime.now()
        self.active_connections[connection_id] = {
            "websocket": websocket,
            "client_id": client_id,
            "connected_at": now,
            "messages_received": 0,
            "messages_sent": 0,
            "last_activity": now,
        }
        self.connection_stats["total_connections"] += 1
        self.connection_stats["active_connections"] += 1
        return connection_id

    def disconnect(self, connection_id: str):
        if connection_id in self.active_connections:
            del self.active_connections[connection_id]
            self.connection_stats["active_connections"] -= 1

    async def send_personal_message(self, message: str, connection_id: str):
        # Hold a local reference: the entry may be removed from the dict
        # while we are awaiting send_text
        conn = self.active_connections.get(connection_id)
        if conn is None:
            return
        await conn["websocket"].send_text(message)
        conn["messages_sent"] += 1
        conn["last_activity"] = datetime.now()
        self.connection_stats["messages_sent"] += 1

    async def broadcast(self, message: str):
        # Iterate over a snapshot: another handler may call disconnect() while
        # we await send_text, which would otherwise raise
        # "RuntimeError: dictionary changed size during iteration"
        for connection in list(self.active_connections.values()):
            await connection["websocket"].send_text(message)
            connection["messages_sent"] += 1
            connection["last_activity"] = datetime.now()
            self.connection_stats["messages_sent"] += 1

    def record_message_received(self, connection_id: str):
        if connection_id in self.active_connections:
            self.active_connections[connection_id]["messages_received"] += 1
            self.active_connections[connection_id]["last_activity"] = datetime.now()
            self.connection_stats["messages_received"] += 1

    def get_stats(self):
        now = datetime.now()
        connection_details = []
        for conn_id, conn in self.active_connections.items():
            connection_details.append({
                "connection_id": conn_id,
                "client_id": conn["client_id"],
                "connected_at": conn["connected_at"].isoformat(),
                "messages_received": conn["messages_received"],
                "messages_sent": conn["messages_sent"],
                "last_activity": conn["last_activity"].isoformat(),
                "duration": (now - conn["connected_at"]).total_seconds(),
            })
        return {
            "global_stats": self.connection_stats,
            "connections": connection_details,
        }


manager = ConnectionManager()
```
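Because `broadcast` awaits `send_text` inside its loop, another client's handler can run `manager.disconnect()` and delete an entry from `active_connections` halfway through, and Python then raises `RuntimeError: dictionary changed size during iteration`. The sketch below reproduces that race without a server: `asyncio.sleep(0)` stands in for the awaited send, and `fake_broadcast`, `fake_disconnect` and `simulate` are hypothetical names. Iterating over `list(...)` of the dictionary avoids the error.

```python
import asyncio
from typing import Dict, Union


async def fake_broadcast(connections: Dict[str, dict], use_snapshot: bool) -> int:
    """Mimics ConnectionManager.broadcast: one awaited send per connection."""
    items = list(connections.items()) if use_snapshot else connections.items()
    sent = 0
    for _conn_id, _conn in items:
        await asyncio.sleep(0)  # stands in for `await websocket.send_text(...)`
        sent += 1
    return sent


async def fake_disconnect(connections: Dict[str, dict], conn_id: str) -> None:
    """Mimics another client's handler calling manager.disconnect() meanwhile."""
    await asyncio.sleep(0)
    connections.pop(conn_id, None)


async def simulate(use_snapshot: bool) -> Union[int, BaseException]:
    connections = {f"conn-{i}": {} for i in range(5)}
    results = await asyncio.gather(
        fake_broadcast(connections, use_snapshot),
        fake_disconnect(connections, "conn-4"),
        return_exceptions=True,
    )
    return results[0]  # number of sends, or the exception the broadcast raised


print(type(asyncio.run(simulate(use_snapshot=False))).__name__)  # RuntimeError
print(asyncio.run(simulate(use_snapshot=True)))  # 5
```

The snapshot version still "sends" to the connection that was removed mid-loop; in the real manager that send would fail on a closed socket, which is one more reason to guard sends as well.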
Updated WebSocket Endpoint
```python
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    connection_id = await manager.connect(websocket, client_id)
    try:
        while True:
            data = await websocket.receive_text()
            manager.record_message_received(connection_id)
            await manager.send_personal_message(f"You sent: {data}", connection_id)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        manager.disconnect(connection_id)
        await manager.broadcast(f"Client #{client_id} left the chat")
```
Periodic Monitoring Task
To monitor inactive connections and perform cleanup, you can add a background task:
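```python
# Newer FastAPI versions deprecate on_event in favor of lifespan handlers,
# but it still works for this example
@app.on_event("startup")
async def startup_event():
    # Keep a reference to the task: the event loop only holds weak
    # references to tasks, so an unreferenced task can be garbage-collected
    app.state.monitoring_task = asyncio.create_task(periodic_monitoring())


async def periodic_monitoring():
    while True:
        await asyncio.sleep(30)  # Check every 30 seconds
        now = datetime.now()

        # Find inactive connections (idle for more than 5 minutes)
        inactive_connections = [
            conn_id
            for conn_id, conn in manager.active_connections.items()
            if (now - conn["last_activity"]).total_seconds() > 300
        ]

        # Log inactive connections
        if inactive_connections:
            print(f"Found {len(inactive_connections)} inactive connections")

        # Optionally close inactive connections. manager.disconnect() alone
        # only drops the bookkeeping and leaves the socket open.
        # for conn_id in inactive_connections:
        #     conn = manager.active_connections.get(conn_id)
        #     if conn:
        #         await conn["websocket"].close()
        #     manager.disconnect(conn_id)  # safe to call twice; it checks membership
```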
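The idle check inside `periodic_monitoring` is easier to test if it lives in a pure function that receives the current time as an argument, so a unit test does not have to wait five minutes. A minimal sketch, where `find_idle_connections` and `IDLE_TIMEOUT` are hypothetical names and each entry is assumed to have the same `last_activity` key the manager stores:

```python
from datetime import datetime, timedelta
from typing import Dict, List

IDLE_TIMEOUT = timedelta(minutes=5)  # same 5-minute threshold as the loop above


def find_idle_connections(
    connections: Dict[str, dict], now: datetime, timeout: timedelta = IDLE_TIMEOUT
) -> List[str]:
    """Return the ids whose last_activity is more than `timeout` before `now`."""
    return [
        conn_id
        for conn_id, conn in connections.items()
        if now - conn["last_activity"] > timeout
    ]


now = datetime(2024, 1, 1, 12, 0, 0)
conns = {
    "fresh": {"last_activity": now - timedelta(seconds=30)},
    "stale": {"last_activity": now - timedelta(minutes=10)},
}
print(find_idle_connections(conns, now))  # ['stale']
```

Inside the monitoring loop the call would be `find_idle_connections(manager.active_connections, datetime.now())`.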
Integration with Monitoring Tools
For production applications, you might want to integrate with dedicated monitoring tools:
Prometheus Integration
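Prometheus is a popular open-source monitoring system: it periodically scrapes numeric metrics from an HTTP endpoint that your application exposes and stores them as time series you can query and alert on.
First, install the Python client library:
```bash
pip install prometheus-client
```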
Now, set up Prometheus metrics in your FastAPI application:
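```python
from fastapi import WebSocket
from prometheus_client import Counter, Gauge, start_http_server

# Define metrics. A Gauge can go up and down, so it fits "currently open";
# Counters only increase. prometheus_client appends "_total" to Counter
# names when it exposes them, so the suffix is omitted here.
WEBSOCKET_CONNECTIONS = Gauge(
    "websocket_active_connections",
    "Number of active WebSocket connections",
)
WEBSOCKET_MESSAGES_RECEIVED = Counter(
    "websocket_messages_received",
    "Number of WebSocket messages received",
)
WEBSOCKET_MESSAGES_SENT = Counter(
    "websocket_messages_sent",
    "Number of WebSocket messages sent",
)


# Serve metrics on a separate port. Avoid 9090: that is the Prometheus
# server's own default port.
@app.on_event("startup")
def start_prometheus():
    start_http_server(8001)  # metrics at http://localhost:8001/metrics


# Extend the per-connection ConnectionManager defined earlier so every
# bookkeeping call also updates the Prometheus metrics
class MonitoredConnectionManager(ConnectionManager):
    async def connect(self, websocket: WebSocket, client_id: str) -> str:
        connection_id = await super().connect(websocket, client_id)
        WEBSOCKET_CONNECTIONS.inc()
        return connection_id

    def disconnect(self, connection_id: str):
        # Only decrement for connections we actually track
        if connection_id in self.active_connections:
            WEBSOCKET_CONNECTIONS.dec()
        super().disconnect(connection_id)

    async def send_personal_message(self, message: str, connection_id: str):
        if connection_id in self.active_connections:
            await super().send_personal_message(message, connection_id)
            WEBSOCKET_MESSAGES_SENT.inc()

    async def broadcast(self, message: str):
        # Counts one send per connection tracked when the broadcast starts
        recipients = len(self.active_connections)
        await super().broadcast(message)
        WEBSOCKET_MESSAGES_SENT.inc(recipients)

    def record_message_received(self, connection_id: str):
        if connection_id in self.active_connections:
            WEBSOCKET_MESSAGES_RECEIVED.inc()
        super().record_message_received(connection_id)


# The endpoints look up `manager` at call time, so rebinding it is enough
manager = MonitoredConnectionManager()
```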
With this setup, you can use Prometheus to monitor and alert on WebSocket metrics, and visualize them using tools like Grafana.
Debugging WebSocket Connections
Sometimes you need to debug specific WebSocket issues. Here's how you can add logging to help:
```python
import logging

# Set up logging. Use level=logging.DEBUG to also see the per-message
# logger.debug() output below; at INFO it is suppressed.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("websocket_monitor")


# This replaces the earlier /ws/{client_id} endpoint. If the same path is
# registered twice, the first registered handler keeps handling requests.
@app.websocket("/ws/{client_id}")
async def websocket_endpoint(websocket: WebSocket, client_id: str):
    connection_id = await manager.connect(websocket, client_id)
    logger.info("Client %s connected with connection_id %s", client_id, connection_id)
    try:
        while True:
            data = await websocket.receive_text()
            manager.record_message_received(connection_id)
            # %.50s truncates the logged message to its first 50 characters
            logger.debug("Received message from %s: %.50s", client_id, data)
            await manager.send_personal_message(f"You sent: {data}", connection_id)
            await manager.broadcast(f"Client #{client_id} says: {data}")
    except WebSocketDisconnect:
        logger.info("Client %s disconnected", client_id)
        manager.disconnect(connection_id)
        await manager.broadcast(f"Client #{client_id} left the chat")
    except Exception:
        # logger.exception logs at ERROR level and includes the traceback
        logger.exception("Error handling connection for %s", client_id)
        manager.disconnect(connection_id)
```
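When many clients are connected, interleaved log lines are hard to attribute. The standard library's `logging.LoggerAdapter` can stamp every message from one connection with its ids automatically. A minimal sketch; `ConnectionLogAdapter` is a hypothetical name and the `[client_id/connection_id]` prefix format is an arbitrary choice:

```python
import logging


class ConnectionLogAdapter(logging.LoggerAdapter):
    """Prefix every message with the client and connection ids (hypothetical helper)."""

    def process(self, msg, kwargs):
        # Called for each log call; returns the (message, kwargs) actually logged
        prefix = f"[{self.extra['client_id']}/{self.extra['connection_id']}]"
        return f"{prefix} {msg}", kwargs


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, format="%(levelname)s %(message)s")
    log = ConnectionLogAdapter(
        logging.getLogger("websocket_monitor"),
        {"client_id": "alice", "connection_id": "example-id"},
    )
    log.info("connected")  # INFO [alice/example-id] connected
    log.debug("received %d characters", 12)  # DEBUG [alice/example-id] received 12 characters
```

In the endpoint you would create the adapter right after `manager.connect(...)`, passing the real `client_id` and `connection_id`, and call `log.info(...)` instead of `logger.info(...)`; lazy `%`-style arguments keep working because the adapter only rewrites the format string.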
Real-world Application: Admin Dashboard
Let's create a simple admin dashboard that displays WebSocket statistics using the API endpoint we built:
```html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>WebSocket Admin Dashboard</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
        .stats-container { display: flex; gap: 20px; margin-bottom: 20px; }
        .stat-box { background: #f5f5f5; border-radius: 5px; padding: 15px; flex: 1; }
        .stat-box h3 { margin-top: 0; }
        .stat-value { font-size: 2em; font-weight: bold; color: #2a6fa8; }
        table { width: 100%; border-collapse: collapse; }
        th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }
        th { background-color: #f2f2f2; }
        tr:nth-child(even) { background-color: #f9f9f9; }
    </style>
</head>
<body>
    <h1>WebSocket Monitoring Dashboard</h1>
    <div class="stats-container">
        <div class="stat-box">
            <h3>Total Connections</h3>
            <div id="total-connections" class="stat-value">0</div>
        </div>
        <div class="stat-box">
            <h3>Active Connections</h3>
            <div id="active-connections" class="stat-value">0</div>
        </div>
        <div class="stat-box">
            <h3>Messages Received</h3>
            <div id="messages-received" class="stat-value">0</div>
        </div>
        <div class="stat-box">
            <h3>Messages Sent</h3>
            <div id="messages-sent" class="stat-value">0</div>
        </div>
    </div>
    <h2>Active Connections</h2>
    <table id="connections-table">
        <thead>
            <tr>
                <th>Connection ID</th>
                <th>Client ID</th>
                <th>Connected At</th>
                <th>Last Activity</th>
                <th>Duration (s)</th>
                <th>Messages Received</th>
                <th>Messages Sent</th>
            </tr>
        </thead>
        <tbody id="connections-body"></tbody>
    </table>
    <script>
        // Fetch stats every 5 seconds. Expects the per-connection
        // get_stats() format: { global_stats: {...}, connections: [...] }
        function fetchStats() {
            fetch('/ws-stats')
                .then(response => response.json())
                .then(data => {
                    // Update global stats
                    document.getElementById('total-connections').textContent =
                        data.global_stats.total_connections;
                    document.getElementById('active-connections').textContent =
                        data.global_stats.active_connections;
                    document.getElementById('messages-received').textContent =
                        data.global_stats.messages_received;
                    document.getElementById('messages-sent').textContent =
                        data.global_stats.messages_sent;

                    // Update connections table
                    const tbody = document.getElementById('connections-body');
                    tbody.innerHTML = '';
                    data.connections.forEach(conn => {
                        const row = document.createElement('tr');
                        // One cell per column, in header order. textContent (not
                        // innerHTML) stops a client-supplied client_id from
                        // injecting markup into the admin page.
                        [
                            conn.connection_id,
                            conn.client_id,
                            conn.connected_at,
                            conn.last_activity,
                            conn.duration.toFixed(1),
                            conn.messages_received,
                            conn.messages_sent,
                        ].forEach(value => {
                            const cell = document.createElement('td');
                            cell.textContent = value;
                            row.appendChild(cell);
                        });
                        tbody.appendChild(row);
                    });
                })
                .catch(error => console.error('Error fetching stats:', error));
        }

        // Initial fetch and set interval
        fetchStats();
        setInterval(fetchStats, 5000);
    </script>
</body>
</html>
```
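Save this as admin_dashboard.html in the directory you start the server from, then serve it with FastAPI:
```python
from pathlib import Path

from fastapi.responses import HTMLResponse

# Relative paths resolve against the server's working directory
DASHBOARD_FILE = Path("admin_dashboard.html")


@app.get("/admin", response_class=HTMLResponse)
def admin_dashboard():
    # A plain `def` route runs in FastAPI's threadpool, so this blocking
    # file read does not stall the event loop
    return DASHBOARD_FILE.read_text(encoding="utf-8")
```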
Summary
Monitoring WebSocket connections is essential for maintaining robust real-time applications. In this tutorial, we've covered:
- Basic tracking of WebSocket connections using a connection manager
- Advanced monitoring with per-connection metrics
- Integration with Prometheus for production-grade monitoring
- Debugging techniques for WebSocket connections
- Creating an admin dashboard to visualize WebSocket activity
By implementing these monitoring techniques, you'll gain valuable insights into your WebSocket application's behavior, making it easier to debug issues, optimize performance, and provide a better user experience.
Additional Resources
- FastAPI WebSocket Documentation
- Prometheus Documentation
- Grafana Dashboards
- WebSockets MDN Documentation
Exercises
- Extend the connection manager to track the IP address of each client
- Create a feature to send warning messages to clients that have been inactive for too long
- Implement rate limiting to prevent clients from sending too many messages
- Add authentication to the admin dashboard
- Create a visualization of connection durations using a charting library like Chart.js