
FastAPI Security Testing

Introduction

Security is a critical aspect of any web application. As developers, we need to ensure our FastAPI applications are protected against common vulnerabilities and that authentication and authorization mechanisms work correctly. In this tutorial, we'll explore how to test security features in FastAPI applications.

Security testing helps identify vulnerabilities before they can be exploited. For FastAPI applications, this means verifying that:

  • Authentication systems work as expected
  • Authorization controls protect sensitive endpoints
  • The application is resistant to common web attacks
  • API rate limiting and other protections function correctly

Setting Up a Test Environment for Security Testing

Before we start writing security tests, we need to set up a proper testing environment.

Basic Test Setup

python
import pytest
from fastapi.testclient import TestClient
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from datetime import datetime, timedelta
from typing import Optional

# Your FastAPI app
app = FastAPI()

# Secret key for JWT encoding/decoding
SECRET_KEY = "testsecretkey"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Test client
client = TestClient(app)
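
Hard-coding `SECRET_KEY` is acceptable in a throwaway test module, but anyone who can read the source can then forge tokens. Here is a minimal sketch of loading the key from the environment instead. It assumes a hypothetical `JWT_SECRET_KEY` variable and an assumed 32-character minimum. `secrets.token_urlsafe` is the standard-library way to generate a suitable random key:

```python
import os
import secrets
from typing import Mapping

# Hypothetical variable name; use whatever your deployment defines
SECRET_ENV_VAR = "JWT_SECRET_KEY"
# Assumed minimum-length policy, not a FastAPI or python-jose requirement
MIN_KEY_LENGTH = 32


def generate_secret_key() -> str:
    """Return a random URL-safe key built from 32 random bytes (43 characters)."""
    return secrets.token_urlsafe(32)


def load_secret_key(env: Mapping[str, str] = os.environ) -> str:
    """Read the JWT signing key from the environment, rejecting missing or short keys."""
    key = env.get(SECRET_ENV_VAR)
    if key is None:
        raise RuntimeError(f"{SECRET_ENV_VAR} is not set")
    if len(key) < MIN_KEY_LENGTH:
        raise RuntimeError(f"{SECRET_ENV_VAR} must be at least {MIN_KEY_LENGTH} characters")
    return key
```

In the app, `SECRET_KEY = load_secret_key()` would replace the literal. Tests can pass a plain dict such as `{"JWT_SECRET_KEY": generate_secret_key()}` instead of touching the real environment.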

Creating an Example User Authentication System

Let's create a simple authentication system to test:

python
# Mock user database
fake_users_db = {
    "testuser": {
        "username": "testuser",
        "full_name": "Test User",
        "email": "[email protected]",
        "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",  # "secret"
        "disabled": False,
    }
}

# Function to create access token
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    expire = datetime.utcnow() + (expires_delta or timedelta(minutes=15))
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt

# Function to get current user
async def get_current_user(token: str = Depends(oauth2_scheme)):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception

    user = fake_users_db.get(username)
    if user is None:
        raise credentials_exception
    return user

# Function to get current active user
async def get_current_active_user(current_user = Depends(get_current_user)):
    if current_user.get("disabled"):
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user

# Login endpoint for generating tokens
@app.post("/token")
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
    user = fake_users_db.get(form_data.username)
    if not user or form_data.password != "secret":  # For testing purposes only, use proper hashing in production
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": user["username"]}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}

# Protected endpoint
@app.get("/users/me")
async def read_users_me(current_user = Depends(get_current_active_user)):
    return current_user

# Admin-only endpoint
@app.get("/admin")
async def admin_route(current_user = Depends(get_current_active_user)):
    if current_user["username"] != "admin":
        raise HTTPException(status_code=403, detail="Not authorized")
    return {"message": "Admin area"}
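
The login endpoint compares the submitted password with the literal "secret", as its comment warns. A safer check recomputes a salted hash and compares it in constant time. The mock database stores a bcrypt hash, which you would normally verify with a library such as passlib or bcrypt. To stay within the standard library, this sketch swaps in PBKDF2-HMAC-SHA256 from `hashlib` instead. The helper names, storage format and iteration count are assumptions:

```python
import hashlib
import hmac
import secrets
from typing import Optional

DEFAULT_ITERATIONS = 600_000  # assumed work factor; tune for your hardware


def hash_password(password: str, iterations: int = DEFAULT_ITERATIONS,
                  salt: Optional[bytes] = None) -> str:
    """Hash a password with PBKDF2-HMAC-SHA256 and a random 16-byte salt."""
    if salt is None:
        salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), salt, iterations)
    # Store the parameters next to the hash so they can be changed later
    return f"pbkdf2_sha256${iterations}${salt.hex()}${digest.hex()}"


def verify_password(password: str, stored: str) -> bool:
    """Recompute the hash with the stored parameters and compare in constant time."""
    try:
        scheme, iterations, salt_hex, digest_hex = stored.split("$")
        if scheme != "pbkdf2_sha256":
            return False
        expected = bytes.fromhex(digest_hex)
        digest = hashlib.pbkdf2_hmac(
            "sha256", password.encode(), bytes.fromhex(salt_hex), int(iterations)
        )
    except ValueError:
        return False  # unknown or corrupted format
    # compare_digest avoids short-circuiting on the first differing byte
    return hmac.compare_digest(digest, expected)
```

With PBKDF2 hashes stored in `fake_users_db`, the login check would become `if not user or not verify_password(form_data.password, user["hashed_password"]):`. The existing bcrypt hash uses a different format, so `verify_password` returns `False` for it rather than raising.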

Testing Authentication

Let's start testing our authentication system:

1. Testing Login and Token Generation

python
def test_login_success():
    response = client.post(
        "/token",
        data={"username": "testuser", "password": "secret"}
    )
    assert response.status_code == 200
    assert "access_token" in response.json()
    assert response.json()["token_type"] == "bearer"

def test_login_invalid_credentials():
    response = client.post(
        "/token",
        data={"username": "testuser", "password": "wrongpassword"}
    )
    assert response.status_code == 401
    assert response.json()["detail"] == "Incorrect username or password"

2. Testing Protected Endpoints

python
def test_protected_route_with_valid_token():
    # First get a token
    login_response = client.post(
        "/token",
        data={"username": "testuser", "password": "secret"}
    )
    token = login_response.json()["access_token"]

    # Use token to access protected endpoint
    response = client.get(
        "/users/me",
        headers={"Authorization": f"Bearer {token}"}
    )
    assert response.status_code == 200
    assert response.json()["username"] == "testuser"

def test_protected_route_without_token():
    response = client.get("/users/me")
    assert response.status_code == 401
    assert response.json()["detail"] == "Not authenticated"

def test_protected_route_with_invalid_token():
    response = client.get(
        "/users/me",
        headers={"Authorization": "Bearer invalid_token"}
    )
    assert response.status_code == 401
    assert response.json()["detail"] == "Could not validate credentials"

Testing Authorization

Authorization testing verifies that users can only access resources they have permission to:

python
def test_admin_route_not_authorized():
    # Log in as regular user
    login_response = client.post(
        "/token",
        data={"username": "testuser", "password": "secret"}
    )
    token = login_response.json()["access_token"]

    # Try to access admin route
    response = client.get(
        "/admin",
        headers={"Authorization": f"Bearer {token}"}
    )
    assert response.status_code == 403
    assert response.json()["detail"] == "Not authorized"

# Add admin user to fake DB for testing
fake_users_db["admin"] = {
    "username": "admin",
    "full_name": "Admin User",
    "email": "[email protected]",
    "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",  # "secret"
    "disabled": False,
}

def test_admin_route_authorized():
    # Log in as admin
    login_response = client.post(
        "/token",
        data={"username": "admin", "password": "secret"}
    )
    token = login_response.json()["access_token"]

    # Access admin route
    response = client.get(
        "/admin",
        headers={"Authorization": f"Bearer {token}"}
    )
    assert response.status_code == 200
    assert response.json()["message"] == "Admin area"
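
Writing one allowed test and one denied test per endpoint scales poorly as roles and routes multiply. A common alternative is an access matrix: list every (user, endpoint) pair with its expected status and check them all. In the sketch below, the table covers this tutorial's routes. The `find_access_mismatches` helper and its `fetch` callable are hypothetical. `fetch` stands for whatever logs in as the user and returns the status code:

```python
from typing import Callable, Dict, List, Optional, Tuple

# Expected status for each (username, path) pair in the tutorial app.
# None means "request sent without an Authorization header".
EXPECTED_STATUS: Dict[Tuple[Optional[str], str], int] = {
    (None, "/users/me"): 401,
    (None, "/admin"): 401,
    ("testuser", "/users/me"): 200,
    ("testuser", "/admin"): 403,
    ("admin", "/users/me"): 200,
    ("admin", "/admin"): 200,
}


def find_access_mismatches(
    fetch: Callable[[Optional[str], str], int],
    expected: Dict[Tuple[Optional[str], str], int] = EXPECTED_STATUS,
) -> List[str]:
    """Call fetch(user, path) for every pair and describe each unexpected status."""
    mismatches = []
    for (user, path), want in expected.items():
        got = fetch(user, path)
        if got != want:
            mismatches.append(f"{user!r} GET {path}: expected {want}, got {got}")
    return mismatches
```

With the tutorial's `client`, `fetch` would post to `/token` when `user` is not `None`. It would then return `client.get(path, headers=...).status_code`, and a single `assert find_access_mismatches(fetch) == []` would cover every combination. Adding a role or route then means adding rows, not writing new test functions.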

Testing CORS Configuration

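Cross-Origin Resource Sharing (CORS) controls which browser origins may call your API. An overly permissive policy can let a malicious site read your API's responses from inside a user's browser, so it is worth testing that only the intended origins are allowed: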

python
# First, add CORS middleware to your application. Do this at import time:
# recent Starlette versions refuse to add middleware once the app has
# started handling requests.
from fastapi.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://allowed-origin.com"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

def test_cors_headers():
    # Simulate a browser preflight request from an allowed origin
    response = client.options(
        "/users/me",
        headers={
            "Origin": "https://allowed-origin.com",
            "Access-Control-Request-Method": "GET",
        },
    )
    assert response.status_code == 200
    assert response.headers["access-control-allow-origin"] == "https://allowed-origin.com"
    assert "GET" in response.headers["access-control-allow-methods"]

def test_cors_disallowed_origin():
    response = client.options(
        "/users/me",
        headers={
            "Origin": "https://disallowed-origin.com",
            "Access-Control-Request-Method": "GET",
        },
    )
    # CORSMiddleware rejects the preflight and omits the allow-origin header
    assert response.status_code == 400
    assert "access-control-allow-origin" not in response.headers
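
Beyond one allowed and one disallowed origin, a CORS test suite should probe near-misses. With a plain `allow_origins` list, `CORSMiddleware` compares the request's `Origin` header against the list as exact strings. Pattern matching only happens if you configure `allow_origin_regex`. So every variant below should be rejected. The function is a simplified model of that comparison, not Starlette's code, and the case list is illustrative:

```python
from typing import Dict, Iterable

ALLOWED_ORIGINS = ["https://allowed-origin.com"]


def is_allowed_origin(origin: str, allowed_origins: Iterable[str]) -> bool:
    """Exact string comparison: scheme, host and port must all match."""
    return origin in set(allowed_origins)


# Candidate Origin headers for CORS tests, mapped to whether they should be allowed
ORIGIN_CASES: Dict[str, bool] = {
    "https://allowed-origin.com": True,                    # exact match
    "http://allowed-origin.com": False,                    # different scheme
    "https://allowed-origin.com:8443": False,              # different port
    "https://api.allowed-origin.com": False,               # subdomain
    "https://allowed-origin.com.attacker.example": False,  # suffix trick
    "null": False,                                         # opaque origin, e.g. sandboxed iframe
}
```

Each key can drive a parametrized version of `test_cors_disallowed_origin`. Send it as the `Origin` header of a preflight request and assert that `access-control-allow-origin` is present exactly when the value is `True`.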

Testing Rate Limiting

Rate limiting is an essential security feature to prevent abuse. Let's implement and test a simple rate limiter:

python
import time
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.testclient import TestClient
from starlette.middleware.base import BaseHTTPMiddleware

# Simple in-memory, per-client-IP sliding-window rate limiter
class RateLimitMiddleware(BaseHTTPMiddleware):
    def __init__(self, app, calls=5, period=60):
        super().__init__(app)
        self.calls = calls
        self.period = period
        self.requests = {}  # client IP -> timestamps of recent requests

    async def dispatch(self, request: Request, call_next):
        client_ip = request.client.host
        current_time = time.time()

        # Drop timestamps that fall outside the current window
        self.requests[client_ip] = [
            timestamp for timestamp in self.requests.get(client_ip, [])
            if timestamp > current_time - self.period
        ]

        # Reject the request if the client has used up its quota
        if len(self.requests[client_ip]) >= self.calls:
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded"}
            )

        # Record this request and pass it on
        self.requests[client_ip].append(current_time)
        return await call_next(request)

# Attach the limiter to a dedicated app with very strict limits. Adding it to
# the main `app` would throttle every other test, because TestClient sends all
# requests from the same client host ("testclient"). A short window keeps the
# test fast.
limited_app = FastAPI()
limited_app.add_middleware(RateLimitMiddleware, calls=3, period=2)

@limited_app.get("/")
async def root():
    return {"message": "ok"}

limited_client = TestClient(limited_app)

def test_rate_limiting():
    # The first 3 requests are allowed
    for _ in range(3):
        response = limited_client.get("/")
        assert response.status_code == 200

    # The 4th request within the window is rejected
    response = limited_client.get("/")
    assert response.status_code == 429
    assert response.json()["detail"] == "Rate limit exceeded"

    # Wait for the window to expire
    time.sleep(2.1)

    # Requests are allowed again
    response = limited_client.get("/")
    assert response.status_code == 200
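
Sleeping in a test slows the suite and can be flaky on a loaded CI machine. One alternative is to move the limiting logic into a plain class that takes its clock as a parameter, so a test can advance time instantly. The `SlidingWindowLimiter` and `FakeClock` classes below are a hypothetical refactoring of the middleware's logic, not part of FastAPI or Starlette. The real clock defaults to `time.monotonic`, which does not jump when the system clock is adjusted:

```python
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict


class SlidingWindowLimiter:
    """Allow at most `calls` requests per `period` seconds for each key (e.g. client IP)."""

    def __init__(self, calls: int, period: float,
                 clock: Callable[[], float] = time.monotonic):
        self.calls = calls
        self.period = period
        self.clock = clock
        self._hits: Dict[str, Deque[float]] = defaultdict(deque)

    def allow(self, key: str) -> bool:
        now = self.clock()
        hits = self._hits[key]
        # Oldest timestamps sit on the left; drop those outside the window
        while hits and hits[0] <= now - self.period:
            hits.popleft()
        if len(hits) >= self.calls:
            return False
        hits.append(now)
        return True


class FakeClock:
    """Test double: time moves only when the test advances it."""

    def __init__(self, start: float = 0.0):
        self.now = start

    def __call__(self) -> float:
        return self.now

    def advance(self, seconds: float) -> None:
        self.now += seconds
```

The middleware's `dispatch` would then reduce to calling `self.limiter.allow(request.client.host)` and returning the 429 response when it is `False`. The timing-sensitive behavior is tested directly on the limiter.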

Testing for SQL Injection Vulnerabilities

To test for SQL injection, we'll build one deliberately vulnerable endpoint and one safe endpoint, then send the same malicious input to both:

python
from sqlalchemy import create_engine, Column, Integer, String, text
from sqlalchemy.orm import Session, declarative_base, sessionmaker

# Setup SQLAlchemy with a file-based SQLite database for testing
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
# check_same_thread=False lets FastAPI's threadpool share SQLite connections
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

# Define User model
class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String, unique=True, index=True)
    email = Column(String, unique=True, index=True)

# Create the tables
Base.metadata.create_all(bind=engine)

# Dependency to get database session
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# Endpoint vulnerable to SQL injection
@app.get("/users/search")
def search_users(query: str):
    # DO NOT DO THIS IN REAL APPLICATIONS: user input is pasted into the SQL text.
    # It exists only so the test below can demonstrate the vulnerability.
    db = SessionLocal()
    try:
        result = db.execute(
            text(f"SELECT * FROM users WHERE username LIKE '%{query}%'")
        ).fetchall()
    finally:
        db.close()
    return {"users": [dict(r._mapping) for r in result]}

# Safe endpoint using SQLAlchemy ORM
@app.get("/users/search-safe")
def search_users_safe(query: str, db: Session = Depends(get_db)):
    # The ORM sends `query` as a bound parameter, never as SQL text
    users = db.query(User).filter(User.username.like(f"%{query}%")).all()
    return {"users": [{"id": u.id, "username": u.username, "email": u.email} for u in users]}

def test_sql_injection():
    # Add a test user (skip if a previous run already created it)
    db = SessionLocal()
    if not db.query(User).filter_by(username="testuser").first():
        db.add(User(username="testuser", email="[email protected]"))
        db.commit()
    db.close()

    # Normal query works
    response = client.get("/users/search", params={"query": "test"})
    assert response.status_code == 200

    # The payload closes the string literal and appends a UNION query
    injection = "test' UNION SELECT username, email, id FROM users--"

    # The vulnerable endpoint executes the injected UNION and returns rows
    # (with shuffled columns) that the LIKE filter alone would not match
    response = client.get("/users/search", params={"query": injection})
    assert response.status_code == 200
    assert len(response.json()["users"]) > 0

    # The safe endpoint behaves normally for ordinary input
    response = client.get("/users/search-safe", params={"query": "test"})
    assert response.status_code == 200
    for user in response.json()["users"]:
        assert "test" in user["username"].lower()

    # ...and treats the payload as a literal search string that matches nobody
    response = client.get("/users/search-safe", params={"query": injection})
    assert response.status_code == 200
    assert response.json()["users"] == []
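
The difference between the two endpoints comes down to whether user input becomes part of the SQL text or travels separately as a bound parameter. Here is the same contrast using only the standard-library `sqlite3` module, with an in-memory database, made-up example rows, and the payload from the test above:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT, email TEXT)")
conn.executemany(
    "INSERT INTO users (username, email) VALUES (?, ?)",
    [("testuser", "test@example.com"), ("alice", "alice@example.com")],
)

payload = "test' UNION SELECT username, email, id FROM users--"


def search_unsafe(query: str):
    # Input is spliced into the SQL text: the quote in the payload ends the literal
    sql = f"SELECT * FROM users WHERE username LIKE '%{query}%'"
    return conn.execute(sql).fetchall()


def search_safe(query: str):
    # Input is sent as a bound parameter; SQLite never parses it as SQL
    return conn.execute(
        "SELECT * FROM users WHERE username LIKE ?", (f"%{query}%",)
    ).fetchall()
```

Both functions agree on ordinary input. With the payload, the unsafe version returns every row in the table, including `alice`, who does not match the search at all. The safe version returns nothing.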

Testing XSS Protection

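Cross-Site Scripting (XSS) lets an attacker inject scripts that run in other users' browsers, typically when user input is inserted into HTML without escaping. Let's test whether our application escapes user input properly: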

python
from fastapi.responses import HTMLResponse
from markupsafe import escape

# The ":path" converter lets {name} contain "/", which the "</script>" in the
# XSS payload needs; with a plain {name} the payload request would return 404
@app.get("/render/{name:path}", response_class=HTMLResponse)
async def render_name(name: str):
    # Unsafe: user input is inserted into the HTML as-is
    return f"<html><body><h1>Hello, {name}!</h1></body></html>"

@app.get("/render-safe/{name:path}", response_class=HTMLResponse)
async def render_name_safe(name: str):
    # Safe: escape() converts <, >, &, and quotes to HTML entities
    return f"<html><body><h1>Hello, {escape(name)}!</h1></body></html>"

def test_xss_vulnerability():
    # Test with normal input
    response = client.get("/render/John")
    assert response.status_code == 200
    assert "<h1>Hello, John!</h1>" in response.text

    # Test with XSS payload
    xss_payload = "<script>alert('XSS')</script>"
    response = client.get(f"/render/{xss_payload}")
    assert response.status_code == 200

    # The vulnerable endpoint reflects the script tag unchanged
    assert "<script>alert('XSS')</script>" in response.text

    # Test safe endpoint
    response = client.get(f"/render-safe/{xss_payload}")
    assert response.status_code == 200

    # The safe endpoint returns the tag escaped
    assert "<script>" not in response.text
    assert "&lt;script&gt;" in response.text
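
A single `<script>` payload misses other common vectors, such as event-handler attributes on `<img>` or `<svg>` tags. The sketch below is a reusable check: it extracts the tag names a payload tries to open and fails if any of them appears unescaped in the response. The payload list is illustrative rather than exhaustive. `render_unsafe` and `render_safe` are stand-ins mirroring the two endpoints, using the standard library's `html.escape` in place of markupsafe:

```python
import html
import re
from typing import List

# A few classic payloads; real scanners use far larger lists
XSS_PAYLOADS: List[str] = [
    "<script>alert('XSS')</script>",
    '"><img src=x onerror=alert(1)>',
    "<svg/onload=alert(1)>",
]


def opened_tags(payload: str) -> List[str]:
    """Names of the HTML tags the payload tries to open, e.g. ['script']."""
    return [name.lower() for name in re.findall(r"<\s*([A-Za-z][A-Za-z0-9]*)", payload)]


def reflected_safely(body: str, payload: str) -> bool:
    """True if none of the payload's tags appear unescaped in the response body."""
    lowered = body.lower()
    return all(f"<{name}" not in lowered for name in opened_tags(payload))


# Stand-ins mirroring the two endpoints above
def render_unsafe(name: str) -> str:
    return f"<html><body><h1>Hello, {name}!</h1></body></html>"


def render_safe(name: str) -> str:
    return f"<html><body><h1>Hello, {html.escape(name)}!</h1></body></html>"
```

Against the real app, pass each endpoint's `response.text` as `body`. The check assumes the payload's tags differ from the page's own markup (`<html>`, `<body>`, `<h1>`).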

Testing For Information Disclosure

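Error responses should not leak stack traces, exception details, or other internal information. Let's check that an unhandled exception produces only a generic message: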

python
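from fastapi.responses import JSONResponse

# Catch-all handler: return a generic message instead of exception details.
# Exception handlers must return a Response object, not a plain dict.
@app.exception_handler(Exception)
async def generic_exception_handler(request, exc):
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error occurred"},
    )

# Endpoint that deliberately raises ZeroDivisionError
@app.get("/error")
def trigger_error():
    return 1 / 0

# By default TestClient re-raises server exceptions inside the test; turn
# that off so we can inspect the response a real client would receive
error_client = TestClient(app, raise_server_exceptions=False)

def test_error_information_disclosure():
    response = error_client.get("/error")
    assert response.status_code == 500

    # Error response should not contain stack traces or exception details
    body = response.text.lower()
    assert "traceback" not in body
    assert "zerodivisionerror" not in body
    assert "division by zero" not in body

    # It should contain only the generic error message
    assert response.json() == {"detail": "Internal server error occurred"}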
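
The assertions above look for a few specific strings. A slightly broader check scans any error body for common leak markers: the traceback header, Python exception class names, stack-frame lines and installed-package paths. The pattern list is an assumption to tailor to your stack, and a clean result does not prove that nothing sensitive leaked:

```python
import re
from typing import Dict, List

# Illustrative patterns; extend them for your stack (SQL fragments, hostnames, keys)
LEAK_PATTERNS: Dict[str, re.Pattern] = {
    "traceback": re.compile(r"Traceback \(most recent call last\)"),
    "exception class": re.compile(r"\b[A-Z]\w*(?:Error|Exception)\b"),
    "stack frame": re.compile(r'File "[^"]+", line \d+'),
    "installed package path": re.compile(r"site-packages"),
}


def find_leaks(body: str) -> List[str]:
    """Return the names of all leak patterns found in a response body."""
    return [name for name, pattern in LEAK_PATTERNS.items() if pattern.search(body)]
```

In a test, `assert find_leaks(response.text) == []` can be applied to every error response, not only the one from `/error`.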

Advanced Security Testing with Third-Party Tools

Unit tests only check the cases you anticipated. For broader coverage, you can also run third-party security scanners, such as OWASP ZAP, against a running instance of your application:

python
import pytest

def test_with_owasp_zap():
    # This requires OWASP ZAP to be installed and a test instance of the app
    # to be running; you would typically run it in a CI/CD pipeline.
    #
    # Outline (zap_scan and no_high_risk_vulnerabilities are placeholders
    # to be implemented with the ZAP API):
    #   1. Start your FastAPI app on a specific port, e.g. http://localhost:8000
    #   2. result = zap_scan("http://localhost:8000")
    #   3. assert no_high_risk_vulnerabilities(result)
    #
    # Skip explicitly instead of passing silently, so the test report shows
    # that this check has not been implemented yet.
    pytest.skip("OWASP ZAP integration not configured")
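
One low-effort way to fill in this placeholder in CI is ZAP's packaged baseline scan, which runs from ZAP's Docker image and reports its verdict through the process exit code. The sketch below builds the command and interprets the exit code, so that logic can be unit-tested without Docker. The image name and exit-code meanings are assumptions based on ZAP's baseline-scan documentation; check them against the ZAP version you run. The helper names are hypothetical:

```python
import subprocess
from typing import Dict, List

# Assumed image name; check ZAP's documentation for the current one
ZAP_IMAGE = "ghcr.io/zaproxy/zaproxy:stable"

# Assumed meaning of zap-baseline.py exit codes
ZAP_EXIT_CODES: Dict[int, str] = {
    0: "pass",
    1: "at least one FAIL",
    2: "warnings only",
    3: "scan error",
}


def build_baseline_command(target_url: str) -> List[str]:
    """Docker invocation for a passive baseline scan of target_url."""
    # Note: from inside the container, "localhost" is the container itself;
    # a target running on the host needs a reachable address or host networking.
    return ["docker", "run", "--rm", "-t", ZAP_IMAGE, "zap-baseline.py", "-t", target_url]


def interpret_exit_code(code: int) -> str:
    return ZAP_EXIT_CODES.get(code, "unknown")


def run_baseline_scan(target_url: str) -> str:
    """Run the scan (requires Docker and a reachable target) and summarize the result."""
    completed = subprocess.run(build_baseline_command(target_url), check=False)
    return interpret_exit_code(completed.returncode)
```

A CI job could then assert `run_baseline_scan(url) == "pass"`. Alternatively, it could accept `"warnings only"` while the team triages existing findings.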

Summary

Security testing is a critical part of developing FastAPI applications. In this tutorial, we've covered:

  1. Authentication Testing: Verifying login mechanisms and token validation
  2. Authorization Testing: Ensuring proper access controls
  3. CORS Configuration Testing: Checking that cross-origin policies are enforced
  4. Rate Limiting Testing: Verifying protection against abuse
  5. SQL Injection Testing: Ensuring database queries are safe
  6. XSS Protection Testing: Checking if user input is properly escaped
  7. Information Disclosure Testing: Ensuring error handling doesn't leak sensitive data

Remember that security testing should be an ongoing process, not a one-time activity. Regularly test your applications for security vulnerabilities, especially after making significant changes.


Exercises

  1. Implement and test a Password Reset feature with security in mind
  2. Add Content Security Policy headers to your FastAPI application and test them
  3. Implement CSRF protection for forms and write tests to verify it works
  4. Test your API for security headers (X-Content-Type-Options, X-Frame-Options, etc.)
  5. Create a dependency that validates and sanitizes input parameters to prevent injection attacks

