Spaces:
Runtime error
Runtime error
from typing import Any, List, Optional, Set, Type | |
from pydantic import ValidationError | |
from inference.core.entities.requests.inference import InferenceRequestImage | |
from inference.enterprise.workflows.entities.base import GraphNone | |
from inference.enterprise.workflows.errors import ( | |
InvalidStepInputDetected, | |
VariableTypeError, | |
) | |
# Step types whose output may be bound to an image-typed field; this set is
# consulted by `validate_selector_holds_image`.
STEPS_WITH_IMAGE = {"InferenceImage", "Crop", "AbsoluteStaticCrop", "RelativeStaticCrop"}
def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None:
    """Ensure `value` is a selector, or a list consisting only of selectors.

    Args:
        value: A single value or a list of values to check.
        field_name: Field name used in the error message.

    Raises:
        ValueError: if `value` (or any element of a list `value`) is not a
            selector string.
    """
    # Treat a scalar as a one-element list so both shapes share one check.
    candidates = value if issubclass(type(value), list) else [value]
    if not all(is_selector(selector_or_value=candidate) for candidate in candidates):
        raise ValueError(f"`{field_name}` field can only contain selector values")
def validate_field_is_in_range_zero_one_or_empty_or_selector(
    value: Any, field_name: str = "confidence"
) -> None:
    """Accept a selector, ``None``, or a number in ``[0.0, 1.0]``.

    Selectors and ``None`` are accepted as-is. Any other value is passed to
    ``validate_value_is_empty_or_number_in_range_zero_one``, which raises
    ``ValueError`` for a wrong type or an out-of-range number.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
    """
    needs_numeric_check = not (is_selector(selector_or_value=value) or value is None)
    if needs_numeric_check:
        validate_value_is_empty_or_number_in_range_zero_one(
            value=value, field_name=field_name
        )
def validate_value_is_empty_or_number_in_range_zero_one(
    value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError
) -> None:
    """Check that `value` is ``None`` or an ``int``/``float`` within ``[0, 1]``.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: if `value` is not ``None``/``int``/``float``, or is a number
            outside ``[0, 1]``. NaN is rejected too, because the chained
            comparison is False for it.
    """
    # NOTE(review): bool passes the type check because it subclasses int.
    validate_field_has_given_type(
        value=value,
        field_name=field_name,
        allowed_types=[type(None), int, float],
        error=error,
    )
    within_bounds = value is None or 0 <= value <= 1
    if not within_bounds:
        raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]")
def validate_value_is_empty_or_selector_or_positive_number(
    value: Any, field_name: str
) -> None:
    """Accept a selector, ``None``, or a strictly positive number.

    Non-selector values are passed to
    ``validate_value_is_empty_or_positive_number``, which raises
    ``ValueError`` when they are invalid.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
    """
    if not is_selector(selector_or_value=value):
        validate_value_is_empty_or_positive_number(value=value, field_name=field_name)
def validate_value_is_empty_or_positive_number(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Check that `value` is ``None`` or a strictly positive ``int``/``float``.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: if `value` is not ``None``/``int``/``float``, or is a number
            that is not strictly greater than zero (including NaN).
    """
    validate_field_has_given_type(
        field_name=field_name,
        allowed_types=[type(None), int, float],
        value=value,
        error=error,
    )
    if value is None:
        return None
    # Written as `not value > 0` rather than `value <= 0`: every ordered
    # comparison with NaN is False, so `value <= 0` would let NaN through.
    if not value > 0:
        raise error(f"Parameter `{field_name}` must be positive (> 0)")
def validate_field_is_list_of_selectors(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Check that `value` is a list whose elements are all selectors.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: if `value` is not a list, or any element is not a selector.
    """
    is_list = issubclass(type(value), list)
    if not is_list:
        raise error(f"`{field_name}` field must be list")
    only_selectors = all(is_selector(selector_or_value=item) for item in value)
    if not only_selectors:
        raise error(f"Parameter `{field_name}` must be a list of selectors")
def validate_field_is_empty_or_selector_or_list_of_string(
    value: Any, field_name: str
) -> None:
    """Accept ``None``, a selector, or a list of strings.

    Anything else goes to ``validate_field_is_list_of_string``, which raises
    ``ValueError`` on failure.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.

    Returns:
        `value` itself when it is a selector or ``None``; ``None`` after a
        successful list-of-string check. NOTE(review): this does not match
        the ``-> None`` annotation. It is preserved for caller compatibility.
    """
    if not (is_selector(selector_or_value=value) or value is None):
        validate_field_is_list_of_string(value=value, field_name=field_name)
        return None
    return value
def validate_field_is_list_of_string(
    value: Any, field_name: str, error: Type[Exception] = ValueError
) -> None:
    """Check that `value` is a list whose elements are all strings.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
        error: Exception class raised on failure.

    Raises:
        error: if `value` is not a list, or contains a non-string element.
    """
    if not issubclass(type(value), list):
        raise error(f"`{field_name}` field must be list")
    every_item_is_text = all(issubclass(type(item), str) for item in value)
    if not every_item_is_text:
        raise error(f"Parameter `{field_name}` must be a list of string")
def validate_field_is_selector_or_one_of_values(
    value: Any, field_name: str, selected_values: set
) -> None:
    """Accept a selector, ``None``, or a member of `selected_values`.

    ``None`` is accepted even if it is not in `selected_values`. Other
    non-selector values go to ``validate_field_is_one_of_selected_values``,
    which raises ``ValueError`` on failure.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
        selected_values: Permitted literal values.

    Returns:
        `value` itself when it is a selector or ``None``; otherwise ``None``.
        NOTE(review): this does not match the ``-> None`` annotation. It is
        preserved for caller compatibility.
    """
    if not is_selector(selector_or_value=value) and value is not None:
        validate_field_is_one_of_selected_values(
            value=value, field_name=field_name, selected_values=selected_values
        )
        return None
    return value
def validate_field_is_one_of_selected_values(
    value: Any,
    field_name: str,
    selected_values: set,
    error: Type[Exception] = ValueError,
) -> None:
    """Raise unless `value` is a member of `selected_values`.

    Args:
        value: Value to check.
        field_name: Field name used in the error message.
        selected_values: Permitted values.
        error: Exception class raised on failure.

    Raises:
        error: if `value` is not in `selected_values`.
    """
    if value in selected_values:
        return None
    raise error(
        f"Value of field `{field_name}` must be in {selected_values}. Found: {value}"
    )
def validate_field_is_selector_or_has_given_type(
    value: Any, field_name: str, allowed_types: List[type]
) -> None:
    """Accept a selector, or a value whose type subclasses one of `allowed_types`.

    Non-selectors go to ``validate_field_has_given_type`` with its default
    error class, so a type mismatch raises ``ValueError``.

    Args:
        value: Value to check.
        field_name: Field name used in error messages.
        allowed_types: Types that a non-selector value may have.
    """
    if not is_selector(selector_or_value=value):
        validate_field_has_given_type(
            value=value,
            field_name=field_name,
            allowed_types=allowed_types,
        )
def validate_field_has_given_type(
    value: Any,
    field_name: str,
    allowed_types: List[type],
    error: Type[Exception] = ValueError,
) -> None:
    """Check that the type of `value` subclasses at least one of `allowed_types`.

    Args:
        value: Value to check.
        field_name: Field name used in the error message.
        allowed_types: Acceptable types; an empty list rejects everything.
        error: Exception class raised on failure.

    Raises:
        error: if no entry of `allowed_types` matches ``type(value)``.
    """
    value_type = type(value)
    matches_some_type = any(
        issubclass(value_type, candidate) for candidate in allowed_types
    )
    if not matches_some_type:
        raise error(
            f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}"
        )
def validate_image_biding(value: Any, field_name: str = "image") -> None:
    """Check that `value`, or each element of a list `value`, validates as an image.

    Every element is passed to ``InferenceRequestImage.model_validate``.
    NOTE(review): "biding" is presumably a typo for "binding". The name is
    kept because callers reference it.

    Args:
        value: A single image payload or a list of them.
        field_name: Field name used in the error message.

    Raises:
        VariableTypeError: if validation raises ``ValueError`` or
            ``ValidationError``; the original exception is chained as cause.
    """
    payloads = value if issubclass(type(value), list) else [value]
    try:
        for payload in payloads:
            InferenceRequestImage.model_validate(payload)
    except (ValueError, ValidationError) as error:
        raise VariableTypeError(
            f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`"
        ) from error
def validate_selector_is_inference_parameter(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Set[str],
) -> None:
    """Require that an applicable field is fed by an ``InferenceParameter`` step.

    Fields not listed in `applicable_fields` are ignored.

    Args:
        step_type: Type of the step being validated (used in the message).
        field_name: Name of the field whose input is being checked.
        input_step: Graph node the field's selector points at.
        applicable_fields: Field names this rule applies to.

    Raises:
        InvalidStepInputDetected: if the field is applicable and
            ``input_step.get_type()`` is not ``"InferenceParameter"``.
    """
    if field_name in applicable_fields:
        source_type = input_step.get_type()
        if source_type not in {"InferenceParameter"}:
            raise InvalidStepInputDetected(
                f"Field {field_name} of step {step_type} comes from invalid input type: {source_type}. "
                f"Expected: `InferenceParameter`"
            )
def validate_selector_holds_image(
    step_type: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Require that an applicable field is fed by a step listed in ``STEPS_WITH_IMAGE``.

    Args:
        step_type: Type of the step being validated (used in the message).
        field_name: Name of the field whose input is being checked.
        input_step: Graph node the field's selector points at.
        applicable_fields: Field names this rule applies to; ``None`` means
            ``{"image"}``.

    Raises:
        InvalidStepInputDetected: if the field is applicable and the type of
            `input_step` is not in ``STEPS_WITH_IMAGE``.
    """
    fields_to_check = {"image"} if applicable_fields is None else applicable_fields
    if field_name not in fields_to_check:
        return None
    if input_step.get_type() in STEPS_WITH_IMAGE:
        return None
    raise InvalidStepInputDetected(
        f"Field {field_name} of step {step_type} comes from invalid input type: {input_step.get_type()}. "
        f"Expected: {STEPS_WITH_IMAGE}"
    )
def validate_selector_holds_detections(
    step_name: str,
    image_selector: Optional[str],
    detections_selector: str,
    field_name: str,
    input_step: GraphNone,
    applicable_fields: Optional[Set[str]] = None,
) -> None:
    """Validate that a detections selector points at a detection-producing step.

    When `field_name` is applicable, three checks run in order:

    1. `input_step` must be one of the detection-producing step types below.
    2. `detections_selector` must end with the ``predictions`` output.
    3. If `image_selector` is given and `input_step` has an ``image``
       attribute, both must reference the same image.

    Args:
        step_name: Name of the step being validated (used in messages).
        image_selector: Image reference of the step being validated, if any.
        detections_selector: Selector pointing at the detections input.
        field_name: Name of the field whose input is being checked.
        input_step: Graph node the detections selector points at.
        applicable_fields: Field names this rule applies to; ``None`` means
            ``{"detections"}``.

    Raises:
        InvalidStepInputDetected: if any of the checks above fails.
    """
    if applicable_fields is None:
        applicable_fields = {"detections"}
    if field_name not in applicable_fields:
        return None
    detection_step_types = {
        "ObjectDetectionModel",
        "KeypointsDetectionModel",
        "InstanceSegmentationModel",
        "DetectionFilter",
        "DetectionsConsensus",
        "DetectionOffset",
        "YoloWorld",
    }
    input_step_type = input_step.get_type()
    if input_step_type not in detection_step_types:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} cannot take as an input predictions from {input_step_type}. "
            "Step requires detection-based output."
        )
    if get_last_selector_chunk(detections_selector) != "predictions":
        raise InvalidStepInputDetected(
            f"Step with name {step_name} must take as input step output of name `predictions`"
        )
    # Skip the image-consistency check when either side has no image
    # reference. Per the original note, filter steps do not hold one.
    if not hasattr(input_step, "image") or image_selector is None:
        return None
    input_step_image_reference = input_step.image
    if image_selector != input_step_image_reference:
        raise InvalidStepInputDetected(
            f"Step with name {step_name} was given detections reference that is bound to different image: "
            f"step.image: {image_selector}, detections step image: {input_step_image_reference}"
        )
def is_selector(selector_or_value: Any) -> bool:
    """Return True when the argument is a string that starts with ``$``.

    Non-string inputs, including ``None`` and lists, are never selectors.
    """
    return issubclass(type(selector_or_value), str) and selector_or_value.startswith("$")
def get_last_selector_chunk(selector: str) -> str:
    """Return the part of `selector` after its last ``.``.

    If the selector has no ``.``, the whole string is returned.
    """
    _, _, last_chunk = selector.rpartition(".")
    return last_chunk