from __future__ import annotations

import asyncio
import copy
import json
import os
import random
import time
import traceback
import uuid
from collections import defaultdict
from queue import Queue as ThreadQueue
from typing import TYPE_CHECKING

import fastapi
from gradio_client.utils import ServerMessage
from typing_extensions import Literal

from gradio import route_utils, routes
from gradio.data_classes import (
    Estimation,
    LogMessage,
    PredictBody,
    Progress,
    ProgressUnit,
)
from gradio.exceptions import Error
from gradio.helpers import TrackedIterable
from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name

if TYPE_CHECKING:
    from gradio.blocks import BlockFunction


class Event:
    def __init__(
        self,
        session_hash: str,
        fn_index: int,
        request: fastapi.Request,
        username: str | None,
        concurrency_id: str,
    ):
        self.session_hash = session_hash
        self.fn_index = fn_index
        self.request = request
        self.username = username
        self.concurrency_id = concurrency_id
        self._id = uuid.uuid4().hex
        self.data: PredictBody | None = None
        self.progress: Progress | None = None
        self.progress_pending: bool = False
        self.alive = True


class EventQueue:
    def __init__(self, concurrency_id: str, concurrency_limit: int | None):
        self.queue: list[Event] = []
        self.concurrency_id = concurrency_id
        self.concurrency_limit = concurrency_limit
        self.current_concurrency = 0
        self.start_times_per_fn_index: defaultdict[int, set[float]] = defaultdict(set)


class ProcessTime:
    def __init__(self):
        self.process_time = 0
        self.count = 0
        self.avg_time = 0

    def add(self, time: float):
        self.process_time += time
        self.count += 1
        self.avg_time = self.process_time / self.count
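

# Worked example (illustrative comment, not executed): after ProcessTime.add(2.0)
# and ProcessTime.add(4.0), the tracker holds process_time == 6.0, count == 2,
# and avg_time == 3.0. The queue keeps one ProcessTime per fn_index and uses this
# running average as the per-call duration estimate when broadcasting ETAs.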


class Queue:
    def __init__(
        self,
        live_updates: bool,
        concurrency_count: int,
        update_intervals: float,
        max_size: int | None,
        block_fns: list[BlockFunction],
        default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
    ):
        self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
        self.pending_event_ids_session: dict[str, set[str]] = {}
        self.pending_message_lock = safe_get_lock()
        self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
        self.stopped = False
        self.max_thread_count = concurrency_count
        self.update_intervals = update_intervals
        self.active_jobs: list[None | list[Event]] = []
        self.delete_lock = safe_get_lock()
        self.server_app = None
        self.process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(
            ProcessTime
        )
        self.live_updates = live_updates
        self.sleep_when_free = 0.05
        self.progress_update_sleep_when_free = 0.1
        self.max_size = max_size
        self.block_fns = block_fns
        self.continuous_tasks: list[Event] = []
        self._asyncio_tasks: list[asyncio.Task] = []
        self.default_concurrency_limit = self._resolve_concurrency_limit(
            default_concurrency_limit
        )

    def start(self):
        self.active_jobs = [None] * self.max_thread_count
        for block_fn in self.block_fns:
            concurrency_id = block_fn.concurrency_id
            concurrency_limit: int | None
            if block_fn.concurrency_limit == "default":
                concurrency_limit = self.default_concurrency_limit
            else:
                concurrency_limit = block_fn.concurrency_limit
            if concurrency_id not in self.event_queue_per_concurrency_id:
                self.event_queue_per_concurrency_id[concurrency_id] = EventQueue(
                    concurrency_id, concurrency_limit
                )
            elif (
                concurrency_limit is not None
            ):  # Update concurrency limit if it is lower than existing limit
                existing_event_queue = self.event_queue_per_concurrency_id[
                    concurrency_id
                ]
                if (
                    existing_event_queue.concurrency_limit is None
                    or concurrency_limit < existing_event_queue.concurrency_limit
                ):
                    existing_event_queue.concurrency_limit = concurrency_limit

        run_coro_in_background(self.start_processing)
        run_coro_in_background(self.start_progress_updates)
        if not self.live_updates:
            run_coro_in_background(self.notify_clients)

    def close(self):
        self.stopped = True

    def send_message(
        self,
        event: Event,
        message_type: str,
        data: dict | None = None,
    ):
        if not event.alive:
            return
        data = {} if data is None else data
        messages = self.pending_messages_per_session[event.session_hash]
        messages.put_nowait({"msg": message_type, "event_id": event._id, **data})

    def _resolve_concurrency_limit(
        self, default_concurrency_limit: int | None | Literal["not_set"]
    ) -> int | None:
        """
        Handles the logic of resolving the default_concurrency_limit, which can be specified
        either via the `default_concurrency_limit` parameter of `Blocks.queue()` or the
        `GRADIO_DEFAULT_CONCURRENCY_LIMIT` environment variable. The parameter in
        `Blocks.queue()` takes precedence over the environment variable.

        Parameters:
            default_concurrency_limit: The default concurrency limit, as specified by a user in `Blocks.queue()`.
        """
        if default_concurrency_limit != "not_set":
            return default_concurrency_limit
        if default_concurrency_limit_env := os.environ.get(
            "GRADIO_DEFAULT_CONCURRENCY_LIMIT"
        ):
            if default_concurrency_limit_env.lower() == "none":
                return None
            else:
                return int(default_concurrency_limit_env)
        else:
            return 1
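
    # Resolution precedence, shown as an illustrative (non-executed) sketch;
    # `queue` here is a hypothetical instance of this class:
    #
    #   os.environ["GRADIO_DEFAULT_CONCURRENCY_LIMIT"] = "4"
    #   queue._resolve_concurrency_limit("not_set")  # -> 4 (env var applies)
    #   queue._resolve_concurrency_limit(8)          # -> 8 (parameter wins)
    #   queue._resolve_concurrency_limit(None)       # -> None (explicitly unlimited)
    #
    # With neither the parameter nor the env var set, the default is 1.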

    def __len__(self):
        total_len = 0
        for event_queue in self.event_queue_per_concurrency_id.values():
            total_len += len(event_queue.queue)
        return total_len

    async def push(
        self, body: PredictBody, request: fastapi.Request, username: str | None
    ) -> tuple[bool, str]:
        if body.session_hash is None:
            return False, "No session hash provided."
        if body.fn_index is None:
            return False, "No function index provided."
        if self.max_size is not None and len(self) >= self.max_size:
            return (
                False,
                f"Queue is full. Max size is {self.max_size} and size is {len(self)}.",
            )

        event = Event(
            body.session_hash,
            body.fn_index,
            request,
            username,
            self.block_fns[body.fn_index].concurrency_id,
        )
        event.data = body
        async with self.pending_message_lock:
            if body.session_hash not in self.pending_messages_per_session:
                self.pending_messages_per_session[body.session_hash] = ThreadQueue()
            if body.session_hash not in self.pending_event_ids_session:
                self.pending_event_ids_session[body.session_hash] = set()
            self.pending_event_ids_session[body.session_hash].add(event._id)
        event_queue = self.event_queue_per_concurrency_id[event.concurrency_id]
        event_queue.queue.append(event)

        self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1)
        return True, event._id
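
    # Return contract (descriptive note): push() returns (False, reason) when the
    # request is rejected (missing session hash or fn_index, or the queue is at
    # max_size), and (True, event_id) once the event is enqueued and an estimation
    # has been broadcast to the clients waiting behind it.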

    def _cancel_asyncio_tasks(self):
        for task in self._asyncio_tasks:
            task.cancel()
        self._asyncio_tasks = []

    def set_server_app(self, app: routes.App):
        self.server_app = app

    def get_active_worker_count(self) -> int:
        count = 0
        for worker in self.active_jobs:
            if worker is not None:
                count += 1
        return count

    def get_events(self) -> tuple[list[Event], bool, str] | None:
        concurrency_ids = list(self.event_queue_per_concurrency_id.keys())
        random.shuffle(concurrency_ids)
        for concurrency_id in concurrency_ids:
            event_queue = self.event_queue_per_concurrency_id[concurrency_id]
            if len(event_queue.queue) and (
                event_queue.concurrency_limit is None
                or event_queue.current_concurrency < event_queue.concurrency_limit
            ):
                first_event = event_queue.queue[0]
                block_fn = self.block_fns[first_event.fn_index]
                events = [first_event]
                batch = block_fn.batch
                if batch:
                    events += [
                        event
                        for event in event_queue.queue[1:]
                        if event.fn_index == first_event.fn_index
                    ][: block_fn.max_batch_size - 1]

                for event in events:
                    event_queue.queue.remove(event)

                return events, batch, concurrency_id
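
    # Scheduling note (descriptive, not upstream documentation): get_events()
    # visits the per-concurrency-id queues in random order so that no concurrency
    # group starves the others. If the matched BlockFunction has batch=True, it
    # also pulls later events with the same fn_index, up to max_batch_size, so
    # they can be processed as a single batched job.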

    async def start_processing(self) -> None:
        try:
            while not self.stopped:
                if len(self) == 0:
                    await asyncio.sleep(self.sleep_when_free)
                    continue

                if None not in self.active_jobs:
                    await asyncio.sleep(self.sleep_when_free)
                    continue

                # Using mutex to avoid editing a list in use
                async with self.delete_lock:
                    event_batch = self.get_events()

                if event_batch:
                    events, batch, concurrency_id = event_batch
                    self.active_jobs[self.active_jobs.index(None)] = events
                    event_queue = self.event_queue_per_concurrency_id[concurrency_id]
                    event_queue.current_concurrency += 1
                    start_time = time.time()
                    event_queue.start_times_per_fn_index[events[0].fn_index].add(
                        start_time
                    )
                    process_event_task = run_coro_in_background(
                        self.process_events, events, batch, start_time
                    )
                    set_task_name(
                        process_event_task,
                        events[0].session_hash,
                        events[0].fn_index,
                        batch,
                    )
                    self._asyncio_tasks.append(process_event_task)

                    if self.live_updates:
                        self.broadcast_estimations(concurrency_id)
                else:
                    await asyncio.sleep(self.sleep_when_free)
        finally:
            self.stopped = True
            self._cancel_asyncio_tasks()

    async def start_progress_updates(self) -> None:
        """
        Because progress updates can be very frequent, we do not necessarily want to send a message per update.
        Rather, we check for progress updates at regular intervals, and send a message if there is a pending update.
        Consecutive progress updates between sends will overwrite each other so only the most recent update will be sent.
        """
        while not self.stopped:
            events = [
                evt for job in self.active_jobs if job is not None for evt in job
            ] + self.continuous_tasks

            if len(events) == 0:
                await asyncio.sleep(self.progress_update_sleep_when_free)
                continue

            for event in events:
                if event.progress_pending and event.progress:
                    event.progress_pending = False
                    self.send_message(
                        event, ServerMessage.progress, event.progress.model_dump()
                    )

            await asyncio.sleep(self.progress_update_sleep_when_free)
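
    # Coalescing sketch (illustrative, not executed): if a worker calls
    # set_progress() several times within one progress_update_sleep_when_free
    # interval, each call overwrites event.progress and re-sets progress_pending,
    # so when start_progress_updates() wakes up it sends a single
    # ServerMessage.progress carrying only the most recent state.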

    def set_progress(
        self,
        event_id: str,
        iterables: list[TrackedIterable] | None,
    ):
        if iterables is None:
            return
        for job in self.active_jobs:
            if job is None:
                continue
            for evt in job:
                if evt._id == event_id:
                    progress_data: list[ProgressUnit] = []
                    for iterable in iterables:
                        progress_unit = ProgressUnit(
                            index=iterable.index,
                            length=iterable.length,
                            unit=iterable.unit,
                            progress=iterable.progress,
                            desc=iterable.desc,
                        )
                        progress_data.append(progress_unit)
                    evt.progress = Progress(progress_data=progress_data)
                    evt.progress_pending = True

    def log_message(
        self,
        event_id: str,
        log: str,
        level: Literal["info", "warning"],
    ):
        events = [
            evt for job in self.active_jobs if job is not None for evt in job
        ] + self.continuous_tasks
        for event in events:
            if event._id == event_id:
                log_message = LogMessage(
                    log=log,
                    level=level,
                )
                self.send_message(event, ServerMessage.log, log_message.model_dump())

    async def clean_events(
        self, *, session_hash: str | None = None, event_id: str | None = None
    ) -> None:
        for job_set in self.active_jobs:
            if job_set:
                for job in job_set:
                    if job.session_hash == session_hash or job._id == event_id:
                        job.alive = False

        async with self.delete_lock:
            events_to_remove: list[Event] = []
            for event_queue in self.event_queue_per_concurrency_id.values():
                for event in event_queue.queue:
                    if event.session_hash == session_hash or event._id == event_id:
                        events_to_remove.append(event)

            for event in events_to_remove:
                self.event_queue_per_concurrency_id[event.concurrency_id].queue.remove(
                    event
                )

    async def notify_clients(self) -> None:
        """
        Notify clients about event statuses in the queue periodically.
        """
        while not self.stopped:
            await asyncio.sleep(self.update_intervals)
            if len(self) > 0:
                for concurrency_id in self.event_queue_per_concurrency_id:
                    self.broadcast_estimations(concurrency_id)

    def broadcast_estimations(
        self, concurrency_id: str, after: int | None = None
    ) -> None:
        wait_so_far = 0
        event_queue = self.event_queue_per_concurrency_id[concurrency_id]
        time_till_available_worker: int | None = 0

        if event_queue.current_concurrency == event_queue.concurrency_limit:
            expected_end_times = []
            for fn_index, start_times in event_queue.start_times_per_fn_index.items():
                if fn_index not in self.process_time_per_fn_index:
                    time_till_available_worker = None
                    break
                process_time = self.process_time_per_fn_index[fn_index].avg_time
                expected_end_times += [
                    start_time + process_time for start_time in start_times
                ]
            if time_till_available_worker is not None and len(expected_end_times) > 0:
                time_of_first_completion = min(expected_end_times)
                time_till_available_worker = max(
                    time_of_first_completion - time.time(), 0
                )

        for rank, event in enumerate(event_queue.queue):
            process_time_for_fn = (
                self.process_time_per_fn_index[event.fn_index].avg_time
                if event.fn_index in self.process_time_per_fn_index
                else None
            )
            rank_eta = (
                process_time_for_fn + wait_so_far + time_till_available_worker
                if process_time_for_fn is not None
                and wait_so_far is not None
                and time_till_available_worker is not None
                else None
            )

            if after is None or rank >= after:
                self.send_message(
                    event,
                    ServerMessage.estimation,
                    Estimation(
                        rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
                    ).model_dump(),
                )
            if event_queue.concurrency_limit is None:
                wait_so_far = 0
            elif wait_so_far is not None and process_time_for_fn is not None:
                wait_so_far += process_time_for_fn / event_queue.concurrency_limit
            else:
                wait_so_far = None
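
    # ETA bookkeeping (descriptive note): the rank_eta sent to the client at
    # position `rank` is roughly
    #     avg_process_time(fn_index) + wait_so_far + time_till_available_worker,
    # where wait_so_far grows by avg_process_time / concurrency_limit for each
    # earlier event (that many events run side by side), and any missing timing
    # statistic makes the estimate fall back to None.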

    def get_status(self) -> Estimation:
        return Estimation(
            queue_size=len(self),
        )

    async def call_prediction(self, events: list[Event], batch: bool):
        body = events[0].data
        if body is None:
            raise ValueError("No event data")
        username = events[0].username
        body.event_id = events[0]._id if not batch else None
        try:
            body.request = events[0].request
        except ValueError:
            pass

        if batch:
            body.data = list(zip(*[event.data.data for event in events if event.data]))
            body.request = events[0].request
            body.batched = True

        app = self.server_app
        if app is None:
            raise Exception("Server app has not been set.")

        api_name = "predict"

        fn_index_inferred = route_utils.infer_fn_index(
            app=app, api_name=api_name, body=body
        )

        gr_request = route_utils.compile_gr_request(
            app=app,
            body=body,
            fn_index_inferred=fn_index_inferred,
            username=username,
            request=None,
        )

        try:
            output = await route_utils.call_process_api(
                app=app,
                body=body,
                gr_request=gr_request,
                fn_index_inferred=fn_index_inferred,
            )
        except Exception as error:
            show_error = app.get_blocks().show_error or isinstance(error, Error)
            traceback.print_exc()
            raise Exception(str(error) if show_error else None) from error

        # To emulate the HTTP response from the predict API,
        # convert the output to a JSON response string.
        # This is done by FastAPI automatically in the HTTP endpoint handlers,
        # but we need to do it manually here.
        response_class = app.router.default_response_class
        if isinstance(response_class, fastapi.datastructures.DefaultPlaceholder):
            actual_response_class = response_class.value
        else:
            actual_response_class = response_class
        http_response = actual_response_class(
            output
        )  # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264
        # Also, decode the JSON string to a Python object, emulating the HTTP client behavior e.g. the `json()` method of `httpx`.
        response_json = json.loads(http_response.body.decode())

        return response_json

    async def process_events(
        self, events: list[Event], batch: bool, begin_time: float
    ) -> None:
        awake_events: list[Event] = []
        fn_index = events[0].fn_index
        try:
            for event in events:
                if event.alive:
                    self.send_message(
                        event,
                        ServerMessage.process_starts,
                        {
                            "eta": self.process_time_per_fn_index[fn_index].avg_time
                            if fn_index in self.process_time_per_fn_index
                            else None
                        },
                    )
                    awake_events.append(event)
            if not awake_events:
                return
            try:
                response = await self.call_prediction(awake_events, batch)
                err = None
            except Exception as e:
                response = None
                err = e
                for event in awake_events:
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": {
                                "error": None
                                if len(e.args) and e.args[0] is None
                                else str(e)
                            },
                            "success": False,
                        },
                    )
            if response and response.get("is_generating", False):
                old_response = response
                old_err = err
                while response and response.get("is_generating", False):
                    old_response = response
                    old_err = err
                    for event in awake_events:
                        self.send_message(
                            event,
                            ServerMessage.process_generating,
                            {
                                "output": old_response,
                                "success": old_response is not None,
                            },
                        )
                    awake_events = [event for event in awake_events if event.alive]
                    if not awake_events:
                        return
                    try:
                        response = await self.call_prediction(awake_events, batch)
                        err = None
                    except Exception as e:
                        response = None
                        err = e
                for event in awake_events:
                    if response is None:
                        relevant_response = err
                    else:
                        relevant_response = old_response or old_err
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": {"error": str(relevant_response)}
                            if isinstance(relevant_response, Exception)
                            else relevant_response,
                            "success": relevant_response
                            and not isinstance(relevant_response, Exception),
                        },
                    )
            elif response:
                output = copy.deepcopy(response)
                for e, event in enumerate(awake_events):
                    if batch and "data" in output:
                        output["data"] = list(zip(*response.get("data")))[e]
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": output,
                            "success": response is not None,
                        },
                    )
            end_time = time.time()
            if response is not None:
                self.process_time_per_fn_index[events[0].fn_index].add(
                    end_time - begin_time
                )
        except Exception as e:
            traceback.print_exc()
        finally:
            event_queue = self.event_queue_per_concurrency_id[events[0].concurrency_id]
            event_queue.current_concurrency -= 1
            start_times = event_queue.start_times_per_fn_index[fn_index]
            if begin_time in start_times:
                start_times.remove(begin_time)
            try:
                self.active_jobs[self.active_jobs.index(events)] = None
            except ValueError:
                # `events` can be absent from `self.active_jobs`
                # when this coroutine is called from the `join_queue` endpoint handler in `routes.py`
                # without putting the `events` into `self.active_jobs`.
                # https://github.com/gradio-app/gradio/blob/f09aea34d6bd18c1e2fef80c86ab2476a6d1dd83/gradio/routes.py#L594-L596
                pass
            for event in events:
                # Always reset the state of the iterator
                # If the job finished successfully, this has no effect
                # If the job is cancelled, this will enable future runs
                # to start "from scratch"
                await self.reset_iterators(event._id)

    async def reset_iterators(self, event_id: str):
        # Do the same thing as the /reset route
        app = self.server_app
        if app is None:
            raise Exception("Server app has not been set.")
        if event_id not in app.iterators:
            # Failure, but don't raise an error
            return
        async with app.lock:
            del app.iterators[event_id]
            app.iterators_to_reset.add(event_id)
        return
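

# Typical lifecycle (illustrative sketch only; the exact call sites live in the
# surrounding Gradio app code, and the argument values below are placeholders):
#
#   queue = Queue(
#       live_updates=True,
#       concurrency_count=1,
#       update_intervals=10,
#       max_size=None,
#       block_fns=blocks.fns,
#   )
#   queue.set_server_app(app)  # give the queue access to the FastAPI app
#   queue.start()              # spawn the background processing loops
#   success, event_id = await queue.push(body, request, username)
#   ...
#   queue.close()              # sets self.stopped so the background loops exit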