File size: 6,012 Bytes
0a1b571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import inspect
from collections.abc import Mapping
from contextvars import ContextVar
from enum import Enum
from fnmatch import fnmatch, fnmatchcase
from functools import wraps
from typing import Annotated, Any, Callable, Literal, Optional
from urllib.parse import ParseResult, urlparse

from fastapi import Depends, Request
from fastapi.routing import APIRouter
from httpx import URL
from pydantic import AnyHttpUrl
from pydantic.errors import UrlHostError
from starlette.datastructures import Headers, MutableHeaders

from hibiapi.utils.cache import endpoint_cache
from hibiapi.utils.net import AsyncCallable_T, AsyncHTTPClient, BaseNetClient

# Attribute name checked by EndpointMeta; members carrying a truthy value
# under this name are never turned into routes.
DONT_ROUTE_KEY: str = "_dont_route"


def dont_route(func: AsyncCallable_T) -> AsyncCallable_T:
    """Decorator: flag ``func`` so endpoint routing skips it.

    The very same function object is returned (no wrapper is created); it
    simply gains a ``DONT_ROUTE_KEY`` attribute set to ``True``.
    """
    setattr(func, DONT_ROUTE_KEY, True)
    return func


class EndpointMeta(type):
    """Metaclass that discovers (and optionally caches) routable endpoint methods.

    A member is routable when it is a coroutine function, its name does not
    begin with an underscore, and it has no truthy ``DONT_ROUTE_KEY``
    attribute.
    """

    @staticmethod
    def _list_router_function(members: dict[str, Any]):
        """Return the routable subset of ``members`` as a new name -> member dict."""
        routable: dict[str, Any] = {}
        for member_name, member in members.items():
            if not inspect.iscoroutinefunction(member):
                continue
            if member_name.startswith("_"):
                continue
            if getattr(member, DONT_ROUTE_KEY, False):
                continue
            routable[member_name] = member
        return routable

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        *,
        cache_endpoints: bool = True,
        **kwargs,
    ):
        """Create the class, wrapping routable methods with ``endpoint_cache``.

        ``cache_endpoints`` is a class keyword (``class X(..., cache_endpoints=False)``)
        that disables the wrapping; any other class keywords are passed on to
        ``type.__new__``.
        """
        routable = cls._list_router_function(namespace)
        if cache_endpoints:
            namespace.update(
                {
                    member_name: endpoint_cache(member)
                    for member_name, member in routable.items()
                }
            )
        return super().__new__(cls, name, bases, namespace, **kwargs)

    @property
    def router_functions(self):
        """Routable members of this class, including inherited ones.

        Members are gathered with ``inspect.getmembers``, so the mapping is
        ordered by member name.
        """
        all_members = dict(inspect.getmembers(self))
        return self._list_router_function(all_members)


class BaseEndpoint(metaclass=EndpointMeta, cache_endpoints=False):
    """Base class for endpoint collections; stores the HTTP client they use.

    Created with ``cache_endpoints=False``. Class keywords are not inherited,
    so subclasses fall back to the metaclass default (``True``) unless they
    pass the keyword themselves.
    """

    def __init__(self, client: AsyncHTTPClient):
        self.client = client

    @staticmethod
    def _join(base: str, endpoint: str, params: dict[str, Any]) -> URL:
        """Build a request URL from ``base``, a path template and parameters.

        Only the scheme and netloc of ``base`` are kept; everything else in it
        (path, params, query, fragment) is discarded. ``None``-valued entries
        of ``params`` are dropped and ``Enum`` members are replaced by their
        ``.value``. The remaining entries fill the ``str.format`` placeholders
        of ``endpoint`` and are also all sent as the query string.
        """
        parsed_base = urlparse(base)
        cleaned: dict[str, Any] = {}
        for key, value in params.items():
            if value is None:
                continue
            cleaned[key] = value.value if isinstance(value, Enum) else value
        target = parsed_base._replace(
            path=endpoint.format(**cleaned),
            params="",
            query="",
            fragment="",
        )
        return URL(url=target.geturl(), params=cleaned)


class SlashRouter(APIRouter):
    """``APIRouter`` that accepts route paths written without a leading slash."""

    def api_route(self, path: str, **kwargs):
        """Register a route, prefixing ``path`` with ``/`` when it lacks one."""
        if not path.startswith("/"):
            path = "/" + path
        return super().api_route(path, **kwargs)


class EndpointRouter(SlashRouter):
    """Router that mounts the routable coroutine methods of an endpoint class.

    Every entry of ``endpoint_class.router_functions`` becomes a GET route at
    ``/<method name>``. Optionally, a deprecated ``/`` route dispatches to a
    method chosen by the ``type`` query parameter.
    """

    @staticmethod
    def _exclude_params(func: Callable, params: Mapping[str, Any]) -> dict[str, Any]:
        """Return the entries of ``params`` whose key names a parameter of ``func``.

        All other keys are dropped so unrelated query parameters are not
        forwarded to ``func``.
        """
        func_params = inspect.signature(func).parameters
        return {k: v for k, v in params.items() if k in func_params}

    @staticmethod
    def _router_signature_convert(
        func,
        endpoint_class: type["BaseEndpoint"],
        request_client: Callable,
        method_name: Optional[str] = None,
    ):
        """Wrap endpoint method ``func`` as a route function.

        The returned coroutine function receives an ``endpoint`` instance from
        the ``request_client`` dependency and forwards its keyword arguments to
        the method named ``method_name`` (default: ``func.__name__``) on that
        instance. Its ``__signature__`` is replaced with ``endpoint`` followed by
        the keyword-only parameters of ``func``, so those are the parameters
        FastAPI resolves for the route.
        """

        @wraps(func)
        async def route_func(endpoint: endpoint_class, **kwargs):
            endpoint_method = getattr(endpoint, method_name or func.__name__)
            return await endpoint_method(**kwargs)

        route_func.__signature__ = inspect.signature(route_func).replace(  # type:ignore
            parameters=[
                inspect.Parameter(
                    name="endpoint",
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=endpoint_class,
                    default=Depends(request_client),
                ),
                *(
                    param
                    for param in inspect.signature(func).parameters.values()
                    if param.kind == inspect.Parameter.KEYWORD_ONLY
                ),
            ]
        )
        return route_func

    def include_endpoint(
        self,
        endpoint_class: type[BaseEndpoint],
        net_client: BaseNetClient,
        add_match_all: bool = True,
    ):
        """Register one GET route per routable method of ``endpoint_class``.

        Args:
            endpoint_class: endpoint class whose ``router_functions`` are mounted.
            net_client: entered with ``async with`` for each request; the client
                it yields is passed to ``endpoint_class``.
            add_match_all: also register the deprecated ``/`` route that picks
                the method from the ``type`` query parameter.
        """
        router_functions = endpoint_class.router_functions

        async def request_client():
            # Yield-style dependency: the endpoint instance is bound to a client
            # that stays open for the duration of the ``async with`` block.
            async with net_client as client:
                yield endpoint_class(client)

        for func_name, func in router_functions.items():
            self.add_api_route(
                path=f"/{func_name}",
                endpoint=self._router_signature_convert(
                    func,
                    endpoint_class=endpoint_class,
                    request_client=request_client,
                    method_name=func_name,
                ),
                methods=["GET"],
            )

        if not add_match_all:
            return

        @self.get("/", description="JournalAD style API routing", deprecated=True)
        async def match_all(
            endpoint: Annotated[endpoint_class, Depends(request_client)],
            request: Request,
            type: Literal[tuple(router_functions.keys())],  # type: ignore
        ):
            # ``type`` shadows the builtin, but it is the public query parameter
            # name of this legacy route, so it cannot be renamed.
            # NOTE: query values are forwarded as raw strings, without the type
            # conversion the per-method routes get.
            func = router_functions[type]
            kwargs = self._exclude_params(func, request.query_params)
            # ``endpoint`` is bound positionally to the method's first parameter
            # (normally ``self``). Drop a query parameter with that name, or
            # e.g. ``?self=x`` would raise "got multiple values for argument".
            first_param = next(iter(inspect.signature(func).parameters.values()), None)
            if first_param is not None and first_param.kind in (
                inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
            ):
                kwargs.pop(first_param.name, None)
            return await func(endpoint, **kwargs)


class BaseHostUrl(AnyHttpUrl):
    """HTTP(S) URL type whose host must match one of ``allowed_hosts``.

    ``allowed_hosts`` holds shell-style wildcard patterns (``fnmatch`` syntax,
    e.g. ``"*.example.com"``). It is empty here, so this base class itself
    accepts no host.
    """

    allowed_hosts: list[str] = []

    @classmethod
    def validate_host(cls, parts) -> tuple[str, Optional[str], str, bool]:
        """Run the parent host validation, then enforce the host allow-list.

        Raises:
            UrlHostError: if the host matches none of ``allowed_hosts``.
        """
        host, tld, host_type, rebuild = super().validate_host(parts)
        if not cls._check_domain(host):
            raise UrlHostError(allowed=cls.allowed_hosts)
        return host, tld, host_type, rebuild

    @classmethod
    def _check_domain(cls, host: str) -> bool:
        """Return True if ``host`` matches any pattern in ``allowed_hosts``.

        Matching is case-insensitive on every platform. Host names are
        case-insensitive, whereas ``fnmatch.fnmatch`` only folds case where
        ``os.path.normcase`` does (Windows). Both sides are therefore
        lower-cased and compared with ``fnmatchcase``.
        """
        normalized = host.lower()
        return any(
            fnmatchcase(normalized, pattern.lower()) for pattern in cls.allowed_hosts
        )


# Context variables named request_headers / response_headers. Neither has a
# default, so ``.get()`` without an argument raises LookupError until a value
# is set in the current context.
request_headers: ContextVar[Headers] = ContextVar("request_headers")
response_headers: ContextVar[MutableHeaders] = ContextVar("response_headers")