Spaces:
Runtime error
Runtime error
from collections.abc import Awaitable | |
from datetime import datetime | |
from typing import Callable | |
from fastapi import Request, Response | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.middleware.gzip import GZipMiddleware | |
from fastapi.middleware.trustedhost import TrustedHostMiddleware | |
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | |
from sentry_sdk.integrations.httpx import HttpxIntegration | |
from starlette.datastructures import MutableHeaders | |
from hibiapi.utils.config import Config | |
from hibiapi.utils.exceptions import BaseServerException, UncaughtException | |
from hibiapi.utils.log import LoguruHandler, logger | |
from hibiapi.utils.routing import request_headers, response_headers | |
from .application import app | |
from .handlers import exception_handler | |
# Type of the ``call_next`` argument taken by the middleware functions in this
# module: a callable that receives a request and returns an awaitable response.
RequestHandler = Callable[[Request], Awaitable[Response]]


# NOTE: the registration order below is kept exactly as it was; reordering
# ``add_middleware`` calls may change how the middlewares wrap each other.

# Response compression is only enabled when the configuration asks for it.
_gzip_enabled = Config["server"]["gzip"].as_bool()
if _gzip_enabled:
    app.add_middleware(GZipMiddleware)

# Every CORS option is read from the ``server.cors`` configuration section.
_cors_options = {
    "allow_origins": Config["server"]["cors"]["origins"].get(list[str]),
    "allow_credentials": Config["server"]["cors"]["credentials"].as_bool(),
    "allow_methods": Config["server"]["cors"]["methods"].get(list[str]),
    "allow_headers": Config["server"]["cors"]["headers"].get(list[str]),
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Host names the server accepts, taken from ``server.allowed``.
_allowed_hosts = Config["server"]["allowed"].get(list[str])
app.add_middleware(TrustedHostMiddleware, allowed_hosts=_allowed_hosts)

app.add_middleware(SentryAsgiMiddleware)

# Run the Sentry httpx integration's setup at import time.
HttpxIntegration.setup_once()


async def request_logger(request: Request, call_next: RequestHandler) -> Response:
    """Log a one-line summary of the request and record its processing time.

    The elapsed time in milliseconds is stored under ``X-Process-Time`` in the
    mapping held by ``response_headers`` (unless that key is already set), and
    a markup-tagged line containing the client address, method, URL, elapsed
    time, user agent and status code is sent to ``logger.info``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request and returns the response.

    Returns:
        The response object returned by ``call_next``.
    """
    # Imported at function scope so this block does not rely on a new
    # file-level import.
    from time import perf_counter

    # A monotonic clock is used for the duration: subtracting two naive
    # ``datetime.now()`` readings gives a wrong (even negative) result when
    # the wall clock is adjusted while the request is in flight.
    started = perf_counter()
    # ``request.client`` may be falsy; fall back to a pair of ``None``.
    host, port = request.client or (None, None)

    response = await call_next(request)

    process_time = (perf_counter() - started) * 1000
    # ``setdefault`` keeps a value that was already stored under this key.
    response_headers.get().setdefault("X-Process-Time", f"{process_time:.3f}")

    status_code, method = response.status_code, request.method.upper()

    # Tag names used for the method and status code, chosen by status class.
    if status_code < 400:
        bg, fg = "green", "red"
    elif status_code < 500:
        bg, fg = "yellow", "blue"
    else:
        bg, fg = "red", "green"

    # The user agent is client-supplied text, so it is passed through the
    # logger's tag escaping before being embedded in the markup below.
    raw_user_agent = request.headers.get("user-agent")
    if raw_user_agent is None:
        user_agent = "<d>Unknown</d>"
    else:
        user_agent = LoguruHandler.escape_tag(raw_user_agent)

    # NOTE(review): unlike the user agent, the URL is interpolated without
    # ``LoguruHandler.escape_tag`` -- confirm that markup-like text in a URL
    # cannot break the logger's tag parsing.
    logger.info(
        f"<m><b>{host}</b>:{port}</m>"
        f" | <{bg.upper()}><b><{fg}>{method}</{fg}></b></{bg.upper()}>"
        f" | <n><b>{str(request.url)!r}</b></n>"
        f" | <c>{process_time:.3f}ms</c>"
        f" | <e>{user_agent}</e>"
        f" | <b><{bg}>{status_code}</{bg}></b>"
    )
    return response


async def contextvar_setter(request: Request, call_next: RequestHandler):
    """Publish request headers and apply collected response headers.

    Before the request is forwarded, ``request_headers`` is set to the
    incoming headers and ``response_headers`` to a new, empty
    ``MutableHeaders``. After ``call_next`` returns, the current content of
    ``response_headers`` is copied onto the response's headers.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request and returns the response.

    Returns:
        The response from ``call_next`` with the collected headers applied.
    """
    request_headers.set(request.headers)
    response_headers.set(MutableHeaders())

    response = await call_next(request)

    # Snapshot the collected headers into a plain dict before merging them.
    collected = {**response_headers.get()}
    response.headers.update(collected)
    return response


async def uncaught_exception_handler(
    request: Request, call_next: RequestHandler
) -> Response:
    """Convert any exception raised by ``call_next`` into a response.

    A ``BaseServerException`` is handed to ``exception_handler`` unchanged;
    any other exception is first wrapped with
    ``UncaughtException.with_exception``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request and returns the response.

    Returns:
        The response from ``call_next``, or the one built by
        ``exception_handler`` when an exception was raised.
    """
    try:
        return await call_next(request)
    except Exception as error:
        if isinstance(error, BaseServerException):
            handled = error
        else:
            handled = UncaughtException.with_exception(error)
        return await exception_handler(request, exc=handled)