112 lines
3.7 KiB
Python
112 lines
3.7 KiB
Python
import logging
|
|
import time
|
|
from fastapi import Request
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.types import ASGIApp
|
|
|
|
from infrastructure.degradation import (
|
|
init_degradation_context,
|
|
try_get_degradation_context,
|
|
clear_degradation_context,
|
|
)
|
|
from infrastructure.resilience.rate_limiter import TokenBucketRateLimiter
|
|
from infrastructure.msgspec_fastapi import MsgSpecJSONResponse
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
# Requests whose measured processing time exceeds this value are logged as
# slow. It is compared against a ``time.perf_counter()`` delta, i.e. seconds.
SLOW_REQUEST_THRESHOLD = 1.0
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-process token-bucket rate limiter with per-path overrides.

    Only request paths starting with ``/api/`` are rate limited; every other
    path is forwarded untouched. Each override maps a path prefix to its own
    ``TokenBucketRateLimiter``; paths that match no override share the
    default limiter.
    """

    def __init__(
        self,
        app: ASGIApp,
        default_rate: float = 30.0,
        default_capacity: int = 60,
        overrides: dict[str, tuple[float, int]] | None = None,
    ):
        """Create the default limiter and one limiter per override prefix.

        Args:
            app: The wrapped ASGI application.
            default_rate: ``rate`` passed to the default limiter.
            default_capacity: ``capacity`` passed to the default limiter.
            overrides: Optional mapping of path prefix to
                ``(rate, capacity)`` for a dedicated limiter.
        """
        super().__init__(app)
        self._default = TokenBucketRateLimiter(rate=default_rate, capacity=default_capacity)
        # Longest prefix first: `_get_limiter` returns the first match, so
        # without this ordering a short prefix (e.g. "/api/") registered
        # before a longer one (e.g. "/api/search") would shadow it forever.
        self._overrides: list[tuple[str, TokenBucketRateLimiter]] = sorted(
            (
                (prefix, TokenBucketRateLimiter(rate=rate, capacity=capacity))
                for prefix, (rate, capacity) in (overrides or {}).items()
            ),
            key=lambda entry: len(entry[0]),
            reverse=True,
        )

    def _get_limiter(self, path: str) -> TokenBucketRateLimiter:
        """Return the limiter of the most specific matching prefix, else the default."""
        for prefix, limiter in self._overrides:
            if path.startswith(prefix):
                return limiter
        return self._default

    async def dispatch(self, request: Request, call_next):
        """Forward the request if a token is available, otherwise answer 429.

        Successful responses are annotated with ``X-RateLimit-Limit`` and
        ``X-RateLimit-Remaining``. Rejected requests receive a JSON error
        body together with ``Retry-After`` and the same rate-limit headers.
        """
        import math  # local import: only needed on the rejection path

        path = request.url.path
        if not path.startswith("/api/"):
            return await call_next(request)

        limiter = self._get_limiter(path)

        if await limiter.try_acquire():
            response = await call_next(request)
            response.headers["X-RateLimit-Limit"] = str(limiter.capacity)
            response.headers["X-RateLimit-Remaining"] = str(limiter.remaining)
            return response

        # Retry-After carries whole seconds. Round up (never down) and keep
        # it at least 1 so a sub-second wait is not advertised as "0", which
        # would invite an immediate retry that is rejected again.
        retry_after = max(1, math.ceil(limiter.retry_after()))
        return MsgSpecJSONResponse(
            status_code=429,
            content={
                "error": {
                    "code": "RATE_LIMITED",
                    "message": "Too many requests",
                    "details": None,
                }
            },
            headers={
                "Retry-After": str(retry_after),
                "X-RateLimit-Limit": str(limiter.capacity),
                "X-RateLimit-Remaining": "0",
            },
        )
|
|
|
|
|
|
class DegradationMiddleware(BaseHTTPMiddleware):
    """Set up a per-request DegradationContext and report it via a header.

    After the downstream handler returns, every entry of the context summary
    whose status is not ``"ok"`` is listed, comma-separated, in the
    ``X-Degraded-Services`` response header. The context is always cleared
    when the request finishes, including when the handler raises.
    """

    async def dispatch(self, request: Request, call_next):
        """Run the request inside a fresh degradation context."""
        init_degradation_context()
        try:
            response = await call_next(request)

            context = try_get_degradation_context()
            # Nothing to report when there is no context or nothing degraded.
            if not context or not context.has_degradation():
                return response

            degraded = [
                name
                for name, status in context.summary().items()
                if status != "ok"
            ]
            header_value = ",".join(degraded)
            if header_value:
                response.headers["X-Degraded-Services"] = header_value
            return response
        finally:
            clear_degradation_context()
|
|
|
|
|
|
class PerformanceMiddleware(BaseHTTPMiddleware):
    """Measure request handling time, expose it in a header and log slow requests.

    Every response gets an ``X-Response-Time`` header with the elapsed time
    in seconds (three decimals). Requests that take longer than
    ``SLOW_REQUEST_THRESHOLD`` are additionally logged at WARNING level.
    """

    def __init__(self, app: ASGIApp):
        """Wrap *app*; no configuration beyond the base middleware."""
        super().__init__(app)

    async def dispatch(self, request: Request, call_next):
        """Time the downstream call and annotate the response with the result."""
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time

        response.headers["X-Response-Time"] = f"{process_time:.3f}s"

        if process_time > SLOW_REQUEST_THRESHOLD:
            # Lazy %-style arguments: the message is only formatted if a
            # handler actually emits the record, and the constant template
            # lets log aggregators group identical events.
            logger.warning(
                "Slow request: %s %s took %.2fs",
                request.method,
                request.url.path,
                process_time,
            )

        return response
|