Files
2026-04-03 15:53:00 +01:00

112 lines
3.7 KiB
Python

import logging
import math
import time

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from infrastructure.degradation import (
    clear_degradation_context,
    init_degradation_context,
    try_get_degradation_context,
)
from infrastructure.msgspec_fastapi import MsgSpecJSONResponse
from infrastructure.resilience.rate_limiter import TokenBucketRateLimiter
# Module-level logger; PerformanceMiddleware emits its slow-request warnings here.
logger = logging.getLogger(__name__)
# Wall-clock duration, in seconds (compared against a time.perf_counter()
# delta), above which PerformanceMiddleware logs a request as slow.
SLOW_REQUEST_THRESHOLD = 1.0
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-process token-bucket rate limiter with per-path overrides.

    Only requests whose path starts with ``/api/`` are rate limited; every
    other path is passed straight through. Each ``/api/`` request is charged
    against the limiter of the *longest* matching override prefix, or the
    default limiter when no override prefix matches.

    Allowed responses get ``X-RateLimit-Limit`` / ``X-RateLimit-Remaining``
    headers; rejected requests get a 429 JSON error body plus
    ``Retry-After``, ``X-RateLimit-Limit`` and ``X-RateLimit-Remaining: 0``.

    Args:
        app: The downstream ASGI application.
        default_rate: ``rate`` passed to the default TokenBucketRateLimiter.
        default_capacity: ``capacity`` passed to the default limiter.
        overrides: Optional mapping of path prefix -> ``(rate, capacity)``.
            Each entry gets its own independent TokenBucketRateLimiter.
    """

    def __init__(
        self,
        app: ASGIApp,
        default_rate: float = 30.0,
        default_capacity: int = 60,
        overrides: dict[str, tuple[float, int]] | None = None,
    ):
        super().__init__(app)
        self._default = TokenBucketRateLimiter(rate=default_rate, capacity=default_capacity)
        # Longest prefix first, so a specific override such as "/api/auth/" is
        # not shadowed by a broader "/api/" that happens to be listed earlier.
        # sorted() is stable (also with reverse=True), so prefixes of equal
        # length keep their configured order.
        ordered = sorted(
            (overrides or {}).items(),
            key=lambda item: len(item[0]),
            reverse=True,
        )
        self._overrides: list[tuple[str, TokenBucketRateLimiter]] = [
            (prefix, TokenBucketRateLimiter(rate=rate, capacity=capacity))
            for prefix, (rate, capacity) in ordered
        ]

    def _get_limiter(self, path: str) -> TokenBucketRateLimiter:
        """Return the limiter for *path*: longest matching override, else the default."""
        for prefix, limiter in self._overrides:
            if path.startswith(prefix):
                return limiter
        return self._default

    @staticmethod
    def _retry_after_seconds(retry_after: float) -> int:
        """Convert the limiter's wait time into a whole-second ``Retry-After`` value.

        Rounds up instead of truncating: ``int(0.4) == 0`` would tell the
        client to retry immediately and be rejected again. Because the
        request was rejected, at least one second is always advertised.
        """
        return max(1, math.ceil(retry_after))

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        if not path.startswith("/api/"):
            return await call_next(request)

        limiter = self._get_limiter(path)
        if await limiter.try_acquire():
            # Snapshot right after acquiring so the header reports what this
            # request left in the bucket, not what remains after a (possibly
            # slow) handler finished while other requests consumed/refilled it.
            # int(): NOTE(review) assumes `remaining` may be fractional; a
            # header value should be a whole token count.
            remaining = int(limiter.remaining)
            response = await call_next(request)
            response.headers["X-RateLimit-Limit"] = str(limiter.capacity)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response

        retry_after = self._retry_after_seconds(limiter.retry_after())
        return MsgSpecJSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": "RATE_LIMITED",
                    "message": "Too many requests",
                    "details": None,
                }
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": str(limiter.capacity),
                "X-RateLimit-Remaining": "0",
            },
        )
class DegradationMiddleware(BaseHTTPMiddleware):
    """Report degraded downstream services via an ``X-Degraded-Services`` header.

    A fresh degradation context is initialised before the request is handed
    on, and it is always cleared afterwards, even when the handler raises.
    The header lists, comma-separated, every entry of the context summary
    whose status is not ``"ok"``; it is omitted when that list is empty.
    """

    async def dispatch(self, request: Request, call_next):
        init_degradation_context()
        try:
            response = await call_next(request)

            context = try_get_degradation_context()
            if not context or not context.has_degradation():
                return response

            header_value = ",".join(
                service
                for service, state in context.summary().items()
                if state != "ok"
            )
            if header_value:
                response.headers["X-Degraded-Services"] = header_value
            return response
        finally:
            clear_degradation_context()
class PerformanceMiddleware(BaseHTTPMiddleware):
    """Add an ``X-Response-Time`` header and log a warning for slow requests.

    The elapsed time is measured with ``time.perf_counter()`` around
    ``call_next`` and written as seconds with three decimals (e.g. ``0.123s``).

    NOTE(review): with BaseHTTPMiddleware, ``call_next`` presumably returns
    once the response has started, so streamed body time may not be included;
    confirm if exact totals matter.

    Args:
        app: The downstream ASGI application.
        slow_threshold: Duration in seconds above which a request is logged
            as slow. ``None`` (the default) uses the module-level
            ``SLOW_REQUEST_THRESHOLD``, looked up on every request.
    """

    def __init__(self, app: ASGIApp, slow_threshold: float | None = None):
        super().__init__(app)
        self._slow_threshold = slow_threshold

    async def dispatch(self, request: Request, call_next):
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        response.headers["X-Response-Time"] = f"{process_time:.3f}s"

        threshold = (
            SLOW_REQUEST_THRESHOLD if self._slow_threshold is None else self._slow_threshold
        )
        if process_time > threshold:
            # Lazy %-style args: formatting only happens if WARNING is enabled.
            logger.warning(
                "Slow request: %s %s took %.2fs",
                request.method,
                request.url.path,
                process_time,
            )
        return response