Spaces:
Configuration error
Configuration error
| import asyncio | |
| from asyncio import BoundedSemaphore | |
| from time import perf_counter, time | |
| from typing import Any, Dict, List, Optional | |
| import orjson | |
| from redis.asyncio import Redis | |
| from inference.core.entities.requests.inference import ( | |
| InferenceRequest, | |
| request_from_type, | |
| ) | |
| from inference.core.entities.responses.inference import response_from_type | |
| from inference.core.env import NUM_PARALLEL_TASKS | |
| from inference.core.managers.base import ModelManager | |
| from inference.core.registries.base import ModelRegistry | |
| from inference.core.registries.roboflow import get_model_type | |
| from inference.enterprise.parallel.tasks import preprocess | |
| from inference.enterprise.parallel.utils import FAILURE_STATE, SUCCESS_STATE | |
class ResultsChecker:
    """
    Class responsible for queuing asynchronous inference runs,
    keeping track of running requests, and awaiting their results.

    Results are delivered over the Redis "results" pub/sub channel and
    consumed by the `loop` coroutine, which signals per-task
    `asyncio.Event`s awaited by `wait_for_response`.
    """

    def __init__(self, redis: Redis):
        # task_id -> Event; set by `loop` once the task reaches a final state.
        self.tasks: Dict[str, asyncio.Event] = {}
        # task_id -> success payload, populated by `loop`, drained by `get_result`.
        self.dones = dict()
        # task_id -> error payload, populated by `loop`, drained by `get_result`.
        self.errors = dict()
        self.running = True
        self.redis = redis
        # Caps the number of tasks in flight at once; acquired in `add_task`
        # and released by `loop` when a result for a tracked task arrives.
        self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS)

    async def add_task(self, task_id: str, request: InferenceRequest):
        """
        Wait until there's an available cycle to queue a task.
        When there are cycles, add the task's id to a list to keep track of its results,
        launch the preprocess celery task, set the task's status to in progress in redis.
        """
        await self.semaphore.acquire()
        # Register the event BEFORE dispatching so `loop` cannot miss the result.
        self.tasks[task_id] = asyncio.Event()
        preprocess.s(request.dict()).delay()

    def get_result(self, task_id: str) -> Any:
        """
        Check the done tasks and errored tasks for this task id.

        Pops (removes) the stored payload; raises if the task errored.
        Only call after the task's Event has been set by `loop`.
        """
        if task_id in self.dones:
            return self.dones.pop(task_id)
        elif task_id in self.errors:
            message = self.errors.pop(task_id)
            raise Exception(message)
        else:
            raise RuntimeError(
                "Task result not found in either success or error dict. Unreachable"
            )

    async def loop(self):
        """
        Main loop. Check all in progress tasks for their status, and if their status is final,
        (either failure or success) then add their results to the appropriate results dictionary.
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.subscribe("results")
            async for message in pubsub.listen():
                # Ignore subscribe/unsubscribe confirmations and other meta messages.
                if message["type"] != "message":
                    continue
                message = orjson.loads(message["data"])
                task_id = message.pop("task_id")
                # Results for tasks launched by other processes are not ours to track.
                if task_id not in self.tasks:
                    continue
                # A tracked task finished: free one slot for `add_task` waiters.
                self.semaphore.release()
                status = message.pop("status")
                if status == FAILURE_STATE:
                    self.errors[task_id] = message["payload"]
                elif status == SUCCESS_STATE:
                    self.dones[task_id] = message["payload"]
                else:
                    raise RuntimeError(
                        "Task result not found in possible states. Unreachable"
                    )
                # Wake the coroutine blocked in `wait_for_response`.
                self.tasks[task_id].set()
                # Yield control so awakened waiters can run between messages.
                await asyncio.sleep(0)

    async def wait_for_response(self, key: str):
        """Block until the task's result is final, then return (or raise) it."""
        event = self.tasks[key]
        await event.wait()
        del self.tasks[key]
        return self.get_result(key)
class DispatchModelManager(ModelManager):
    """
    Model manager that performs no local inference: each request is handed
    to a ResultsChecker, which dispatches it to remote workers and awaits
    the published result.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        checker: ResultsChecker,
        models: Optional[dict] = None,
    ):
        super().__init__(model_registry, models)
        self.checker = checker

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
        """
        Fan a request out to the workers (one task per image for batched
        requests), await every response, and rebuild typed response objects.

        Returns a list of responses for batched input, a single response
        otherwise. Raises NotImplementedError if visualization is requested.
        """
        if request.visualize_predictions:
            raise NotImplementedError("Visualisation of prediction is not supported")
        request.start = time()
        started_at = perf_counter()
        task_type = self.get_task_type(model_id, request.api_key)

        batched = isinstance(request.image, list)
        if batched:
            base = request.dict()
            images = base.pop("image")
            # Drop the id so each per-image sub-request receives a fresh one.
            del base["id"]
            requests = [
                request_from_type(task_type, dict(**base, image=img))
                for img in images
            ]
        else:
            requests = [request]

        # Queue every task first (bounded by the checker's semaphore) ...
        await asyncio.gather(
            *(self.checker.add_task(sub.id, sub) for sub in requests)
        )
        # ... then await all of their results concurrently.
        response_jsons = await asyncio.gather(
            *(self.checker.wait_for_response(sub.id) for sub in requests)
        )

        responses = []
        for payload in response_jsons:
            typed_response = response_from_type(task_type, payload)
            # Per-response elapsed time, measured from the start of this call.
            typed_response.time = perf_counter() - started_at
            responses.append(typed_response)

        return responses if batched else responses[0]

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: str = None
    ) -> None:
        """No-op: models are loaded by the remote workers, not locally."""
        pass

    def __contains__(self, model_id: str) -> bool:
        """Every model id is accepted; resolution happens worker-side."""
        return True

    def get_task_type(self, model_id: str, api_key: str = None) -> str:
        """Look up the task type (first element of the model-type tuple)."""
        return get_model_type(model_id, api_key)[0]