Spaces:
Runtime error
Runtime error
import json | |
from typing import List, Tuple | |
from inference.core.active_learning.entities import ( | |
Prediction, | |
PredictionFileType, | |
PredictionType, | |
SerialisedPrediction, | |
) | |
from inference.core.constants import ( | |
CLASSIFICATION_TASK, | |
INSTANCE_SEGMENTATION_TASK, | |
OBJECT_DETECTION_TASK, | |
) | |
from inference.core.exceptions import PredictionFormatNotSupported | |
def adjust_prediction_to_client_scaling_factor(
    prediction: dict, scaling_factor: float, prediction_type: PredictionType
) -> dict:
    """Divide the spatial values of a prediction by ``scaling_factor``.

    The ``prediction`` dict is modified in place and the same object is
    returned.

    Args:
        prediction: Prediction payload. An optional "image" entry is replaced
            by a dict holding only the rounded, rescaled "width" and "height";
            the "predictions" entry (when eligible) is rescaled element-wise.
        scaling_factor: Divisor applied to sizes and coordinates. Values
            within 1e-5 of 1.0 leave the prediction untouched.
        prediction_type: Task identifier compared against the instance
            segmentation and object detection task constants.

    Returns:
        The (possibly mutated) ``prediction`` dict.
    """
    factor_is_identity = abs(scaling_factor - 1.0) < 1e-5
    if factor_is_identity:
        return prediction

    if "image" in prediction:
        image_size = prediction["image"]
        # Only width and height survive; any other key under "image" is dropped.
        prediction["image"] = {
            dimension: round(image_size[dimension] / scaling_factor)
            for dimension in ("width", "height")
        }

    skip_rescaling = predictions_should_not_be_post_processed(
        prediction=prediction, prediction_type=prediction_type
    )
    if skip_rescaling:
        return prediction

    if prediction_type == INSTANCE_SEGMENTATION_TASK:
        rescaled_masks = adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
            predictions=prediction["predictions"],
            scaling_factor=scaling_factor,
            points_key="points",
        )
        prediction["predictions"] = rescaled_masks
    if prediction_type == OBJECT_DETECTION_TASK:
        rescaled_boxes = adjust_object_detection_predictions_to_client_scaling_factor(
            predictions=prediction["predictions"],
            scaling_factor=scaling_factor,
        )
        prediction["predictions"] = rescaled_boxes
    return prediction
def predictions_should_not_be_post_processed(
    prediction: dict, prediction_type: PredictionType
) -> bool:
    """Tell whether a prediction must be left out of coordinate rescaling.

    Returns True for stub outputs (an "is_stub" key is present), payloads
    without a "predictions" key, classification task types, and payloads whose
    "predictions" collection is empty; False otherwise.
    """
    if "is_stub" in prediction:
        return True
    if "predictions" not in prediction:
        return True
    if CLASSIFICATION_TASK in prediction_type:
        return True
    # Nothing to rescale when there are no individual predictions.
    return len(prediction["predictions"]) == 0
def adjust_object_detection_predictions_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Rescale the bounding box of every detection by ``scaling_factor``.

    Each element is passed through
    ``adjust_bbox_coordinates_to_client_scaling_factor`` and the results are
    collected, in order, into a new list.
    """
    return [
        adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=detection,
            scaling_factor=scaling_factor,
        )
        for detection in predictions
    ]
def adjust_prediction_with_bbox_and_points_to_client_scaling_factor(
    predictions: List[dict],
    scaling_factor: float,
    points_key: str,
) -> List[dict]:
    """Rescale both the bounding box and the point list of every prediction.

    Args:
        predictions: Prediction dicts carrying bounding-box fields plus a list
            of points stored under ``points_key``.
        scaling_factor: Divisor applied to box and point coordinates.
        points_key: Key under which each prediction keeps its points.

    Returns:
        A new list with the rescaled predictions, in the original order.
    """
    rescaled_predictions = []
    for item in predictions:
        # Box first, then the points belonging to the same prediction.
        rescaled_item = adjust_bbox_coordinates_to_client_scaling_factor(
            bbox=item,
            scaling_factor=scaling_factor,
        )
        rescaled_points = adjust_points_coordinates_to_client_scaling_factor(
            points=rescaled_item[points_key],
            scaling_factor=scaling_factor,
        )
        rescaled_item[points_key] = rescaled_points
        rescaled_predictions.append(rescaled_item)
    return rescaled_predictions
def adjust_bbox_coordinates_to_client_scaling_factor(
    bbox: dict,
    scaling_factor: float,
) -> dict:
    """Divide the "x", "y", "width" and "height" entries of ``bbox`` in place.

    Other keys are left untouched. The same dict object is returned.
    """
    for field in ("x", "y", "width", "height"):
        bbox[field] = bbox[field] / scaling_factor
    return bbox
def adjust_points_coordinates_to_client_scaling_factor(
    points: List[dict],
    scaling_factor: float,
) -> List[dict]:
    """Divide the "x" and "y" entries of every point by ``scaling_factor``.

    Point dicts are updated in place; a new list referencing those same dicts
    is returned.
    """
    rescaled_points = []
    for vertex in points:
        for axis in ("x", "y"):
            vertex[axis] = vertex[axis] / scaling_factor
        rescaled_points.append(vertex)
    return rescaled_points
def encode_prediction(
    prediction: Prediction,
    prediction_type: PredictionType,
) -> Tuple[SerialisedPrediction, PredictionFileType]:
    """Serialise a prediction and report the file type of the encoded payload.

    Args:
        prediction: Prediction payload to encode.
        prediction_type: Task identifier; it is checked for containing the
            classification task constant.

    Returns:
        ``(json.dumps(prediction), "json")`` when ``prediction_type`` does not
        contain the classification task constant, otherwise
        ``(prediction["top"], "txt")`` when the prediction has a "top" key.

    Raises:
        PredictionFormatNotSupported: For a classification prediction that
            has no "top" key.
    """
    if CLASSIFICATION_TASK not in prediction_type:
        return json.dumps(prediction), "json"
    if "top" in prediction:
        return prediction["top"], "txt"
    # Plain literal: the message has no placeholders, so no f-prefix is needed.
    raise PredictionFormatNotSupported(
        "Prediction type or prediction format not supported."
    )