HibiAPI / hibiapi /app /middlewares.py
DengFengLai's picture
DF.
0a1b571
raw
history blame
3.33 kB
import time
from collections.abc import Awaitable
from datetime import datetime
from typing import Callable

from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from sentry_sdk.integrations.httpx import HttpxIntegration
from starlette.datastructures import MutableHeaders

from hibiapi.utils.config import Config
from hibiapi.utils.exceptions import BaseServerException, UncaughtException
from hibiapi.utils.log import LoguruHandler, logger
from hibiapi.utils.routing import request_headers, response_headers

from .application import app
from .handlers import exception_handler
# Signature of the ``call_next`` callable handed to every HTTP middleware below.
RequestHandler = Callable[[Request], Awaitable[Response]]

# Bind the configuration sections once so each option is read from one place.
_server_config = Config["server"]
_cors_config = _server_config["cors"]

# Response compression is optional and controlled by the server configuration.
if _server_config["gzip"].as_bool():
    app.add_middleware(GZipMiddleware)

# Cross-origin policy, taken entirely from the "cors" configuration section.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_config["origins"].get(list[str]),
    allow_credentials=_cors_config["credentials"].as_bool(),
    allow_methods=_cors_config["methods"].get(list[str]),
    allow_headers=_cors_config["headers"].get(list[str]),
)

# Restrict the accepted Host header values to the configured list.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=_server_config["allowed"].get(list[str]),
)

# Sentry instrumentation for the ASGI application and for outgoing httpx calls.
app.add_middleware(SentryAsgiMiddleware)
HttpxIntegration.setup_once()
@app.middleware("http")
async def request_logger(request: Request, call_next: RequestHandler) -> Response:
    """Log one line per HTTP request and record how long handling took.

    The elapsed time, in milliseconds with three decimals, is stored as
    ``X-Process-Time`` in the shared ``response_headers`` mapping unless a
    value is already present there. A tag-marked log line is then emitted
    with the client address, method, URL, duration, user agent and status.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request further down the stack.

    Returns:
        The response produced by ``call_next``, unchanged.
    """
    # Use a monotonic clock: differences of naive wall-clock timestamps can
    # jump (DST change, clock adjustment) and yield bogus or negative values.
    start_time = time.perf_counter()
    host, port = request.client or (None, None)

    response = await call_next(request)

    process_time = (time.perf_counter() - start_time) * 1000
    response_headers.get().setdefault("X-Process-Time", f"{process_time:.3f}")

    status_code, method = response.status_code, request.method.upper()
    # Colour pair used for the method and status fields, chosen by status class.
    if status_code < 400:
        bg, fg = "green", "red"
    elif status_code < 500:
        bg, fg = "yellow", "blue"
    else:
        bg, fg = "red", "green"

    user_agent = (
        LoguruHandler.escape_tag(request.headers["user-agent"])
        if "user-agent" in request.headers
        else "<d>Unknown</d>"
    )
    # The URL is client-controlled just like the user agent, so it is escaped
    # the same way before being placed into the tag-marked message. The repr
    # is taken first so the escaping is not itself re-escaped by repr().
    # NOTE(review): assumes escape_tag neutralises markup-like text - confirm.
    url = LoguruHandler.escape_tag(repr(str(request.url)))

    logger.info(
        f"<m><b>{host}</b>:{port}</m>"
        f" | <{bg.upper()}><b><{fg}>{method}</{fg}></b></{bg.upper()}>"
        f" | <n><b>{url}</b></n>"
        f" | <c>{process_time:.3f}ms</c>"
        f" | <e>{user_agent}</e>"
        f" | <b><{bg}>{status_code}</{bg}></b>"
    )
    return response
@app.middleware("http")
async def contextvar_setter(request: Request, call_next: RequestHandler):
    """Expose per-request header containers and apply collected headers.

    Stores the request's headers in ``request_headers`` and a fresh, empty
    ``MutableHeaders`` in ``response_headers`` before passing the request on.
    Afterwards, whatever ``response_headers`` holds is copied onto the
    outgoing response.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request further down the stack.

    Returns:
        The downstream response with the collected headers merged in.
    """
    collected = MutableHeaders()
    request_headers.set(request.headers)
    response_headers.set(collected)

    outgoing = await call_next(request)

    # Re-read the holder rather than reusing ``collected``, and flatten it to
    # a plain dict before merging, exactly as a ``{**mapping}`` unpack would.
    extra_headers = dict(response_headers.get())
    outgoing.headers.update(extra_headers)
    return outgoing
@app.middleware("http")
async def uncaught_exception_handler(
    request: Request, call_next: RequestHandler
) -> Response:
    """Turn any exception raised further down the stack into a response.

    A ``BaseServerException`` is passed to ``exception_handler`` as is; any
    other exception is first wrapped with ``UncaughtException.with_exception``.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request further down the stack.

    Returns:
        The downstream response, or the one built by ``exception_handler``
        when the downstream call raised.
    """
    try:
        return await call_next(request)
    except Exception as error:
        # Deliberately broad: this is the outer boundary for request handling.
        if isinstance(error, BaseServerException):
            server_error = error
        else:
            server_error = UncaughtException.with_exception(error)
        return await exception_handler(request, exc=server_error)