from __future__ import annotations

import asyncio
import copy
import json
import os
import random
import time
import traceback
import uuid
from collections import defaultdict
from queue import Queue as ThreadQueue
from typing import TYPE_CHECKING

import fastapi
from gradio_client.utils import ServerMessage
from typing_extensions import Literal

from gradio import route_utils, routes
from gradio.data_classes import (
    Estimation,
    LogMessage,
    PredictBody,
    Progress,
    ProgressUnit,
)
from gradio.exceptions import Error
from gradio.helpers import TrackedIterable
from gradio.utils import LRUCache, run_coro_in_background, safe_get_lock, set_task_name

if TYPE_CHECKING:
    from gradio.blocks import BlockFunction


class Event:
    def __init__(
        self,
        session_hash: str,
        fn_index: int,
        request: fastapi.Request,
        username: str | None,
        concurrency_id: str,
    ):
        self.session_hash = session_hash
        self.fn_index = fn_index
        self.request = request
        self.username = username
        self.concurrency_id = concurrency_id
        self._id = uuid.uuid4().hex
        self.data: PredictBody | None = None
        self.progress: Progress | None = None
        self.progress_pending: bool = False
        self.alive = True


class EventQueue:
    def __init__(self, concurrency_id: str, concurrency_limit: int | None):
        self.queue: list[Event] = []
        self.concurrency_id = concurrency_id
        self.concurrency_limit = concurrency_limit
        self.current_concurrency = 0
        self.start_times_per_fn_index: defaultdict[int, set[float]] = defaultdict(set)


class ProcessTime:
    def __init__(self):
        self.process_time = 0
        self.count = 0
        self.avg_time = 0

    def add(self, time: float):
        self.process_time += time
        self.count += 1
        self.avg_time = self.process_time / self.count
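
# Illustrative note (not part of the original module): ProcessTime keeps a
# running average of how long a function takes. For example, after add(2.0)
# and add(4.0), process_time == 6.0, count == 2, and avg_time == 3.0. This
# average feeds the ETA estimates broadcast to queued clients below.
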
class Queue:
    def __init__(
        self,
        live_updates: bool,
        concurrency_count: int,
        update_intervals: float,
        max_size: int | None,
        block_fns: list[BlockFunction],
        default_concurrency_limit: int | None | Literal["not_set"] = "not_set",
    ):
        self.pending_messages_per_session: LRUCache[str, ThreadQueue] = LRUCache(2000)
        self.pending_event_ids_session: dict[str, set[str]] = {}
        self.pending_message_lock = safe_get_lock()
        self.event_queue_per_concurrency_id: dict[str, EventQueue] = {}
        self.stopped = False
        self.max_thread_count = concurrency_count
        self.update_intervals = update_intervals
        self.active_jobs: list[None | list[Event]] = []
        self.delete_lock = safe_get_lock()
        self.server_app = None
        self.process_time_per_fn_index: defaultdict[int, ProcessTime] = defaultdict(
            ProcessTime
        )
        self.live_updates = live_updates
        self.sleep_when_free = 0.05
        self.progress_update_sleep_when_free = 0.1
        self.max_size = max_size
        self.block_fns = block_fns
        self.continuous_tasks: list[Event] = []
        self._asyncio_tasks: list[asyncio.Task] = []
        self.default_concurrency_limit = self._resolve_concurrency_limit(
            default_concurrency_limit
        )

    def start(self):
        self.active_jobs = [None] * self.max_thread_count
        for block_fn in self.block_fns:
            concurrency_id = block_fn.concurrency_id
            concurrency_limit: int | None
            if block_fn.concurrency_limit == "default":
                concurrency_limit = self.default_concurrency_limit
            else:
                concurrency_limit = block_fn.concurrency_limit
            if concurrency_id not in self.event_queue_per_concurrency_id:
                self.event_queue_per_concurrency_id[concurrency_id] = EventQueue(
                    concurrency_id, concurrency_limit
                )
            elif (
                concurrency_limit is not None
            ):  # Update concurrency limit if it is lower than existing limit
                existing_event_queue = self.event_queue_per_concurrency_id[
                    concurrency_id
                ]
                if (
                    existing_event_queue.concurrency_limit is None
                    or concurrency_limit < existing_event_queue.concurrency_limit
                ):
                    existing_event_queue.concurrency_limit = concurrency_limit

        run_coro_in_background(self.start_processing)
        run_coro_in_background(self.start_progress_updates)
        if not self.live_updates:
            run_coro_in_background(self.notify_clients)
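
    # Illustrative note (not part of the original module): when several
    # functions share a concurrency_id, start() keeps the strictest (lowest)
    # limit for the shared EventQueue. For example, if two functions both use
    # concurrency_id "gpu" with limits 4 and 2, the shared queue ends up with
    # concurrency_limit == 2; a limit of None (unlimited) is overridden by any
    # integer limit.
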
    def close(self):
        self.stopped = True

    def send_message(
        self,
        event: Event,
        message_type: str,
        data: dict | None = None,
    ):
        if not event.alive:
            return
        data = {} if data is None else data
        messages = self.pending_messages_per_session[event.session_hash]
        messages.put_nowait({"msg": message_type, "event_id": event._id, **data})

    def _resolve_concurrency_limit(
        self, default_concurrency_limit: int | None | Literal["not_set"]
    ) -> int | None:
        """
        Handles the logic of resolving the default_concurrency_limit, as it can be specified
        either via the `default_concurrency_limit` parameter of `Blocks.queue()` or the
        `GRADIO_DEFAULT_CONCURRENCY_LIMIT` environment variable. The parameter in
        `Blocks.queue()` takes precedence over the environment variable.
        Parameters:
            default_concurrency_limit: The default concurrency limit, as specified by a user in `Blocks.queue()`.
        """
        if default_concurrency_limit != "not_set":
            return default_concurrency_limit
        if default_concurrency_limit_env := os.environ.get(
            "GRADIO_DEFAULT_CONCURRENCY_LIMIT"
        ):
            if default_concurrency_limit_env.lower() == "none":
                return None
            else:
                return int(default_concurrency_limit_env)
        else:
            return 1
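
    # Illustrative note (not part of the original module): resolution examples
    # for the default concurrency limit, derived from the logic above:
    #   _resolve_concurrency_limit(5)          -> 5    (explicit parameter wins)
    #   _resolve_concurrency_limit("not_set")  -> 3    if GRADIO_DEFAULT_CONCURRENCY_LIMIT=3
    #   _resolve_concurrency_limit("not_set")  -> None if GRADIO_DEFAULT_CONCURRENCY_LIMIT=none
    #   _resolve_concurrency_limit("not_set")  -> 1    if the variable is unset
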
    def __len__(self):
        total_len = 0
        for event_queue in self.event_queue_per_concurrency_id.values():
            total_len += len(event_queue.queue)
        return total_len

    async def push(
        self, body: PredictBody, request: fastapi.Request, username: str | None
    ) -> tuple[bool, str]:
        if body.session_hash is None:
            return False, "No session hash provided."
        if body.fn_index is None:
            return False, "No function index provided."
        if self.max_size is not None and len(self) >= self.max_size:
            return (
                False,
                f"Queue is full. Max size is {self.max_size} and size is {len(self)}.",
            )

        event = Event(
            body.session_hash,
            body.fn_index,
            request,
            username,
            self.block_fns[body.fn_index].concurrency_id,
        )
        event.data = body
        async with self.pending_message_lock:
            if body.session_hash not in self.pending_messages_per_session:
                self.pending_messages_per_session[body.session_hash] = ThreadQueue()
            if body.session_hash not in self.pending_event_ids_session:
                self.pending_event_ids_session[body.session_hash] = set()
            self.pending_event_ids_session[body.session_hash].add(event._id)

        event_queue = self.event_queue_per_concurrency_id[event.concurrency_id]
        event_queue.queue.append(event)

        self.broadcast_estimations(event.concurrency_id, len(event_queue.queue) - 1)
        return True, event._id

    def _cancel_asyncio_tasks(self):
        for task in self._asyncio_tasks:
            task.cancel()
        self._asyncio_tasks = []

    def set_server_app(self, app: routes.App):
        self.server_app = app

    def get_active_worker_count(self) -> int:
        count = 0
        for worker in self.active_jobs:
            if worker is not None:
                count += 1
        return count

    def get_events(self) -> tuple[list[Event], bool, str] | None:
        concurrency_ids = list(self.event_queue_per_concurrency_id.keys())
        random.shuffle(concurrency_ids)
        for concurrency_id in concurrency_ids:
            event_queue = self.event_queue_per_concurrency_id[concurrency_id]
            if len(event_queue.queue) and (
                event_queue.concurrency_limit is None
                or event_queue.current_concurrency < event_queue.concurrency_limit
            ):
                first_event = event_queue.queue[0]
                block_fn = self.block_fns[first_event.fn_index]
                events = [first_event]
                batch = block_fn.batch
                if batch:
                    events += [
                        event
                        for event in event_queue.queue[1:]
                        if event.fn_index == first_event.fn_index
                    ][: block_fn.max_batch_size - 1]

                for event in events:
                    event_queue.queue.remove(event)

                return events, batch, concurrency_id
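
    # Illustrative note (not part of the original module): get_events() picks a
    # concurrency group at random that still has spare capacity, then dequeues
    # either a single event or, for a batched function, up to max_batch_size
    # events that share the same fn_index. For example, with max_batch_size == 4
    # and a queue of [fn0, fn0, fn1, fn0], the three fn0 events are returned as
    # one batch and fn1 stays queued.
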
    async def start_processing(self) -> None:
        try:
            while not self.stopped:
                if len(self) == 0:
                    await asyncio.sleep(self.sleep_when_free)
                    continue

                if None not in self.active_jobs:
                    await asyncio.sleep(self.sleep_when_free)
                    continue

                # Using mutex to avoid editing a list in use
                async with self.delete_lock:
                    event_batch = self.get_events()

                if event_batch:
                    events, batch, concurrency_id = event_batch
                    self.active_jobs[self.active_jobs.index(None)] = events
                    event_queue = self.event_queue_per_concurrency_id[concurrency_id]
                    event_queue.current_concurrency += 1
                    start_time = time.time()
                    event_queue.start_times_per_fn_index[events[0].fn_index].add(
                        start_time
                    )
                    process_event_task = run_coro_in_background(
                        self.process_events, events, batch, start_time
                    )
                    set_task_name(
                        process_event_task,
                        events[0].session_hash,
                        events[0].fn_index,
                        batch,
                    )
                    self._asyncio_tasks.append(process_event_task)

                    if self.live_updates:
                        self.broadcast_estimations(concurrency_id)
                else:
                    await asyncio.sleep(self.sleep_when_free)
        finally:
            self.stopped = True
            self._cancel_asyncio_tasks()

    async def start_progress_updates(self) -> None:
        """
        Because progress updates can be very frequent, we do not necessarily want to send a message per update.
        Rather, we check for progress updates at regular intervals, and send a message if there is a pending update.
        Consecutive progress updates between sends will overwrite each other so only the most recent update will be sent.
        """
        while not self.stopped:
            events = [
                evt for job in self.active_jobs if job is not None for evt in job
            ] + self.continuous_tasks

            if len(events) == 0:
                await asyncio.sleep(self.progress_update_sleep_when_free)
                continue

            for event in events:
                if event.progress_pending and event.progress:
                    event.progress_pending = False
                    self.send_message(
                        event, ServerMessage.progress, event.progress.model_dump()
                    )

            await asyncio.sleep(self.progress_update_sleep_when_free)
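
    # Illustrative note (not part of the original module): with
    # progress_update_sleep_when_free == 0.1, pending progress is flushed
    # roughly ten times per second, so at most ~10 progress messages per second
    # are sent per event regardless of how often the tracked iterables update;
    # intermediate updates between flushes are simply overwritten.
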
    def set_progress(
        self,
        event_id: str,
        iterables: list[TrackedIterable] | None,
    ):
        if iterables is None:
            return
        for job in self.active_jobs:
            if job is None:
                continue
            for evt in job:
                if evt._id == event_id:
                    progress_data: list[ProgressUnit] = []
                    for iterable in iterables:
                        progress_unit = ProgressUnit(
                            index=iterable.index,
                            length=iterable.length,
                            unit=iterable.unit,
                            progress=iterable.progress,
                            desc=iterable.desc,
                        )
                        progress_data.append(progress_unit)
                    evt.progress = Progress(progress_data=progress_data)
                    evt.progress_pending = True

    def log_message(
        self,
        event_id: str,
        log: str,
        level: Literal["info", "warning"],
    ):
        events = [
            evt for job in self.active_jobs if job is not None for evt in job
        ] + self.continuous_tasks
        for event in events:
            if event._id == event_id:
                log_message = LogMessage(
                    log=log,
                    level=level,
                )
                self.send_message(event, ServerMessage.log, log_message.model_dump())

    async def clean_events(
        self, *, session_hash: str | None = None, event_id: str | None = None
    ) -> None:
        for job_set in self.active_jobs:
            if job_set:
                for job in job_set:
                    if job.session_hash == session_hash or job._id == event_id:
                        job.alive = False

        async with self.delete_lock:
            events_to_remove: list[Event] = []
            for event_queue in self.event_queue_per_concurrency_id.values():
                for event in event_queue.queue:
                    if event.session_hash == session_hash or event._id == event_id:
                        events_to_remove.append(event)

            for event in events_to_remove:
                self.event_queue_per_concurrency_id[event.concurrency_id].queue.remove(
                    event
                )

    async def notify_clients(self) -> None:
        """
        Notify clients about event statuses in the queue periodically.
        """
        while not self.stopped:
            await asyncio.sleep(self.update_intervals)
            if len(self) > 0:
                for concurrency_id in self.event_queue_per_concurrency_id:
                    self.broadcast_estimations(concurrency_id)

    def broadcast_estimations(
        self, concurrency_id: str, after: int | None = None
    ) -> None:
        wait_so_far = 0
        event_queue = self.event_queue_per_concurrency_id[concurrency_id]
        time_till_available_worker: int | None = 0

        if event_queue.current_concurrency == event_queue.concurrency_limit:
            expected_end_times = []
            for fn_index, start_times in event_queue.start_times_per_fn_index.items():
                if fn_index not in self.process_time_per_fn_index:
                    time_till_available_worker = None
                    break
                process_time = self.process_time_per_fn_index[fn_index].avg_time
                expected_end_times += [
                    start_time + process_time for start_time in start_times
                ]
            if time_till_available_worker is not None and len(expected_end_times) > 0:
                time_of_first_completion = min(expected_end_times)
                time_till_available_worker = max(
                    time_of_first_completion - time.time(), 0
                )

        for rank, event in enumerate(event_queue.queue):
            process_time_for_fn = (
                self.process_time_per_fn_index[event.fn_index].avg_time
                if event.fn_index in self.process_time_per_fn_index
                else None
            )
            rank_eta = (
                process_time_for_fn + wait_so_far + time_till_available_worker
                if process_time_for_fn is not None
                and wait_so_far is not None
                and time_till_available_worker is not None
                else None
            )

            if after is None or rank >= after:
                self.send_message(
                    event,
                    ServerMessage.estimation,
                    Estimation(
                        rank=rank, rank_eta=rank_eta, queue_size=len(event_queue.queue)
                    ).model_dump(),
                )
            if event_queue.concurrency_limit is None:
                wait_so_far = 0
            elif wait_so_far is not None and process_time_for_fn is not None:
                wait_so_far += process_time_for_fn / event_queue.concurrency_limit
            else:
                wait_so_far = None
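
    # Illustrative note (not part of the original module): for an event at
    # position `rank` in the queue, the broadcast ETA is roughly
    #     rank_eta = time_till_available_worker + wait_so_far + process_time_for_fn
    # where time_till_available_worker estimates when the first running job
    # frees a slot, and wait_so_far accumulates avg_time / concurrency_limit
    # for each event ahead in the queue. If any average is unknown, the
    # estimate collapses to None.
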
    def get_status(self) -> Estimation:
        return Estimation(
            queue_size=len(self),
        )

    async def call_prediction(self, events: list[Event], batch: bool):
        body = events[0].data
        if body is None:
            raise ValueError("No event data")
        username = events[0].username
        body.event_id = events[0]._id if not batch else None
        try:
            body.request = events[0].request
        except ValueError:
            pass

        if batch:
            body.data = list(zip(*[event.data.data for event in events if event.data]))
            body.request = events[0].request
            body.batched = True

        app = self.server_app
        if app is None:
            raise Exception("Server app has not been set.")

        api_name = "predict"

        fn_index_inferred = route_utils.infer_fn_index(
            app=app, api_name=api_name, body=body
        )

        gr_request = route_utils.compile_gr_request(
            app=app,
            body=body,
            fn_index_inferred=fn_index_inferred,
            username=username,
            request=None,
        )

        try:
            output = await route_utils.call_process_api(
                app=app,
                body=body,
                gr_request=gr_request,
                fn_index_inferred=fn_index_inferred,
            )
        except Exception as error:
            show_error = app.get_blocks().show_error or isinstance(error, Error)
            traceback.print_exc()
            raise Exception(str(error) if show_error else None) from error

        # To emulate the HTTP response from the predict API,
        # convert the output to a JSON response string.
        # This is done by FastAPI automatically in the HTTP endpoint handlers,
        # but we need to do it manually here.
        response_class = app.router.default_response_class
        if isinstance(response_class, fastapi.datastructures.DefaultPlaceholder):
            actual_response_class = response_class.value
        else:
            actual_response_class = response_class
        http_response = actual_response_class(
            output
        )  # Do the same as https://github.com/tiangolo/fastapi/blob/0.87.0/fastapi/routing.py#L264
        # Also, decode the JSON string to a Python object, emulating the HTTP client behavior e.g. the `json()` method of `httpx`.
        response_json = json.loads(http_response.body.decode())

        return response_json
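
    # Illustrative note (not part of the original module): response_json is a
    # plain dict, i.e. whatever the predict endpoint would have returned after a
    # JSON round trip. process_events() below only relies on its "data" and
    # "is_generating" keys.
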
    async def process_events(
        self, events: list[Event], batch: bool, begin_time: float
    ) -> None:
        awake_events: list[Event] = []
        fn_index = events[0].fn_index
        try:
            for event in events:
                if event.alive:
                    self.send_message(
                        event,
                        ServerMessage.process_starts,
                        {
                            "eta": self.process_time_per_fn_index[fn_index].avg_time
                            if fn_index in self.process_time_per_fn_index
                            else None
                        },
                    )
                    awake_events.append(event)
            if not awake_events:
                return
            try:
                response = await self.call_prediction(awake_events, batch)
                err = None
            except Exception as e:
                response = None
                err = e
                for event in awake_events:
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": {
                                "error": None
                                if len(e.args) and e.args[0] is None
                                else str(e)
                            },
                            "success": False,
                        },
                    )
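
            # Illustrative note (not part of the original module): when the
            # first response reports is_generating, the loop below keeps calling
            # call_prediction() for further chunks, streaming each one as a
            # process_generating message, and finally sends the last complete
            # payload (or the error) as process_completed.
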
            if response and response.get("is_generating", False):
                old_response = response
                old_err = err
                while response and response.get("is_generating", False):
                    old_response = response
                    old_err = err
                    for event in awake_events:
                        self.send_message(
                            event,
                            ServerMessage.process_generating,
                            {
                                "output": old_response,
                                "success": old_response is not None,
                            },
                        )
                    awake_events = [event for event in awake_events if event.alive]
                    if not awake_events:
                        return
                    try:
                        response = await self.call_prediction(awake_events, batch)
                        err = None
                    except Exception as e:
                        response = None
                        err = e
                for event in awake_events:
                    if response is None:
                        relevant_response = err
                    else:
                        relevant_response = old_response or old_err
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": {"error": str(relevant_response)}
                            if isinstance(relevant_response, Exception)
                            else relevant_response,
                            "success": relevant_response
                            and not isinstance(relevant_response, Exception),
                        },
                    )
            elif response:
                output = copy.deepcopy(response)
                for e, event in enumerate(awake_events):
                    if batch and "data" in output:
                        output["data"] = list(zip(*response.get("data")))[e]
                    self.send_message(
                        event,
                        ServerMessage.process_completed,
                        {
                            "output": output,
                            "success": response is not None,
                        },
                    )
            end_time = time.time()
            if response is not None:
                self.process_time_per_fn_index[events[0].fn_index].add(
                    end_time - begin_time
                )
        except Exception as e:
            traceback.print_exc()
        finally:
            event_queue = self.event_queue_per_concurrency_id[events[0].concurrency_id]
            event_queue.current_concurrency -= 1
            start_times = event_queue.start_times_per_fn_index[fn_index]
            if begin_time in start_times:
                start_times.remove(begin_time)
            try:
                self.active_jobs[self.active_jobs.index(events)] = None
            except ValueError:
                # `events` can be absent from `self.active_jobs`
                # when this coroutine is called from the `join_queue` endpoint handler in `routes.py`
                # without putting the `events` into `self.active_jobs`.
                # https://github.com/gradio-app/gradio/blob/f09aea34d6bd18c1e2fef80c86ab2476a6d1dd83/gradio/routes.py#L594-L596
                pass
            for event in events:
                # Always reset the state of the iterator
                # If the job finished successfully, this has no effect
                # If the job is cancelled, this will enable future runs
                # to start "from scratch"
                await self.reset_iterators(event._id)

    async def reset_iterators(self, event_id: str):
        # Do the same thing as the /reset route
        app = self.server_app
        if app is None:
            raise Exception("Server app has not been set.")
        if event_id not in app.iterators:
            # Failure, but don't raise an error
            return
        async with app.lock:
            del app.iterators[event_id]
            app.iterators_to_reset.add(event_id)
        return