from __future__ import annotations

import hashlib
import json
from collections import deque
from dataclasses import dataclass as python_dataclass
from tempfile import NamedTemporaryFile, _TemporaryFileWrapper
from typing import TYPE_CHECKING, AsyncGenerator, BinaryIO, List, Optional, Tuple, Union

import fastapi
import httpx
import multipart
from gradio_client.documentation import document, set_documentation_group
from multipart.multipart import parse_options_header
from starlette.datastructures import FormData, Headers, UploadFile
from starlette.formparsers import MultiPartException, MultipartPart

from gradio import utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
from gradio.state_holder import SessionState

if TYPE_CHECKING:
    from gradio.blocks import Blocks
    from gradio.routes import App

set_documentation_group("routes")

class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/
    """

    def __init__(self, dict_):
        self.__dict__.update(dict_)
        for key, value in dict_.items():
            if isinstance(value, dict):
                value = Obj(value)
            elif isinstance(value, list):
                # `Obj` itself only accepts mappings (it calls `.items()`),
                # so lists are converted element-wise instead.
                value = [Obj(v) if isinstance(v, dict) else v for v in value]
            setattr(self, key, value)

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)

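# Illustrative usage (sample values are made up): `Obj` gives attribute-style
# access to a nested dictionary.
#
# >>> obj = Obj({"user": {"name": "alice"}, "accept": "json"})
# >>> obj.user.name
# 'alice'
# >>> "accept" in obj
# True
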
@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters, and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged-in user.
    Example:
        import gradio as gr
        def echo(text, request: gr.Request):
            if request:
                print("Request headers dictionary:", request.headers)
                print("IP address:", request.client.host)
                print("Query parameters:", dict(request.query_params))
            return text
        io = gr.Interface(echo, "textbox", "textbox").launch()
    Demos: request_ip_headers
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for queueing).
        Parameters:
            request: A fastapi.Request
        """
        self.request = request
        self.username = username
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        if isinstance(d, dict):
            return json.loads(json.dumps(d), object_hook=Obj)
        else:
            return d

    def __getattr__(self, name):
        if self.request:
            return self.dict_to_obj(getattr(self.request, name))
        else:
            try:
                obj = self.kwargs[name]
            except KeyError as ke:
                raise AttributeError(
                    f"'Request' object has no attribute '{name}'"
                ) from ke
            return self.dict_to_obj(obj)

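# Illustrative usage (hypothetical values): when a job runs off the queue, no
# fastapi.Request is available, so a Request can be built from kwargs instead,
# and attribute access falls through to them.
#
# >>> req = Request(username="admin", headers={"x-forwarded-for": "1.2.3.4"})
# >>> req.username
# 'admin'
# >>> req.headers["x-forwarded-for"]
# '1.2.3.4'
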
class FnIndexInferError(Exception):
    pass


def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    if body.fn_index is None:
        for i, fn in enumerate(app.get_blocks().dependencies):
            if fn["api_name"] == api_name:
                return i
        raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
    else:
        return body.fn_index

def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    # If this fn_index cancels jobs, then the only input we need is the
    # current session hash
    if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
        body.data = [body.session_hash]
    if body.request:
        if body.batched:
            # In the batched case, body.request holds one request per sample.
            gr_request = [
                Request(username=username, request=request)
                for request in body.request
            ]
        else:
            gr_request = Request(username=username, request=body.request)
    else:
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        gr_request = Request(username=username, request=request)
    return gr_request

def restore_session_state(app: App, body: PredictBody):
    event_id = body.event_id
    session_hash = getattr(body, "session_hash", None)
    if session_hash is not None:
        session_state = app.state_holder[session_hash]
        # The app.iterators_to_reset set keeps track of the event ids
        # whose jobs have been cancelled. When a job is cancelled,
        # the /reset route marks it as having been reset. That way, if the
        # cancel job finishes BEFORE the job being cancelled, the cancelled
        # job will not overwrite the state of the iterator.
        if event_id is None:
            iterator = None
        elif event_id in app.iterators_to_reset:
            iterator = None
            app.iterators_to_reset.remove(event_id)
        else:
            iterator = app.iterators.get(event_id)
    else:
        session_state = SessionState(app.get_blocks())
        iterator = None
    return session_state, iterator

def prepare_event_data(
    blocks: Blocks,
    body: PredictBody,
) -> EventData:
    target = body.trigger_id
    event_data = EventData(
        blocks.blocks.get(target) if target else None,
        body.event_data,
    )
    return event_data

async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred: int,
):
    session_state, iterator = restore_session_state(app=app, body=body)
    dependency = app.get_blocks().dependencies[fn_index_inferred]
    event_data = prepare_event_data(app.get_blocks(), body)
    event_id = body.event_id
    session_hash = getattr(body, "session_hash", None)
    inputs = body.data
    batch_in_single_out = not body.batched and dependency["batch"]
    if batch_in_single_out:
        inputs = [inputs]
    try:
        with utils.MatplotlibBackendMananger():  # sic: name as defined in gradio.utils
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterator=iterator,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
                in_event_listener=True,
            )
        iterator = output.pop("iterator", None)
        if event_id is not None:
            app.iterators[event_id] = iterator  # type: ignore
        if isinstance(output, Error):
            raise output
    except BaseException:
        iterator = app.iterators.get(event_id) if event_id is not None else None
        if iterator is not None:  # close off any streams that are still open
            run_id = id(iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            for stream in pending_streams.values():
                stream.append(None)
        raise
    if batch_in_single_out:
        output["data"] = output["data"][0]
    return output

def strip_url(orig_url: str) -> str:
    """
    Strips the query parameters and trailing slash from a URL.
    """
    parsed_url = httpx.URL(orig_url)
    stripped_url = parsed_url.copy_with(query=None)
    stripped_url = str(stripped_url)
    return stripped_url.rstrip("/")

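# Illustrative behavior (example URL is made up):
#
# >>> strip_url("https://example.com/demo/?param=1")
# 'https://example.com/demo'
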
def _user_safe_decode(src: bytes, codec: str) -> str:
    try:
        return src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        return src.decode("latin-1")

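# Illustrative behavior: latin-1 maps every byte to a code point, so the
# fallback decode never raises.
#
# >>> _user_safe_decode(b"\xff\xfe", "utf-8")
# 'ÿþ'
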
class GradioUploadFile(UploadFile):
    """UploadFile with a sha attribute."""

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        super().__init__(file, size=size, filename=filename, headers=headers)
        self.sha = hashlib.sha1()

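# Illustrative usage (hypothetical in-memory file): the multipart parser below
# updates `sha` chunk by chunk as it writes, so the digest is available as soon
# as parsing finishes.
#
# >>> import io
# >>> f = GradioUploadFile(io.BytesIO(), filename="a.txt")
# >>> f.sha.update(b"hello")
# >>> f.sha.hexdigest()
# 'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d'
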
@python_dataclass
class FileUploadProgressUnit:
    filename: str
    chunk_size: int
    is_done: bool

class FileUploadProgress:
    def __init__(self) -> None:
        self._statuses: dict[str, deque[FileUploadProgressUnit]] = {}

    def track(self, upload_id: str):
        if upload_id not in self._statuses:
            self._statuses[upload_id] = deque()

    def update(self, upload_id: str, filename: str, message_bytes: bytes):
        if upload_id not in self._statuses:
            self._statuses[upload_id] = deque()
        self._statuses[upload_id].append(
            FileUploadProgressUnit(filename, len(message_bytes), is_done=False)
        )

    def set_done(self, upload_id: str):
        self._statuses[upload_id].append(FileUploadProgressUnit("", 0, is_done=True))

    def stop_tracking(self, upload_id: str):
        if upload_id in self._statuses:
            del self._statuses[upload_id]

    def status(self, upload_id: str) -> deque[FileUploadProgressUnit]:
        if upload_id not in self._statuses:
            return deque()
        return self._statuses[upload_id]

    def is_tracked(self, upload_id: str):
        return upload_id in self._statuses

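# Illustrative lifecycle (hypothetical upload id and filename): the tracker is
# keyed by upload id; `update` appends one unit per streamed chunk and
# `set_done` appends a final sentinel unit.
#
# >>> progress = FileUploadProgress()
# >>> progress.track("upload-123")
# >>> progress.update("upload-123", "photo.png", b"chunk of bytes")
# >>> progress.set_done("upload-123")
# >>> [unit.is_done for unit in progress.status("upload-123")]
# [False, True]
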
class GradioMultiPartParser:
    """Vendored from starlette.MultipartParser.
    Thanks starlette!
    Made the following modifications:
    - Use GradioUploadFile instead of UploadFile
    - Use NamedTemporaryFile instead of SpooledTemporaryFile
    - Compute hash of data as the request is streamed
    """

    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: AsyncGenerator[bytes, None],
        *,
        max_files: Union[int, float] = 1000,
        max_fields: Union[int, float] = 1000,
        upload_id: str | None = None,
        upload_progress: FileUploadProgress | None = None,
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: List[Tuple[str, Union[str, UploadFile]]] = []
        self.upload_id = upload_id
        self.upload_progress = upload_progress
        self._current_files = 0
        self._current_fields = 0
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        self._file_parts_to_write: List[Tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: List[MultipartPart] = []
        self._files_to_close_on_error: List[_TemporaryFileWrapper] = []

    def on_part_begin(self) -> None:
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message_bytes = data[start:end]
        if self.upload_progress is not None:
            self.upload_progress.update(
                self.upload_id,  # type: ignore
                self._current_part.file.filename,  # type: ignore
                message_bytes,
            )
        if self._current_part.file is None:
            self._current_part.data += message_bytes
        else:
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        if self._current_part.file is None:
            self.items.append(
                (
                    self._current_part.field_name,
                    _user_safe_decode(self._current_part.data, self._charset),
                )
            )
        else:
            self._file_parts_to_finish.append(self._current_part)
            # The file can be added to the items right now even though it's not
            # finished yet, because it will be finished in the `parse()` method, before
            # self.items is used in the return value.
            self.items.append((self._current_part.field_name, self._current_part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        field = self._current_partial_header_name.lower()
        if field == b"content-disposition":
            self._current_part.content_disposition = self._current_partial_header_value
        self._current_part.item_headers.append(
            (field, self._current_partial_header_value)
        )
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        disposition, options = parse_options_header(
            self._current_part.content_disposition
        )
        try:
            self._current_part.field_name = _user_safe_decode(
                options[b"name"], self._charset
            )
        except KeyError as e:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            ) from e
        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            tempfile = NamedTemporaryFile(delete=False)
            self._files_to_close_on_error.append(tempfile)
            self._current_part.file = GradioUploadFile(
                file=tempfile,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        pass

    async def parse(self) -> FormData:
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError as e:
            raise MultiPartException("Missing boundary in multipart.") from e

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data. This needs to use await with the UploadFile
                # methods, which call the corresponding file methods *in a
                # threadpool*; if they were called directly in the callback
                # methods above (regular, non-async functions), they would
                # block the event loop in the main thread.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                    part.file.sha.update(data)  # type: ignore
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
        except MultiPartException as exc:
            # Close all the files if there was an error.
            for file in self._files_to_close_on_error:
                file.close()
            raise exc

        parser.finalize()
        if self.upload_progress is not None:
            self.upload_progress.set_done(self.upload_id)  # type: ignore
        return FormData(self.items)
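
# Minimal usage sketch. Assumptions: `request` is a starlette/FastAPI Request,
# and `handle_upload` is a hypothetical handler, not part of this module; it
# only illustrates how the parser is driven from an upload route.
#
#     async def handle_upload(request):
#         parser = GradioMultiPartParser(
#             request.headers,
#             request.stream(),
#             max_files=1000,
#             max_fields=1000,
#         )
#         form = await parser.parse()
#         for name, value in form.multi_items():
#             if isinstance(value, GradioUploadFile):
#                 print(name, value.filename, value.sha.hexdigest())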