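"""Response caching utilities for HibiAPI endpoints.

The cashews cache backend is configured from the ``cache`` section of the
application config.  ``cache_config`` / ``disable_cache`` attach per-endpoint
cache settings, and ``endpoint_cache`` wraps an endpoint coroutine so that its
validated call arguments are hashed into a cache key and standard
``Cache-Control`` request/response headers are honored.
"""
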
import hashlib
from collections.abc import Awaitable
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast

from cashews import Cache
from pydantic import BaseModel
from pydantic.decorator import ValidatedFunction

from .config import Config
from .log import logger

CACHE_CONFIG_KEY = "_cache_config"

AsyncFunc = Callable[..., Awaitable[Any]]
T_AsyncFunc = TypeVar("T_AsyncFunc", bound=AsyncFunc)


CACHE_ENABLED = Config["cache"]["enabled"].as_bool()
CACHE_DELTA = timedelta(seconds=Config["cache"]["ttl"].as_number())
CACHE_URI = Config["cache"]["uri"].as_str()
CACHE_CONTROLLABLE = Config["cache"]["controllable"].as_bool()
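
# A minimal sketch of the configuration section these values are read from.
# The exact file format is determined by .config.Config; the layout below is
# an assumption for illustration only:
#
#   cache:
#     enabled: true
#     ttl: 3600              # seconds
#     uri: "redis://127.0.0.1:6379/0"
#     controllable: true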

cache = Cache(name="hibiapi")
try:
    cache.setup(CACHE_URI)
except Exception as e:
    logger.warning(
        f"Cache URI <y>{CACHE_URI!r}</y> setup <r><b>failed</b></r>: "
        f"<r>{e!r}</r>, use memory backend instead."
    )
    # Fall back to the in-memory backend so the cache stays usable,
    # as the warning above promises.
    cache.setup("mem://")


class CacheConfig(BaseModel):
    endpoint: AsyncFunc
    namespace: str
    enabled: bool = True
    ttl: timedelta = CACHE_DELTA

    @staticmethod
    def new(
        function: AsyncFunc,
        *,
        enabled: bool = True,
        ttl: timedelta = CACHE_DELTA,
        namespace: Optional[str] = None,
    ):
        return CacheConfig(
            endpoint=function,
            enabled=enabled,
            ttl=ttl,
            namespace=namespace or function.__qualname__,
        )


def cache_config(
    enabled: bool = True,
    ttl: timedelta = CACHE_DELTA,
    namespace: Optional[str] = None,
):
    def decorator(function: T_AsyncFunc) -> T_AsyncFunc:
        setattr(
            function,
            CACHE_CONFIG_KEY,
            CacheConfig.new(function, enabled=enabled, ttl=ttl, namespace=namespace),
        )
        return function

    return decorator


disable_cache = cache_config(enabled=False)
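
# A minimal sketch of how these decorators are meant to be applied to endpoint
# methods (the class and method names below are hypothetical):
#
#   class Illust:
#       @cache_config(ttl=timedelta(minutes=10), namespace="pixiv:illust")
#       async def detail(self, illust_id: int): ...
#
#       @disable_cache
#       async def random(self): ...
#
# cache_config only records a CacheConfig on the function; caching itself
# happens once endpoint_cache (below) wraps the function.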


class CachedValidatedFunction(ValidatedFunction):
    def serialize(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> BaseModel:
        values = self.build_values(args=args, kwargs=kwargs)
        return self.model(**values)


def endpoint_cache(function: T_AsyncFunc) -> T_AsyncFunc:
    """Wrap an endpoint coroutine with argument-keyed response caching."""
    from .routing import request_headers, response_headers

    vf = CachedValidatedFunction(function, config={})
    config = cast(
        CacheConfig,
        getattr(function, CACHE_CONFIG_KEY, None) or CacheConfig.new(function),
    )

    config.enabled = CACHE_ENABLED and config.enabled

    @wraps(function)
    async def wrapper(*args, **kwargs):
        # Default to a cacheable ("public") response; clients may override the
        # policy via the Cache-Control request header when cache.controllable
        # is enabled.
        cache_policy = "public"

        if CACHE_CONTROLLABLE:
            cache_policy = request_headers.get().get("cache-control", cache_policy)

        # A disabled config or an explicit "no-store" bypasses the cache entirely.
        if not config.enabled or cache_policy.casefold() == "no-store":
            return await vf.call(*args, **kwargs)

        # Validate the call arguments into a pydantic model, then hash its
        # sorted JSON dump (excluding `self`) to build a deterministic key.
        key = (
            f"{config.namespace}:"
            + hashlib.md5(
                (model := vf.serialize(args=args, kwargs=kwargs))
                .json(exclude={"self"}, sort_keys=True, ensure_ascii=False)
                .encode()
            ).hexdigest()
        )

        response_header = response_headers.get()
        result: Optional[Any] = None

        if cache_policy.casefold() == "no-cache":
            # "no-cache" forces revalidation: drop the stored entry first.
            await cache.delete(key)
        elif result := await cache.get(key):
            logger.debug(f"Request hit cache <b><e>{key}</e></b>")
            response_header.setdefault("X-Cache-Hit", key)

        # Cache miss (or invalidated entry): run the endpoint and store the result.
        if result is None:
            result = await vf.execute(model)
            await cache.set(key, result, expire=config.ttl)

        # Advertise the remaining lifetime of the cached entry to clients.
        if (cache_remain := await cache.get_expire(key)) > 0:
            response_header.setdefault("Cache-Control", f"max-age={cache_remain}")

        return result

    return wrapper  # type:ignore
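

# A minimal usage sketch (the endpoint name below is hypothetical, not part of
# this module):
#
#   @endpoint_cache
#   @cache_config(ttl=timedelta(minutes=5))
#   async def fetch(user_id: int) -> dict: ...
#
# On a miss the wrapper validates the arguments, runs the endpoint and stores
# the result under "fetch:<md5 of the serialized argument model>".  On a hit it
# sets an X-Cache-Hit response header; in both cases a Cache-Control:
# max-age=<remaining seconds> header advertises the entry's remaining lifetime.
# With cache.controllable enabled, a request header of "Cache-Control: no-store"
# bypasses the cache and "no-cache" forces the entry to be refreshed.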