import json
from dataclasses import asdict
from multiprocessing import shared_memory
from typing import Dict, List, Tuple

import numpy as np
from celery import Celery
from redis import ConnectionPool, Redis

import inference.enterprise.parallel.celeryconfig
from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import REDIS_HOST, REDIS_PORT, STUB_CACHE_SIZE
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.managers.decorators.locked_load import (
    LockedLoadModelManagerDecorator,
)
from inference.core.managers.stub_loader import StubLoaderManager
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.utils import (
    SUCCESS_STATE,
    SharedMemoryMetadata,
    failure_handler,
    shm_manager,
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES

# Redis connection pool shared by every task executed in this worker process.
pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
app = Celery("tasks", broker=f"redis://{REDIS_HOST}:{REDIS_PORT}")
app.config_from_object(inference.enterprise.parallel.celeryconfig)

# Model manager stack: a stub loader wrapped with locked loading (so concurrent
# tasks don't load the same model twice) and a fixed-size cache that evicts
# models once STUB_CACHE_SIZE is reached.
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
model_manager = StubLoaderManager(model_registry)
model_manager = WithFixedSizeCache(
    LockedLoadModelManagerDecorator(model_manager), max_size=STUB_CACHE_SIZE
)


@app.task(queue="pre")  # Celery task registration; the queue name is assumed
def preprocess(request: Dict):
    redis_client = Redis(connection_pool=pool)
    with failure_handler(redis_client, request["id"]):
        model_manager.add_model(request["model_id"], request["api_key"])
        model_type = model_manager.get_task_type(request["model_id"])
        request = request_from_type(model_type, request)
        image, preprocess_return_metadata = model_manager.preprocess(
            request.model_id, request
        )
        # multi image requests are split into single image requests upstream and rebatched later
        image = image[0]
        request.image.value = None  # avoid writing image again since it's in memory
        shm = shared_memory.SharedMemory(create=True, size=image.nbytes)
        with shm_manager(shm):
            # Copy the preprocessed image into the shared segment so the
            # inference worker can read it without re-serializing the pixels.
            shared = np.ndarray(image.shape, dtype=image.dtype, buffer=shm.buf)
            shared[:] = image[:]
            shm_metadata = SharedMemoryMetadata(shm.name, image.shape, image.dtype.name)
            queue_infer_task(
                redis_client, shm_metadata, request, preprocess_return_metadata
            )
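

# --- Illustrative only, not part of this module. ---
# A minimal sketch of kicking off the pipeline above. It assumes the task
# registration decorator and a caller that already holds a serializable
# request dict (the required fields depend on the InferenceRequest subtype).
# Note that only metadata travels through Redis: the preprocessed tensor is
# handed to the inference worker zero-copy through the shared memory segment.
def _example_submit(request_dict: Dict) -> None:
    # Fire-and-forget; the response arrives later on the "results" channel.
    preprocess.delay(request_dict)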


@app.task(queue="post")  # Celery task registration; the queue name is assumed
def postprocess(
    shm_info_list: Tuple[Dict], request: Dict, preproc_return_metadata: Dict
):
    redis_client = Redis(connection_pool=pool)
    shm_info_list: List[SharedMemoryMetadata] = [
        SharedMemoryMetadata(**metadata) for metadata in shm_info_list
    ]
    with failure_handler(redis_client, request["id"]):
        # Attach to the output segments written by the inference worker; they
        # are unlinked once postprocessing succeeds.
        with shm_manager(
            *[shm_metadata.shm_name for shm_metadata in shm_info_list],
            unlink_on_success=True,
        ) as shms:
            model_manager.add_model(request["model_id"], request["api_key"])
            model_type = model_manager.get_task_type(request["model_id"])
            request = request_from_type(model_type, request)
            outputs = load_outputs(shm_info_list, shms)
            request_dict = dict(**request.dict())
            model_id = request_dict.pop("model_id")
            response = model_manager.postprocess(
                model_id,
                outputs,
                preproc_return_metadata,
                **request_dict,
                return_image_dims=True,
            )[0]
            write_response(redis_client, response, request.id)


def load_outputs(
    shm_info_list: List[SharedMemoryMetadata], shms: List[shared_memory.SharedMemory]
) -> Tuple[np.ndarray, ...]:
    outputs = []
    for args, shm in zip(shm_info_list, shms):
        # Rebuild each output as a view over its shared segment, restoring the
        # batch dimension of 1 that the single-image split removed.
        output = np.ndarray(
            [1] + args.array_shape, dtype=args.array_dtype, buffer=shm.buf
        )
        outputs.append(output)
    return tuple(outputs)
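

# --- Illustrative only, not part of this module. ---
# A minimal sketch of the producer side of load_outputs: an inference worker
# would write each single-image output array into its own shared segment and
# forward the metadata dicts to postprocess. The function name is an
# assumption; the real worker lives elsewhere in inference.enterprise.parallel.
def _example_write_output(output: np.ndarray) -> Dict:
    shm = shared_memory.SharedMemory(create=True, size=output.nbytes)
    dest = np.ndarray(output.shape, dtype=output.dtype, buffer=shm.buf)
    dest[:] = output[:]
    del dest  # drop the buffer view so the segment can be closed cleanly
    metadata = SharedMemoryMetadata(shm.name, list(output.shape), output.dtype.name)
    shm.close()  # postprocess reattaches by name and unlinks on success
    return asdict(metadata)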


def queue_infer_task(
    redis: Redis,
    shm_metadata: SharedMemoryMetadata,
    request: InferenceRequest,
    preprocess_return_metadata: Dict,
):
    return_vals = {
        "shm_metadata": asdict(shm_metadata),
        "request": request.dict(),
        "preprocess_metadata": preprocess_return_metadata,
    }
    return_vals = json.dumps(return_vals)
    # Atomically enqueue the task (scored by request start time) and bump the
    # per-model pending-request counter.
    pipe = redis.pipeline()
    pipe.zadd(f"infer:{request.model_id}", {return_vals: request.start})
    pipe.hincrby("requests", request.model_id, 1)
    pipe.execute()
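

# --- Illustrative only, not part of this module. ---
# A minimal sketch of the consumer side of queue_infer_task: a worker pops the
# lowest-scored (oldest) entry for a model and reattaches to the image segment
# written by preprocess. zpopmin and the function name are assumptions for
# illustration; the real batching worker lives elsewhere in the package.
def _example_pop_infer_task(redis_client: Redis, model_id: str) -> np.ndarray:
    entries = redis_client.zpopmin(f"infer:{model_id}", count=1)
    if not entries:
        raise RuntimeError("no queued inference task for this model")
    payload, _start_time = entries[0]
    task = json.loads(payload)
    meta = SharedMemoryMetadata(**task["shm_metadata"])
    shm = shared_memory.SharedMemory(name=meta.shm_name)
    view = np.ndarray(meta.array_shape, dtype=meta.array_dtype, buffer=shm.buf)
    image = view.copy()  # copy out of the segment before detaching
    del view
    shm.close()
    return image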


def write_response(redis: Redis, response: InferenceResponse, request_id: str):
    response = response.dict(exclude_none=True, by_alias=True)
    # Publish the finished response so the caller listening on the "results"
    # channel can pick it up by task id.
    redis.publish(
        "results",
        json.dumps(
            {"status": SUCCESS_STATE, "task_id": request_id, "payload": response}
        ),
    )
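

# --- Illustrative only, not part of this module. ---
# A minimal sketch of the receiving end of write_response: a caller subscribes
# to the "results" channel and filters messages by its own task id. The
# function name is an assumption, and failure states emitted by
# failure_handler are ignored here for brevity.
def _example_await_result(redis_client: Redis, request_id: str) -> Dict:
    pubsub = redis_client.pubsub(ignore_subscribe_messages=True)
    pubsub.subscribe("results")
    try:
        for message in pubsub.listen():
            result = json.loads(message["data"])
            if result.get("task_id") == request_id:
                return result
    finally:
        pubsub.close()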