|
from __future__ import annotations |
|
|
|
import json |
|
from copy import deepcopy |
|
from typing import TYPE_CHECKING, Optional, Union |
|
|
|
import fastapi |
|
from gradio_client.documentation import document, set_documentation_group |
|
|
|
from gradio import utils |
|
from gradio.data_classes import PredictBody |
|
from gradio.exceptions import Error |
|
from gradio.helpers import EventData |
|
|
|
if TYPE_CHECKING: |
|
from gradio.routes import App |
|
|
|
# NOTE(review): presumably files the `@document()`-decorated definitions that
# follow under the "routes" documentation group — confirm against
# gradio_client.documentation.
set_documentation_group("routes")
|
|
|
|
|
class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/

    Every key of the source dictionary becomes an attribute. Nested dictionaries
    become nested `Obj` instances, and lists are converted element by element so
    dictionaries inside lists are wrapped as well. Values can be read either as
    attributes (`obj.a.b`) or by key (`obj["a"]["b"]`).
    """

    def __init__(self, dict_):
        """
        Parameters:
            dict_: mapping whose items become the attributes of this object
        """
        for key, value in dict_.items():
            setattr(self, key, self._convert(value))

    @staticmethod
    def _convert(value):
        """Wrap dicts in `Obj`, convert lists element-wise, return anything else unchanged."""
        if isinstance(value, dict):
            return Obj(value)
        if isinstance(value, list):
            # A list has no `.items()`, so it must never be passed to `Obj(...)`
            # directly; only its elements are candidates for wrapping.
            return [Obj._convert(element) for element in value]
        return value

    def __getitem__(self, item):
        """Return the value stored under `item`; raises KeyError when absent."""
        return self.__dict__[item]

    def __setitem__(self, item, value):
        """Store `value` under `item` (equivalent to setting the attribute)."""
        self.__dict__[item] = value

    def __iter__(self):
        """Yield `(key, value)` pairs; nested `Obj` values are yielded as plain dicts."""
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        """Return True if `item` is a key of this object or of any nested `Obj` value."""
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def keys(self):
        """Return a view of the top-level keys."""
        return self.__dict__.keys()

    def values(self):
        """Return a view of the top-level values."""
        return self.__dict__.values()

    def items(self):
        """Return a view of the top-level `(key, value)` pairs."""
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)
|
|
|
|
|
@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged in user.
    Example:
        import gradio as gr
        def echo(name, request: gr.Request):
            print("Request headers dictionary:", request.headers)
            print("IP address:", request.client.host)
            return name
        io = gr.Interface(echo, "textbox", "textbox").launch()
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for websocket-based queueing).
        Parameters:
            request: A fastapi.Request
        """
        self.request = request
        self.username = username
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        """Return `d` converted to a nested `Obj` when it is a dict, otherwise `d` itself."""
        if isinstance(d, dict):
            # Round-trip through JSON so `object_hook` wraps every nested dict.
            return json.loads(json.dumps(d), object_hook=Obj)
        return d

    def __getattr__(self, name):
        """
        Resolve attributes that are not set on the instance itself.

        When a wrapped request is present the attribute is read from it; otherwise
        it is looked up in the keyword arguments given at construction. Dict results
        are converted with `dict_to_obj`. Raises AttributeError when the name is not
        found in the keyword arguments.
        """
        # `__getattr__` only runs when normal lookup fails, so reading
        # `self.request` / `self.kwargs` here would re-enter this method forever
        # on an instance whose `__init__` has not run (e.g. while `copy` or
        # `pickle` rebuild the object and probe for `__setstate__`). Going
        # through `__dict__` turns that case into a plain AttributeError.
        instance_state = self.__dict__
        wrapped_request = instance_state.get("request")
        if wrapped_request:
            return self.dict_to_obj(getattr(wrapped_request, name))
        try:
            obj = instance_state.get("kwargs", {})[name]
        except KeyError as ke:
            raise AttributeError(
                f"'Request' object has no attribute '{name}'"
            ) from ke
        return self.dict_to_obj(obj)
|
|
|
|
|
class FnIndexInferError(Exception):
    """Signals that an `fn_index` could not be inferred for a given `api_name`."""
|
|
|
|
|
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    """
    Determine which dependency index a prediction request refers to.

    Parameters:
        app: application whose `get_blocks().dependencies` list is searched
        api_name: API name to match against each dependency's "api_name" entry
        body: request body; an explicit `fn_index` on it takes precedence
    Returns:
        `body.fn_index` when it is set, otherwise the position of the first
        dependency whose "api_name" equals `api_name`.
    Raises:
        FnIndexInferError: if `body.fn_index` is None and no dependency matches.
    """
    if body.fn_index is not None:
        return body.fn_index

    matching_position = next(
        (
            position
            for position, dependency in enumerate(app.get_blocks().dependencies)
            if dependency["api_name"] == api_name
        ),
        None,
    )
    if matching_position is None:
        raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
    return matching_position
|
|
|
|
|
def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
):
    """
    Build the `Request` object(s) handed to the prediction function.

    Parameters:
        app: application whose dependency list is consulted
        body: request body; its `data` is overwritten with `[body.session_hash]`
            when the selected dependency has a truthy "cancels" entry
        fn_index_inferred: index of the dependency being invoked
        username: user name attached to every created `Request`
        request: underlying fastapi request, used only when `body.request` is empty
    Returns:
        A list of `Request` objects when `body.request` is set and `body.batched`
        is truthy, otherwise a single `Request`.
    Raises:
        ValueError: if neither `body.request` nor `request` is available.
    """
    selected_dependency = app.get_blocks().dependencies[fn_index_inferred]
    if selected_dependency["cancels"]:
        # A cancelling dependency receives only the session hash as its input.
        body.data = [body.session_hash]

    if not body.request:
        # No serialized request in the body: fall back to the live request.
        if request is None:
            raise ValueError("request must be provided if body.request is None")
        return Request(username=username, request=request)

    if body.batched:
        # One Request per entry of the batched payload.
        return [Request(username=username, **req) for req in body.request]

    assert isinstance(body.request, dict)
    return Request(username=username, **body.request)
|
|
|
|
|
def restore_session_state(app: App, body: PredictBody):
    """
    Fetch (creating it on first use) the state and iterators of the request's session.

    Parameters:
        app: application holding `state_holder`, `iterators` and `iterators_to_reset`
        body: request body providing `fn_index` and, optionally, `session_hash`
    Returns:
        A `(session_state, iterators)` tuple. Both are fresh empty dicts when the
        body has no session hash.
    """
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)

    if session_hash is None:
        return {}, {}

    if session_hash not in app.state_holder:
        # First request of this session: seed the state with a private copy of
        # the value of every block flagged as stateful.
        initial_state = {}
        for block_id, block in app.get_blocks().blocks.items():
            if getattr(block, "stateful", False):
                initial_state[block_id] = deepcopy(getattr(block, "value", None))
        app.state_holder[session_hash] = initial_state
    session_state = app.state_holder[session_hash]

    pending_resets = app.iterators_to_reset[session_hash]
    if fn_index in pending_resets:
        # This fn_index was flagged for reset: hand back an empty iterator
        # mapping instead of the stored one, and clear the flag.
        iterators = {}
        pending_resets.remove(fn_index)
    else:
        iterators = app.iterators[session_hash]

    return session_state, iterators
|
|
|
|
|
async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred: int,
):
    """
    Run the dependency at `fn_index_inferred` through `app.get_blocks().process_api`.

    Session state and iterators are restored with `restore_session_state`, an
    `EventData` is built from the dependency's first target block (if any) and
    `body.event_data`, and the result of `process_api` is returned.

    When the dependency is marked "batch" but the request is not batched, the
    inputs are wrapped in a one-element list before the call and the first
    element of `output["data"]` is unwrapped afterwards.

    Parameters:
        app: application providing the blocks, iterators and pending streams
        body: request body (data, event data, optional session hash / event id)
        gr_request: request object(s) forwarded to `process_api`
        fn_index_inferred: index of the dependency to execute
    Returns:
        The output returned by `process_api`, with its "iterator" entry removed.
    Raises:
        Any exception raised while processing is re-raised after the pending
        streams of the previous iterator have been terminated.
    """
    session_state, iterators = restore_session_state(app=app, body=body)

    dependency = app.get_blocks().dependencies[fn_index_inferred]

    # Only the first target of the dependency is used for the event data.
    target = dependency["targets"][0] if len(dependency["targets"]) else None
    event_data = EventData(
        app.get_blocks().blocks.get(target) if target else None,
        body.event_data,
    )

    event_id = getattr(body, "event_id", None)

    # NOTE(review): iterator bookkeeping below is keyed by `body.fn_index`, while
    # the call itself uses `fn_index_inferred` — confirm they always agree.
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)
    inputs = body.data

    # Batch dependency invoked with a non-batched request: present the inputs
    # as a batch of one and unwrap the output again after the call.
    batch_in_single_out = not body.batched and dependency["batch"]
    if batch_in_single_out:
        inputs = [inputs]

    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterators=iterators,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
            )
        # Remember the iterator returned by this run (None when absent) for the
        # session, so a later call can pick it up via `restore_session_state`.
        iterator = output.pop("iterator", None)
        if hasattr(body, "session_hash"):
            app.iterators[body.session_hash][fn_index] = iterator
        # NOTE(review): this check runs after `output.pop(...)` has already been
        # called on `output` — confirm whether an `Error` can actually reach it.
        if isinstance(output, Error):
            raise output
    except BaseException:
        # BaseException (not just Exception) so the cleanup also runs for
        # non-Exception errors such as task cancellation.
        iterator = iterators.get(fn_index, None)
        if iterator is not None:
            # Pending streams are looked up per session under id(iterator).
            run_id = id(iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            # NOTE(review): appending None presumably marks end-of-stream for
            # consumers of these lists — confirm against the stream reader.
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]

    return output
|
|