Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import io | |
import os | |
import typing | |
from pathlib import Path | |
from ._types import ( | |
AsyncByteStream, | |
FileContent, | |
FileTypes, | |
RequestData, | |
RequestFiles, | |
SyncByteStream, | |
) | |
from ._utils import ( | |
format_form_param, | |
guess_content_type, | |
peek_filelike_length, | |
primitive_value_to_str, | |
to_bytes, | |
) | |
def get_multipart_boundary_from_content_type(
    content_type: bytes | None,
) -> bytes | None:
    """
    Extract the `boundary` parameter from a `multipart/form-data` content type.

    Returns the boundary value with surrounding whitespace and double quotes
    removed, or `None` when `content_type` is empty, is not a
    `multipart/form-data` media type, or carries no `boundary` parameter.
    """
    if not content_type:
        return None
    # Media type names are case-insensitive, just like the parameter name
    # matched below, so compare against a lower-cased copy.
    if not content_type.lower().startswith(b"multipart/form-data"):
        return None
    # parse boundary according to
    # https://www.rfc-editor.org/rfc/rfc2046#section-5.1.1
    prefix = b"boundary="
    for section in content_type.split(b";"):
        section = section.strip()
        if section.lower().startswith(prefix):
            # Slice the original (not lower-cased) bytes so that the case of
            # the boundary value itself is preserved.
            return section[len(prefix) :].strip(b'"')
    return None
class DataField:
    """
    One plain (non-file) form field inside a multipart encoded body.

    The rendered header block and the rendered payload are each computed on
    first use and then cached on the instance.
    """

    def __init__(self, name: str, value: str | bytes | int | float | None) -> None:
        name_ok = isinstance(name, str)
        if not name_ok:
            raise TypeError(
                f"Invalid type for name. Expected str, got {type(name)}: {name!r}"
            )

        value_ok = value is None or isinstance(value, (str, bytes, int, float))
        if not value_ok:
            raise TypeError(
                "Invalid type for value. Expected primitive type,"
                f" got {type(value)}: {value!r}"
            )

        self.name = name
        # Bytes are stored untouched; every other accepted value is converted
        # to its string form.
        rendered: str | bytes
        if isinstance(value, bytes):
            rendered = value
        else:
            rendered = primitive_value_to_str(value)
        self.value = rendered

    def render_headers(self) -> bytes:
        """Return the Content-Disposition header block for this field."""
        try:
            return self._headers
        except AttributeError:
            pass

        disposition_name = format_form_param("name", self.name)
        pieces = [b"Content-Disposition: form-data; ", disposition_name, b"\r\n\r\n"]
        self._headers = b"".join(pieces)
        return self._headers

    def render_data(self) -> bytes:
        """Return the field value encoded as bytes."""
        try:
            return self._data
        except AttributeError:
            pass

        self._data = to_bytes(self.value)
        return self._data

    def get_length(self) -> int:
        """Return the combined size of the header block and the payload."""
        header_size = len(self.render_headers())
        payload_size = len(self.render_data())
        return header_size + payload_size

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block followed by the payload."""
        head = self.render_headers()
        yield head
        body = self.render_data()
        yield body
class FileField:
    """
    A single file field item, within a multipart form field.

    `value` is either the file content itself (a file-like object, `str` or
    `bytes`) or a tuple of `(filename, content)`,
    `(filename, content, content_type)` or
    `(filename, content, content_type, headers)`.
    """

    # Number of bytes requested from a file-like object per read.
    CHUNK_SIZE = 64 * 1024

    def __init__(self, name: str, value: FileTypes) -> None:
        """
        Unpack `value` into a filename, file content and per-part headers.

        Raises `TypeError` if the content is an `io.StringIO` or any other
        text-mode stream.
        """
        self.name = name

        fileobj: FileContent

        extra_headers: typing.Mapping[str, str] | None = None
        content_type: str | None = None

        # This large tuple based API largely mirrors requests' API.
        # It would be good to think of better APIs for this that we could
        # include in httpx 2.0 since variable length tuples (especially of
        # 4 elements) are quite unwieldy.
        if isinstance(value, tuple):
            if len(value) == 2:
                # neither the 3rd parameter (content_type) nor the 4th (headers)
                # was included
                filename, fileobj = value
            elif len(value) == 3:
                filename, fileobj, content_type = value
            else:
                # all 4 parameters included
                filename, fileobj, content_type, extra_headers = value  # type: ignore
        else:
            filename = Path(str(getattr(value, "name", "upload"))).name
            fileobj = value

        # Work on a private copy: a Content-Type entry may be added below, and
        # the caller's mapping (possibly shared between several files) must
        # not be modified as a side effect.
        headers: dict[str, str] = dict(extra_headers or {})

        if content_type is None:
            content_type = guess_content_type(filename)

        # Compare the whole header name; a substring test would treat
        # unrelated headers such as "X-Content-Type-Options" as Content-Type.
        has_content_type_header = any(
            key.lower() == "content-type" for key in headers
        )
        if content_type is not None and not has_content_type_header:
            # Note that unlike requests, we ignore the content_type provided in
            # the 3rd tuple element if it is also included in the headers.
            # requests does the opposite (it overwrites the header with the
            # 3rd tuple element).
            headers["Content-Type"] = content_type

        if isinstance(fileobj, io.StringIO):
            raise TypeError(
                "Multipart file uploads require 'io.BytesIO', not 'io.StringIO'."
            )
        if isinstance(fileobj, io.TextIOBase):
            raise TypeError(
                "Multipart file uploads must be opened in binary mode, not text mode."
            )

        self.filename = filename
        self.file = fileobj
        self.headers = headers

    def get_length(self) -> int | None:
        """
        Return the size of the rendered headers plus the file content, or
        `None` when the content length cannot be determined upfront.
        """
        headers = self.render_headers()

        if isinstance(self.file, (str, bytes)):
            return len(headers) + len(to_bytes(self.file))

        file_length = peek_filelike_length(self.file)

        # If we can't determine the filesize without reading it into memory,
        # then return `None` here, to indicate an unknown file length.
        if file_length is None:
            return None

        return len(headers) + file_length

    def render_headers(self) -> bytes:
        """
        Return the header block for this part, terminated by a blank line.

        The result is built once and cached on the instance.
        """
        if not hasattr(self, "_headers"):
            parts = [
                b"Content-Disposition: form-data; ",
                format_form_param("name", self.name),
            ]
            # An empty or missing filename omits the `filename` parameter.
            if self.filename:
                filename = format_form_param("filename", self.filename)
                parts.extend([b"; ", filename])
            for header_name, header_value in self.headers.items():
                key, val = f"\r\n{header_name}: ".encode(), header_value.encode()
                parts.extend([key, val])
            parts.append(b"\r\n\r\n")
            self._headers = b"".join(parts)

        return self._headers

    def render_data(self) -> typing.Iterator[bytes]:
        """
        Yield the file content as bytes, in chunks of at most `CHUNK_SIZE`
        per read for file-like objects.
        """
        if isinstance(self.file, (str, bytes)):
            yield to_bytes(self.file)
            return

        if hasattr(self.file, "seek"):
            # Rewind so that the content is sent from the start each time the
            # field is rendered; streams that cannot seek are read as-is.
            try:
                self.file.seek(0)
            except io.UnsupportedOperation:
                pass

        chunk = self.file.read(self.CHUNK_SIZE)
        while chunk:
            yield to_bytes(chunk)
            chunk = self.file.read(self.CHUNK_SIZE)

    def render(self) -> typing.Iterator[bytes]:
        """Yield the header block followed by the file content."""
        yield self.render_headers()
        yield from self.render_data()
class MultipartStream(SyncByteStream, AsyncByteStream):
    """
    Streaming request body encoded as multipart form data.

    Iterating the stream, synchronously or asynchronously, yields the encoded
    body chunk by chunk.
    """

    def __init__(
        self,
        data: RequestData,
        files: RequestFiles,
        boundary: bytes | None = None,
    ) -> None:
        # Without an explicit boundary, use 16 random bytes rendered as hex.
        if boundary is None:
            boundary = os.urandom(16).hex().encode("ascii")

        boundary_text = boundary.decode("ascii")

        self.boundary = boundary
        self.content_type = "multipart/form-data; boundary=%s" % boundary_text
        self.fields = list(self._iter_fields(data, files))

    def _iter_fields(
        self, data: RequestData, files: RequestFiles
    ) -> typing.Iterator[FileField | DataField]:
        """Yield every data field first, followed by every file field."""
        for field_name, field_value in data.items():
            # A tuple or list value expands into one field per element.
            if isinstance(field_value, (tuple, list)):
                entries = field_value
            else:
                entries = [field_value]
            for entry in entries:
                yield DataField(name=field_name, value=entry)

        if isinstance(files, typing.Mapping):
            file_pairs = files.items()
        else:
            file_pairs = files
        for field_name, file_value in file_pairs:
            yield FileField(name=field_name, value=file_value)

    def iter_chunks(self) -> typing.Iterator[bytes]:
        """Yield the encoded body: each framed field, then the closing line."""
        opening = b"--%s\r\n" % self.boundary
        closing = b"--%s--\r\n" % self.boundary
        for field in self.fields:
            yield opening
            yield from field.render()
            yield b"\r\n"
        yield closing

    def get_content_length(self) -> int | None:
        """
        Return the length of the multipart encoded content, or `None` if
        any field has a length that cannot be determined upfront.
        """
        boundary_length = len(self.boundary)
        # Framing around each field: b"--{boundary}\r\n" before, b"\r\n" after.
        per_field_framing = (2 + boundary_length + 2) + 2

        total = 0
        for field in self.fields:
            size = field.get_length()
            if size is None:
                return None
            total += per_field_framing + size

        # Closing delimiter: b"--{boundary}--\r\n"
        total += 2 + boundary_length + 4
        return total

    # Content stream interface.
    def get_headers(self) -> dict[str, str]:
        """Return the framing header plus `Content-Type` for this body."""
        total = self.get_content_length()
        headers: dict[str, str]
        if total is None:
            headers = {"Transfer-Encoding": "chunked"}
        else:
            headers = {"Content-Length": str(total)}
        headers["Content-Type"] = self.content_type
        return headers

    def __iter__(self) -> typing.Iterator[bytes]:
        yield from self.iter_chunks()

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        # `yield from` is not permitted inside an async generator.
        for part in self.iter_chunks():
            yield part