Spaces:
Runtime error
Runtime error
import inspect | |
from collections.abc import Mapping | |
from contextvars import ContextVar | |
from enum import Enum | |
from fnmatch import fnmatch | |
from functools import wraps | |
from typing import Annotated, Any, Callable, Literal, Optional | |
from urllib.parse import ParseResult, urlparse | |
from fastapi import Depends, Request | |
from fastapi.routing import APIRouter | |
from httpx import URL | |
from pydantic import AnyHttpUrl | |
from pydantic.errors import UrlHostError | |
from starlette.datastructures import Headers, MutableHeaders | |
from hibiapi.utils.cache import endpoint_cache | |
from hibiapi.utils.net import AsyncCallable_T, AsyncHTTPClient, BaseNetClient | |
# Attribute name that `dont_route` sets on a coroutine function; functions
# carrying a truthy value under this name are skipped during route discovery.
DONT_ROUTE_KEY: str = "_dont_route"
def dont_route(func: AsyncCallable_T) -> AsyncCallable_T:
    """Decorator that excludes *func* from automatic route discovery.

    The function object is flagged in place by setting the ``DONT_ROUTE_KEY``
    attribute to ``True`` and is then returned as-is (no wrapper is created),
    so it remains callable exactly as before.
    """
    flagged = func
    setattr(flagged, DONT_ROUTE_KEY, True)
    return flagged
class EndpointMeta(type):
    """Metaclass that discovers, and optionally caches, an endpoint class's routes.

    A "router function" is any coroutine function whose name does not start
    with ``_`` and that does not carry a truthy ``DONT_ROUTE_KEY`` attribute
    (see ``dont_route``).

    Class keyword arguments:
        cache_endpoints: when true (the default), each router function defined
            in the class body being created is replaced by
            ``endpoint_cache(func)`` before the class object is built.
    """

    @staticmethod
    def _list_router_function(members: dict[str, Any]) -> dict[str, Any]:
        """Return the subset of *members* (name -> object) that are router functions."""
        return {
            name: member
            for name, member in members.items()
            if (
                inspect.iscoroutinefunction(member)
                and not name.startswith("_")
                and not getattr(member, DONT_ROUTE_KEY, False)
            )
        }

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        namespace: dict[str, Any],
        *,
        cache_endpoints: bool = True,
        **kwargs,
    ):
        if cache_endpoints:
            # Only entries of this class body's namespace are wrapped here;
            # attributes inherited from bases are not touched.
            for func_name, func in cls._list_router_function(namespace).items():
                namespace[func_name] = endpoint_cache(func)
        return super().__new__(cls, name, bases, namespace, **kwargs)

    @property
    def router_functions(self) -> dict[str, Any]:
        """Router functions visible on this class, including inherited members.

        Exposed as a property so callers can write
        ``SomeEndpoint.router_functions.items()``.
        """
        return self._list_router_function(dict(inspect.getmembers(self)))
class BaseEndpoint(metaclass=EndpointMeta, cache_endpoints=False):
    """Base class for a collection of endpoint coroutines sharing one HTTP client.

    Router functions of subclasses are discovered by ``EndpointMeta``; this base
    class is created with ``cache_endpoints=False``.
    """

    def __init__(self, client: AsyncHTTPClient):
        self.client = client

    @staticmethod
    def _join(base: str, endpoint: str, params: dict[str, Any]) -> URL:
        """Build a request URL from *base*, a path template and parameters.

        ``None`` values are dropped and ``Enum`` members are replaced by their
        ``.value``. The remaining params are used both to fill ``{placeholders}``
        in *endpoint* and as the query string. Only the scheme and netloc of
        *base* are kept; any path, query or fragment in *base* is discarded.

        Args:
            base: URL whose scheme and host are used.
            endpoint: path template formatted with ``str.format(**params)``.
            params: path/query parameters; the caller's dict is not mutated.

        Returns:
            An ``httpx.URL`` carrying the formatted path and the query params.

        Raises:
            KeyError: if *endpoint* references a placeholder missing from the
                (filtered) params.
        """
        host: ParseResult = urlparse(base)
        # Unwrap enums and drop unset values so str.format and the query string
        # both see plain values.
        # NOTE(review): values consumed by the path template are also sent as
        # query params - confirm that is intended by the upstream APIs.
        params = {
            k: (v.value if isinstance(v, Enum) else v)
            for k, v in params.items()
            if v is not None
        }
        return URL(
            url=ParseResult(
                scheme=host.scheme,
                netloc=host.netloc,
                path=endpoint.format(**params),
                params="",
                query="",
                fragment="",
            ).geturl(),
            params=params,
        )
class SlashRouter(APIRouter):
    """``APIRouter`` that accepts route paths given without a leading slash."""

    def api_route(self, path: str, **kwargs):
        # Normalise "foo" to "/foo"; paths already starting with "/" pass through.
        if not path.startswith("/"):
            path = "/" + path
        return super().api_route(path, **kwargs)
class EndpointRouter(SlashRouter):
    """Router that exposes the router functions of an endpoint class as GET routes."""

    @staticmethod
    def _exclude_params(func: Callable, params: Mapping[str, Any]) -> dict[str, Any]:
        """Return only the entries of *params* whose key names a parameter of *func*."""
        func_params = inspect.signature(func).parameters
        return {k: v for k, v in params.items() if k in func_params}

    @staticmethod
    def _router_signature_convert(
        func,
        endpoint_class: type["BaseEndpoint"],
        request_client: Callable,
        method_name: Optional[str] = None,
    ):
        """Wrap an endpoint method as a FastAPI route handler.

        The returned coroutine function receives an ``endpoint_class`` instance
        through ``Depends(request_client)``. It forwards its keyword arguments to
        the method named *method_name* (default: ``func.__name__``) on that
        instance. Its advertised signature is ``endpoint`` followed by the
        keyword-only parameters of *func*, so the framework derives the request
        parameters from exactly those.
        """
        target_name = method_name or func.__name__

        # ``wraps`` copies __name__/__doc__ so the generated route is named and
        # described after the endpoint method instead of "route_func".
        @wraps(func)
        async def route_func(endpoint: endpoint_class, **kwargs):
            return await getattr(endpoint, target_name)(**kwargs)

        # Build the signature explicitly: after ``wraps``, inspect.signature()
        # would follow ``__wrapped__`` to *func* and leak its return annotation.
        # The explicit ``__signature__`` takes precedence over ``__wrapped__``.
        route_func.__signature__ = inspect.Signature(  # type:ignore
            parameters=[
                inspect.Parameter(
                    name="endpoint",
                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    annotation=endpoint_class,
                    default=Depends(request_client),
                ),
                *(
                    param
                    for param in inspect.signature(func).parameters.values()
                    if param.kind == inspect.Parameter.KEYWORD_ONLY
                ),
            ]
        )
        return route_func

    def include_endpoint(
        self,
        endpoint_class: type[BaseEndpoint],
        net_client: BaseNetClient,
        add_match_all: bool = True,
    ):
        """Register one ``GET /<name>`` route per router function of *endpoint_class*.

        For each request, an ``endpoint_class`` instance is built around the
        client yielded by ``async with net_client``. When *add_match_all* is true
        (and there is at least one router function), a ``GET /`` dispatcher is
        also registered. It picks the function via the ``type`` query parameter.
        """
        router_functions = endpoint_class.router_functions

        async def request_client():
            async with net_client as client:
                yield endpoint_class(client)

        for func_name, func in router_functions.items():
            self.add_api_route(
                path=f"/{func_name}",
                endpoint=self._router_signature_convert(
                    func,
                    endpoint_class=endpoint_class,
                    request_client=request_client,
                    method_name=func_name,
                ),
                methods=["GET"],
            )

        # With no router functions the ``type`` Literal would have no members and
        # could never match, so there is nothing to dispatch to.
        if not add_match_all or not router_functions:
            return

        async def match_all(
            endpoint: Annotated[endpoint_class, Depends(request_client)],
            request: Request,
            type: Literal[tuple(router_functions.keys())],  # type: ignore
        ):
            # Resolve through the instance (as the per-function routes do) so the
            # bound method's signature excludes ``self``. Otherwise a client-sent
            # ``self=`` query param would collide with the positional instance.
            method = getattr(endpoint, type)
            return await method(**self._exclude_params(method, request.query_params))

        # NOTE(review): the original defined match_all without ever registering
        # it; mounting it at the router root "/" is an assumption - confirm.
        self.add_api_route("/", match_all, methods=["GET"])
class BaseHostUrl(AnyHttpUrl):
    """``AnyHttpUrl`` that only accepts hosts matching ``allowed_hosts``.

    ``allowed_hosts`` holds ``fnmatch``-style patterns (e.g. ``"*.example.com"``).
    A URL whose host matches none of them is rejected with ``UrlHostError``. With
    an empty list, every host is rejected.
    """

    # NOTE(review): mutable class-level default, shared by every subclass that
    # does not override it; mutating it in place would affect all of them.
    allowed_hosts: list[str] = []

    @classmethod
    def validate_host(cls, parts) -> tuple[str, Optional[str], str, bool]:
        """Run the parent host validation, then enforce ``allowed_hosts``.

        Returns:
            The parent's ``(host, tld, host_type, rebuild)`` tuple, unchanged.

        Raises:
            UrlHostError: if the validated host matches no allowed pattern.
        """
        host, tld, host_type, rebuild = super().validate_host(parts)
        if not cls._check_domain(host):
            raise UrlHostError(allowed=cls.allowed_hosts)
        return host, tld, host_type, rebuild

    @classmethod
    def _check_domain(cls, host: str) -> bool:
        """Return True if *host* matches at least one ``allowed_hosts`` pattern."""
        # Test the match result itself, not the truthiness of the matched pattern.
        return any(fnmatch(host, pattern) for pattern in cls.allowed_hosts)
# Context-local header holders: a ContextVar keeps a separate value per
# execution context (e.g. per asyncio task). Presumably set per request by code
# elsewhere (not shown here) - confirm.
request_headers: ContextVar[Headers] = ContextVar("request_headers")
response_headers: ContextVar[MutableHeaders] = ContextVar("response_headers")