Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import typing | |
import sniffio | |
from .._models import Request, Response | |
from .._types import AsyncByteStream | |
from .base import AsyncBaseTransport | |
if typing.TYPE_CHECKING:  # pragma: no cover
    import asyncio

    import trio

    # Union of the event classes `create_event` can return. Only static type
    # checkers see this alias: TYPE_CHECKING is False at runtime, so neither
    # backend module is imported here when the program actually runs.
    Event = typing.Union[asyncio.Event, trio.Event]

# A single ASGI message (or scope): a mutable mapping with string keys.
_Message = typing.MutableMapping[str, typing.Any]
# The ASGI `receive` callable: no arguments, awaits the next message.
_Receive = typing.Callable[[], typing.Awaitable[_Message]]
# The ASGI `send` callable: takes one message, awaits nothing.
_Send = typing.Callable[[_Message], typing.Awaitable[None]]
# An ASGI application: awaited as `app(scope, receive, send)`.
_ASGIApp = typing.Callable[[_Message, _Receive, _Send], typing.Awaitable[None]]
# Names exported by `from <module> import *`: only the transport class is public.
__all__ = ["ASGITransport"]
def create_event() -> Event:
    """Create a new event object for whichever async library is running.

    ``sniffio.current_async_library()`` decides the backend: ``trio.Event``
    when running under trio, ``asyncio.Event`` in every other case. The
    backend module is imported inside the function, so ``trio`` is only
    imported when it is the library actually in use.
    """
    running_under_trio = sniffio.current_async_library() == "trio"

    if not running_under_trio:
        import asyncio

        return asyncio.Event()

    import trio

    return trio.Event()
class ASGIResponseStream(AsyncByteStream):
    """Byte stream over the body chunks collected from an ASGI application.

    Iterating the stream produces exactly one ``bytes`` object: all chunks
    concatenated in order (``b""`` when no chunks were collected).
    """

    def __init__(self, body: list[bytes]) -> None:
        # The list is stored as-is (not copied); it is only joined once
        # iteration begins.
        self._body = body

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        whole_body = b"".join(self._body)
        yield whole_body
class ASGITransport(AsyncBaseTransport):
    """
    A custom AsyncTransport that handles sending requests directly to an ASGI app.

    ```python
    transport = httpx.ASGITransport(
        app=app,
        root_path="/submount",
        client=("1.2.3.4", 123)
    )
    client = httpx.AsyncClient(transport=transport)
    ```

    Arguments:

    * `app` - The ASGI application.
    * `raise_app_exceptions` - Boolean indicating if exceptions in the application
       should be raised. Defaults to `True`. Can be set to `False` for use cases
       such as testing the content of a client 500 response.
    * `root_path` - The root path on which the ASGI application should be mounted.
    * `client` - A two-tuple indicating the client IP and port of incoming requests.
    """

    def __init__(
        self,
        app: _ASGIApp,
        raise_app_exceptions: bool = True,
        root_path: str = "",
        client: tuple[str, int] = ("127.0.0.1", 123),
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.root_path = root_path
        self.client = client

    async def handle_async_request(
        self,
        request: Request,
    ) -> Response:
        """Run ``self.app`` once for ``request`` and return the buffered response.

        The request body is fed to the app through ``receive()``. Every
        ``http.response.body`` chunk the app sends is collected in memory and
        returned as a single-chunk stream once the app coroutine has finished.

        Raises:
            Exception: whatever the app raised, when ``raise_app_exceptions``
                is true.
            AssertionError: if the app breaks the message protocol checked
                below, e.g. it sends ``http.response.start`` twice, sends a
                body message after the final one, or returns without
                starting/completing the response.
        """
        assert isinstance(request.stream, AsyncByteStream)

        # ASGI scope.
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": request.method,
            "headers": [(k.lower(), v) for (k, v) in request.headers.raw],
            "scheme": request.url.scheme,
            "path": request.url.path,
            # Drop everything from the first b"?" onward; the query is passed
            # separately as "query_string".
            "raw_path": request.url.raw_path.split(b"?")[0],
            "query_string": request.url.query,
            # NOTE(review): if request.url.port is None (no explicit port),
            # this does not match the ASGI spec's integer port; confirm
            # whether target apps care.
            "server": (request.url.host, request.url.port),
            "client": self.client,
            "root_path": self.root_path,
        }

        # Request: iterator over the outgoing body, drained by receive().
        request_body_chunks = request.stream.__aiter__()
        request_complete = False

        # Response: state accumulated by send().
        status_code = None
        response_headers = None
        body_parts: list[bytes] = []
        response_started = False
        # Set once the app sends its final body message (or fails while
        # exceptions are being swallowed). It releases any receive() call
        # made after the request body was fully delivered.
        response_complete = create_event()

        # ASGI callables.

        async def receive() -> dict[str, typing.Any]:
            nonlocal request_complete

            if request_complete:
                # The body has already been fully delivered; the only event
                # left is a disconnect, reported once the response is done.
                await response_complete.wait()
                return {"type": "http.disconnect"}

            try:
                body = await request_body_chunks.__anext__()
            except StopAsyncIteration:
                request_complete = True
                return {"type": "http.request", "body": b"", "more_body": False}
            return {"type": "http.request", "body": body, "more_body": True}

        async def send(message: typing.MutableMapping[str, typing.Any]) -> None:
            nonlocal status_code, response_headers, response_started

            if message["type"] == "http.response.start":
                assert not response_started

                status_code = message["status"]
                response_headers = message.get("headers", [])
                response_started = True

            elif message["type"] == "http.response.body":
                assert not response_complete.is_set()
                body = message.get("body", b"")
                more_body = message.get("more_body", False)

                # Responses to HEAD requests carry no body; discard any chunks.
                if body and request.method != "HEAD":
                    body_parts.append(body)

                if not more_body:
                    response_complete.set()
            # Any other message type is ignored.

        try:
            await self.app(scope, receive, send)
        except Exception:  # noqa: PIE-786
            if self.raise_app_exceptions:
                raise

            # Swallowing the error: finish the response ourselves. A status
            # the app already sent is kept; otherwise report a bare 500.
            response_complete.set()
            if status_code is None:
                status_code = 500
            if response_headers is None:
                # Same (list) form as message.get("headers", []) above.
                response_headers = []

        assert response_complete.is_set()
        assert status_code is not None
        assert response_headers is not None

        stream = ASGIResponseStream(body_parts)

        return Response(status_code, headers=response_headers, stream=stream)