"""Shared test helpers for observability / log field assertions.""" from __future__ import annotations import logging from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient from starlette.exceptions import HTTPException as StarletteHTTPException from core.exception_handlers import ( circuit_open_error_handler, client_disconnected_handler, configuration_error_handler, external_service_error_handler, general_exception_handler, http_exception_handler, request_validation_error_handler, resource_not_found_handler, source_resolution_error_handler, starlette_http_exception_handler, validation_error_handler, ) from core.exceptions import ( ClientDisconnectedError, ConfigurationError, ExternalServiceError, ResourceNotFoundError, SourceResolutionError, ValidationError, ) from infrastructure.resilience.retry import CircuitOpenError def add_production_exception_handlers(app: FastAPI) -> FastAPI: app.add_exception_handler(ClientDisconnectedError, client_disconnected_handler) app.add_exception_handler(ResourceNotFoundError, resource_not_found_handler) app.add_exception_handler(ExternalServiceError, external_service_error_handler) app.add_exception_handler(CircuitOpenError, circuit_open_error_handler) app.add_exception_handler(ValidationError, validation_error_handler) app.add_exception_handler(ConfigurationError, configuration_error_handler) app.add_exception_handler(SourceResolutionError, source_resolution_error_handler) app.add_exception_handler(HTTPException, http_exception_handler) app.add_exception_handler(StarletteHTTPException, starlette_http_exception_handler) app.add_exception_handler(RequestValidationError, request_validation_error_handler) app.add_exception_handler(Exception, general_exception_handler) return app def build_test_client(app: FastAPI) -> TestClient: add_production_exception_handlers(app) return TestClient(app, raise_server_exceptions=False) def assert_log_fields( records: list[logging.LogRecord], prefix: str, required_fields: list[str], *, min_count: int = 1, ) -> list[str]: """Assert that log records matching *prefix* contain all *required_fields*. Returns the matching messages for further inspection. Parameters ---------- records: ``caplog.records`` or equivalent list of ``LogRecord``. prefix: The log message prefix to filter on (e.g. ``"audiodb.cache"``). required_fields: Key names that must appear as ``key=`` in every matching message. min_count: Minimum number of matching records expected (default 1). """ matching = [r.message for r in records if r.message.startswith(prefix)] assert len(matching) >= min_count, ( f"Expected >= {min_count} log(s) starting with '{prefix}', found {len(matching)}" ) for msg in matching: for field in required_fields: assert f"{field}=" in msg, ( f"Field '{field}=' missing in log: {msg}" ) return matching