Spaces: Running on Zero
| import torch | |
| from ultralytics import YOLO | |
| from modules.AutoDetailer import SEGS, AD_util, tensor_util | |
| from typing import List, Tuple, Optional | |
class UltraBBoxDetector:
    """#### Bounding-box detector backed by a YOLO model."""

    # Model used for inference; stays None at class level until an instance
    # assigns its own model in __init__.
    bbox_model: Optional[YOLO] = None

    def __init__(self, bbox_model: YOLO):
        """#### Store the YOLO model used for detection.

        #### Args:
            - `bbox_model` (YOLO): The YOLO model to run inference with.
        """
        self.bbox_model = bbox_model

    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook: Optional[callable] = None,
    ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
        """#### Run bounding-box detection on an image and build one SEG per kept box.

        #### Args:
            - `image` (torch.Tensor): The input image tensor. Dimensions 1 and 2 are read as height and width (assumes a (batch, height, width, ...) layout — TODO confirm).
            - `threshold` (float): The detection threshold forwarded to `AD_util.inference_bbox`.
            - `dilation` (int): Mask dilation amount; dilation is applied only when this is greater than 0.
            - `crop_factor` (float): The crop factor forwarded to `AD_util.make_crop_region`.
            - `drop_size` (int, optional): Boxes must be strictly larger than this in both dimensions to be kept. Values below 1 are raised to 1. Defaults to 1.
            - `detailer_hook` (callable, optional): Accepted for interface compatibility; not used by this method. Defaults to None.

        #### Returns:
            - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The `(height, width)` of the image and the list of detected segments.
        """
        min_size = max(drop_size, 1)

        pil_image = tensor_util.tensor2pil(image)
        detections = AD_util.inference_bbox(self.bbox_model, pil_image, threshold)

        masks = AD_util.create_segmasks(detections)
        if dilation > 0:
            masks = AD_util.dilate_masks(masks, dilation)

        segments = []
        img_h, img_w = image.shape[1], image.shape[2]

        # detections[0] supplies the label paired with each mask entry.
        for entry, label in zip(masks, detections[0]):
            bbox, mask = entry[0], entry[1]
            box_y1, box_x1, box_y2, box_x2 = bbox

            # minimum dimension must be (2,2) to avoid squeeze issue
            if not (box_x2 - box_x1 > min_size and box_y2 - box_y1 > min_size):
                continue

            region = AD_util.make_crop_region(img_w, img_h, bbox, crop_factor)
            segments.append(
                SEGS.SEG(
                    AD_util.crop_image(image, region),
                    AD_util.crop_ndarray2(mask, region),
                    entry[2],  # detection confidence
                    region,
                    bbox,
                    label,
                    None,
                )
            )

        return (img_h, img_w), segments
class UltraSegmDetector:
    """#### Segment detector wrapper around a YOLO model.

    This class only stores the model; it defines no detection method of its own.
    """

    # Model handle; None at class level until an instance sets it in __init__.
    bbox_model: Optional[YOLO] = None

    def __init__(self, bbox_model: YOLO):
        """#### Store the YOLO model on the detector.

        #### Args:
            - `bbox_model` (YOLO): The YOLO model to keep for detection.
        """
        self.bbox_model = bbox_model
class NO_SEGM_DETECTOR:
    """#### Empty marker class standing in for "no segment detector"."""
class UltralyticsDetectorProvider:
    """#### Loads a YOLO model and wraps it in detector objects."""

    def doit(self, model_name: str) -> Tuple[UltraBBoxDetector, UltraSegmDetector]:
        """#### Load the named YOLO model and build both detectors around it.

        #### Args:
            - `model_name` (str): File name of the model, appended to the `./_internal/yolos/` directory.

        #### Returns:
            - `Tuple[UltraBBoxDetector, UltraSegmDetector]`: A bounding-box detector and a segment detector that share the same loaded model.
        """
        model_path = "./_internal/yolos/" + model_name
        yolo_model = AD_util.load_yolo(model_path)

        bbox_detector = UltraBBoxDetector(yolo_model)
        segm_detector = UltraSegmDetector(yolo_model)
        return bbox_detector, segm_detector
class BboxDetectorForEach:
    """#### Runs a bounding-box detector on an image, optionally filtering by label."""

    def doit(
        self,
        bbox_detector: UltraBBoxDetector,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int,
        labels: Optional[str] = None,
        detailer_hook: Optional[callable] = None,
    ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
        """#### Detect segments with `bbox_detector` and apply an optional label filter.

        #### Args:
            - `bbox_detector` (UltraBBoxDetector): The detector whose `detect` method is called.
            - `image` (torch.Tensor): The input image tensor.
            - `threshold` (float): The detection threshold.
            - `dilation` (int): The dilation factor for masks.
            - `crop_factor` (float): The crop factor for bounding boxes.
            - `drop_size` (int): The minimum size of bounding boxes to keep.
            - `labels` (str, optional): Comma-separated labels to keep. `None` or an empty string disables filtering. Defaults to None.
            - `detailer_hook` (callable, optional): Forwarded unchanged to `bbox_detector.detect`. Defaults to None.

        #### Returns:
            - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The detector's result, or the first element returned by `SEGS.SEGSLabelFilter.filter` when labels were given.
        """
        detected = bbox_detector.detect(
            image, threshold, dilation, crop_factor, drop_size, detailer_hook
        )

        # No label filter requested: hand back the raw detections.
        if labels is None or labels == "":
            return detected

        wanted_labels = labels.split(",")
        if wanted_labels:
            # The filter returns a pair; only its first element is kept.
            detected, _ = SEGS.SEGSLabelFilter.filter(detected, wanted_labels)
        return detected
class WildcardChooser:
    """#### Class to choose wildcards for segments.

    Hands out the configured items one per call and starts over from the
    beginning once every item has been served.
    """

    def __init__(self, items: List[Tuple[None, str]], randomize_when_exhaust: bool):
        """#### Initialize the WildcardChooser.

        #### Args:
            - `items` (List[Tuple[None, str]]): The list of items to choose from.
            - `randomize_when_exhaust` (bool): Whether to reshuffle the items each time the list has been fully consumed.
        """
        self.i = 0  # index of the next item to hand out
        self.items = items
        self.randomize_when_exhaust = randomize_when_exhaust

    def get(self, seg: "SEGS.SEG") -> Tuple[None, str]:
        """#### Get the next item, wrapping around when the list is exhausted.

        #### Args:
            - `seg` (SEGS.SEG): The segment being processed; it does not influence which item is returned.

        #### Returns:
            - `Tuple[None, str]`: The next item from the list.

        #### Raises:
            - `IndexError`: If the chooser was created with no items.
        """
        # Imported here because the module does not import `random` at the top level.
        import random

        if not self.items:
            raise IndexError("WildcardChooser has no items to choose from")

        # Every item has been served: start over instead of indexing past the
        # end, which previously raised IndexError on the call after the last item.
        if self.i >= len(self.items):
            self.i = 0
            if self.randomize_when_exhaust:
                # Shuffle into a new list so the caller's list is left untouched.
                self.items = random.sample(self.items, len(self.items))

        item = self.items[self.i]
        self.i += 1
        return item
def process_wildcard_for_segs(wildcard: str) -> Tuple[None, WildcardChooser]:
    """#### Wrap a wildcard string in a single-item WildcardChooser.

    #### Args:
        - `wildcard` (str): The wildcard text to hand out.

    #### Returns:
        - `Tuple[None, WildcardChooser]`: `None` followed by a chooser holding the single item `(None, wildcard)`, created with `randomize_when_exhaust=False`.
    """
    single_item = (None, wildcard)
    chooser = WildcardChooser([single_item], False)
    return None, chooser