Spaces:
Running
Running
from __future__ import annotations | |
import logging | |
import re | |
import threading | |
import types | |
import typing | |
import h2.config # type: ignore[import-untyped] | |
import h2.connection # type: ignore[import-untyped] | |
import h2.events # type: ignore[import-untyped] | |
from .._base_connection import _TYPE_BODY | |
from .._collections import HTTPHeaderDict | |
from ..connection import HTTPSConnection, _get_default_user_agent | |
from ..exceptions import ConnectionError | |
from ..response import BaseHTTPResponse | |
# Second name bound to the imported ``HTTPSConnection`` class.
# NOTE(review): presumably kept so the stock class can be restored after it
# has been swapped out elsewhere -- nothing in this section reads it; confirm
# against the callers.
orig_HTTPSConnection = HTTPSConnection

# Type parameter describing the object wrapped by ``_LockedObject``.
T = typing.TypeVar("T")

# Module-level logger.
log = logging.getLogger(__name__)
# Lower-case HTTP token characters only. The pattern is anchored with ``\Z``
# rather than ``$``: ``$`` also matches just before a trailing newline, which
# would let a name such as b"x-foo\n" pass validation.
RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+\Z")

# Matches NUL / CR / LF anywhere in the value, or whitespace at either end.
RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\0\x00\x0a\x0d\r\n]|^[ \r\n\t]|[ \r\n\t]$")
def _is_legal_header_name(name: bytes) -> bool:
    """
    Return True when *name* is an acceptable HTTP/2 field name.

    "An implementation that validates fields according to the definitions in Sections
    5.1 and 5.5 of [HTTP] only needs an additional check that field names do not
    include uppercase characters." (https://httpwg.org/specs/rfc9113.html#n-field-validity)

    `http.client._is_legal_header_name` does not validate the field name according to the
    HTTP 1.1 spec, so we do that here, in addition to checking for uppercase characters.

    This does not allow for the `:` character in the header name, so should not
    be used to validate pseudo-headers.
    """
    # ``fullmatch`` requires the whole input to be consumed. With ``match`` a
    # pattern ending in ``$`` would also accept a name followed by a single
    # trailing newline (e.g. b"x-foo\n").
    return RE_IS_LEGAL_HEADER_NAME.fullmatch(name) is not None
def _is_illegal_header_value(value: bytes) -> bool:
    """
    Return True when *value* must be rejected as an HTTP/2 field value.

    "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed
    (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. A field
    value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB,
    0x20 or 0x09)." (https://httpwg.org/specs/rfc9113.html#n-field-validity)
    """
    # Any hit of the "illegal" pattern anywhere in the value disqualifies it.
    offending = RE_IS_ILLEGAL_HEADER_VALUE.search(value)
    return offending is not None
class _LockedObject(typing.Generic[T]):
    """
    Keep one object behind a re-entrant lock.

    The wrapper exists to give a simple way of guarding an object that
    cannot safely be used from several threads at the same time. Use it as
    a context manager: entering acquires the lock and hands back the guarded
    object, leaving the ``with`` block releases the lock again.
    """

    __slots__ = ("lock", "_obj")

    def __init__(self, obj: T):
        """Store *obj* and create the lock that protects it."""
        self._obj = obj
        self.lock = threading.RLock()

    def __enter__(self) -> T:
        """Acquire the lock and return the protected object."""
        self.lock.acquire()
        return self._obj

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> None:
        """Release the lock; the exception arguments are not inspected."""
        self.lock.release()
class HTTP2Connection(HTTPSConnection):
    """
    HTTPS connection that talks HTTP/2 through an ``h2`` state machine.

    A single request stream is tracked at a time: ``putrequest`` selects the
    stream id and queues the pseudo-headers, ``putheader`` queues regular
    fields, ``endheaders`` / ``send`` write frames to the socket and
    ``getresponse`` reads from the socket until that stream has ended.
    Proxies and tunnels are rejected with ``NotImplementedError``.
    """

    def __init__(
        self, host: str, port: int | None = None, **kwargs: typing.Any
    ) -> None:
        """
        Create the HTTP/2 state, then initialise the parent connection.

        :raises NotImplementedError: if ``proxy`` / ``proxy_config`` is passed
            or the parent initialiser leaves a tunnel host configured.
        """
        # The HTTP/2 state is set up before ``super().__init__`` runs, keeping
        # the original statement order.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream: int | None = None
        self._headers: list[tuple[bytes, bytes]] = []

        if "proxy" in kwargs or "proxy_config" in kwargs:  # Defensive:
            raise NotImplementedError("Proxies aren't supported with HTTP/2")

        super().__init__(host, port, **kwargs)

        if self._tunnel_host is not None:
            raise NotImplementedError("Tunneling isn't supported with HTTP/2")

    def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]:
        """Build a fresh client-side ``H2Connection`` wrapped in a lock."""
        config = h2.config.H2Configuration(client_side=True)
        return _LockedObject(h2.connection.H2Connection(config=config))

    def _send_pending(self, conn: h2.connection.H2Connection) -> None:
        """Write whatever bytes *conn* has queued for sending to the socket."""
        if data_to_send := conn.data_to_send():
            self.sock.sendall(data_to_send)

    def connect(self) -> None:
        """Connect via the parent class, then send the HTTP/2 preamble."""
        super().connect()
        with self._h2_conn as conn:
            conn.initiate_connection()
            self._send_pending(conn)

    def putrequest(  # type: ignore[override]
        self,
        method: str,
        url: str,
        **kwargs: typing.Any,
    ) -> None:
        """putrequest

        This deviates from the HTTPConnection method signature since we never need to override
        sending accept-encoding headers or the host header.

        Queues the four request pseudo-headers and reserves the next available
        stream id.

        :raises NotImplementedError: if ``skip_host`` or
            ``skip_accept_encoding`` is passed.
        """
        if "skip_host" in kwargs:
            raise NotImplementedError("`skip_host` isn't supported")
        if "skip_accept_encoding" in kwargs:
            raise NotImplementedError("`skip_accept_encoding` isn't supported")

        # An empty URL is normalised to "/" and that normalised value is what
        # goes on the wire: previously ``:path`` was built from the raw
        # ``url`` and could therefore be empty.
        self._request_url = url or "/"
        self._validate_path(self._request_url)  # type: ignore[attr-defined]

        # A host containing ":" is wrapped in brackets before the port is added.
        if ":" in self.host:
            authority = f"[{self.host}]:{self.port or 443}"
        else:
            authority = f"{self.host}:{self.port or 443}"

        # Start from a clean list so fields left behind by a request that
        # failed before ``endheaders`` cannot leak into this one.
        self._headers = [
            (b":scheme", b"https"),
            (b":method", method.encode()),
            (b":authority", authority.encode()),
            (b":path", self._request_url.encode()),
        ]

        with self._h2_conn as conn:
            self._h2_stream = conn.get_next_available_stream_id()

    def putheader(self, header: str | bytes, *values: str | bytes) -> None:
        """
        Queue one header line per value in *values* under the name *header*.

        The name is lower-cased before validation.

        :raises ValueError: if the name or any value fails validation.
        """
        # TODO SKIPPABLE_HEADERS from urllib3 are ignored.
        name = header.encode() if isinstance(header, str) else header
        name = name.lower()  # A lot of upstream code uses capitalized headers.
        if not _is_legal_header_name(name):
            raise ValueError(f"Illegal header name {name!r}")

        for value in values:
            raw_value = value.encode() if isinstance(value, str) else value
            if _is_illegal_header_value(raw_value):
                raise ValueError(f"Illegal header value {raw_value!r}")
            self._headers.append((name, raw_value))

    def endheaders(self, message_body: typing.Any = None) -> None:  # type: ignore[override]
        """
        Send the queued headers on the current stream.

        The stream is ended together with the headers when *message_body* is
        ``None``; otherwise it is left open for a following ``send`` call.

        :raises ConnectionError: if ``putrequest`` has not been called.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            conn.send_headers(
                stream_id=self._h2_stream,
                headers=self._headers,
                end_stream=(message_body is None),
            )
            self._send_pending(conn)
        self._headers = []  # Reset headers for the next request.

    def send(self, data: typing.Any) -> None:
        """Send data to the server.

        `data` can be: `str`, `bytes`, an iterable, or file-like objects
        that support a .read() method.

        The stream is ended once all of *data* has been handed to h2, and the
        resulting frames (including the end-of-stream frame) are written to
        the socket before returning.

        NOTE(review): chunks are passed to ``send_data`` without consulting
        the flow-control window or maximum frame size -- presumably large
        bodies can make h2 raise; confirm against the h2 documentation.

        :raises ConnectionError: if ``putrequest`` has not been called.
        :raises TypeError: if *data* is not one of the supported kinds.
        """
        if self._h2_stream is None:
            raise ConnectionError("Must call `putrequest` first.")

        with self._h2_conn as conn:
            self._send_pending(conn)

            if hasattr(data, "read"):  # file-like objects
                while chunk := data.read(self.blocksize):
                    if isinstance(chunk, str):
                        chunk = chunk.encode()  # pragma: no cover
                    conn.send_data(self._h2_stream, chunk, end_stream=False)
                    self._send_pending(conn)
                conn.end_stream(self._h2_stream)
                # Flush the end-of-stream frame; it used to stay queued until
                # some later call happened to write to the socket.
                self._send_pending(conn)
                return

            if isinstance(data, str):  # str -> bytes
                data = data.encode()

            try:
                if isinstance(data, bytes):
                    conn.send_data(self._h2_stream, data, end_stream=True)
                    self._send_pending(conn)
                else:
                    for chunk in data:
                        conn.send_data(self._h2_stream, chunk, end_stream=False)
                        self._send_pending(conn)
                    conn.end_stream(self._h2_stream)
                    # Same as above: make sure end-of-stream reaches the peer.
                    self._send_pending(conn)
            except TypeError as err:
                raise TypeError(
                    "`data` should be str, bytes, iterable, or file. got %r"
                    % type(data)
                ) from err

    def set_tunnel(
        self,
        host: str,
        port: int | None = None,
        headers: typing.Mapping[str, str] | None = None,
        scheme: str = "http",
    ) -> None:
        """Always raise ``NotImplementedError``; tunnels are unsupported."""
        raise NotImplementedError(
            "HTTP/2 does not support setting up a tunnel through a proxy"
        )

    def getresponse(  # type: ignore[override]
        self,
    ) -> HTTP2Response:
        """
        Read from the socket until the stream ends and return the response.

        The whole body is buffered in memory before the ``HTTP2Response`` is
        built.

        :raises ConnectionError: if the peer closes the connection or resets
            the current stream before it has ended, or if the stream ends
            without a ``:status`` field having been received.
        """
        status: int | None = None
        # Bound up front so a stream that ends without response headers
        # cannot hit an unbound local when the response is built.
        headers = HTTPHeaderDict()
        data = bytearray()

        with self._h2_conn as conn:
            end_stream = False
            while not end_stream:
                # TODO: Arbitrary read value.
                received_data = self.sock.recv(65535)
                if not received_data:
                    # An empty read means the peer closed the connection;
                    # without this check the loop would spin forever.
                    raise ConnectionError(
                        "Connection closed before the HTTP/2 response was complete"
                    )

                for event in conn.receive_data(received_data):
                    if isinstance(event, h2.events.ResponseReceived):
                        headers = HTTPHeaderDict()
                        for header, value in event.headers:
                            if header == b":status":
                                status = int(value.decode())
                            else:
                                headers.add(
                                    header.decode("ascii"), value.decode("ascii")
                                )
                    elif isinstance(event, h2.events.DataReceived):
                        data += event.data
                        # Tell h2 the bytes were consumed so it can reopen
                        # the flow-control window.
                        conn.acknowledge_received_data(
                            event.flow_controlled_length, event.stream_id
                        )
                    elif isinstance(event, h2.events.StreamEnded):
                        end_stream = True
                    elif (
                        isinstance(event, h2.events.StreamReset)
                        and not end_stream
                        and event.stream_id == self._h2_stream
                    ):
                        # The stream will never end normally after a reset.
                        raise ConnectionError(
                            "HTTP/2 stream was reset before the response was complete"
                        )

                self._send_pending(conn)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if status is None:
            raise ConnectionError("HTTP/2 stream ended without a response status")

        return HTTP2Response(
            status=status,
            headers=headers,
            request_url=self._request_url,
            data=bytes(data),
        )

    def request(  # type: ignore[override]
        self,
        method: str,
        url: str,
        body: _TYPE_BODY | None = None,
        headers: typing.Mapping[str, str] | None = None,
        *,
        preload_content: bool = True,
        decode_content: bool = True,
        enforce_content_length: bool = True,
        **kwargs: typing.Any,
    ) -> None:
        """Send an HTTP/2 request

        Queues the request line and headers (adding a default ``user-agent``
        when none was given) and then sends them, followed by *body* when it
        is truthy. ``preload_content``, ``decode_content``,
        ``enforce_content_length`` and any extra keyword arguments are
        accepted but not used by this method.
        """
        # TODO: ``chunked`` is often present in ``kwargs`` from upstream; it
        # is accepted and ignored rather than rejected.
        # raise NotImplementedError("`chunked` isn't supported with HTTP/2")

        if self.sock is not None:
            self.sock.settimeout(self.timeout)

        self.putrequest(method, url)

        for k, v in (headers or {}).items():
            # ``transfer-encoding: chunked`` is dropped instead of forwarded.
            if k.lower() == "transfer-encoding" and v == "chunked":
                continue
            self.putheader(k, v)

        if b"user-agent" not in dict(self._headers):
            self.putheader(b"user-agent", _get_default_user_agent())

        if body:
            self.endheaders(message_body=body)
            self.send(body)
        else:
            self.endheaders()

    def close(self) -> None:
        """
        Close the connection and reset all HTTP/2 state.

        Telling the peer that the connection is closing is best effort: any
        failure while doing so is logged at debug level and otherwise ignored.
        """
        with self._h2_conn as conn:
            try:
                conn.close_connection()
                if self.sock is not None:
                    self._send_pending(conn)
            except Exception:
                # Deliberately best effort -- the socket is closed below
                # regardless -- but leave a trace instead of hiding the error.
                log.debug(
                    "Ignoring error while closing HTTP/2 connection", exc_info=True
                )

        # Reset all our HTTP/2 connection state.
        self._h2_conn = self._new_h2_conn()
        self._h2_stream = None
        self._headers = []

        super().close()
class HTTP2Response(BaseHTTPResponse):
    """
    Minimal, fully buffered response object for an HTTP/2 exchange.

    The complete body is supplied up front as ``data``; nothing is streamed.
    """

    # TODO: This is a woefully incomplete response object, but works for non-streaming.
    def __init__(
        self,
        status: int,
        headers: HTTPHeaderDict,
        request_url: str,
        data: bytes,
        decode_content: bool = False,  # TODO: support decoding
    ) -> None:
        """
        :param status: Numeric response status.
        :param headers: Response header fields.
        :param request_url: URL the request was made for.
        :param data: The complete response body.
        :param decode_content: Passed through to the base class.
        """
        super().__init__(
            status=status,
            headers=headers,
            # Following CPython, we map HTTP versions to major * 10 + minor integers
            version=20,
            version_string="HTTP/2",
            # No reason phrase in HTTP/2
            reason=None,
            decode_content=decode_content,
            request_url=request_url,
        )
        self._data = data
        # The body is already fully in memory, so nothing remains to be read.
        self.length_remaining = 0

    # Exposed as a property so ``response.data`` yields the bytes, matching
    # how the base response class exposes ``data``; as a plain method the
    # attribute access returned a bound method instead of the body.
    @property
    def data(self) -> bytes:
        """The complete response body."""
        return self._data

    def get_redirect_location(self) -> None:
        """Always return ``None``."""
        return None

    def close(self) -> None:
        """Do nothing."""
        pass