File size: 3,333 Bytes
0a1b571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import time
from collections.abc import Awaitable
from datetime import datetime
from typing import Callable

from fastapi import Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
from sentry_sdk.integrations.httpx import HttpxIntegration
from starlette.datastructures import MutableHeaders

from hibiapi.utils.config import Config
from hibiapi.utils.exceptions import BaseServerException, UncaughtException
from hibiapi.utils.log import LoguruHandler, logger
from hibiapi.utils.routing import request_headers, response_headers

from .application import app
from .handlers import exception_handler

# Signature of the ``call_next`` callable that the HTTP middleware functions
# in this module receive: it takes the incoming Request and returns an
# awaitable resolving to the Response.
RequestHandler = Callable[[Request], Awaitable[Response]]


# NOTE(review): everything below runs at import time and registers
# middleware on the shared ``app``. Starlette is understood to wrap
# later-registered middleware around earlier ones, so the registration order
# is significant -- confirm before reordering.

# Response compression is optional and only enabled when server.gzip is set.
if Config["server"]["gzip"].as_bool():
    app.add_middleware(GZipMiddleware)
# CORS policy is taken entirely from the server.cors.* configuration keys.
app.add_middleware(
    CORSMiddleware,
    allow_origins=Config["server"]["cors"]["origins"].get(list[str]),
    allow_credentials=Config["server"]["cors"]["credentials"].as_bool(),
    allow_methods=Config["server"]["cors"]["methods"].get(list[str]),
    allow_headers=Config["server"]["cors"]["headers"].get(list[str]),
)
# Allowed Host header values come from the server.allowed configuration key.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=Config["server"]["allowed"].get(list[str]),
)
# Sentry's ASGI middleware is registered unconditionally.
app.add_middleware(SentryAsgiMiddleware)

# Invoked directly at import time rather than through sentry_sdk.init();
# presumably this installs Sentry's httpx instrumentation -- TODO confirm.
HttpxIntegration.setup_once()


@app.middleware("http")
async def request_logger(request: Request, call_next: RequestHandler) -> Response:
    """Log one markup-formatted access line per request and time its handling.

    The elapsed time of ``call_next`` is measured in milliseconds, stored as
    ``X-Process-Time`` in the ``response_headers`` context variable (unless a
    value is already present there) and included in the log line together
    with the client address, method, URL, user agent and status code.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request on and returns the response.

    Returns:
        The response produced by ``call_next``, unchanged.

    Note:
        If ``call_next`` raises, the exception propagates and nothing is
        logged by this middleware.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or become negative) when the system clock is adjusted mid-request.
    start_time = time.perf_counter()
    host, port = request.client or (None, None)
    response = await call_next(request)
    process_time = (time.perf_counter() - start_time) * 1000
    response_headers.get().setdefault("X-Process-Time", f"{process_time:.3f}")

    status_code, method = response.status_code, request.method.upper()
    # Pick the colour pair by status class: success/redirect, client error,
    # server error.
    if status_code < 400:
        bg, fg = "green", "red"
    elif status_code < 500:
        bg, fg = "yellow", "blue"
    else:
        bg, fg = "red", "green"

    # Both the user agent and the URL are client-controlled. They are escaped
    # so that tag-like text in them (e.g. "</b>") is not interpreted as
    # markup by the logger.
    user_agent = (
        LoguruHandler.escape_tag(request.headers["user-agent"])
        if "user-agent" in request.headers
        else "<d>Unknown</d>"
    )
    # Escape after repr() so the escaping itself is not re-quoted by repr().
    url = LoguruHandler.escape_tag(repr(str(request.url)))

    logger.info(
        f"<m><b>{host}</b>:{port}</m>"
        f" | <{bg.upper()}><b><{fg}>{method}</{fg}></b></{bg.upper()}>"
        f" | <n><b>{url}</b></n>"
        f" | <c>{process_time:.3f}ms</c>"
        f" | <e>{user_agent}</e>"
        f" | <b><{bg}>{status_code}</{bg}></b>"
    )
    return response


@app.middleware("http")
async def contextvar_setter(request: Request, call_next: RequestHandler) -> Response:
    """Publish per-request header state through context variables.

    Before the request is passed on, ``request_headers`` is set to the
    incoming headers and ``response_headers`` to a fresh, empty
    ``MutableHeaders``. After ``call_next`` returns, whatever has been placed
    in ``response_headers`` is copied onto the outgoing response.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request on and returns the response.

    Returns:
        The response from ``call_next`` with the collected headers applied.
    """
    request_headers.set(request.headers)
    response_headers.set(MutableHeaders())
    response = await call_next(request)
    # The collected headers are flattened into a plain dict first, so each
    # header name contributes a single value to the update.
    response.headers.update({**response_headers.get()})
    return response


@app.middleware("http")
async def uncaught_exception_handler(
    request: Request, call_next: RequestHandler
) -> Response:
    """Turn any exception raised further down the stack into a response.

    Exceptions that are already ``BaseServerException`` instances are handed
    to ``exception_handler`` as they are; anything else is first wrapped with
    ``UncaughtException.with_exception``.

    Args:
        request: The incoming request.
        call_next: Callable that passes the request on and returns the response.

    Returns:
        The normal response, or the one built by ``exception_handler`` when
        ``call_next`` raised.
    """
    try:
        return await call_next(request)
    except Exception as error:
        if isinstance(error, BaseServerException):
            server_error = error
        else:
            server_error = UncaughtException.with_exception(error)
        return await exception_handler(request, exc=server_error)