Spaces:
Runtime error
Runtime error
File size: 3,747 Bytes
0a1b571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import hashlib
from collections.abc import Awaitable
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
from cashews import Cache
from pydantic import BaseModel
from pydantic.decorator import ValidatedFunction
from .config import Config
from .log import logger
# Attribute name under which a per-endpoint CacheConfig is stored on the
# decorated function object.
CACHE_CONFIG_KEY = "_cache_config"

AsyncFunc = Callable[..., Awaitable[Any]]
T_AsyncFunc = TypeVar("T_AsyncFunc", bound=AsyncFunc)

# Global cache settings read from the application configuration.
CACHE_ENABLED = Config["cache"]["enabled"].as_bool()
CACHE_DELTA = timedelta(seconds=Config["cache"]["ttl"].as_number())
CACHE_URI = Config["cache"]["uri"].as_str()
CACHE_CONTROLLABLE = Config["cache"]["controllable"].as_bool()

cache = Cache(name="hibiapi")

try:
    cache.setup(CACHE_URI)
except Exception as e:
    logger.warning(
        f"Cache URI <y>{CACHE_URI!r}</y> setup <r><b>failed</b></r>: "
        f"<r>{e!r}</r>, use memory backend instead."
    )
    # The warning above promises an in-memory fallback; without actually
    # registering one, the cache has no backend and cashews rejects every
    # get/set/delete at request time instead of degrading gracefully.
    cache.setup("mem://")
class CacheConfig(BaseModel):
    """Caching options attached to a single async endpoint."""

    # The async callable whose results are cached.
    endpoint: AsyncFunc
    # Prefix used for every cache key generated for this endpoint.
    namespace: str
    # Per-endpoint switch; a False value disables caching for this endpoint.
    enabled: bool = True
    # Lifetime given to each cached result.
    ttl: timedelta = CACHE_DELTA

    @staticmethod
    def new(
        function: AsyncFunc,
        *,
        enabled: bool = True,
        ttl: timedelta = CACHE_DELTA,
        namespace: Optional[str] = None,
    ):
        """Build a CacheConfig for *function*.

        When *namespace* is empty or None, the function's ``__qualname__`` is
        used as the namespace instead.
        """
        resolved_namespace = namespace or function.__qualname__
        options = dict(
            endpoint=function,
            enabled=enabled,
            ttl=ttl,
            namespace=resolved_namespace,
        )
        return CacheConfig(**options)
def cache_config(
    enabled: bool = True,
    ttl: timedelta = CACHE_DELTA,
    namespace: Optional[str] = None,
):
    """Return a decorator that attaches caching options to an async endpoint.

    The decorator builds a CacheConfig from the given options, stores it on
    the function under ``CACHE_CONFIG_KEY``, and hands back the very same
    function object.
    """

    def attach(function: T_AsyncFunc) -> T_AsyncFunc:
        options = CacheConfig.new(
            function, enabled=enabled, ttl=ttl, namespace=namespace
        )
        setattr(function, CACHE_CONFIG_KEY, options)
        return function

    return attach


# Shorthand decorator that marks an endpoint as not cacheable.
disable_cache = cache_config(enabled=False)
class CachedValidatedFunction(ValidatedFunction):
    """ValidatedFunction with the argument-validation step exposed on its own.

    Separating validation from execution lets a caller build a cache key from
    the validated model before deciding whether to run the function.
    """

    def serialize(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> BaseModel:
        """Bind *args*/*kwargs* to parameters and return them as a model."""
        return self.model(**self.build_values(args=args, kwargs=kwargs))
def endpoint_cache(function: T_AsyncFunc) -> T_AsyncFunc:
    """Wrap an async endpoint so its results are stored in ``cache``.

    The cache key is the endpoint's namespace plus an MD5 digest of the
    validated arguments (serialized to sorted JSON, excluding ``self``).

    When ``CACHE_CONTROLLABLE`` is set, the request's ``Cache-Control`` header
    is honoured:
    * ``no-store`` bypasses the cache entirely;
    * ``no-cache`` drops any stored entry and recomputes (and re-stores) it.

    On a cache hit the key is reported in the ``X-Cache-Hit`` response header.
    Whenever the entry has a positive remaining lifetime, it is reported as
    ``Cache-Control: max-age=...``.
    """
    from .routing import request_headers, response_headers

    vf = CachedValidatedFunction(function, config={})
    config = cast(
        CacheConfig,
        getattr(function, CACHE_CONFIG_KEY, None) or CacheConfig.new(function),
    )
    # The global switch overrides any per-endpoint setting.
    config.enabled = CACHE_ENABLED and config.enabled

    @wraps(function)
    async def wrapper(*args, **kwargs):
        cache_policy = "public"
        if CACHE_CONTROLLABLE:
            cache_policy = request_headers.get().get("cache-control", cache_policy)
        # Cache-Control may carry several comma-separated directives
        # (e.g. "no-cache, max-age=0"), so match individual directives rather
        # than comparing the whole header value.
        directives = {
            directive.strip().casefold() for directive in cache_policy.split(",")
        }

        if not config.enabled or "no-store" in directives:
            return await vf.call(*args, **kwargs)

        model = vf.serialize(args=args, kwargs=kwargs)
        digest = hashlib.md5(
            model.json(exclude={"self"}, sort_keys=True, ensure_ascii=False).encode()
        ).hexdigest()
        key = f"{config.namespace}:{digest}"

        response_header = response_headers.get()
        result: Optional[Any] = None
        if "no-cache" in directives:
            await cache.delete(key)
        # Compare against None explicitly: falsy results such as [] or {} are
        # legitimate cache hits and must not trigger a recomputation.
        elif (result := await cache.get(key)) is not None:
            logger.debug(f"Request hit cache <b><e>{key}</e></b>")
            response_header.setdefault("X-Cache-Hit", key)

        if result is None:
            result = await vf.execute(model)
            await cache.set(key, result, expire=config.ttl)

        if (cache_remain := await cache.get_expire(key)) > 0:
            response_header.setdefault("Cache-Control", f"max-age={cache_remain}")
        return result

    return wrapper  # type:ignore
|