Aatricks's picture
Upload folder using huggingface_hub
d9a2e19 verified
from collections import namedtuple
import numpy as np
import torch
from modules.AutoDetailer import mask_util
# Immutable record describing a single segment.
# The first six fields are required. `defaults` is applied to the right-most
# field(s), so only `control_net_wrapper` is optional and falls back to None.
SEG = namedtuple(
    "SEG",
    "cropped_image cropped_mask confidence crop_region bbox label "
    "control_net_wrapper",
    defaults=(None,),
)
def segs_bitwise_and_mask(segs: tuple, mask: torch.Tensor) -> tuple:
    """#### Intersect each segment's cropped mask with a region of a full mask.

    The mask is first passed through `mask_util.make_2d_mask`, moved to the
    CPU and scaled to `uint8`. For every segment in `segs[1]`, the window
    selected by its `crop_region` is AND-ed bitwise with the segment's own
    cropped mask (also scaled to `uint8`), and the result is scaled back to
    `float32` by dividing by 255.

    #### Args:
        - `segs` (tuple): Pair whose first element is returned unchanged and
          whose second element is an iterable of `SEG`-like records.
        - `mask` (torch.Tensor): Mask to intersect with. Presumably holds
          values in [0, 1] and covers the image the crop regions index into
          -- TODO confirm against callers.

    #### Returns:
        - `tuple`: `(segs[0], items)` where `items` is a list of new `SEG`
          records carrying the combined masks. `control_net_wrapper` is
          always reset to `None` on the new records.
    """
    # Work on byte values so the AND operates on integers rather than floats.
    full_mask_u8 = (mask_util.make_2d_mask(mask).cpu().numpy() * 255).astype(np.uint8)

    def _intersect(seg):
        """Build a copy of `seg` whose mask is AND-ed with the matching window."""
        seg_mask_u8 = (seg.cropped_mask * 255).astype(np.uint8)
        region = seg.crop_region
        # crop_region is indexed as columns [0]:[2] and rows [1]:[3].
        window = full_mask_u8[region[1] : region[3], region[0] : region[2]]
        combined = np.bitwise_and(seg_mask_u8, window).astype(np.float32) / 255.0
        return SEG(
            cropped_image=seg.cropped_image,
            cropped_mask=combined,
            confidence=seg.confidence,
            crop_region=seg.crop_region,
            bbox=seg.bbox,
            label=seg.label,
            control_net_wrapper=None,
        )

    updated = [_intersect(seg) for seg in segs[1]]
    return segs[0], updated
class SegsBitwiseAndMask:
    """#### Thin wrapper exposing `segs_bitwise_and_mask` through a `doit` method."""

    def doit(self, segs: tuple, mask: torch.Tensor) -> tuple:
        """#### Combine the segment masks with `mask` via bitwise AND.

        #### Args:
            - `segs` (tuple): Segmentation pair forwarded to
              `segs_bitwise_and_mask`.
            - `mask` (torch.Tensor): The mask tensor forwarded to
              `segs_bitwise_and_mask`.

        #### Returns:
            - `tuple`: A one-element tuple whose only item is the
              `(segs[0], updated_items)` pair produced by
              `segs_bitwise_and_mask`.
        """
        masked_segs = segs_bitwise_and_mask(segs, mask)
        return (masked_segs,)
class SEGSLabelFilter:
    """#### Class to filter segmentation labels."""

    @staticmethod
    def filter(segs: tuple, labels: list) -> tuple:
        """#### Split segments into "kept" and "remaining" groups by label.

        The current implementation performs no per-label selection: every
        segment is kept and the remainder is always empty, whatever `labels`
        contains.

        #### Args:
            - `segs` (tuple): Segmentation pair; `segs[0]` is reused as the
              first element of the (empty) remainder pair.
            - `labels` (list): Labels to filter by. Accepted for interface
              compatibility but currently ignored.

        #### Returns:
            - `tuple`: `(segs, (segs[0], []))` -- the input `segs` object
              itself, followed by a pair holding `segs[0]` and an empty list.
        """
        # NOTE: a stripped-label set used to be built here but was never read.
        # It was dropped so the call no longer does dead work or fails on
        # labels that cannot influence the result.
        return (
            segs,
            (segs[0], []),
        )