# HibiAPI — hibiapi/utils/cache.py
import hashlib
from collections.abc import Awaitable
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from cashews import Cache
from pydantic import BaseModel
from pydantic.decorator import ValidatedFunction
from .config import Config
from .log import logger
# Attribute name under which a CacheConfig is stashed on decorated endpoints.
CACHE_CONFIG_KEY = "_cache_config"
# Type aliases for async endpoint callables.
AsyncFunc = Callable[..., Awaitable[Any]]
T_AsyncFunc = TypeVar("T_AsyncFunc", bound=AsyncFunc)
# Cache behavior is driven by the [cache] section of the application config.
CACHE_ENABLED = Config["cache"]["enabled"].as_bool()
CACHE_DELTA = timedelta(seconds=Config["cache"]["ttl"].as_number())
CACHE_URI = Config["cache"]["uri"].as_str()
# When true, clients may influence caching via the Cache-Control request header.
CACHE_CONTROLLABLE = Config["cache"]["controllable"].as_bool()
# Shared cashews cache instance used by all endpoints in this process.
cache = Cache(name="hibiapi")
try:
    cache.setup(CACHE_URI)
except Exception as e:
    # Best-effort setup: if the configured backend URI (e.g. redis://...)
    # cannot be initialized, log a warning and keep the default in-memory
    # backend rather than failing module import.
    logger.warning(
        f"Cache URI <y>{CACHE_URI!r}</y> setup <r><b>failed</b></r>: "
        f"<r>{e!r}</r>, use memory backend instead."
    )
class CacheConfig(BaseModel):
    """Per-endpoint cache settings, attached to endpoints via ``cache_config``."""

    # The wrapped async endpoint callable.
    endpoint: AsyncFunc
    # Key prefix used to group this endpoint's cache entries.
    namespace: str
    enabled: bool = True
    ttl: timedelta = CACHE_DELTA

    @staticmethod
    def new(
        function: AsyncFunc,
        *,
        enabled: bool = True,
        ttl: timedelta = CACHE_DELTA,
        namespace: Optional[str] = None,
    ):
        """Build a config for *function*; the namespace falls back to its qualified name."""
        resolved_namespace = namespace or function.__qualname__
        return CacheConfig(
            endpoint=function,
            namespace=resolved_namespace,
            enabled=enabled,
            ttl=ttl,
        )
def cache_config(
    enabled: bool = True,
    ttl: timedelta = CACHE_DELTA,
    namespace: Optional[str] = None,
):
    """Decorator factory that attaches a :class:`CacheConfig` to an endpoint.

    The config is stored on the function object under ``CACHE_CONFIG_KEY`` and
    later picked up by :func:`endpoint_cache`.
    """

    def decorator(function: T_AsyncFunc) -> T_AsyncFunc:
        config = CacheConfig.new(
            function, enabled=enabled, ttl=ttl, namespace=namespace
        )
        setattr(function, CACHE_CONFIG_KEY, config)
        return function

    return decorator
# Shorthand decorator that marks an endpoint as never cached.
disable_cache = cache_config(enabled=False)
class CachedValidatedFunction(ValidatedFunction):
    """``ValidatedFunction`` that can snapshot a call as a pydantic model."""

    def serialize(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> BaseModel:
        """Bind *args*/*kwargs* against the signature and return the populated model."""
        return self.model(**self.build_values(args=args, kwargs=kwargs))
def endpoint_cache(function: T_AsyncFunc) -> T_AsyncFunc:
    """Wrap an async endpoint with transparent response caching.

    The cache key is ``<namespace>:<md5 of the JSON-serialized, validated
    call arguments>``.  Behavior honors the per-endpoint :class:`CacheConfig`
    and, when ``CACHE_CONTROLLABLE`` is set, the client's ``Cache-Control``
    request header:

    * ``no-store`` — bypass the cache entirely (neither read nor write);
    * ``no-cache`` — delete any cached value, recompute, and store the
      fresh result.

    On a cache hit the ``X-Cache-Hit`` response header is set; whenever a
    positive TTL remains, ``Cache-Control: max-age=...`` is set as well.
    """
    # Imported lazily to avoid a circular import with .routing.
    from .routing import request_headers, response_headers

    vf = CachedValidatedFunction(function, config={})
    config = cast(
        CacheConfig,
        getattr(function, CACHE_CONFIG_KEY, None) or CacheConfig.new(function),
    )
    # Per-endpoint caching only takes effect when caching is globally enabled.
    config.enabled = CACHE_ENABLED and config.enabled

    @wraps(function)
    async def wrapper(*args, **kwargs):
        cache_policy = "public"
        if CACHE_CONTROLLABLE:
            cache_policy = request_headers.get().get("cache-control", cache_policy)

        if not config.enabled or cache_policy.casefold() == "no-store":
            return await vf.call(*args, **kwargs)

        # `self` is excluded from the digest so bound methods hash
        # identically regardless of the instance they are called on.
        model = vf.serialize(args=args, kwargs=kwargs)
        key = (
            f"{config.namespace}:"
            + hashlib.md5(
                model.json(exclude={"self"}, sort_keys=True, ensure_ascii=False).encode()
            ).hexdigest()
        )

        response_header = response_headers.get()
        result: Optional[Any] = None
        if cache_policy.casefold() == "no-cache":
            await cache.delete(key)
        else:
            result = await cache.get(key)
            # Fix: compare against None rather than truth-testing, so that
            # falsy-but-valid cached payloads (e.g. "", {}, 0) are reported
            # as hits too.  A cached literal None remains indistinguishable
            # from a miss and is recomputed, as before.
            if result is not None:
                logger.debug(f"Request hit cache <b><e>{key}</e></b>")
                response_header.setdefault("X-Cache-Hit", key)

        if result is None:
            result = await vf.execute(model)
            await cache.set(key, result, expire=config.ttl)

        if (cache_remain := await cache.get_expire(key)) > 0:
            response_header.setdefault("Cache-Control", f"max-age={cache_remain}")

        return result

    return wrapper  # type:ignore