"""Endpoint-level response caching helpers.

This module configures a shared ``cashews`` cache instance from the project
configuration and exposes:

* ``cache_config`` / ``disable_cache`` -- decorators that attach a
  ``CacheConfig`` to an async endpoint function.
* ``endpoint_cache`` -- a decorator that wraps an async endpoint so its result
  is stored in, and served from, the shared cache.
"""

import hashlib
from collections.abc import Awaitable
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast

from cashews import Cache
from pydantic import BaseModel
from pydantic.decorator import ValidatedFunction

from .config import Config
from .log import logger

# Attribute name under which ``cache_config`` stores the per-endpoint settings.
CACHE_CONFIG_KEY = "_cache_config"

AsyncFunc = Callable[..., Awaitable[Any]]
T_AsyncFunc = TypeVar("T_AsyncFunc", bound=AsyncFunc)

# Global cache settings read once from the project configuration.
CACHE_ENABLED = Config["cache"]["enabled"].as_bool()
CACHE_DELTA = timedelta(seconds=Config["cache"]["ttl"].as_number())
CACHE_URI = Config["cache"]["uri"].as_str()
CACHE_CONTROLLABLE = Config["cache"]["controllable"].as_bool()

cache = Cache(name="hibiapi")
try:
    cache.setup(CACHE_URI)
except Exception as e:
    # Deliberately broad: a bad cache URI must not prevent the app from
    # starting; the failure is logged and caching degrades to memory.
    logger.warning(
        f"Cache URI {CACHE_URI!r} setup failed: "
        f"{e!r}, use memory backend instead."
    )
    # Make the fallback announced above explicit instead of relying on
    # whatever default backend the installed cashews version may provide.
    cache.setup("mem://")


class CacheConfig(BaseModel):
    """Per-endpoint cache settings."""

    # The async endpoint function these settings belong to.
    endpoint: AsyncFunc
    # Prefix used when building cache keys for this endpoint.
    namespace: str
    # Whether results of this endpoint may be cached.
    enabled: bool = True
    # Expiry passed to the cache backend when a result is stored.
    ttl: timedelta = CACHE_DELTA

    @staticmethod
    def new(
        function: AsyncFunc,
        *,
        enabled: bool = True,
        ttl: timedelta = CACHE_DELTA,
        namespace: Optional[str] = None,
    ) -> "CacheConfig":
        """Build a ``CacheConfig`` for ``function``.

        When ``namespace`` is empty or ``None`` the function's
        ``__qualname__`` is used as the namespace.
        """
        return CacheConfig(
            endpoint=function,
            enabled=enabled,
            ttl=ttl,
            namespace=namespace or function.__qualname__,
        )


def cache_config(
    enabled: bool = True,
    ttl: timedelta = CACHE_DELTA,
    namespace: Optional[str] = None,
):
    """Return a decorator that attaches cache settings to an async function.

    The decorated function is returned unchanged; the settings are stored on
    it as the attribute named by ``CACHE_CONFIG_KEY``, where
    ``endpoint_cache`` later looks them up.
    """

    def decorator(function: T_AsyncFunc) -> T_AsyncFunc:
        setattr(
            function,
            CACHE_CONFIG_KEY,
            CacheConfig.new(function, enabled=enabled, ttl=ttl, namespace=namespace),
        )
        return function

    return decorator


# Ready-made decorator marking an endpoint as not cacheable.
disable_cache = cache_config(enabled=False)


class CachedValidatedFunction(ValidatedFunction):
    """``ValidatedFunction`` that can expose the validated argument model."""

    def serialize(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> BaseModel:
        """Build the argument model for a call made with ``args``/``kwargs``."""
        values = self.build_values(args=args, kwargs=kwargs)
        return self.model(**values)


def endpoint_cache(function: T_AsyncFunc) -> T_AsyncFunc:
    """Wrap an async endpoint so that its results go through the shared cache.

    Settings come from the ``CacheConfig`` attached by ``cache_config``; a
    default one is created when the function carries none. Caching is active
    only when both the global ``CACHE_ENABLED`` flag and the endpoint's own
    ``enabled`` flag are true.

    When ``CACHE_CONTROLLABLE`` is set, the request's ``cache-control`` header
    value is compared case-insensitively against:

    * ``no-store`` -- bypass the cache entirely.
    * ``no-cache`` -- drop any stored entry, call the endpoint, store anew.

    Any other value behaves like the default policy: serve from the cache when
    an entry exists, otherwise call the endpoint and store the result.
    """
    # Imported lazily, inside the function, rather than at module import time.
    from .routing import request_headers, response_headers

    vf = CachedValidatedFunction(function, config={})
    config = cast(
        CacheConfig,
        getattr(function, CACHE_CONFIG_KEY, None) or CacheConfig.new(function),
    )

    # The global switch can only turn caching off, never force it on.
    config.enabled = CACHE_ENABLED and config.enabled

    @wraps(function)
    async def wrapper(*args, **kwargs):
        cache_policy = "public"
        if CACHE_CONTROLLABLE:
            cache_policy = request_headers.get().get("cache-control", cache_policy)
        policy = cache_policy.casefold()

        if not config.enabled or policy == "no-store":
            return await vf.call(*args, **kwargs)

        # The key is the namespace plus a digest of the validated arguments
        # (``self`` excluded), serialised with sorted keys.
        model = vf.serialize(args=args, kwargs=kwargs)
        digest = hashlib.md5(
            model.json(exclude={"self"}, sort_keys=True, ensure_ascii=False).encode()
        ).hexdigest()
        key = f"{config.namespace}:{digest}"

        response_header = response_headers.get()
        result: Optional[Any] = None

        if policy == "no-cache":
            await cache.delete(key)
        else:
            result = await cache.get(key)
            # Compare against None rather than truthiness so that falsy cached
            # values (empty list/dict, 0, "") are still reported as hits.
            if result is not None:
                logger.debug(f"Request hit cache {key}")
                response_header.setdefault("X-Cache-Hit", key)

        if result is None:
            result = await vf.execute(model)
            await cache.set(key, result, expire=config.ttl)

        # Advertise the remaining lifetime of the entry to the client.
        if (cache_remain := await cache.get_expire(key)) > 0:
            response_header.setdefault("Cache-Control", f"max-age={cache_remain}")

        return result

    return wrapper  # type:ignore