Spaces:
Runtime error
Runtime error
File size: 6,012 Bytes
0a1b571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import inspect
from collections.abc import Mapping
from contextvars import ContextVar
from enum import Enum
from fnmatch import fnmatch
from functools import wraps
from typing import Annotated, Any, Callable, Literal, Optional
from urllib.parse import ParseResult, urlparse
from fastapi import Depends, Request
from fastapi.routing import APIRouter
from httpx import URL
from pydantic import AnyHttpUrl
from pydantic.errors import UrlHostError
from starlette.datastructures import Headers, MutableHeaders
from hibiapi.utils.cache import endpoint_cache
from hibiapi.utils.net import AsyncCallable_T, AsyncHTTPClient, BaseNetClient
# Attribute name used to flag callables that must not be exposed as routes.
DONT_ROUTE_KEY = "_dont_route"


def dont_route(func: AsyncCallable_T) -> AsyncCallable_T:
    """Flag *func* as excluded from routing and return it unchanged.

    Sets the attribute named by ``DONT_ROUTE_KEY`` to ``True`` on *func*.
    Because the same object is returned, this can be applied as a decorator.
    """
    setattr(func, DONT_ROUTE_KEY, True)
    return func
class EndpointMeta(type):
    """Metaclass that finds routable coroutine methods and optionally caches them."""

    @staticmethod
    def _list_router_function(members: dict[str, Any]):
        """Return the entries of *members* that qualify as router functions.

        An entry qualifies when its value is a coroutine function, its name
        does not start with an underscore, and the value carries no truthy
        ``DONT_ROUTE_KEY`` attribute.
        """
        routable = {}
        for member_name, member in members.items():
            if not inspect.iscoroutinefunction(member):
                continue
            if member_name.startswith("_"):
                continue
            if getattr(member, DONT_ROUTE_KEY, False):
                continue
            routable[member_name] = member
        return routable

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        *,
        cache_endpoints: bool = True,
        **kwargs,
    ):
        """Create the class, wrapping its router functions with ``endpoint_cache``.

        Wrapping only happens when *cache_endpoints* is true; otherwise the
        namespace is passed through untouched. Remaining keyword arguments are
        forwarded to ``type.__new__``.
        """
        routable = cls._list_router_function(namespace)
        if cache_endpoints:
            for attr_name, coroutine in routable.items():
                namespace[attr_name] = endpoint_cache(coroutine)
        return super().__new__(cls, name, bases, namespace, **kwargs)

    @property
    def router_functions(self):
        """Router functions among all members of the class, inherited ones included."""
        all_members = dict(inspect.getmembers(self))
        return self._list_router_function(all_members)
class BaseEndpoint(metaclass=EndpointMeta, cache_endpoints=False):
    """Base class for endpoint collections that share one HTTP client.

    The class itself is created with ``cache_endpoints=False``.
    """

    def __init__(self, client: AsyncHTTPClient):
        """Store *client* on the instance as ``self.client``."""
        self.client = client

    @staticmethod
    def _join(base: str, endpoint: str, params: dict[str, Any]) -> URL:
        """Build a request URL from a base address, a path template and parameters.

        Args:
            base: URL whose scheme and network location are reused.
            endpoint: Path template, filled in with ``str.format`` using the
                cleaned parameters as keyword arguments.
            params: Parameters; ``None`` values are dropped and ``Enum``
                members are replaced by their ``.value``.

        Returns:
            A ``URL`` built from the formatted address, with the cleaned
            parameters also passed as ``params``.
        """
        origin: ParseResult = urlparse(base)

        cleaned: dict[str, Any] = {}
        for key, value in params.items():
            if value is None:
                continue
            cleaned[key] = value.value if isinstance(value, Enum) else value

        # Only scheme and netloc are taken from the base; any path, query or
        # fragment it carried is discarded.
        target = ParseResult(
            scheme=origin.scheme,
            netloc=origin.netloc,
            path=endpoint.format(**cleaned),
            params="",
            query="",
            fragment="",
        )
        return URL(url=target.geturl(), params=cleaned)
class SlashRouter(APIRouter):
    """``APIRouter`` whose ``api_route`` paths always begin with a slash."""

    def api_route(self, path: str, **kwargs):
        """Prefix *path* with ``/`` when it lacks one, then defer to ``APIRouter``."""
        if not path.startswith("/"):
            path = "/" + path
        return super().api_route(path, **kwargs)
class EndpointRouter(SlashRouter):
    """Router that registers an endpoint class's router functions as GET routes."""

    @staticmethod
    def _exclude_params(func: Callable, params: Mapping[str, Any]) -> dict[str, Any]:
        """Return the entries of *params* whose key is a parameter name of *func*."""
        accepted = inspect.signature(func).parameters
        kept: dict[str, Any] = {}
        for key, value in params.items():
            if key in accepted:
                kept[key] = value
        return kept

    @staticmethod
    def _router_signature_convert(
        func,
        endpoint_class: type["BaseEndpoint"],
        request_client: Callable,
        method_name: Optional[str] = None,
    ):
        """Wrap *func* in a route handler with an injectable ``endpoint`` parameter.

        The handler's advertised signature consists of an ``endpoint``
        parameter defaulting to ``Depends(request_client)``, followed by the
        keyword-only parameters of *func*. When called, it looks up
        ``method_name`` (or ``func.__name__``) on the endpoint instance and
        awaits it with the received keyword arguments.
        """

        @wraps(func)
        async def route_func(endpoint: endpoint_class, **kwargs):
            # Looked up on the instance, so the call goes through a bound method.
            bound_method = getattr(endpoint, method_name or func.__name__)
            return await bound_method(**kwargs)

        endpoint_parameter = inspect.Parameter(
            name="endpoint",
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=endpoint_class,
            default=Depends(request_client),
        )
        keyword_only_parameters = [
            parameter
            for parameter in inspect.signature(func).parameters.values()
            if parameter.kind == inspect.Parameter.KEYWORD_ONLY
        ]
        route_func.__signature__ = inspect.signature(route_func).replace(  # type:ignore
            parameters=[endpoint_parameter, *keyword_only_parameters]
        )
        return route_func

    def include_endpoint(
        self,
        endpoint_class: type[BaseEndpoint],
        net_client: BaseNetClient,
        add_match_all: bool = True,
    ):
        """Register every router function of *endpoint_class* on this router.

        Args:
            endpoint_class: Class whose ``router_functions`` are exposed, one
                GET route per function at ``/<function name>``.
            net_client: Async context manager entered per request; the value
                it yields is passed to ``endpoint_class``.
            add_match_all: When true, also register a deprecated ``/`` route
                that dispatches on the ``type`` query parameter.
        """
        routable = endpoint_class.router_functions

        async def request_client():
            # Generator dependency: the endpoint is yielded while the
            # net client context is open.
            async with net_client as client:
                yield endpoint_class(client)

        for func_name, func in routable.items():
            handler = self._router_signature_convert(
                func,
                endpoint_class=endpoint_class,
                request_client=request_client,
                method_name=func_name,
            )
            self.add_api_route(
                path=f"/{func_name}",
                endpoint=handler,
                methods=["GET"],
            )

        if add_match_all:

            @self.get("/", description="JournalAD style API routing", deprecated=True)
            async def match_all(
                endpoint: Annotated[endpoint_class, Depends(request_client)],
                request: Request,
                type: Literal[tuple(routable.keys())],  # type: ignore
            ):
                selected = routable[type]
                # Forward only the query parameters that the selected
                # function declares in its signature.
                forwarded = self._exclude_params(selected, request.query_params)
                return await selected(endpoint, **forwarded)
class BaseHostUrl(AnyHttpUrl):
    """HTTP URL type that only accepts hosts matching ``allowed_hosts``.

    Subclasses are expected to override ``allowed_hosts``; with the default
    empty list no host matches and validation always fails.
    """

    # Patterns compared against the host with ``fnmatch`` (shell-style wildcards).
    allowed_hosts: list[str] = []

    @classmethod
    def validate_host(cls, parts) -> tuple[str, Optional[str], str, bool]:
        """Run the parent host validation, then enforce the allow-list.

        Returns:
            The ``(host, tld, host_type, rebuild)`` tuple from the parent
            implementation, unchanged.

        Raises:
            UrlHostError: If the host matches none of ``allowed_hosts``.
        """
        host, tld, host_type, rebuild = super().validate_host(parts)
        if not cls._check_domain(host):
            raise UrlHostError(allowed=cls.allowed_hosts)
        return host, tld, host_type, rebuild

    @classmethod
    def _check_domain(cls, host: str) -> bool:
        """Return whether *host* matches at least one pattern in ``allowed_hosts``."""
        # Test the match results themselves. Filtering first and calling
        # ``any`` on the surviving patterns would judge the truthiness of the
        # pattern strings, so a matching empty pattern would count as no match.
        return any(fnmatch(host, pattern) for pattern in cls.allowed_hosts)
# Context-local holders for header objects. Neither is created with a default,
# so ``.get()`` raises ``LookupError`` until a value has been set in the current
# context. NOTE(review): the code that sets them is not in this module.
request_headers: ContextVar[Headers] = ContextVar("request_headers")
response_headers: ContextVar[MutableHeaders] = ContextVar("response_headers")
|