Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,967 Bytes
d9a2e19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
import random
from typing import List, Tuple, Optional

import torch
from ultralytics import YOLO

from modules.AutoDetailer import SEGS, AD_util, tensor_util
class UltraBBoxDetector:
    """#### Bounding-box detector backed by a YOLO model.

    Runs the stored model on an image and converts every sufficiently large
    detection into a `SEGS.SEG`.
    """

    # Class-level default; every instance overwrites it in `__init__`.
    bbox_model: Optional[YOLO] = None

    def __init__(self, bbox_model: YOLO):
        """#### Store the YOLO model used by `detect`.

        #### Args:
            - `bbox_model` (YOLO): The YOLO model to run for detection.
        """
        self.bbox_model = bbox_model

    def detect(
        self,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int = 1,
        detailer_hook: Optional[callable] = None,
    ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
        """#### Run bounding-box detection and build one segment per kept box.

        #### Args:
            - `image` (torch.Tensor): The input image tensor; dimensions 1 and 2
              are read as height and width.
            - `threshold` (float): Detection threshold forwarded to `AD_util.inference_bbox`.
            - `dilation` (int): Masks are dilated by this amount when it is greater than 0.
            - `crop_factor` (float): Crop factor forwarded to `AD_util.make_crop_region`.
            - `drop_size` (int, optional): Boxes whose width or height does not exceed
              this value are dropped. Values below 1 are raised to 1. Defaults to 1.
            - `detailer_hook` (callable, optional): Accepted for interface
              compatibility; this method does not use it. Defaults to None.

        #### Returns:
            - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: `(image.shape[1], image.shape[2])`
              and the list of segments that passed the size filter.
        """
        min_extent = max(drop_size, 1)

        pil_image = tensor_util.tensor2pil(image)
        detections = AD_util.inference_bbox(self.bbox_model, pil_image, threshold)

        masks = AD_util.create_segmasks(detections)
        if dilation > 0:
            masks = AD_util.dilate_masks(masks, dilation)

        height, width = image.shape[1], image.shape[2]
        found = []

        # `detections[0]` supplies the label paired with each segmask entry.
        for entry, label in zip(masks, detections[0]):
            bbox = entry[0]
            mask = entry[1]
            y1, x1, y2, x2 = bbox

            # Note kept from the original: minimum dimension must be (2, 2)
            # to avoid a squeeze issue.
            large_enough = x2 - x1 > min_extent and y2 - y1 > min_extent
            if not large_enough:
                continue

            region = AD_util.make_crop_region(width, height, bbox, crop_factor)
            image_crop = AD_util.crop_image(image, region)
            mask_crop = AD_util.crop_ndarray2(mask, region)

            # entry[2] is stored as the segment's confidence.
            found.append(
                SEGS.SEG(image_crop, mask_crop, entry[2], region, bbox, label, None)
            )

        return (height, width), found
class UltraSegmDetector:
    """#### Segment detector wrapper around a YOLO model.

    As written here the class only stores the model; it defines no
    detection method of its own.
    """

    # Class-level default; every instance overwrites it in `__init__`.
    bbox_model: Optional[YOLO] = None

    def __init__(self, bbox_model: YOLO):
        """#### Store the YOLO model on the instance.

        #### Args:
            - `bbox_model` (YOLO): The YOLO model to keep for detection.
        """
        self.bbox_model = bbox_model
class NO_SEGM_DETECTOR:
    """#### Empty marker class used when no segment detector is available."""
class UltralyticsDetectorProvider:
    """#### Factory that loads a YOLO model and wraps it in detector objects."""

    def doit(self, model_name: str) -> Tuple[UltraBBoxDetector, UltraSegmDetector]:
        """#### Load the named YOLO model and build both detectors from it.

        The model is loaded once via `AD_util.load_yolo` and the same model
        object is handed to both detectors.

        #### Args:
            - `model_name` (str): File name of the model, appended to `./_internal/yolos/`.

        #### Returns:
            - `Tuple[UltraBBoxDetector, UltraSegmDetector]`: The bounding-box
              detector and the segment detector, in that order.
        """
        yolo_path = "./_internal/yolos/" + model_name
        yolo = AD_util.load_yolo(yolo_path)

        bbox_detector = UltraBBoxDetector(yolo)
        segm_detector = UltraSegmDetector(yolo)
        return bbox_detector, segm_detector
class BboxDetectorForEach:
    """#### Run a bounding-box detector and optionally filter its segments by label."""

    def doit(
        self,
        bbox_detector: UltraBBoxDetector,
        image: torch.Tensor,
        threshold: float,
        dilation: int,
        crop_factor: float,
        drop_size: int,
        labels: Optional[str] = None,
        detailer_hook: Optional[callable] = None,
    ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
        """#### Detect segments in an image, then keep only the requested labels.

        #### Args:
            - `bbox_detector` (UltraBBoxDetector): Detector whose `detect` method is called.
            - `image` (torch.Tensor): The input image tensor.
            - `threshold` (float): The detection threshold.
            - `dilation` (int): The dilation factor for masks.
            - `crop_factor` (float): The crop factor for bounding boxes.
            - `drop_size` (int): The minimum size of bounding boxes to keep.
            - `labels` (str, optional): Comma-separated labels to keep. `None` or an
              empty string disables filtering. Defaults to None.
            - `detailer_hook` (callable, optional): Forwarded unchanged to
              `bbox_detector.detect`. Defaults to None.

        #### Returns:
            - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The detector's result, or the
              first value returned by `SEGS.SEGSLabelFilter.filter` when labels are given.
        """
        detected = bbox_detector.detect(
            image, threshold, dilation, crop_factor, drop_size, detailer_hook
        )

        # No label restriction requested: return the raw detections.
        if labels is None or labels == "":
            return detected

        wanted = labels.split(",")
        if len(wanted) == 0:
            return detected

        # The filter returns a pair; only its first element is used here.
        kept, _ = SEGS.SEGSLabelFilter.filter(detected, wanted)
        return kept
class WildcardChooser:
    """#### Hand out wildcard items to segments one at a time, cycling when exhausted.

    Items are returned in order. Once every item has been handed out the
    chooser starts over from the beginning instead of raising, so a chooser
    holding a single item can serve any number of segments.
    """

    def __init__(self, items: List[Tuple[None, str]], randomize_when_exhaust: bool):
        """#### Initialize the WildcardChooser.

        #### Args:
            - `items` (List[Tuple[None, str]]): The list of items to choose from.
            - `randomize_when_exhaust` (bool): Whether to reshuffle the items each
              time the list has been fully consumed.
        """
        self.i = 0
        self.items = items
        self.randomize_when_exhaust = randomize_when_exhaust

    def get(self, seg: "SEGS.SEG") -> Tuple[None, str]:
        """#### Return the next item, wrapping around after the last one.

        #### Args:
            - `seg` (SEGS.SEG): The segment requesting an item. It is not
              inspected; selection depends only on call order.

        #### Returns:
            - `Tuple[None, str]`: The next item from the list.

        #### Raises:
            - `IndexError`: If the chooser was created with no items.
        """
        if self.i >= len(self.items):
            # Every item has been handed out: begin another pass.
            self.i = 0
            if self.randomize_when_exhaust:
                # Rebind to a shuffled copy instead of shuffling in place so
                # the list object passed in by the caller is never mutated.
                self.items = random.sample(self.items, len(self.items))

        item = self.items[self.i]
        self.i += 1
        return item
def process_wildcard_for_segs(wildcard: str) -> Tuple[None, WildcardChooser]:
    """#### Wrap a wildcard string in a single-item WildcardChooser.

    #### Args:
        - `wildcard` (str): The wildcard text to hand to segments.

    #### Returns:
        - `Tuple[None, WildcardChooser]`: Always `None` as the first element, and a
          chooser holding the single item `(None, wildcard)` with
          `randomize_when_exhaust` set to False.
    """
    single_entry = [(None, wildcard)]
    chooser = WildcardChooser(single_entry, False)
    return None, chooser
|