Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import logging | |
import os | |
import ssl | |
import typing | |
from pathlib import Path | |
import certifi | |
from ._compat import set_minimum_tls_version_1_2 | |
from ._models import Headers | |
from ._types import CertTypes, HeaderTypes, TimeoutTypes, VerifyTypes | |
from ._urls import URL | |
from ._utils import get_ca_bundle_from_env | |
# Names exported by a star-import of this module.
__all__ = [
    "Limits",
    "Proxy",
    "Timeout",
    "create_ssl_context",
]
# OpenSSL cipher-list string handed to `SSLContext.set_ciphers`.
# Entries are ":"-separated; earlier entries are preferred.
DEFAULT_CIPHERS = (
    # Ephemeral (EC)DH key exchange with AES-GCM or ChaCha20.
    "ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:"
    # Broader (EC)DH matches, then any AES mode.
    "ECDH+AESGCM:DH+AESGCM:ECDH+AES:DH+AES:"
    # RSA key exchange as a fallback.
    "RSA+AESGCM:RSA+AES:"
    # Exclusions: anonymous/null suites, MD5 and DSS.
    "!aNULL:!eNULL:!MD5:!DSS"
)

# Module logger; debug output for SSL context construction goes here.
logger = logging.getLogger("httpx")
class UnsetType:
    """Sentinel type for keyword arguments that were not supplied by the caller."""


# Shared sentinel instance, used as the default for `Timeout`'s parameters.
UNSET = UnsetType()
def create_ssl_context(
    cert: CertTypes | None = None,
    verify: VerifyTypes = True,
    trust_env: bool = True,
    http2: bool = False,
) -> ssl.SSLContext:
    """
    Build an `ssl.SSLContext` from httpx-style SSL options.

    Convenience wrapper: constructs an `SSLConfig` with the given options
    and returns the context it built.
    """
    config = SSLConfig(
        cert=cert,
        verify=verify,
        trust_env=trust_env,
        http2=http2,
    )
    return config.ssl_context
class SSLConfig:
    """
    SSL Configuration.

    Builds an `ssl.SSLContext` from the `cert`, `verify`, `trust_env` and
    `http2` options. The context is created once, in `__init__`, and is
    exposed as the `ssl_context` attribute.
    """

    # CA bundle location reported by certifi; used when `verify` is a bool
    # and no bundle was picked up from the environment.
    DEFAULT_CA_BUNDLE_PATH = Path(certifi.where())

    def __init__(
        self,
        *,
        cert: CertTypes | None = None,
        verify: VerifyTypes = True,
        trust_env: bool = True,
        http2: bool = False,
    ) -> None:
        self.cert = cert
        self.verify = verify
        self.trust_env = trust_env
        self.http2 = http2
        self.ssl_context = self.load_ssl_context()

    def load_ssl_context(self) -> ssl.SSLContext:
        """
        Build the SSL context, choosing verified or unverified mode.

        NOTE(review): the dispatch is on truthiness, so any falsy `verify`
        value (e.g. an empty string), not only `False`, produces an
        unverified context.
        """
        logger.debug(
            "load_ssl_context verify=%r cert=%r trust_env=%r http2=%r",
            self.verify,
            self.cert,
            self.trust_env,
            self.http2,
        )
        if self.verify:
            return self.load_ssl_context_verify()
        return self.load_ssl_context_no_verify()

    def load_ssl_context_no_verify(self) -> ssl.SSLContext:
        """
        Return an SSL context for unverified connections.

        Hostname checking is off and peer certificates are not required.
        Client certificates, if configured, are still loaded.
        """
        context = self._create_default_ssl_context()
        # check_hostname has to be disabled before verify_mode may be
        # lowered to CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        self._load_client_certs(context)
        return context

    def load_ssl_context_verify(self) -> ssl.SSLContext:
        """
        Return an SSL context for verified connections.

        `verify` is interpreted as follows:

        * an `ssl.SSLContext`: used as-is (client certs are still loaded);
        * `True`: the CA bundle from `get_ca_bundle_from_env()` when
          `trust_env` is set and that returns a value, otherwise
          `DEFAULT_CA_BUNDLE_PATH`;
        * a path: a CA bundle file, or a directory of CA certificates.

        Raises:
            IOError: if `verify` is a path that does not exist.
        """
        if self.trust_env and self.verify is True:
            ca_bundle = get_ca_bundle_from_env()
            if ca_bundle is not None:
                self.verify = ca_bundle

        if isinstance(self.verify, ssl.SSLContext):
            # Allow passing in our own SSLContext object that's pre-configured.
            context = self.verify
            self._load_client_certs(context)
            return context
        elif isinstance(self.verify, bool):
            ca_bundle_path = self.DEFAULT_CA_BUNDLE_PATH
        elif Path(self.verify).exists():
            ca_bundle_path = Path(self.verify)
        else:
            raise IOError(
                "Could not find a suitable TLS CA certificate bundle, "
                "invalid path: {}".format(self.verify)
            )

        context = self._create_default_ssl_context()
        context.verify_mode = ssl.CERT_REQUIRED
        context.check_hostname = True

        # Signal to server support for PHA in TLS 1.3. Raises an
        # AttributeError if only read-only access is implemented.
        try:
            context.post_handshake_auth = True
        except AttributeError:  # pragma: no cover
            pass

        # Disable using 'commonName' for SSLContext.check_hostname
        # when the 'subjectAltName' extension isn't available.
        try:
            context.hostname_checks_common_name = False
        except AttributeError:  # pragma: no cover
            pass

        # A file is loaded as a CA bundle; a directory as a CA path. A path
        # that is neither loads no CA certificates.
        if ca_bundle_path.is_file():
            cafile = str(ca_bundle_path)
            logger.debug("load_verify_locations cafile=%r", cafile)
            context.load_verify_locations(cafile=cafile)
        elif ca_bundle_path.is_dir():
            capath = str(ca_bundle_path)
            logger.debug("load_verify_locations capath=%r", capath)
            context.load_verify_locations(capath=capath)

        self._load_client_certs(context)
        return context

    def _create_default_ssl_context(self) -> ssl.SSLContext:
        """
        Creates the default SSLContext object that's used for both verified
        and unverified connections.

        Applies the minimum TLS version helper, disables TLS compression,
        restricts ciphers to `DEFAULT_CIPHERS`, and advertises ALPN
        protocols (adding "h2" only when `http2` is enabled). When
        `trust_env` is set and SSLKEYLOGFILE is non-empty, that file is used
        as the context's key log.
        """
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        set_minimum_tls_version_1_2(context)
        context.options |= ssl.OP_NO_COMPRESSION
        context.set_ciphers(DEFAULT_CIPHERS)

        if ssl.HAS_ALPN:
            alpn_idents = ["http/1.1", "h2"] if self.http2 else ["http/1.1"]
            context.set_alpn_protocols(alpn_idents)

        keylogfile = os.environ.get("SSLKEYLOGFILE")
        if keylogfile and self.trust_env:
            context.keylog_filename = keylogfile

        return context

    def _load_client_certs(self, ssl_context: ssl.SSLContext) -> None:
        """
        Loads client certificates into our SSLContext object.

        `cert` may be:

        * a single certificate path (`str` or any `os.PathLike`, matching the
          path handling already accepted for `verify`);
        * a `(certfile, keyfile)` pair;
        * a `(certfile, keyfile, password)` triple.

        `None`, or any other shape, loads nothing.
        """
        if self.cert is None:
            return
        # os.PathLike is accepted too: previously a pathlib.Path was silently
        # ignored, leaving the connection without a client certificate.
        if isinstance(self.cert, (str, os.PathLike)):
            ssl_context.load_cert_chain(certfile=self.cert)
        elif isinstance(self.cert, tuple) and len(self.cert) == 2:
            ssl_context.load_cert_chain(certfile=self.cert[0], keyfile=self.cert[1])
        elif isinstance(self.cert, tuple) and len(self.cert) == 3:
            ssl_context.load_cert_chain(
                certfile=self.cert[0],
                keyfile=self.cert[1],
                password=self.cert[2],
            )
class Timeout:
    """
    Timeout configuration.

    **Usage**:

    Timeout(None)               # No timeouts.
    Timeout(5.0)                # 5s timeout on all operations.
    Timeout(None, connect=5.0)  # 5s timeout on connect, no other timeouts.
    Timeout(5.0, connect=10.0)  # 10s timeout on connect. 5s timeout elsewhere.
    Timeout(5.0, pool=None)     # No timeout on acquiring connection from pool.
                                # 5s timeout elsewhere.

    The first argument may also be:

    * an existing `Timeout`, copied as-is. No keyword overrides are allowed;
      this is checked with `assert`.
    * a tuple `(connect, read[, write[, pool]])`. Missing trailing entries
      become `None`, and keyword arguments are ignored.

    If all four keyword arguments are given, the first argument is ignored.
    """

    def __init__(
        self,
        timeout: TimeoutTypes | UnsetType = UNSET,
        *,
        connect: None | float | UnsetType = UNSET,
        read: None | float | UnsetType = UNSET,
        write: None | float | UnsetType = UNSET,
        pool: None | float | UnsetType = UNSET,
    ) -> None:
        overrides = (connect, read, write, pool)

        if isinstance(timeout, Timeout):
            # Copying an existing Timeout: per-field overrides are not allowed.
            for override in overrides:
                assert override is UNSET
            self.connect = timeout.connect  # type: typing.Optional[float]
            self.read = timeout.read  # type: typing.Optional[float]
            self.write = timeout.write  # type: typing.Optional[float]
            self.pool = timeout.pool  # type: typing.Optional[float]
            return

        if isinstance(timeout, tuple):
            # (connect, read[, write[, pool]]); absent trailing items -> None.
            size = len(timeout)
            self.connect = timeout[0]
            self.read = timeout[1]
            self.write = timeout[2] if size > 2 else None
            self.pool = timeout[3] if size > 3 else None
            return

        if all(not isinstance(value, UnsetType) for value in overrides):
            # Every field was given explicitly.
            self.connect, self.read, self.write, self.pool = overrides
            return

        if isinstance(timeout, UnsetType):
            raise ValueError(
                "httpx.Timeout must either include a default, or set all "
                "four parameters explicitly."
            )

        # Fields left unset fall back to the shared `timeout` default.
        self.connect, self.read, self.write, self.pool = (
            timeout if isinstance(value, UnsetType) else value
            for value in overrides
        )

    def as_dict(self) -> dict[str, float | None]:
        """Return the four timeout values keyed by operation name."""
        return dict(
            connect=self.connect,
            read=self.read,
            write=self.write,
            pool=self.pool,
        )

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return all(
            getattr(self, field) == getattr(other, field)
            for field in ("connect", "read", "write", "pool")
        )

    def __repr__(self) -> str:
        name = self.__class__.__name__
        values = (self.connect, self.read, self.write, self.pool)
        # Collapse to the single-value form when all four settings agree.
        if len(set(values)) == 1:
            return f"{name}(timeout={self.connect})"
        return (
            f"{name}(connect={self.connect}, "
            f"read={self.read}, write={self.write}, pool={self.pool})"
        )
class Limits:
    """
    Configuration for limits to various client behaviors.

    **Parameters:**

    * **max_connections** - The maximum number of concurrent connections that may be
      established.
    * **max_keepalive_connections** - Allow the connection pool to maintain
      keep-alive connections below this point. Should be less than or equal
      to `max_connections`.
    * **keepalive_expiry** - Time limit on idle keep-alive connections in seconds.

    Values are stored as given. No validation is performed here, including
    of the keep-alive / max-connections relationship.
    """

    def __init__(
        self,
        *,
        max_connections: int | None = None,
        max_keepalive_connections: int | None = None,
        keepalive_expiry: float | None = 5.0,
    ) -> None:
        self.max_connections = max_connections
        self.max_keepalive_connections = max_keepalive_connections
        self.keepalive_expiry = keepalive_expiry

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return all(
            getattr(self, field) == getattr(other, field)
            for field in (
                "max_connections",
                "max_keepalive_connections",
                "keepalive_expiry",
            )
        )

    def __repr__(self) -> str:
        settings = ", ".join(
            [
                f"max_connections={self.max_connections}",
                f"max_keepalive_connections={self.max_keepalive_connections}",
                f"keepalive_expiry={self.keepalive_expiry}",
            ]
        )
        return f"{self.__class__.__name__}({settings})"
class Proxy:
    """
    Configuration for sending requests through a proxy.

    **Parameters:**

    * **url** - The proxy URL. The scheme must be `http`, `https` or
      `socks5`. A username/password embedded in the URL is moved into
      `auth` and removed from the stored URL.
    * **ssl_context** - Optional SSL context, stored as-is.
    * **auth** - Optional `(username, password)` pair. Credentials embedded
      in `url` take precedence over this argument.
    * **headers** - Optional headers, stored as a `Headers` instance.

    Raises:
        ValueError: if the URL scheme is not supported.
    """

    def __init__(
        self,
        url: URL | str,
        *,
        ssl_context: ssl.SSLContext | None = None,
        auth: tuple[str, str] | None = None,
        headers: HeaderTypes | None = None,
    ) -> None:
        proxy_url = URL(url)
        proxy_headers = Headers(headers)

        if proxy_url.scheme not in ("http", "https", "socks5"):
            raise ValueError(f"Unknown scheme for proxy URL {proxy_url!r}")

        has_userinfo = bool(proxy_url.username or proxy_url.password)
        if has_userinfo:
            # Credentials in the URL win over `auth`; strip them so they are
            # not kept in the stored URL.
            auth = (proxy_url.username, proxy_url.password)
            proxy_url = proxy_url.copy_with(username=None, password=None)

        self.url = proxy_url
        self.auth = auth
        self.headers = proxy_headers
        self.ssl_context = ssl_context

    def raw_auth(self) -> tuple[bytes, bytes] | None:
        """Return the proxy credentials UTF-8 encoded, or None without auth."""
        if self.auth is None:
            return None
        username, password = self.auth[0], self.auth[1]
        return (username.encode("utf-8"), password.encode("utf-8"))

    def __repr__(self) -> str:
        # Only the username is shown; the password is always masked.
        parts = [repr(str(self.url))]
        if self.auth:
            masked = (self.auth[0], "********")
            parts.append(f"auth={masked!r}")
        if self.headers:
            parts.append(f"headers={dict(self.headers)!r}")
        return f"Proxy({', '.join(parts)})"
# Module-wide defaults.
# A single 5 second value applied to connect, read, write and pool.
DEFAULT_TIMEOUT_CONFIG = Timeout(5.0)
# Up to 100 connections, up to 20 kept alive; keep-alive expiry uses the
# `Limits` default.
DEFAULT_LIMITS = Limits(max_keepalive_connections=20, max_connections=100)
# NOTE(review): presumably the default cap on followed redirects — confirm
# against callers.
DEFAULT_MAX_REDIRECTS = 20