Spaces:
Configuration error
Configuration error
| import queue | |
| from queue import Queue | |
| from threading import Thread | |
| from typing import Any, List, Optional | |
| from inference.core import logger | |
| from inference.core.active_learning.accounting import image_can_be_submitted_to_batch | |
| from inference.core.active_learning.batching import generate_batch_name | |
| from inference.core.active_learning.configuration import ( | |
| prepare_active_learning_configuration, | |
| prepare_active_learning_configuration_inplace, | |
| ) | |
| from inference.core.active_learning.core import ( | |
| execute_datapoint_registration, | |
| execute_sampling, | |
| ) | |
| from inference.core.active_learning.entities import ( | |
| ActiveLearningConfiguration, | |
| Prediction, | |
| PredictionType, | |
| ) | |
| from inference.core.cache.base import BaseCache | |
| from inference.core.utils.image_utils import load_image | |
# Default capacity of the registration task queue created by
# ThreadingActiveLearningMiddleware.init() and init_from_config().
MAX_REGISTRATION_QUEUE_SIZE = 512
class NullActiveLearningMiddleware:
    """No-op Active Learning middleware.

    Exposes registration, thread-lifecycle and context-manager methods whose
    bodies do nothing, so every call is accepted and silently ignored.
    """

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Accept a batch of inputs and predictions and discard them."""
        return None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Accept a single input/prediction pair and discard it."""
        return None

    def start_registration_thread(self) -> None:
        """Do nothing; this middleware owns no thread."""
        return None

    def stop_registration_thread(self) -> None:
        """Do nothing; this middleware owns no thread."""
        return None

    def __enter__(self) -> "NullActiveLearningMiddleware":
        """Return this instance unchanged so it can be used in `with`."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the context without any cleanup; exceptions are not suppressed."""
        return None
class ActiveLearningMiddleware:
    """Registers datapoints for Active Learning synchronously.

    Every `register()` call runs the whole pipeline in the caller's thread:
    image loading, sampling, batch-limit check and datapoint registration.
    When the instance holds no configuration (`None`), registration is a
    no-op.
    """

    @classmethod
    def init(
        cls, api_key: str, model_id: str, cache: BaseCache
    ) -> "ActiveLearningMiddleware":
        """Build a middleware from the configuration prepared for `model_id`.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for later registration calls.
            model_id: Identifier of the model the configuration is for.
            cache: Cache forwarded to configuration preparation and kept for
                later registration calls.

        Returns:
            A new instance of `cls`.
        """
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    @classmethod
    def init_from_config(
        cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
    ) -> "ActiveLearningMiddleware":
        """Build a middleware from an explicitly supplied configuration dict.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for later registration calls.
            model_id: Identifier of the model the configuration is for.
            cache: Cache kept for later registration calls (it is not used
                while preparing the configuration here).
            config: Raw Active Learning configuration, or `None`.

        Returns:
            A new instance of `cls`.
        """
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
    ):
        """Store the collaborators used by the registration pipeline.

        Args:
            api_key: API key used for the batch-limit check and registration.
            configuration: Active Learning configuration; `None` turns
                `_execute_registration()` into a no-op.
            cache: Cache handed to datapoint registration.
        """
        self._api_key = api_key
        self._configuration = configuration
        self._cache = cache

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Register each (input, prediction) pair through `register()`.

        Pairs are formed with `zip`, so surplus elements of the longer list
        are ignored.
        """
        for inference_input, prediction in zip(inference_inputs, predictions):
            self.register(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Run the registration pipeline for a single datapoint immediately."""
        self._execute_registration(
            inference_input=inference_input,
            prediction=prediction,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )

    def _execute_registration(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Sample one datapoint and, if selected and allowed, register it.

        Steps: load the image, run the configured sampling methods, stop if
        no strategy matched, stop if the batch cannot accept more images,
        otherwise register the datapoint.

        Args:
            inference_input: Value passed to `load_image` to obtain the image.
            prediction: Prediction associated with the input.
            prediction_type: Type of the prediction.
            disable_preproc_auto_orient: Forwarded to `load_image`.
        """
        if self._configuration is None:
            return None
        image, is_bgr = load_image(
            value=inference_input,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        if not is_bgr:
            # Reverse the last axis so later steps always receive the channel
            # order flagged as BGR (presumably RGB -> BGR; verify against
            # load_image).
            image = image[:, :, ::-1]
        matching_strategies = execute_sampling(
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            sampling_methods=self._configuration.sampling_methods,
        )
        if len(matching_strategies) == 0:
            # No sampling strategy selected this datapoint.
            return None
        batch_name = generate_batch_name(configuration=self._configuration)
        if not image_can_be_submitted_to_batch(
            batch_name=batch_name,
            workspace_id=self._configuration.workspace_id,
            dataset_id=self._configuration.dataset_id,
            max_batch_images=self._configuration.max_batch_images,
            api_key=self._api_key,
        ):
            logger.debug("Limit on Active Learning batch size reached.")
            return None
        execute_datapoint_registration(
            cache=self._cache,
            matching_strategies=matching_strategies,
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            configuration=self._configuration,
            api_key=self._api_key,
            batch_name=batch_name,
        )
class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
    """Active Learning middleware that registers datapoints on a worker thread.

    `register()` only enqueues a task. A thread started with
    `start_registration_thread()` (or by entering the context manager) takes
    tasks from the queue and runs `_execute_registration()` for each one.
    A `None` item in the queue tells the worker to terminate.
    """

    @classmethod
    def init(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        """Build a middleware from the configuration prepared for `model_id`.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for later registration calls.
            model_id: Identifier of the model the configuration is for.
            cache: Cache forwarded to configuration preparation and kept for
                later registration calls.
            max_queue_size: Capacity passed to the task `Queue`.

        Returns:
            A new instance of `cls` with a fresh task queue.
        """
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    @classmethod
    def init_from_config(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        config: Optional[dict],
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        """Build a middleware from an explicitly supplied configuration dict.

        Args:
            api_key: API key forwarded to configuration preparation and kept
                for later registration calls.
            model_id: Identifier of the model the configuration is for.
            cache: Cache kept for later registration calls.
            config: Raw Active Learning configuration, or `None`.
            max_queue_size: Capacity passed to the task `Queue`.

        Returns:
            A new instance of `cls` with a fresh task queue.
        """
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
        task_queue: Queue,
    ):
        """Store the task queue on top of the base-class state.

        Args:
            api_key: API key used by the registration pipeline.
            configuration: Active Learning configuration, or `None`.
            cache: Cache handed to datapoint registration.
            task_queue: Queue carrying registration tasks to the worker.
        """
        super().__init__(api_key=api_key, configuration=configuration, cache=cache)
        self._task_queue = task_queue
        # None means "no worker thread is currently running".
        self._registration_thread: Optional[Thread] = None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Enqueue a registration task without blocking the caller.

        If the queue is full the datapoint is dropped and a warning is
        logged; no exception reaches the caller.
        """
        logger.debug("Putting registration task into queue")
        try:
            self._task_queue.put_nowait(
                (
                    inference_input,
                    prediction,
                    prediction_type,
                    disable_preproc_auto_orient,
                )
            )
        except queue.Full:
            logger.warning(
                "Dropping datapoint registered in Active Learning due to insufficient processing "
                "capabilities."
            )

    def start_registration_thread(self) -> None:
        """Start the worker thread, unless one has already been started."""
        if self._registration_thread is not None:
            logger.warning("Registration thread already started.")
            return None
        logger.debug("Starting registration thread")
        self._registration_thread = Thread(target=self._consume_queue)
        self._registration_thread.start()

    def stop_registration_thread(self) -> None:
        """Ask the worker to terminate and wait until it has finished.

        Tasks enqueued before the stop request are processed first, because
        the `None` sentinel is appended at the end of the queue.
        """
        if self._registration_thread is None:
            logger.warning("Registration thread is already stopped.")
            return None
        logger.debug("Stopping registration thread")
        self._task_queue.put(None)
        # join() without a timeout returns only once the thread has ended,
        # so no liveness check is needed afterwards.
        self._registration_thread.join()
        self._registration_thread = None

    def _consume_queue(self) -> None:
        """Worker loop: handle tasks until the termination sentinel arrives."""
        queue_closed = False
        while not queue_closed:
            queue_closed = self._consume_queue_task()

    def _consume_queue_task(self) -> bool:
        """Take one item from the queue and process it.

        Returns:
            True when the item was the `None` sentinel (worker should stop),
            False after handling a regular registration task.
        """
        logger.debug("Consuming registration task")
        task = self._task_queue.get()
        logger.debug("Received registration task")
        if task is None:
            logger.debug("Terminating registration thread")
            self._task_queue.task_done()
            return True
        inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
        try:
            self._execute_registration(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )
        except Exception as error:
            # Error handling to be decided
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"Error is suppressed in favour of normal operations of registration thread."
            )
        finally:
            # Always balance the get() above, even if a non-Exception error
            # propagates out of the registration call.
            self._task_queue.task_done()
        return False

    def __enter__(self) -> "ThreadingActiveLearningMiddleware":
        """Start the worker thread and return this instance."""
        self.start_registration_thread()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the worker thread; exceptions from the `with` body propagate."""
        self.stop_registration_thread()