import random
from typing import Callable, List, Optional, Tuple

import torch
from ultralytics import YOLO

from modules.AutoDetailer import SEGS, AD_util, tensor_util


class UltraBBoxDetector:
    """#### Class to detect bounding boxes using a YOLO model."""

    bbox_model: Optional[YOLO] = None

    def __init__(self, bbox_model: YOLO):
        """#### Initialize the UltraBBoxDetector with a YOLO model.

        #### Args:
            - `bbox_model` (YOLO): The YOLO model to use for detection.
        """
        self.bbox_model = bbox_model

    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook: Optional[Callable] = None,
    ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
        """#### Detect bounding boxes in an image.

        #### Args:
            - `image` (torch.Tensor): The input image tensor.
            - `threshold` (float): The detection threshold.
            - `dilation` (int): The dilation factor for masks.
            - `crop_factor` (float): The crop factor for bounding boxes.
            - `drop_size` (int, optional): The minimum size of bounding boxes to keep. Defaults to 1.
            - `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None.

        #### Returns:
            - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments.
        """
        drop_size = max(drop_size, 1)
        detected_results = AD_util.inference_bbox(
            self.bbox_model, tensor_util.tensor2pil(image), threshold
        )
        segmasks = AD_util.create_segmasks(detected_results)

        if dilation > 0:
            segmasks = AD_util.dilate_masks(segmasks, dilation)

        items = []
        h = image.shape[1]
        w = image.shape[2]

        for x, label in zip(segmasks, detected_results[0]):
            item_bbox = x[0]
            item_mask = x[1]

            y1, x1, y2, x2 = item_bbox

            if (
                x2 - x1 > drop_size and y2 - y1 > drop_size
            ):  # minimum dimension must be (2,2) to avoid squeeze issue
                crop_region = AD_util.make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = AD_util.crop_image(image, crop_region)
                cropped_mask = AD_util.crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                item = SEGS.SEG(
                    cropped_image,
                    cropped_mask,
                    confidence,
                    crop_region,
                    item_bbox,
                    label,
                    None,
                )

                items.append(item)

        shape = h, w
        segs = shape, items

        return segs
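
# --- Illustrative usage sketch (not part of the library API) -------------------
# A minimal example of driving `UltraBBoxDetector.detect` directly. The model
# filename "face_yolov8m.pt" and the dummy image tensor are assumptions made for
# illustration only; substitute a real detection model and an image tensor of
# shape [batch, height, width, channels] with values in [0, 1].
def _example_bbox_detection() -> List[SEGS.SEG]:
    model = AD_util.load_yolo("./_internal/yolos/face_yolov8m.pt")  # hypothetical model file
    detector = UltraBBoxDetector(model)

    dummy_image = torch.rand(1, 512, 512, 3)  # stand-in for a real image tensor
    (height, width), segments = detector.detect(
        dummy_image,
        threshold=0.5,    # keep detections with confidence >= 0.5
        dilation=4,       # grow each mask by 4 pixels
        crop_factor=3.0,  # crop a region 3x the bbox size around each detection
        drop_size=10,     # discard boxes smaller than 10 px in either dimension
    )
    return segments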
class UltraSegmDetector:
    """#### Class to detect segments using a YOLO model."""

    bbox_model: Optional[YOLO] = None

    def __init__(self, bbox_model: YOLO):
        """#### Initialize the UltraSegmDetector with a YOLO model.

        #### Args:
            - `bbox_model` (YOLO): The YOLO model to use for detection.
        """
        self.bbox_model = bbox_model


class NO_SEGM_DETECTOR:
    """#### Placeholder class for no segment detector."""

    pass
class UltralyticsDetectorProvider:
    """#### Class to provide YOLO models for detection."""

    def doit(self, model_name: str) -> Tuple[UltraBBoxDetector, UltraSegmDetector]:
        """#### Load a YOLO model and return detectors.

        #### Args:
            - `model_name` (str): The name of the YOLO model to load.

        #### Returns:
            - `Tuple[UltraBBoxDetector, UltraSegmDetector]`: The bounding box and segment detectors.
        """
        model = AD_util.load_yolo("./_internal/yolos/" + model_name)
        return UltraBBoxDetector(model), UltraSegmDetector(model)
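
# --- Illustrative usage sketch -------------------------------------------------
# Shows how the provider is expected to be used: `doit` resolves `model_name`
# relative to "./_internal/yolos/" and returns a (bbox, segm) detector pair
# backed by the same YOLO model. The filename below is an assumption.
def _example_provider_usage() -> Tuple[UltraBBoxDetector, UltraSegmDetector]:
    provider = UltralyticsDetectorProvider()
    bbox_detector, segm_detector = provider.doit("face_yolov8m.pt")  # hypothetical filename
    return bbox_detector, segm_detector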
class BboxDetectorForEach:
    """#### Class to detect bounding boxes for each segment."""

    def doit(
        self,
        bbox_detector: UltraBBoxDetector,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int,
        labels: Optional[str] = None,
        detailer_hook: Optional[Callable] = None,
    ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
        """#### Detect bounding boxes for each segment in an image.

        #### Args:
            - `bbox_detector` (UltraBBoxDetector): The bounding box detector.
            - `image` (torch.Tensor): The input image tensor.
            - `threshold` (float): The detection threshold.
            - `dilation` (int): The dilation factor for masks.
            - `crop_factor` (float): The crop factor for bounding boxes.
            - `drop_size` (int): The minimum size of bounding boxes to keep.
            - `labels` (str, optional): A comma-separated list of labels to keep. Defaults to None.
            - `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None.

        #### Returns:
            - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments.
        """
        segs = bbox_detector.detect(
            image, threshold, dilation, crop_factor, drop_size, detailer_hook
        )

        if labels is not None and labels != "":
            labels = labels.split(",")
            if len(labels) > 0:
                segs, _ = SEGS.SEGSLabelFilter.filter(segs, labels)

        return segs
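
# --- Illustrative usage sketch -------------------------------------------------
# Demonstrates the label-filtering path of `BboxDetectorForEach.doit`: the
# `labels` string is split on commas and only segments whose label survives
# `SEGS.SEGSLabelFilter.filter` are kept. The model filename and dummy image
# are the same assumptions as in `_example_bbox_detection`.
def _example_filtered_detection() -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
    model = AD_util.load_yolo("./_internal/yolos/face_yolov8m.pt")  # hypothetical model file
    detector = UltraBBoxDetector(model)
    dummy_image = torch.rand(1, 512, 512, 3)  # stand-in for a real image tensor

    return BboxDetectorForEach().doit(
        detector,
        dummy_image,
        threshold=0.5,
        dilation=4,
        crop_factor=3.0,
        drop_size=10,
        labels="face",  # keep only segments labeled "face"
    )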
class WildcardChooser:
    """#### Class to choose wildcards for segments."""

    def __init__(self, items: List[Tuple[None, str]], randomize_when_exhaust: bool):
        """#### Initialize the WildcardChooser.

        #### Args:
            - `items` (List[Tuple[None, str]]): The list of items to choose from.
            - `randomize_when_exhaust` (bool): Whether to shuffle the items once the list is exhausted.
        """
        self.i = 0
        self.items = items
        self.randomize_when_exhaust = randomize_when_exhaust

    def get(self, seg: SEGS.SEG) -> Tuple[None, str]:
        """#### Get the next item from the list.

        #### Args:
            - `seg` (SEGS.SEG): The segment.

        #### Returns:
            - `Tuple[None, str]`: The next item from the list.
        """
        # Wrap around once the list is exhausted, optionally reshuffling it,
        # so repeated calls never index past the end of `self.items`.
        if self.i >= len(self.items):
            self.i = 0
            if self.randomize_when_exhaust:
                random.shuffle(self.items)

        item = self.items[self.i]
        self.i += 1

        return item


def process_wildcard_for_segs(wildcard: str) -> Tuple[None, WildcardChooser]:
    """#### Process a wildcard for segments.

    #### Args:
        - `wildcard` (str): The wildcard.

    #### Returns:
        - `Tuple[None, WildcardChooser]`: The processed wildcard and a WildcardChooser.
    """
    return None, WildcardChooser([(None, wildcard)], False)
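
# --- Illustrative usage sketch -------------------------------------------------
# `process_wildcard_for_segs` wraps a single prompt string in a WildcardChooser
# that hands back the same (None, wildcard) pair for every segment. The prompt
# text and the placeholder segment value are assumptions for illustration.
def _example_wildcard_usage() -> str:
    _, chooser = process_wildcard_for_segs("a detailed face")
    # In a detailer loop this would be called once per detected segment;
    # `get` accepts the segment but does not use it when selecting the item.
    _none, prompt = chooser.get(None)
    return prompt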