Spaces:
Runtime error
Runtime error
from typing import List, Optional, Tuple | |
import numpy as np | |
from inference.core.entities.responses.inference import ( | |
InferenceResponseImage, | |
Keypoint, | |
KeypointsDetectionInferenceResponse, | |
KeypointsPrediction, | |
) | |
from inference.core.exceptions import ModelArtefactError | |
from inference.core.models.object_detection_base import ( | |
ObjectDetectionBaseOnnxRoboflowInferenceModel, | |
) | |
from inference.core.models.types import PreprocessReturnMetadata | |
from inference.core.models.utils.keypoints import model_keypoints_to_response | |
from inference.core.models.utils.validate import ( | |
get_num_classes_from_model_prediction_shape, | |
) | |
from inference.core.nms import w_np_non_max_suppression | |
from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints | |
# Default post-processing parameters; used as keyword defaults by
# KeypointsDetectionBaseOnnxRoboflowInferenceModel.postprocess below.
DEFAULT_CONFIDENCE = 0.4
DEFAULT_IOU_THRESH = 0.3
DEFAULT_CLASS_AGNOSTIC_NMS = False
# NOTE: the lowercase "l" in this name is a typo. The misspelled name is kept
# because code in this file references it (and other modules may import it);
# the correctly spelled alias below is provided for new code.
DEFAUlT_MAX_DETECTIONS = 300
DEFAULT_MAX_DETECTIONS = DEFAUlT_MAX_DETECTIONS
DEFAULT_MAX_CANDIDATES = 3000
class KeypointsDetectionBaseOnnxRoboflowInferenceModel(
    ObjectDetectionBaseOnnxRoboflowInferenceModel
):
    """Roboflow ONNX keypoint detection model.

    Builds on the ONNX object detection base class and adds keypoint-aware
    post-processing and response construction.
    """

    task_type = "keypoint-detection"

    def __init__(self, model_id: str, *args, **kwargs):
        """Initialise the model by delegating to the object detection base class.

        Args:
            model_id (str): Identifier of the model, forwarded to the base class.
            *args: Positional arguments forwarded to the base class.
            **kwargs: Keyword arguments forwarded to the base class.
        """
        super().__init__(model_id, *args, **kwargs)

    def get_infer_bucket_file_list(self) -> list:
        """Returns the list of files to be downloaded from the inference bucket for ONNX model.

        Returns:
            list: A list of filenames specific to ONNX keypoint detection models.
        """
        return ["environment.json", "class_names.txt", "keypoints_metadata.json"]

    def postprocess(
        self,
        predictions: Tuple[np.ndarray],
        preproc_return_metadata: PreprocessReturnMetadata,
        class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS,
        confidence: float = DEFAULT_CONFIDENCE,
        iou_threshold: float = DEFAULT_IOU_THRESH,
        max_candidates: int = DEFAULT_MAX_CANDIDATES,
        max_detections: int = DEFAUlT_MAX_DETECTIONS,
        return_image_dims: bool = False,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Postprocesses the keypoint detection predictions.

        Runs non-max suppression, maps boxes and keypoints back to the original
        image coordinates and builds the response objects.

        Args:
            predictions (Tuple[np.ndarray]): Raw predictions from the model; only the first element is used.
            preproc_return_metadata (PreprocessReturnMetadata): Preprocessing metadata; the keys
                "img_dims" and "disable_preproc_static_crop" are read.
            class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False.
            confidence (float): Confidence threshold for filtering detections. Default is 0.4.
            iou_threshold (float): IoU threshold for non-max suppression. Default is 0.3.
            max_candidates (int): Maximum number of candidate detections. Default is 3000.
            max_detections (int): Maximum number of final detections. Default is 300.
            return_image_dims (bool): Not read by this method; kept in the signature for callers that pass it.
            **kwargs: Forwarded unchanged to `make_response` (e.g. `class_filter`, `request`).

        Returns:
            List[KeypointsDetectionInferenceResponse]: The post-processed predictions.
        """
        predictions = predictions[0]
        number_of_classes = len(self.get_class_names)
        # Everything past the first 5 values and the per-class scores is handed
        # to NMS as "masks" and later treated as keypoint data.
        # NOTE(review): presumably the 5 leading values are 4 box coords + 1
        # objectness score - confirm against the NMS helper.
        num_keypoint_values = predictions.shape[2] - 5 - number_of_classes
        predictions = w_np_non_max_suppression(
            predictions,
            conf_thresh=confidence,
            iou_thresh=iou_threshold,
            class_agnostic=class_agnostic_nms,
            max_detections=max_detections,
            max_candidate_detections=max_candidates,
            num_masks=num_keypoint_values,
        )
        infer_shape = (self.img_size_h, self.img_size_w)
        img_dims = preproc_return_metadata["img_dims"]
        disable_static_crop = preproc_return_metadata["disable_preproc_static_crop"]
        predictions = post_process_bboxes(
            predictions=predictions,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=disable_static_crop,
        )
        predictions = post_process_keypoints(
            predictions=predictions,
            # Negative index: keypoint values are the trailing entries of each row.
            keypoints_start_index=-num_keypoint_values,
            infer_shape=infer_shape,
            img_dims=img_dims,
            preproc=self.preproc,
            resize_method=self.resize_method,
            disable_preproc_static_crop=disable_static_crop,
        )
        return self.make_response(predictions, img_dims, **kwargs)

    def make_response(
        self,
        predictions: List[List[float]],
        img_dims: List[Tuple[int, int]],
        class_filter: Optional[List[str]] = None,
        *args,
        **kwargs,
    ) -> List[KeypointsDetectionInferenceResponse]:
        """Constructs keypoints detection response objects based on predictions.

        Each prediction row is read as
        [x1, y1, x2, y2, confidence, <not read here>, class_id, keypoint values...].

        Args:
            predictions (List[List[float]]): Per-image lists of prediction rows.
            img_dims (List[Tuple[int, int]]): Dimensions of the images as (height, width).
                A dict holding them under the "img_dims" key is also accepted.
            class_filter (Optional[List[str]]): A list of class names to keep, if provided.
            **kwargs: If a "request" entry is present, its `keypoint_confidence`
                attribute is used as the keypoint confidence threshold (otherwise 0.0).

        Returns:
            List[KeypointsDetectionInferenceResponse]: A list of response objects containing keypoints detection predictions.
        """
        if isinstance(img_dims, dict) and "img_dims" in img_dims:
            img_dims = img_dims["img_dims"]
        keypoint_confidence_threshold = 0.0
        if "request" in kwargs:
            keypoint_confidence_threshold = kwargs["request"].keypoint_confidence

        responses = []
        for image_index, image_predictions in enumerate(predictions):
            keypoints_predictions = []
            for pred in image_predictions:
                class_id = int(pred[6])
                class_name = self.class_names[class_id]
                if class_filter and class_name not in class_filter:
                    continue
                keypoints = model_keypoints_to_response(
                    keypoints_metadata=self.keypoints_metadata,
                    keypoints=pred[7:],
                    # The class id lives at index 6, the same slot used for
                    # "class"/"class_id". The former index
                    # `4 + number_of_classes` only coincided with it for
                    # two-class models and pointed into the keypoint values
                    # for models with more classes.
                    predicted_object_class_id=class_id,
                    keypoint_confidence_threshold=keypoint_confidence_threshold,
                )
                keypoints_predictions.append(
                    KeypointsPrediction(
                        # Passing args as a dictionary here since one of the args is 'class' (a reserved keyword in Python)
                        **{
                            "x": (pred[0] + pred[2]) / 2,
                            "y": (pred[1] + pred[3]) / 2,
                            "width": pred[2] - pred[0],
                            "height": pred[3] - pred[1],
                            "confidence": pred[4],
                            "class": class_name,
                            "class_id": class_id,
                            "keypoints": keypoints,
                        }
                    )
                )
            responses.append(
                KeypointsDetectionInferenceResponse(
                    predictions=keypoints_predictions,
                    image=InferenceResponseImage(
                        width=img_dims[image_index][1],
                        height=img_dims[image_index][0],
                    ),
                )
            )
        return responses

    def keypoints_count(self) -> int:
        """Returns the number of keypoints predicted by the model.

        The base implementation always raises; `validate_model_classes` calls
        this method, so subclasses are expected to override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def validate_model_classes(self) -> None:
        """Checks that the class count implied by the model output matches `self.num_classes`.

        The class count is derived from the third dimension of the model output
        shape and the value returned by `keypoints_count()`.

        Raises:
            ValueError: If the derived number of classes differs from `self.num_classes`.
        """
        num_keypoints = self.keypoints_count()
        output_shape = self.get_model_output_shape()
        num_classes = get_num_classes_from_model_prediction_shape(
            len_prediction=output_shape[2], keypoints=num_keypoints
        )
        if num_classes != self.num_classes:
            raise ValueError(
                f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})"
            )