Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import io | |
import itertools | |
import sys | |
import typing | |
from .._models import Request, Response | |
from .._types import SyncByteStream | |
from .base import BaseTransport | |
if typing.TYPE_CHECKING: | |
from _typeshed import OptExcInfo # pragma: no cover | |
from _typeshed.wsgi import WSGIApplication # pragma: no cover | |
__all__ = ["WSGITransport"]

# Generic element type used by _skip_leading_empty_chunks.
_T = typing.TypeVar("_T")
def _skip_leading_empty_chunks(body: typing.Iterable[_T]) -> typing.Iterable[_T]:
    """
    Drop falsy chunks from the start of `body`.

    `body` is consumed eagerly up to and including its first truthy chunk.
    The returned iterable yields that chunk followed by the untouched
    remainder of `body` (later empty chunks are kept). If `body` contains
    no truthy chunk at all, an empty list is returned.

    The eager pull means a generator-based WSGI application is advanced far
    enough to have called `start_response` (PEP 3333 requires that call to
    happen before the first body bytestring is yielded).
    """
    remaining = iter(body)
    while True:
        try:
            first = next(remaining)
        except StopIteration:
            return []
        if first:
            return itertools.chain((first,), remaining)
class WSGIByteStream(SyncByteStream):
    """
    Byte stream over the iterable returned by a WSGI application.

    Leading empty chunks are skipped eagerly when the stream is built, and
    `close()` forwards to the iterable's own `close()` method when it has one.
    """

    def __init__(self, result: typing.Iterable[bytes]) -> None:
        # Grab close() from the raw iterable first: the wrapper produced by
        # _skip_leading_empty_chunks does not expose it.
        self._close = getattr(result, "close", None)
        self._result = _skip_leading_empty_chunks(result)

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self._result

    def close(self) -> None:
        closer = self._close
        if closer is None:
            return
        closer()
class WSGITransport(BaseTransport):
    """
    A custom transport that handles sending requests directly to a WSGI app.

    The simplest way to use this functionality is to use the `app` argument.

    ```
    client = httpx.Client(app=app)
    ```

    Alternatively, you can set up the transport instance explicitly.
    This allows you to include any additional configuration arguments specific
    to the WSGITransport class:

    ```
    transport = httpx.WSGITransport(
        app=app,
        script_name="/submount",
        remote_addr="1.2.3.4"
    )
    client = httpx.Client(transport=transport)
    ```

    Arguments:

    * `app` - The WSGI application.
    * `raise_app_exceptions` - Boolean indicating if an exception the application
       reports through `start_response(..., exc_info)` should be re-raised.
       Defaults to `True`. Can be set to `False` for use cases such as testing
       the content of a client 500 response. Exceptions raised directly by the
       application always propagate.
    * `script_name` - The root path on which the WSGI application should be mounted.
    * `remote_addr` - A string indicating the client IP of incoming requests.
    * `wsgi_errors` - Text stream exposed to the application as `wsgi.errors`.
       Defaults to `sys.stderr` when not given.
    """

    def __init__(
        self,
        app: WSGIApplication,
        raise_app_exceptions: bool = True,
        script_name: str = "",
        remote_addr: str = "127.0.0.1",
        wsgi_errors: typing.TextIO | None = None,
    ) -> None:
        self.app = app
        self.raise_app_exceptions = raise_app_exceptions
        self.script_name = script_name
        self.remote_addr = remote_addr
        self.wsgi_errors = wsgi_errors

    def _build_environ(self, request: Request) -> dict[str, typing.Any]:
        """
        Build the PEP 3333 environ dict for `request`.

        The request body must already have been read; it is exposed to the
        application as an in-memory `wsgi.input` stream.
        """
        port = request.url.port or {"http": 80, "https": 443}[request.url.scheme]
        environ: dict[str, typing.Any] = {
            "wsgi.version": (1, 0),
            "wsgi.url_scheme": request.url.scheme,
            "wsgi.input": io.BytesIO(request.content),
            "wsgi.errors": self.wsgi_errors or sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "REQUEST_METHOD": request.method,
            "SCRIPT_NAME": self.script_name,
            "PATH_INFO": request.url.path,
            "QUERY_STRING": request.url.query.decode("ascii"),
            "SERVER_NAME": request.url.host,
            "SERVER_PORT": str(port),
            "SERVER_PROTOCOL": "HTTP/1.1",
            "REMOTE_ADDR": self.remote_addr,
        }
        for header_key, header_value in request.headers.raw:
            key = header_key.decode("ascii").upper().replace("-", "_")
            if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"):
                key = "HTTP_" + key
            # PEP 3333: environ header values are native strings carrying
            # ISO-8859-1 code points, so decode as latin-1 (a strict superset
            # of ASCII) rather than failing on non-ASCII bytes.
            value = header_value.decode("latin-1")
            if key in environ:
                # A repeated header must not silently overwrite the earlier
                # occurrence(s): fold the values into one field (RFC 9110
                # §5.3). Cookie pairs are joined with "; " (RFC 9113 §8.2.3).
                separator = "; " if key == "HTTP_COOKIE" else ", "
                value = environ[key] + separator + value
            environ[key] = value
        return environ

    def handle_request(self, request: Request) -> Response:
        """
        Dispatch `request` to the WSGI application and return its response.

        The whole request body is read into memory first. The application's
        iterable is wrapped in a `WSGIByteStream`, which eagerly skips leading
        empty chunks so that a generator-based application has called
        `start_response` before the status and headers are read here.

        Raises:
            The exception passed to `start_response` as `exc_info`, when
            `raise_app_exceptions` is true. Exceptions raised directly by the
            application (while being called or while producing its first
            non-empty chunk) propagate unchanged.
        """
        request.read()
        environ = self._build_environ(request)

        seen_status: str | None = None
        seen_response_headers: list[tuple[str, str]] | None = None
        seen_exc_info: OptExcInfo | None = None

        def start_response(
            status: str,
            response_headers: list[tuple[str, str]],
            exc_info: OptExcInfo | None = None,
        ) -> typing.Callable[[bytes], typing.Any]:
            nonlocal seen_status, seen_response_headers, seen_exc_info
            seen_status = status
            seen_response_headers = response_headers
            seen_exc_info = exc_info
            # NOTE(review): bytes passed to the returned legacy `write()`
            # callable are discarded; only the returned iterable becomes the
            # response body.
            return lambda _: None

        result = self.app(environ, start_response)
        stream = WSGIByteStream(result)

        assert seen_status is not None
        assert seen_response_headers is not None

        if seen_exc_info and seen_exc_info[0] and self.raise_app_exceptions:
            # No Response will own the stream on this path, so release the
            # application's iterable now (PEP 3333 requires close() to be
            # called when the iterable provides it) before propagating.
            stream.close()
            raise seen_exc_info[1]

        status_code = int(seen_status.split()[0])
        headers = [
            # Header names are ASCII tokens; values are latin-1 per PEP 3333.
            (key.encode("ascii"), value.encode("latin-1"))
            for key, value in seen_response_headers
        ]
        return Response(status_code, headers=headers, stream=stream)