import asyncio
from copy import deepcopy
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

from inference.core.entities.requests.clip import ClipCompareRequest
from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest
from inference.core.entities.requests.inference import (
    ClassificationInferenceRequest,
    InstanceSegmentationInferenceRequest,
    KeypointsDetectionInferenceRequest,
    ObjectDetectionInferenceRequest,
)
from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest
from inference.core.env import (
    HOSTED_CLASSIFICATION_URL,
    HOSTED_CORE_MODEL_URL,
    HOSTED_DETECT_URL,
    HOSTED_INSTANCE_SEGMENTATION_URL,
    LOCAL_INFERENCE_API_URL,
    WORKFLOWS_REMOTE_API_TARGET,
    WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
    WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
)
from inference.core.managers.base import ModelManager
from inference.enterprise.workflows.complier.entities import StepExecutionMode
from inference.enterprise.workflows.complier.steps_executors.constants import (
    CENTER_X_KEY,
    CENTER_Y_KEY,
    ORIGIN_COORDINATES_KEY,
    ORIGIN_SIZE_KEY,
    PARENT_COORDINATES_SUFFIX,
)
from inference.enterprise.workflows.complier.steps_executors.types import (
    NextStepReference,
    OutputsLookup,
)
from inference.enterprise.workflows.complier.steps_executors.utils import (
    get_image,
    make_batches,
    resolve_parameter,
)
from inference.enterprise.workflows.complier.utils import construct_step_selector
from inference.enterprise.workflows.entities.steps import (
    ClassificationModel,
    ClipComparison,
    InstanceSegmentationModel,
    KeypointsDetectionModel,
    MultiLabelClassificationModel,
    ObjectDetectionModel,
    OCRModel,
    RoboflowModel,
    StepInterface,
    YoloWorld,
)
from inference_sdk import InferenceConfiguration, InferenceHTTPClient

MODEL_TYPE2PREDICTION_TYPE = {
    "ClassificationModel": "classification",
    "MultiLabelClassificationModel": "classification",
    "ObjectDetectionModel": "object-detection",
    "InstanceSegmentationModel": "instance-segmentation",
    "KeypointsDetectionModel": "keypoint-detection",
}

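# Entry point for Roboflow model steps (classification, detection, segmentation,
# keypoints). It resolves the model id and input images, dispatches to local or
# remote execution, then enriches the serialised predictions with prediction
# type, parent info and parent-coordinate anchoring before registering them in
# the outputs lookup. Classification results only get `parent_id` at the top
# level (nested_key=None), since they carry no per-detection entries to tag.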
async def run_roboflow_model_step(
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    model_id = resolve_parameter(
        selector_or_value=step.model_id,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_roboflow_model_predictions_locally(
            image=image,
            model_id=model_id,
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_roboflow_model_predictions_from_remote_api(
            image=image,
            model_id=model_id,
            step=step,
            runtime_parameters=runtime_parameters,
            outputs_lookup=outputs_lookup,
            api_key=api_key,
        )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type=MODEL_TYPE2PREDICTION_TYPE[step.get_type()],
    )
    if step.type in {"ClassificationModel", "MultiLabelClassificationModel"}:
        serialised_result = attach_parent_info(
            image=image, results=serialised_result, nested_key=None
        )
    else:
        serialised_result = attach_parent_info(image=image, results=serialised_result)
    serialised_result = anchor_detections_in_parent_coordinates(
        image=image,
        serialised_result=serialised_result,
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup

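# Local execution path: build the model-specific inference request, ensure the
# model is registered with the ModelManager, run inference in-process, and
# serialise the pydantic response(s) to plain dicts (always returned as a list).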
async def get_roboflow_model_predictions_locally(
    image: List[dict],
    model_id: str,
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    request_constructor = MODEL_TYPE2REQUEST_CONSTRUCTOR[step.type]
    request = request_constructor(
        step=step,
        image=image,
        api_key=api_key,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_manager.add_model(
        model_id=model_id,
        api_key=api_key,
    )
    result = await model_manager.infer_from_request(model_id=model_id, request=request)
    if isinstance(result, list):
        serialised_result = [e.dict(by_alias=True, exclude_none=True) for e in result]
    else:
        serialised_result = [result.dict(by_alias=True, exclude_none=True)]
    return serialised_result

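# The request constructors below share one pattern: `partial(resolve_parameter, ...)`
# binds the runtime parameters and outputs lookup once, so each step field can be
# resolved with a single-argument `resolve(...)` call, whether it holds a literal
# value or a selector pointing at an input or another step's output. For
# illustration (hypothetical selector), `resolve("$inputs.confidence")` would be
# looked up at runtime, while `resolve(0.4)` returns the literal unchanged.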
def construct_classification_request(
    step: Union[ClassificationModel, MultiLabelClassificationModel],
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> ClassificationInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return ClassificationInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        confidence=resolve(step.confidence),
        disable_active_learning=resolve(step.disable_active_learning),
    )


def construct_object_detection_request(
    step: ObjectDetectionModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> ObjectDetectionInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return ObjectDetectionInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
    )


def construct_instance_segmentation_request(
    step: InstanceSegmentationModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InstanceSegmentationInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InstanceSegmentationInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        mask_decode_mode=resolve(step.mask_decode_mode),
        tradeoff_factor=resolve(step.tradeoff_factor),
    )


def construct_keypoints_detection_request(
    step: KeypointsDetectionModel,
    image: Any,
    api_key: Optional[str],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> KeypointsDetectionInferenceRequest:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return KeypointsDetectionInferenceRequest(
        api_key=api_key,
        model_id=resolve(step.model_id),
        image=image,
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        keypoint_confidence=resolve(step.keypoint_confidence),
    )


MODEL_TYPE2REQUEST_CONSTRUCTOR = {
    "ClassificationModel": construct_classification_request,
    "MultiLabelClassificationModel": construct_classification_request,
    "ObjectDetectionModel": construct_object_detection_request,
    "InstanceSegmentationModel": construct_instance_segmentation_request,
    "KeypointsDetectionModel": construct_keypoints_detection_request,
}

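# Remote execution path: send the images to an inference HTTP API through the
# SDK client. On the hosted platform the v0 API is selected, and the request is
# shaped by a model-type-specific InferenceConfiguration; batch size and request
# concurrency are capped by the WORKFLOWS_REMOTE_EXECUTION_* environment knobs.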
async def get_roboflow_model_predictions_from_remote_api(
    image: List[dict],
    model_id: str,
    step: RoboflowModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    configuration = MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR[step.type](
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    client.configure(inference_configuration=configuration)
    inference_input = [i["value"] for i in image]
    results = await client.infer_async(
        inference_input=inference_input,
        model_id=model_id,
    )
    if not isinstance(results, list):
        return [results]
    return results

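# The InferenceConfiguration constructors mirror the request constructors above,
# mapping the same step fields onto SDK-side configuration (note the field
# renames, e.g. `confidence` -> `confidence_threshold`, `keypoint_confidence`
# -> `keypoint_confidence_threshold`).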
def construct_http_client_configuration_for_classification_step(
    step: Union[ClassificationModel, MultiLabelClassificationModel],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        confidence_threshold=resolve(step.confidence),
        disable_active_learning=resolve(step.disable_active_learning),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


def construct_http_client_configuration_for_detection_step(
    step: ObjectDetectionModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence_threshold=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


def construct_http_client_configuration_for_segmentation_step(
    step: InstanceSegmentationModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence_threshold=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        mask_decode_mode=resolve(step.mask_decode_mode),
        tradeoff_factor=resolve(step.tradeoff_factor),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


def construct_http_client_configuration_for_keypoints_detection_step(
    step: KeypointsDetectionModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> InferenceConfiguration:
    resolve = partial(
        resolve_parameter,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    return InferenceConfiguration(
        disable_active_learning=resolve(step.disable_active_learning),
        class_agnostic_nms=resolve(step.class_agnostic_nms),
        class_filter=resolve(step.class_filter),
        confidence_threshold=resolve(step.confidence),
        iou_threshold=resolve(step.iou_threshold),
        max_detections=resolve(step.max_detections),
        max_candidates=resolve(step.max_candidates),
        keypoint_confidence_threshold=resolve(step.keypoint_confidence),
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )


MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR = {
    "ClassificationModel": construct_http_client_configuration_for_classification_step,
    "MultiLabelClassificationModel": construct_http_client_configuration_for_classification_step,
    "ObjectDetectionModel": construct_http_client_configuration_for_detection_step,
    "InstanceSegmentationModel": construct_http_client_configuration_for_segmentation_step,
    "KeypointsDetectionModel": construct_http_client_configuration_for_keypoints_detection_step,
}

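# YOLO-World is a zero-shot detector: instead of a trained model id, the step
# resolves a list of class names (plus optional model version and confidence)
# at runtime. Its results are post-processed exactly like regular
# object-detection predictions.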
async def run_yolo_world_model_step(
    step: YoloWorld,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    class_names = resolve_parameter(
        selector_or_value=step.class_names,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    model_version = resolve_parameter(
        selector_or_value=step.version,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    confidence = resolve_parameter(
        selector_or_value=step.confidence,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_yolo_world_predictions_locally(
            image=image,
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_yolo_world_predictions_from_remote_api(
            image=image,
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
            step=step,
            api_key=api_key,
        )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type="object-detection",
    )
    serialised_result = attach_parent_info(image=image, results=serialised_result)
    serialised_result = anchor_detections_in_parent_coordinates(
        image=image,
        serialised_result=serialised_result,
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup


async def get_yolo_world_predictions_locally(
    image: List[dict],
    class_names: List[str],
    model_version: Optional[str],
    confidence: Optional[float],
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    serialised_result = []
    for single_image in image:
        inference_request = YOLOWorldInferenceRequest(
            image=single_image,
            yolo_world_version_id=model_version,
            confidence=confidence,
            text=class_names,
        )
        yolo_world_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=inference_request,
            core_model="yolo_world",
            api_key=api_key,
        )
        result = await model_manager.infer_from_request(
            yolo_world_model_id, inference_request
        )
        serialised_result.append(result.dict())
    return serialised_result

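# Remote YOLO-World path: images are grouped with `make_batches` and each batch
# is sent through the SDK's dedicated endpoint. Batches are processed
# sequentially; note the batch size reuses the concurrency limit, so one batch
# roughly corresponds to one round of parallel requests inside the client.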
async def get_yolo_world_predictions_from_remote_api(
    image: List[dict],
    class_names: List[str],
    model_version: Optional[str],
    confidence: Optional[float],
    step: YoloWorld,
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    configuration = InferenceConfiguration(
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
    client.configure(inference_configuration=configuration)
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    image_batches = list(
        make_batches(
            iterable=image,
            batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    serialised_result = []
    for single_batch in image_batches:
        batch_results = await client.infer_from_yolo_world_async(
            inference_input=[i["value"] for i in single_batch],
            class_names=class_names,
            model_version=model_version,
            confidence=confidence,
        )
        serialised_result.extend(batch_results)
    return serialised_result

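# OCR step: delegates to the DocTR core model. As a core model it is loaded via
# `load_core_model` rather than by a user-supplied model id, and its results
# attach parent info at the top level only (nested_key=None).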
async def run_ocr_model_step(
    step: OCRModel,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_ocr_predictions_locally(
            image=image,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_ocr_predictions_from_remote_api(
            step=step,
            image=image,
            api_key=api_key,
        )
    serialised_result = attach_parent_info(
        image=image,
        results=serialised_result,
        nested_key=None,
    )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type="ocr",
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup


async def get_ocr_predictions_locally(
    image: List[dict],
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    serialised_result = []
    for single_image in image:
        inference_request = DoctrOCRInferenceRequest(
            image=single_image,
        )
        doctr_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=inference_request,
            core_model="doctr",
            api_key=api_key,
        )
        result = await model_manager.infer_from_request(
            doctr_model_id, inference_request
        )
        serialised_result.append(result.dict())
    return serialised_result


async def get_ocr_predictions_from_remote_api(
    step: OCRModel,
    image: List[dict],
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    configuration = InferenceConfiguration(
        max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE,
        max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
    )
    client.configure(configuration)
    result = await client.ocr_image_async(
        inference_input=[i["value"] for i in image],
    )
    if len(image) == 1:
        return [result]
    return result

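# CLIP comparison step: compares each image against the resolved text prompt(s)
# in CLIP's shared embedding space. Like OCR, it runs through a core model
# locally and through the hosted core-model endpoint remotely.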
async def run_clip_comparison_step(
    step: ClipComparison,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
    model_manager: ModelManager,
    api_key: Optional[str],
    step_execution_mode: StepExecutionMode,
) -> Tuple[NextStepReference, OutputsLookup]:
    image = get_image(
        step=step,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    text = resolve_parameter(
        selector_or_value=step.text,
        runtime_parameters=runtime_parameters,
        outputs_lookup=outputs_lookup,
    )
    if step_execution_mode is StepExecutionMode.LOCAL:
        serialised_result = await get_clip_comparison_locally(
            image=image,
            text=text,
            model_manager=model_manager,
            api_key=api_key,
        )
    else:
        serialised_result = await get_clip_comparison_from_remote_api(
            step=step,
            image=image,
            text=text,
            api_key=api_key,
        )
    serialised_result = attach_parent_info(
        image=image,
        results=serialised_result,
        nested_key=None,
    )
    serialised_result = attach_prediction_type_info(
        results=serialised_result,
        prediction_type="embeddings-comparison",
    )
    outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result
    return None, outputs_lookup

async def get_clip_comparison_locally(
    image: List[dict],
    text: str,
    model_manager: ModelManager,
    api_key: Optional[str],
) -> List[dict]:
    serialised_result = []
    for single_image in image:
        inference_request = ClipCompareRequest(
            subject=single_image, subject_type="image", prompt=text, prompt_type="text"
        )
        clip_model_id = load_core_model(
            model_manager=model_manager,
            inference_request=inference_request,
            core_model="clip",
            api_key=api_key,
        )
        result = await model_manager.infer_from_request(
            clip_model_id, inference_request
        )
        serialised_result.append(result.dict())
    return serialised_result

async def get_clip_comparison_from_remote_api(
    step: ClipComparison,
    image: List[dict],
    text: str,
    api_key: Optional[str],
) -> List[dict]:
    api_url = resolve_model_api_url(step=step)
    client = InferenceHTTPClient(
        api_url=api_url,
        api_key=api_key,
    )
    if WORKFLOWS_REMOTE_API_TARGET == "hosted":
        client.select_api_v0()
    image_batches = list(
        make_batches(
            iterable=image,
            batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS,
        )
    )
    serialised_result = []
    for single_batch in image_batches:
        coroutines = []
        for single_image in single_batch:
            coroutine = client.clip_compare_async(
                subject=single_image["value"],
                prompt=text,
            )
            coroutines.append(coroutine)
        batch_results = list(await asyncio.gather(*coroutines))
        serialised_result.extend(batch_results)
    return serialised_result

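# Core models are addressed as "<name>/<version>", where the version is read
# from the request's `<name>_version_id` field. A minimal sketch of the result
# (illustrative ids, the concrete versions are not pinned by this module):
#   load_core_model(..., core_model="doctr") -> "doctr/<doctr_version_id>"
#   load_core_model(..., core_model="clip")  -> "clip/<clip_version_id>"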
def load_core_model(
    model_manager: ModelManager,
    inference_request: Union[DoctrOCRInferenceRequest, ClipCompareRequest],
    core_model: str,
    api_key: Optional[str] = None,
) -> str:
    if api_key:
        inference_request.api_key = api_key
    version_id_field = f"{core_model}_version_id"
    core_model_id = f"{core_model}/{getattr(inference_request, version_id_field)}"
    model_manager.add_model(core_model_id, inference_request.api_key)
    return core_model_id

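# The helpers below operate on parallel lists: the i-th serialised result is
# assumed to correspond to the i-th input image (they are zip()-ed together).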
def attach_prediction_type_info(
    results: List[Dict[str, Any]],
    prediction_type: str,
    key: str = "prediction_type",
) -> List[Dict[str, Any]]:
    for result in results:
        result[key] = prediction_type
    return results


def attach_parent_info(
    image: List[Dict[str, Any]],
    results: List[Dict[str, Any]],
    nested_key: Optional[str] = "predictions",
) -> List[Dict[str, Any]]:
    return [
        attach_parent_info_to_image_detections(
            image=i, predictions=p, nested_key=nested_key
        )
        for i, p in zip(image, results)
    ]


def attach_parent_info_to_image_detections(
    image: Dict[str, Any],
    predictions: Dict[str, Any],
    nested_key: Optional[str],
) -> Dict[str, Any]:
    predictions["parent_id"] = image["parent_id"]
    if nested_key is None:
        return predictions
    for prediction in predictions[nested_key]:
        prediction["parent_id"] = image["parent_id"]
    return predictions

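# Detections produced on cropped inputs are expressed in crop-local coordinates.
# Anchoring adds "*<PARENT_COORDINATES_SUFFIX>" copies translated by the offsets
# stored in the image's origin coordinates, so downstream steps can reason in
# the parent image's frame. Illustrative sketch (assuming the suffix is
# "_parent_coordinates" and stored offsets of (100, 200)): a prediction with
# {"x": 50, "y": 40} gains a parent-coordinates twin with {"x": 150, "y": 240};
# widths and heights are left untouched.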
def anchor_detections_in_parent_coordinates(
    image: List[Dict[str, Any]],
    serialised_result: List[Dict[str, Any]],
    image_metadata_key: str = "image",
    detections_key: str = "predictions",
) -> List[Dict[str, Any]]:
    return [
        anchor_image_detections_in_parent_coordinates(
            image=i,
            serialised_result=d,
            image_metadata_key=image_metadata_key,
            detections_key=detections_key,
        )
        for i, d in zip(image, serialised_result)
    ]


def anchor_image_detections_in_parent_coordinates(
    image: Dict[str, Any],
    serialised_result: Dict[str, Any],
    image_metadata_key: str = "image",
    detections_key: str = "predictions",
) -> Dict[str, Any]:
    serialised_result[f"{detections_key}{PARENT_COORDINATES_SUFFIX}"] = deepcopy(
        serialised_result[detections_key]
    )
    serialised_result[f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"] = deepcopy(
        serialised_result[image_metadata_key]
    )
    if ORIGIN_COORDINATES_KEY not in image:
        return serialised_result
    shift_x, shift_y = (
        image[ORIGIN_COORDINATES_KEY][CENTER_X_KEY],
        image[ORIGIN_COORDINATES_KEY][CENTER_Y_KEY],
    )
    for detection in serialised_result[f"{detections_key}{PARENT_COORDINATES_SUFFIX}"]:
        detection["x"] += shift_x
        detection["y"] += shift_y
    serialised_result[f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"] = image[
        ORIGIN_COORDINATES_KEY
    ][ORIGIN_SIZE_KEY]
    return serialised_result

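# Hosted-platform routing: each step type maps to the hosted endpoint that
# serves it; outside the hosted target every request goes to the local
# inference API instead.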
ROBOFLOW_MODEL2HOSTED_ENDPOINT = {
    "ClassificationModel": HOSTED_CLASSIFICATION_URL,
    "MultiLabelClassificationModel": HOSTED_CLASSIFICATION_URL,
    "ObjectDetectionModel": HOSTED_DETECT_URL,
    "KeypointsDetectionModel": HOSTED_DETECT_URL,
    "InstanceSegmentationModel": HOSTED_INSTANCE_SEGMENTATION_URL,
    "OCRModel": HOSTED_CORE_MODEL_URL,
    "ClipComparison": HOSTED_CORE_MODEL_URL,
}


def resolve_model_api_url(step: StepInterface) -> str:
    if WORKFLOWS_REMOTE_API_TARGET != "hosted":
        return LOCAL_INFERENCE_API_URL
    return ROBOFLOW_MODEL2HOSTED_ENDPOINT[step.get_type()]