Spaces:
Runtime error
Runtime error
File size: 3,784 Bytes
0a1b571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import functools
from collections.abc import Coroutine
from types import TracebackType
from typing import (
Any,
Callable,
ClassVar,
Optional,
TypeVar,
Union,
)
from httpx import (
URL,
AsyncClient,
Cookies,
HTTPError,
HTTPStatusError,
Request,
Response,
ResponseNotRead,
TransportError,
)
from .decorators import Retry, TimeIt
from .exceptions import UpstreamAPIException
from .log import logger
# Type variable for "a callable that returns a coroutine". Decorators such as
# catch_network_error take and return this type so the decorated function's
# type is preserved for type checkers.
AsyncCallable_T = TypeVar(
    "AsyncCallable_T",
    bound=Callable[..., Coroutine],
)
class AsyncHTTPClient(AsyncClient):
    """httpx ``AsyncClient`` that logs every request and response.

    The request/response logging hooks are installed once, at construction
    time, in front of any ``event_hooks`` the caller supplied. Installing them
    once (instead of re-assigning ``event_hooks`` on each ``request()`` call)
    keeps caller-supplied hooks intact.

    ``request`` is additionally wrapped with ``Retry(exceptions=[TransportError])``;
    the retry policy itself lives in ``.decorators``.
    """

    # Owning BaseNetClient; assigned by BaseNetClient.create_client right
    # after the client is constructed.
    net_client: "BaseNetClient"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Construct the client (all arguments go to ``AsyncClient``) and
        prepend the logging hooks to whatever ``event_hooks`` were given."""
        super().__init__(*args, **kwargs)
        existing = self.event_hooks
        # Logging hooks run first; caller hooks follow in their original order.
        self.event_hooks = {
            "request": [self._log_request, *existing.get("request", [])],
            "response": [self._log_response, *existing.get("response", [])],
        }

    @staticmethod
    async def _log_request(request: Request) -> None:
        """Debug-log the method and URL of an outgoing request."""
        method, url = request.method, request.url
        logger.debug(
            f"Network request <y>sent</y>: <b><e>{method}</e> <u>{url}</u></b>"
        )

    @staticmethod
    async def _log_response(response: Response) -> None:
        """Debug-log method, URL, status code and body length of a response.

        The length is logged as -1 when the body has not been read yet.
        """
        method, url = response.request.method, response.url
        code = response.status_code
        try:
            length = len(response.content)
        except ResponseNotRead:
            # Response hooks may run before the body is read; accessing
            # ``content`` then raises instead of returning bytes.
            length = -1
        logger.debug(
            f"Network request <g>finished</g>: <b><e>{method}</e> "
            f"<u>{url}</u> <m>{code}</m></b> <m>{length}</m>"
        )

    @Retry(exceptions=[TransportError])
    async def request(self, method: str, url: Union[URL, str], **kwargs: Any):
        """Send a request via ``AsyncClient.request``.

        Wrapped by ``Retry`` for ``TransportError``; logging is done by the
        hooks installed in ``__init__``.
        """
        return await super().request(method, url, **kwargs)
class BaseNetClient:
    """Holds a shared :class:`AsyncHTTPClient` and hands it out via ``async with``.

    ``async with net_client as client`` yields the current client, creating
    and opening a fresh one when there is none or the current one is closed.
    A clean exit leaves the client open for reuse; an exit caused by an
    exception closes it so the next ``async with`` starts with a new client.
    """

    # Number of currently entered ``async with`` blocks. It is updated through
    # ``self.__class__``, so once entered, each subclass keeps its own counter.
    connections: ClassVar[int] = 0
    # Registry of live clients created by any BaseNetClient (shared by all
    # subclasses via the explicit ``BaseNetClient.clients`` reference).
    # Closed clients are pruned each time a new client is created.
    clients: ClassVar[list[AsyncHTTPClient]] = []
    # Client currently handed out by ``async with``; may be None after an
    # error exit until the next ``__aenter__`` recreates it.
    client: Optional[AsyncHTTPClient] = None

    def __init__(
        self,
        headers: Optional[dict[str, Any]] = None,
        cookies: Optional[Cookies] = None,
        proxies: Optional[dict[str, str]] = None,
        client_class: type[AsyncHTTPClient] = AsyncHTTPClient,
    ):
        """Store client settings and eagerly create the first client.

        Args:
            headers: Default headers for every created client; ``{}`` if
                omitted or empty.
            cookies: Cookie jar passed to created clients; a new empty
                ``Cookies()`` if omitted or empty.
            proxies: Forwarded as ``proxies=`` to the client constructor.
            client_class: Client type instantiated by :meth:`create_client`.
        """
        self.cookies, self.client_class = cookies or Cookies(), client_class
        self.headers: dict[str, Any] = headers or {}
        self.proxies: Any = proxies or {}  # Bypass type checker
        self.create_client()

    def create_client(self) -> AsyncHTTPClient:
        """Create a new client, register it, make it current and return it."""
        # NOTE(review): httpx deprecated ``proxies=`` (0.26) and later removed
        # it (0.28) in favour of ``proxy=``/``mounts=`` — confirm the pinned
        # httpx version still accepts this argument.
        self.client = self.client_class(
            headers=self.headers,
            proxies=self.proxies,
            cookies=self.cookies,
            http2=True,
            follow_redirects=True,
        )
        self.client.net_client = self
        registry = BaseNetClient.clients
        # Drop already-closed clients (in place, so outside references to the
        # list stay valid); otherwise the registry gains one dead client per
        # error/reconnect cycle and never shrinks.
        registry[:] = [client for client in registry if not client.is_closed]
        registry.append(self.client)
        return self.client

    async def __aenter__(self) -> AsyncHTTPClient:
        """Return the current client, creating and opening a new one if the
        current one is missing or closed."""
        if not self.client or self.client.is_closed:
            self.client = await self.create_client().__aenter__()
        self.__class__.connections += 1
        return self.client

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Leave an ``async with`` block.

        On a clean exit the client stays open so later blocks can reuse it.
        If the block raised, the current client is detached and closed so the
        next ``__aenter__`` builds a fresh one. Exceptions are never
        suppressed (this method returns None).
        """
        self.__class__.connections -= 1
        # Decide on exc_type alone rather than the truthiness of all three
        # values, which would skip the close if e.g. traceback were None.
        if exc_type is None:
            return
        client = self.client
        if client is not None and not client.is_closed:
            # NOTE(review): this client is shared by every concurrently
            # entered block (see ``connections``); closing it here affects
            # them too — confirm that is intended.
            self.client = None
            await client.__aexit__(exc_type, exc_value, traceback)
def catch_network_error(function: AsyncCallable_T) -> AsyncCallable_T:
    """Time an async function with ``TimeIt`` and translate httpx errors.

    ``HTTPStatusError`` is re-raised as ``UpstreamAPIException`` with the
    upstream response body as ``detail``; any other ``HTTPError`` is re-raised
    as a bare ``UpstreamAPIException``. The original error is chained as
    ``__cause__``.

    Args:
        function: The coroutine function to wrap.

    Returns:
        A coroutine function with the same call signature and metadata as
        ``function``.
    """
    timed_func = TimeIt(function)

    # Take metadata (name, docstring, ``__wrapped__``) from the decorated
    # function itself rather than the TimeIt wrapper, so it is correct
    # whether or not TimeIt preserves metadata on its own.
    @functools.wraps(function)
    async def wrapper(*args, **kwargs):
        try:
            return await timed_func(*args, **kwargs)
        except HTTPStatusError as e:
            # Upstream replied with an error status: forward its body.
            raise UpstreamAPIException(detail=e.response.text) from e
        except HTTPError as e:
            # Any other httpx HTTPError (must follow the more specific
            # HTTPStatusError clause above).
            raise UpstreamAPIException from e

    return wrapper  # type:ignore
|