import functools
from collections.abc import Coroutine
from types import TracebackType
from typing import (
Any,
Callable,
ClassVar,
Optional,
TypeVar,
Union,
)
from httpx import (
URL,
AsyncClient,
Cookies,
HTTPError,
HTTPStatusError,
Request,
Response,
ResponseNotRead,
TransportError,
)
from .decorators import Retry, TimeIt
from .exceptions import UpstreamAPIException
from .log import logger
# Type variable standing for "any async callable" (a callable returning a
# coroutine); catch_network_error is annotated with it so the decorated
# function keeps its declared type.
AsyncCallable_T = TypeVar(
    "AsyncCallable_T",
    bound=Callable[..., Coroutine],
)
class AsyncHTTPClient(AsyncClient):
    """``httpx.AsyncClient`` that debug-logs traffic and retries ``request`` on transport errors.

    Instances are normally built by ``BaseNetClient.create_client``, which
    assigns ``net_client`` to point back at the owning ``BaseNetClient``.
    """

    # Back-reference to the owning BaseNetClient (set by BaseNetClient.create_client).
    net_client: "BaseNetClient"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create the client and register the logging event hooks once.

        All arguments are forwarded unchanged to ``httpx.AsyncClient``. Any
        caller-supplied ``event_hooks`` are preserved and run after the
        logging hooks. Because the hooks live on the client, they apply to
        every request it sends, including ones issued via ``send``/``stream``.
        """
        super().__init__(*args, **kwargs)
        user_hooks = self.event_hooks
        self.event_hooks = {
            "request": [self._log_request, *user_hooks.get("request", [])],
            "response": [self._log_response, *user_hooks.get("response", [])],
        }

    @staticmethod
    async def _log_request(request: Request):
        """Event hook: debug-log the method and URL of an outgoing request."""
        method, url = request.method, request.url
        logger.debug(f"Network request sent: {method} {url}")

    @staticmethod
    async def _log_response(response: Response):
        """Event hook: debug-log method, URL, status code and body length of a response.

        httpx calls response hooks before the body has been read, so
        ``response.content`` is usually unavailable here. In that case the
        server-declared ``Content-Length`` is logged (this may be the encoded,
        e.g. compressed, size). -1 is logged when the header is absent or not a
        plain decimal number.
        """
        method, url = response.request.method, response.url
        code = response.status_code
        try:
            length = len(response.content)
        except ResponseNotRead:
            declared = response.headers.get("Content-Length", "")
            length = int(declared) if declared.isdecimal() else -1
        logger.debug(f"Network request finished: {method} {url} {code} {length}")

    @Retry(exceptions=[TransportError])
    async def request(self, method: str, url: Union[URL, str], **kwargs: Any) -> Response:
        """Send a request through ``httpx.AsyncClient.request``.

        Wrapped by the project's ``Retry`` decorator configured with
        ``TransportError``. Presumably that decorator retries on transport
        failures; confirm the exact policy in ``.decorators``.
        """
        return await super().request(method, url, **kwargs)
class BaseNetClient:
    """Owns an ``AsyncHTTPClient`` built from fixed settings and hands it out via ``async with``.

    The same client is reused across ``async with`` blocks. It is closed (and
    rebuilt on the next ``__aenter__``) only when a block exits with an exception.
    """

    # Count of ``async with`` blocks currently entered.
    # NOTE(review): updated through ``self.__class__``, so a subclass gets its
    # own counter the first time it enters (seeded from the base value at that
    # moment) — confirm per-class counting is intended.
    connections: ClassVar[int] = 0
    # Clients created by any BaseNetClient (shared via BaseNetClient.clients);
    # closed entries are pruned whenever a new client is created.
    clients: ClassVar[list[AsyncHTTPClient]] = []
    # Client currently owned by this instance; None after it was closed on error.
    client: Optional[AsyncHTTPClient] = None

    def __init__(
        self,
        headers: Optional[dict[str, Any]] = None,
        cookies: Optional[Cookies] = None,
        proxies: Optional[dict[str, str]] = None,
        client_class: type[AsyncHTTPClient] = AsyncHTTPClient,
    ):
        """Store the client settings and eagerly create the first client.

        :param headers: default headers for every request (empty if omitted).
        :param cookies: cookie jar shared by every client this instance creates.
        :param proxies: proxy mapping forwarded to the client class.
        :param client_class: ``AsyncHTTPClient`` subclass to instantiate.
        """
        self.cookies, self.client_class = cookies or Cookies(), client_class
        self.headers: dict[str, Any] = headers or {}
        self.proxies: Any = proxies or {}  # Bypass type checker
        self.create_client()

    def create_client(self) -> AsyncHTTPClient:
        """Build a new client from this instance's settings, register and return it.

        The new client uses HTTP/2 and follows redirects, and becomes
        ``self.client``. Closed clients are first removed from the shared
        registry, so it does not grow without bound when clients are closed
        on error and recreated.
        """
        registry = BaseNetClient.clients
        # Prune in place so any outside reference to the list stays valid.
        registry[:] = [existing for existing in registry if not existing.is_closed]
        client = self.client_class(
            headers=self.headers,
            proxies=self.proxies,
            cookies=self.cookies,
            http2=True,
            follow_redirects=True,
        )
        self.client = client
        client.net_client = self
        registry.append(client)
        return client

    async def __aenter__(self) -> AsyncHTTPClient:
        """Return the current client, (re)creating and entering it if missing or closed."""
        if not self.client or self.client.is_closed:
            self.client = await self.create_client().__aenter__()
        self.__class__.connections += 1
        return self.client

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        """Leave an ``async with`` block.

        On a clean exit the client is left open for reuse. If the block raised,
        the client is detached from this instance and closed, so the next
        ``__aenter__`` builds a fresh one. Returns None, so the exception
        propagates.
        """
        self.__class__.connections -= 1
        if exc_type is None:
            return
        client = self.client
        if client is not None and not client.is_closed:
            # Detach before awaiting the close so no one reuses a closing client.
            self.client = None
            await client.__aexit__(exc_type, exc_value, traceback)
def catch_network_error(function: AsyncCallable_T) -> AsyncCallable_T:
    """Decorate an async callable so httpx errors surface as ``UpstreamAPIException``.

    The call is wrapped with the project's ``TimeIt`` decorator.

    - ``HTTPStatusError`` is re-raised as ``UpstreamAPIException(detail=<response body>)``
      when the body has been read. If it has not (e.g. a streamed response),
      a bare ``UpstreamAPIException`` is raised instead.
    - Any other ``HTTPError`` is re-raised as a bare ``UpstreamAPIException``.

    In every case the original httpx exception is chained as ``__cause__``.
    """
    timed_func = TimeIt(function)

    # Take metadata from the undecorated function. TimeIt(...) may be a callable
    # object lacking __name__/__doc__, and wraps() would otherwise also copy
    # that object's instance __dict__ onto the wrapper.
    @functools.wraps(function)
    async def wrapper(*args, **kwargs):
        try:
            return await timed_func(*args, **kwargs)
        except HTTPStatusError as e:
            try:
                detail = e.response.text
            except ResponseNotRead:
                # Body never read: .text would raise instead of giving a detail.
                raise UpstreamAPIException from e
            raise UpstreamAPIException(detail=detail) from e
        except HTTPError as e:
            raise UpstreamAPIException from e

    return wrapper  # type:ignore