"""Register the middleware stack on the application object."""

import time
from collections.abc import Awaitable
from datetime import datetime
from typing import Callable

from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from sentry_sdk.integrations.httpx import HttpxIntegration
from starlette.datastructures import MutableHeaders

from hibiapi.utils.config import Config
from hibiapi.utils.exceptions import BaseServerException, UncaughtException
from hibiapi.utils.log import LoguruHandler, logger
from hibiapi.utils.routing import request_headers, response_headers

from .application import app
from .handlers import exception_handler

# Signature of the ``call_next`` callable handed to every HTTP middleware.
RequestHandler = Callable[[Request], Awaitable[Response]]


# Response compression is opt-in through the server configuration.
if Config["server"]["gzip"].as_bool():
    app.add_middleware(GZipMiddleware)

app.add_middleware(
    CORSMiddleware,
    allow_origins=Config["server"]["cors"]["origins"].get(list[str]),
    allow_credentials=Config["server"]["cors"]["credentials"].as_bool(),
    allow_methods=Config["server"]["cors"]["methods"].get(list[str]),
    allow_headers=Config["server"]["cors"]["headers"].get(list[str]),
)
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=Config["server"]["allowed"].get(list[str]),
)
app.add_middleware(SentryAsgiMiddleware)

HttpxIntegration.setup_once()


@app.middleware("http")
async def request_logger(request: Request, call_next: RequestHandler) -> Response:
    """Log one line per handled request and publish its processing time.

    The elapsed time (in milliseconds) is stored as ``X-Process-Time`` in the
    context-local ``response_headers`` unless that header is already present,
    and a summary line with client address, method, URL, duration, user agent
    and status code is emitted through ``logger.info``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The response produced by ``call_next``, unmodified by this function.
    """
    # A monotonic, high-resolution clock is used for the duration so that
    # wall-clock adjustments (NTP, DST, manual changes) cannot skew it.
    started = time.perf_counter()
    host, port = request.client or (None, None)

    response = await call_next(request)

    process_time = (time.perf_counter() - started) * 1000  # milliseconds
    # NOTE(review): relies on ``response_headers`` already holding a value for
    # this request (see ``contextvar_setter``); confirm the registration order
    # is preserved if the middleware stack is ever rearranged.
    response_headers.get().setdefault("X-Process-Time", f"{process_time:.3f}")

    # Pick the colour tags for the log line from the status class.
    if response.status_code < 400:
        bg, fg = "green", "red"
    elif response.status_code < 500:
        bg, fg = "yellow", "blue"
    else:
        bg, fg = "red", "green"

    status_code, method = response.status_code, request.method.upper()
    # The user agent is client-controlled text, so it is passed through
    # ``LoguruHandler.escape_tag`` before being embedded in the log message.
    user_agent = (
        LoguruHandler.escape_tag(request.headers["user-agent"])
        if "user-agent" in request.headers
        else "Unknown"
    )
    logger.info(
        f"{host}:{port}"
        f" | <{bg.upper()}><{fg}>{method}"
        f" | {str(request.url)!r}"
        f" | {process_time:.3f}ms"
        f" | {user_agent}"
        f" | <{bg}>{status_code}"
    )
    return response


@app.middleware("http")
async def contextvar_setter(request: Request, call_next: RequestHandler) -> Response:
    """Expose request headers and collect extra response headers via context.

    ``request_headers`` is set to the headers of the current request and
    ``response_headers`` to a fresh ``MutableHeaders``. After the request has
    been handled, whatever ``response_headers`` contains is merged into the
    outgoing response's headers.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The response from ``call_next`` with the collected headers applied.
    """
    request_headers.set(request.headers)
    response_headers.set(MutableHeaders())

    response = await call_next(request)

    response.headers.update({**response_headers.get()})
    return response


@app.middleware("http")
async def uncaught_exception_handler(
    request: Request, call_next: RequestHandler
) -> Response:
    """Turn any exception raised further down the stack into a response.

    Exceptions that are already ``BaseServerException`` instances are passed
    to ``exception_handler`` as-is; every other exception is first wrapped
    with ``UncaughtException.with_exception``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The normal response, or the one built by ``exception_handler``.
    """
    try:
        response = await call_next(request)
    except Exception as error:
        # Catching broadly is deliberate here: the failure is not swallowed
        # but handed to ``exception_handler`` to build the response.
        response = await exception_handler(
            request,
            exc=(
                error
                if isinstance(error, BaseServerException)
                else UncaughtException.with_exception(error)
            ),
        )
    return response