File size: 3,784 Bytes
0a1b571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import functools
from collections.abc import Coroutine
from types import TracebackType
from typing import (
    Any,
    Callable,
    ClassVar,
    Optional,
    TypeVar,
    Union,
)

from httpx import (
    URL,
    AsyncClient,
    Cookies,
    HTTPError,
    HTTPStatusError,
    Request,
    Response,
    ResponseNotRead,
    TransportError,
)

from .decorators import Retry, TimeIt
from .exceptions import UpstreamAPIException
from .log import logger

# Type variable for any callable that returns a coroutine (e.g. an
# ``async def`` function). Decorators such as ``catch_network_error`` take and
# return this type so type checkers keep the decorated function's signature.
AsyncCallable_T = TypeVar("AsyncCallable_T", bound=Callable[..., Coroutine])


class AsyncHTTPClient(AsyncClient):
    """``httpx.AsyncClient`` that debug-logs traffic and retries requests.

    The logging event hooks are installed once, at construction time, and are
    placed in front of any ``event_hooks`` the caller supplied. Previously the
    hooks were reassigned on every ``request`` call, which silently discarded
    caller-provided hooks.
    """

    # Back-reference to the owning BaseNetClient; assigned by
    # BaseNetClient.create_client right after construction.
    net_client: "BaseNetClient"

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        # Merge instead of replace: keep whatever hooks the caller configured,
        # but always run the logging hooks first so a failing user hook
        # (e.g. one that raises on bad status) cannot suppress the log line.
        user_hooks = self.event_hooks
        self.event_hooks = {
            "request": [self._log_request, *user_hooks.get("request", [])],
            "response": [self._log_response, *user_hooks.get("response", [])],
        }

    @staticmethod
    async def _log_request(request: Request):
        """Debug-log an outgoing request (httpx ``request`` event hook)."""
        method, url = request.method, request.url
        logger.debug(
            f"Network request <y>sent</y>: <b><e>{method}</e> <u>{url}</u></b>"
        )

    @staticmethod
    async def _log_response(response: Response):
        """Debug-log a received response (httpx ``response`` event hook).

        Logs the method, URL, status code and body length. The body length is
        the real content length when the body is already loaded, otherwise the
        ``Content-Length`` header, otherwise ``-1``.
        """
        method, url = response.request.method, response.url
        code = response.status_code
        try:
            length: Union[int, str] = len(response.content)
        except ResponseNotRead:
            # httpx invokes response hooks before it reads a streamed body, so
            # for real network responses the content is usually unavailable
            # here. Reading it would break streaming, so use the header value.
            length = response.headers.get("content-length", -1)
        logger.debug(
            f"Network request <g>finished</g>: <b><e>{method}</e> "
            f"<u>{url}</u> <m>{code}</m></b> <m>{length}</m>"
        )

    @Retry(exceptions=[TransportError])
    async def request(self, method: str, url: Union[URL, str], **kwargs):
        """Send a request via ``AsyncClient.request``.

        Wrapped by the project's ``Retry`` decorator, configured for
        ``TransportError``.
        """
        return await super().request(method, url, **kwargs)


class BaseNetClient:
    """Holds an :class:`AsyncHTTPClient` and hands it out via ``async with``.

    The client is built eagerly in ``__init__`` and kept open across
    ``async with`` blocks. It is closed only when a block exits with an
    exception, and a fresh client is created on the next ``__aenter__``.
    """

    # Count of currently entered ``async with`` blocks. It is updated through
    # ``self.__class__``, so a subclass starts its own counter on first use.
    connections: ClassVar[int] = 0
    # Registry of clients created by any BaseNetClient. Closed clients are
    # pruned whenever a new client is registered, so the list cannot grow
    # without bound as clients get recycled.
    clients: ClassVar[list[AsyncHTTPClient]] = []

    client: Optional[AsyncHTTPClient] = None

    def __init__(
        self,
        headers: Optional[dict[str, Any]] = None,
        cookies: Optional[Cookies] = None,
        proxies: Optional[dict[str, str]] = None,
        client_class: type[AsyncHTTPClient] = AsyncHTTPClient,
    ):
        """Store the client configuration and create the first client.

        :param headers: default headers for every request (``{}`` if omitted).
        :param cookies: cookie jar shared by the clients (new jar if omitted).
        :param proxies: proxy mapping forwarded as ``proxies=`` to the client.
        :param client_class: client type to instantiate.
        """
        self.cookies, self.client_class = cookies or Cookies(), client_class
        self.headers: dict[str, Any] = headers or {}
        self.proxies: Any = proxies or {}  # Bypass type checker

        self.create_client()

    def create_client(self):
        """Build a new client from the stored config, register and return it."""
        self.client = self.client_class(
            headers=self.headers,
            proxies=self.proxies,
            cookies=self.cookies,
            http2=True,
            follow_redirects=True,
        )
        self.client.net_client = self
        registry = BaseNetClient.clients
        # Slice-assign to keep the same list object that other code may hold.
        registry[:] = [c for c in registry if not c.is_closed]
        registry.append(self.client)
        return self.client

    async def __aenter__(self):
        """Return the shared client, (re)creating and opening it if needed."""
        if not self.client or self.client.is_closed:
            self.client = await self.create_client().__aenter__()

        self.__class__.connections += 1
        return self.client

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ):
        """Leave a block; close and drop the client only if it raised."""
        self.__class__.connections -= 1

        # Normal exit: keep the client open for reuse by later blocks.
        if not (exc_type and exc_value and traceback):
            return
        if self.client and not self.client.is_closed:
            # Detach first so a concurrent __aenter__ builds a fresh client
            # instead of reusing the one being closed.
            client = self.client
            self.client = None
            await client.__aexit__(exc_type, exc_value, traceback)


def catch_network_error(function: AsyncCallable_T) -> AsyncCallable_T:
    """Time an async callable and translate httpx failures.

    The callable is wrapped with ``TimeIt``. Any ``HTTPStatusError`` it raises
    becomes ``UpstreamAPIException(detail=<response body text>)``. Any other
    ``HTTPError`` becomes a bare ``UpstreamAPIException``. In both cases the
    original error is chained as ``__cause__``.
    """
    timed = TimeIt(function)

    async def guarded(*args, **kwargs):
        try:
            outcome = await timed(*args, **kwargs)
        # HTTPStatusError subclasses HTTPError, so it must be matched first.
        except HTTPStatusError as status_error:
            raise UpstreamAPIException(
                detail=status_error.response.text
            ) from status_error
        except HTTPError as http_error:
            raise UpstreamAPIException from http_error
        return outcome

    return functools.wraps(timed)(guarded)  # type:ignore