File size: 7,949 Bytes
870ab6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 |
from __future__ import annotations
import json
from copy import deepcopy
from typing import TYPE_CHECKING, Optional, Union
import fastapi
from gradio_client.documentation import document, set_documentation_group
from gradio import utils
from gradio.data_classes import PredictBody
from gradio.exceptions import Error
from gradio.helpers import EventData
if TYPE_CHECKING:
from gradio.routes import App
# Presumably assigns the @document()-decorated objects defined below (e.g.
# `Request`) to the "routes" group of gradio's generated documentation —
# semantics assumed from the name; confirm in gradio_client.documentation.
set_documentation_group("routes")
class Obj:
    """
    Using a class to convert dictionaries into objects. Used by the `Request` class.
    Credit: https://www.geeksforgeeks.org/convert-nested-python-dictionary-to-object/

    Every key becomes an attribute (also reachable as `obj[key]`). Nested dicts
    become nested `Obj` instances; lists stay lists, with any dicts inside them
    converted element-wise.
    """

    def __init__(self, dict_):
        for key, value in dict_.items():
            # Call the helper through the class: a key named "_convert" set on
            # the instance earlier in this loop would otherwise shadow it.
            setattr(self, key, Obj._convert(value))

    @staticmethod
    def _convert(value):
        """Wrap dicts as `Obj`, map over lists recursively, pass anything else through."""
        # A list must never be handed to Obj(...) directly: it has no .items()
        # and cannot be merged into an instance __dict__.
        if isinstance(value, dict):
            return Obj(value)
        if isinstance(value, list):
            return [Obj._convert(item) for item in value]
        return value

    def __getitem__(self, item):
        return self.__dict__[item]

    def __setitem__(self, item, value):
        self.__dict__[item] = value

    def __iter__(self):
        # Yields (key, value) pairs; nested Obj values are yielded as plain
        # (shallow) dicts so that dict(iter(obj)) is dict-like at the top levels.
        for key, value in self.__dict__.items():
            if isinstance(value, Obj):
                yield (key, dict(value))
            else:
                yield (key, value)

    def __contains__(self, item) -> bool:
        # Membership also searches nested Obj values (not lists), so a key
        # anywhere in the nested-dict tree counts as contained.
        if item in self.__dict__:
            return True
        for value in self.__dict__.values():
            if isinstance(value, Obj) and item in value:
                return True
        return False

    def keys(self):
        return self.__dict__.keys()

    def values(self):
        return self.__dict__.values()

    def items(self):
        return self.__dict__.items()

    def __str__(self) -> str:
        return str(self.__dict__)

    def __repr__(self) -> str:
        return str(self.__dict__)
@document()
class Request:
    """
    A Gradio request object that can be used to access the request headers, cookies,
    query parameters and other information about the request from within the prediction
    function. The class is a thin wrapper around the fastapi.Request class. Attributes
    of this class include: `headers`, `client`, `query_params`, and `path_params`. If
    auth is enabled, the `username` attribute can be used to get the logged in user.
    Example:
        import gradio as gr
        def echo(name, request: gr.Request):
            print("Request headers dictionary:", request.headers)
            print("IP address:", request.client.host)
            return name
        io = gr.Interface(echo, "textbox", "textbox").launch()
    """

    def __init__(
        self,
        request: fastapi.Request | None = None,
        username: str | None = None,
        **kwargs,
    ):
        """
        Can be instantiated with either a fastapi.Request or by manually passing in
        attributes (needed for websocket-based queueing).
        Parameters:
            request: A fastapi.Request
            username: The username of the logged in user, if any.
        """
        self.request = request
        self.username = username
        # Manually supplied attributes (e.g. headers, client), used only when
        # no fastapi.Request is given.
        self.kwargs: dict = kwargs

    def dict_to_obj(self, d):
        # Convert (possibly nested) dicts into attribute-accessible Obj trees
        # via a JSON round-trip; any other value is returned unchanged.
        if isinstance(d, dict):
            return json.loads(json.dumps(d), object_hook=Obj)
        else:
            return d

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. Instance state is read
        # from __dict__ directly: on an instance whose __init__ never ran (e.g.
        # one built via __new__ by copy/pickle, which then probes for
        # __setstate__), `self.request` would re-enter __getattr__ forever.
        request = self.__dict__.get("request")
        if request is not None:
            return self.dict_to_obj(getattr(request, name))
        kwargs = self.__dict__.get("kwargs", {})
        try:
            obj = kwargs[name]
        except KeyError as ke:
            raise AttributeError(
                f"'Request' object has no attribute '{name}'"
            ) from ke
        return self.dict_to_obj(obj)
class FnIndexInferError(Exception):
    """Raised when a request carries no fn_index and its api_name matches no dependency."""
def infer_fn_index(app: App, api_name: str, body: PredictBody) -> int:
    """
    Resolve which dependency (event handler) a predict request targets.

    An explicit `body.fn_index` always wins. Otherwise the index of the first
    dependency whose "api_name" equals `api_name` is returned.

    Raises:
        FnIndexInferError: if `body.fn_index` is None and no dependency has
            the given `api_name`.
    """
    explicit_index = body.fn_index
    if explicit_index is not None:
        return explicit_index

    dependencies = app.get_blocks().dependencies
    candidates = (
        position
        for position, dependency in enumerate(dependencies)
        if dependency["api_name"] == api_name
    )
    found = next(candidates, None)
    if found is None:
        raise FnIndexInferError(f"Could not infer fn_index for api_name {api_name}.")
    return found
def compile_gr_request(
    app: App,
    body: PredictBody,
    fn_index_inferred: int,
    username: Optional[str],
    request: Optional[fastapi.Request],
) -> Request | list[Request]:
    """
    Build the gradio `Request` object(s) handed to the event handler.

    Side effect: if the dependency at `fn_index_inferred` cancels other events,
    `body.data` is replaced with `[body.session_hash]`, since that is the only
    input a cancel handler needs.

    Parameters:
        app: the gradio App serving the request.
        body: the parsed predict request body.
        fn_index_inferred: index of the dependency being run.
        username: the logged-in username, if any.
        request: the live fastapi.Request, used when `body.request` is empty.
    Returns:
        A list of `Request` objects (one per serialized request) for batched
        bodies carrying `body.request`; otherwise a single `Request`.
    Raises:
        TypeError: if a non-batched `body.request` is not a dict.
        ValueError: if `body.request` is empty and `request` is None.
    """
    if app.get_blocks().dependencies[fn_index_inferred]["cancels"]:
        body.data = [body.session_hash]

    if body.request:
        if body.batched:
            return [Request(username=username, **req) for req in body.request]
        # Explicit check instead of `assert`: body.request comes from the
        # client, and asserts are stripped when Python runs with -O.
        if not isinstance(body.request, dict):
            raise TypeError(
                "body.request must be a dict for a non-batched request, "
                f"got {type(body.request).__name__}"
            )
        return Request(username=username, **body.request)

    if request is None:
        raise ValueError("request must be provided if body.request is None")
    return Request(username=username, request=request)
def restore_session_state(app: App, body: PredictBody):
    """
    Fetch the per-session block state and iterators for the request in `body`.

    On the first request of a session, its state is seeded with a deep copy of
    the `value` of every block marked `stateful`.

    Returns:
        A `(session_state, iterators)` tuple. Both are fresh empty dicts when
        `body` has no session_hash. `iterators` is a fresh empty dict when this
        fn_index was flagged for reset; otherwise it is the session's live
        iterator dict from `app.iterators`.
    """
    fn_index = body.fn_index
    session_hash = getattr(body, "session_hash", None)
    if session_hash is None:
        return {}, {}

    if session_hash not in app.state_holder:
        all_blocks = app.get_blocks().blocks
        # Deep-copied so mutations during this session don't leak back into
        # the blocks' stored values.
        app.state_holder[session_hash] = {
            block_id: deepcopy(getattr(component, "value", None))
            for block_id, component in all_blocks.items()
            if getattr(component, "stateful", False)
        }
    session_state = app.state_holder[session_hash]

    # iterators_to_reset records fn_indices that the /reset route marked as
    # cancelled. Consuming the mark here hands back an empty iterator dict, so
    # a cancel that finishes BEFORE the cancelled job cannot have the cancelled
    # job overwrite the iterator state afterwards.
    reset_marks = app.iterators_to_reset[session_hash]
    if fn_index not in reset_marks:
        return session_state, app.iterators[session_hash]
    reset_marks.remove(fn_index)
    return session_state, {}
async def call_process_api(
    app: App,
    body: PredictBody,
    gr_request: Union[Request, list[Request]],
    fn_index_inferred,
):
    """
    Run the dependency at `fn_index_inferred` for one predict request.

    Restores the session's state and iterators, calls
    `app.get_blocks().process_api`, and stores any returned iterator on
    `app.iterators` so a later call in the same session can resume it.

    Parameters:
        app: the gradio App serving the request.
        body: the parsed predict request body.
        gr_request: the gradio Request (or list of Requests for batched calls)
            exposed to the event handler.
        fn_index_inferred: index of the dependency to run (see infer_fn_index).
    Returns:
        The dict returned by process_api without its "iterator" entry. When a
        batched dependency was called with a single, non-batched payload,
        "data" is unwrapped from the batch of one.
    Raises:
        Re-raises anything raised while processing, after appending None to
        the pending streams of this dependency's current iterator.
    """
    session_state, iterators = restore_session_state(app=app, body=body)

    dependency = app.get_blocks().dependencies[fn_index_inferred]
    target = dependency["targets"][0] if dependency["targets"] else None
    # `is not None`: 0 is a valid block id and must not be treated as "no target".
    event_data = EventData(
        app.get_blocks().blocks.get(target) if target is not None else None,
        body.event_data,
    )

    event_id = getattr(body, "event_id", None)
    session_hash = getattr(body, "session_hash", None)
    inputs = body.data

    # A batched dependency called with a single payload: wrap the inputs into
    # a batch of one here and unwrap the output at the end.
    batch_in_single_out = not body.batched and dependency["batch"]
    if batch_in_single_out:
        inputs = [inputs]

    # Iterators are keyed by fn_index_inferred (not body.fn_index, which is
    # None for api_name calls) because that is the key process_api receives.
    try:
        with utils.MatplotlibBackendMananger():
            output = await app.get_blocks().process_api(
                fn_index=fn_index_inferred,
                inputs=inputs,
                request=gr_request,
                state=session_state,
                iterators=iterators,
                session_hash=session_hash,
                event_id=event_id,
                event_data=event_data,
            )
        # Check before .pop(): an Error instance has no .pop() method.
        if isinstance(output, Error):
            raise output
        iterator = output.pop("iterator", None)
        # Only persist iterators for real sessions; restore_session_state
        # never hands back iterators when session_hash is None.
        if session_hash is not None:
            app.iterators[session_hash][fn_index_inferred] = iterator
    except BaseException:
        iterator = iterators.get(fn_index_inferred, None)
        if iterator is not None:  # close off any streams that are still open
            run_id = id(iterator)
            pending_streams: dict[int, list] = (
                app.get_blocks().pending_streams[session_hash].get(run_id, {})
            )
            # presumably None is the end-of-stream sentinel — confirm with consumers
            for stream in pending_streams.values():
                stream.append(None)
        raise

    if batch_in_single_out:
        output["data"] = output["data"][0]

    return output
|