# Spaces: Running on Zero  (page-header residue from the hosting site; not part of this module)
from collections import namedtuple

import numpy as np
import torch

from modules.AutoDetailer import mask_util
# Record describing one detected segment. The first six fields are required;
# `defaults` applies to the right-most field only, so `control_net_wrapper`
# may be omitted and then defaults to None.
SEG = namedtuple(
    "SEG",
    [
        "cropped_image",
        "cropped_mask",
        "confidence",
        "crop_region",
        "bbox",
        "label",
        "control_net_wrapper",
    ],
    defaults=[None],
)
def segs_bitwise_and_mask(segs: tuple, mask: torch.Tensor) -> tuple:
    """#### Apply bitwise AND operation between segmentation masks and a given mask.

    Each segment's `cropped_mask` is combined with the window of `mask`
    selected by that segment's `crop_region`. Both operands are scaled by 255
    and cast to `uint8` before `np.bitwise_and`; the result is converted back
    to `float32` and divided by 255.

    #### Args:
        - `segs` (tuple): A tuple containing segmentation information. `segs[0]`
          is returned unchanged and `segs[1]` is iterated as the segments.
        - `mask` (torch.Tensor): The mask tensor.

    #### Returns:
        - `tuple`: `(segs[0], items)`, where `items` is a list of new `SEG`
          entries carrying the combined mask. Their `control_net_wrapper` is
          always set to `None`.
    """
    mask = mask_util.make_2d_mask(mask)
    # NOTE(review): the * 255 scaling assumes mask values lie in [0, 1] - confirm.
    mask = (mask.cpu().numpy() * 255).astype(np.uint8)

    items = []
    for seg in segs[1]:
        # `.astype` is a NumPy method, so `cropped_mask` must be an ndarray here.
        cropped_mask = (seg.cropped_mask * 255).astype(np.uint8)
        crop_region = seg.crop_region
        # Indices 1 and 3 bound the first axis of the mask, indices 0 and 2 the
        # second (presumably crop_region is (x1, y1, x2, y2) - confirm).
        mask_window = mask[
            crop_region[1] : crop_region[3], crop_region[0] : crop_region[2]
        ]
        # Both operands are already uint8, so no further cast is needed.
        new_mask = np.bitwise_and(cropped_mask, mask_window)
        new_mask = new_mask.astype(np.float32) / 255.0
        items.append(
            SEG(
                seg.cropped_image,
                new_mask,
                seg.confidence,
                seg.crop_region,
                seg.bbox,
                seg.label,
                None,
            )
        )
    return segs[0], items
class SegsBitwiseAndMask:
    """#### Class to apply bitwise AND operation between segmentation masks and a given mask."""

    def doit(self, segs: tuple, mask: torch.Tensor) -> tuple:
        """#### Apply bitwise AND operation between segmentation masks and a given mask.

        Thin wrapper that delegates to `segs_bitwise_and_mask` and wraps its
        result in a one-element tuple.

        #### Args:
            - `segs` (tuple): A tuple containing segmentation information.
            - `mask` (torch.Tensor): The mask tensor.

        #### Returns:
            - `tuple`: A one-element tuple whose only item is the value returned
              by `segs_bitwise_and_mask(segs, mask)`.
        """
        return (segs_bitwise_and_mask(segs, mask),)
class SEGSLabelFilter:
    """#### Class to filter segmentation labels."""

    @staticmethod
    def filter(segs: tuple, labels: list) -> tuple:
        """#### Filter segmentation labels.

        No filtering is performed: every segment is kept and the list of
        remaining segments is always empty. `labels` is accepted for interface
        compatibility but is not consulted.

        Declared as a static method so it can be called both on the class and
        on an instance.

        #### Args:
            - `segs` (tuple): A tuple containing segmentation information.
            - `labels` (list): A list of labels to filter (currently unused).

        #### Returns:
            - `tuple`: `(segs, (segs[0], []))` - the original segmentation and a
              second segmentation that reuses `segs[0]` with an empty list.
        """
        return (
            segs,
            (segs[0], []),
        )