Spaces:
Runtime error
Runtime error
| import hashlib | |
| from collections.abc import Awaitable | |
| from datetime import timedelta | |
| from functools import wraps | |
| from typing import Any, Callable, Optional, TypeVar, cast | |
| from cashews import Cache | |
| from pydantic import BaseModel | |
| from pydantic.decorator import ValidatedFunction | |
| from .config import Config | |
| from .log import logger | |
# Name of the attribute under which ``cache_config`` stores a per-endpoint
# ``CacheConfig`` on the decorated function.
CACHE_CONFIG_KEY = "_cache_config"

# Any coroutine function, and a type variable that preserves the concrete
# callable type through decorators.
AsyncFunc = Callable[..., Awaitable[Any]]
T_AsyncFunc = TypeVar("T_AsyncFunc", bound=AsyncFunc)

# Global cache settings, read once at import time from the "cache" section
# of the project configuration.
CACHE_ENABLED = Config["cache"]["enabled"].as_bool()
# The configured TTL is interpreted as a number of seconds.
CACHE_DELTA = timedelta(seconds=Config["cache"]["ttl"].as_number())
CACHE_URI = Config["cache"]["uri"].as_str()
CACHE_CONTROLLABLE = Config["cache"]["controllable"].as_bool()

# Shared cache instance used by every cached endpoint.
cache = Cache(name="hibiapi")

try:
    cache.setup(CACHE_URI)
except Exception as e:
    # Deliberately broad: a misconfigured or unreachable backend must not
    # prevent the application from starting.
    logger.warning(
        f"Cache URI <y>{CACHE_URI!r}</y> setup <r><b>failed</b></r>: "
        f"<r>{e!r}</r>, use memory backend instead."
    )
    # Actually perform the fallback announced in the warning; without it the
    # cache would be left with no backend configured at all.
    cache.setup("mem://")
class CacheConfig(BaseModel):
    """Caching settings attached to a single async endpoint."""

    # The endpoint coroutine function these settings belong to.
    endpoint: AsyncFunc
    # Prefix placed in front of every cache key built for the endpoint.
    namespace: str
    # Per-endpoint switch; ``endpoint_cache`` combines it with CACHE_ENABLED.
    enabled: bool = True
    # Expiry passed to the cache backend when a result is stored.
    ttl: timedelta = CACHE_DELTA

    @staticmethod
    def new(
        function: AsyncFunc,
        *,
        enabled: bool = True,
        ttl: timedelta = CACHE_DELTA,
        namespace: Optional[str] = None,
    ) -> "CacheConfig":
        """Build the cache settings for ``function``.

        Declared as a static method because it takes neither ``self`` nor
        ``cls``; this keeps ``CacheConfig.new(...)`` working and also makes
        the call safe on an instance.

        Args:
            function: The async endpoint the settings apply to.
            enabled: Whether caching is enabled for this endpoint.
            ttl: How long a cached result is kept.
            namespace: Cache key prefix; when empty or ``None`` the
                function's ``__qualname__`` is used instead.

        Returns:
            A new ``CacheConfig`` describing ``function``.
        """
        return CacheConfig(
            endpoint=function,
            enabled=enabled,
            ttl=ttl,
            namespace=namespace or function.__qualname__,
        )
def cache_config(
    enabled: bool = True,
    ttl: timedelta = CACHE_DELTA,
    namespace: Optional[str] = None,
):
    """Create a decorator that records cache settings on an endpoint.

    The decorator does not wrap the endpoint: it stores a ``CacheConfig``
    on the function under ``CACHE_CONFIG_KEY`` and returns the very same
    function object.

    Args:
        enabled: Whether caching is enabled for the endpoint.
        ttl: How long a cached result is kept.
        namespace: Optional cache key prefix forwarded to ``CacheConfig.new``.

    Returns:
        A decorator accepting and returning the async endpoint.
    """

    def attach(function: T_AsyncFunc) -> T_AsyncFunc:
        settings = CacheConfig.new(
            function,
            enabled=enabled,
            ttl=ttl,
            namespace=namespace,
        )
        setattr(function, CACHE_CONFIG_KEY, settings)
        return function

    return attach


# Ready-made decorator that marks an endpoint as not cacheable.
disable_cache = cache_config(enabled=False)
class CachedValidatedFunction(ValidatedFunction):
    """``ValidatedFunction`` that can expose a call's arguments as a model."""

    def serialize(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> BaseModel:
        """Return the argument model for a call, without running the function.

        Args:
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.

        Returns:
            An instance of ``self.model`` built from the given arguments.
        """
        return self.model(**self.build_values(args=args, kwargs=kwargs))
def endpoint_cache(function: T_AsyncFunc) -> T_AsyncFunc:
    """Wrap an async endpoint so that its results are served from the cache.

    The cache key is ``"<namespace>:<md5 of the JSON-serialized arguments>"``
    (the ``self`` argument is excluded). When ``CACHE_CONTROLLABLE`` is set,
    the request's ``cache-control`` header selects the policy:

    * ``no-store``: bypass the cache entirely.
    * ``no-cache``: drop any stored entry, then recompute and store.
    * anything else: return the stored entry when there is one.

    Args:
        function: The async endpoint to wrap. Settings stored on it under
            ``CACHE_CONFIG_KEY`` are honoured; otherwise defaults are used.

    Returns:
        An async wrapper that carries the metadata of ``function``.
    """
    # Imported here rather than at module level, presumably to avoid a
    # circular import with the routing module -- TODO confirm.
    from .routing import request_headers, response_headers

    vf = CachedValidatedFunction(function, config={})
    config = cast(
        CacheConfig,
        getattr(function, CACHE_CONFIG_KEY, None) or CacheConfig.new(function),
    )
    # The global switch overrides the per-endpoint setting.
    config.enabled = CACHE_ENABLED and config.enabled

    # Preserve the endpoint's name, docstring and ``__wrapped__`` so code
    # that introspects the endpoint still sees the original function.
    @wraps(function)
    async def wrapper(*args, **kwargs):
        cache_policy = "public"
        if CACHE_CONTROLLABLE:
            cache_policy = request_headers.get().get("cache-control", cache_policy)
        # Fold case once; the policy is compared several times below.
        cache_policy = cache_policy.casefold()

        if not config.enabled or cache_policy == "no-store":
            return await vf.call(*args, **kwargs)

        # md5 is used only to derive a compact key, not for security.
        model = vf.serialize(args=args, kwargs=kwargs)
        payload = model.json(exclude={"self"}, sort_keys=True, ensure_ascii=False)
        key = f"{config.namespace}:" + hashlib.md5(payload.encode()).hexdigest()

        response_header = response_headers.get()
        result: Optional[Any] = None

        if cache_policy == "no-cache":
            await cache.delete(key)
        else:
            result = await cache.get(key)
            # Compare against None, not truthiness, so cached falsy values
            # (0, "", [], {}) are still reported as hits.
            if result is not None:
                logger.debug(f"Request hit cache <b><e>{key}</e></b>")
                response_header.setdefault("X-Cache-Hit", key)

        if result is None:
            result = await vf.execute(model)
            await cache.set(key, result, expire=config.ttl)

        # Tell the client how long the stored entry remains valid.
        if (cache_remain := await cache.get_expire(key)) > 0:
            response_header.setdefault("Cache-Control", f"max-age={cache_remain}")

        return result

    return wrapper  # type:ignore