Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,789 Bytes
d9a2e19 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
from collections import namedtuple
import numpy as np
import torch
from modules.AutoDetailer import mask_util
# A single detected segment. Every field is positional; only the last one,
# `control_net_wrapper`, is optional and defaults to None.
SEG = namedtuple(
    "SEG",
    "cropped_image cropped_mask confidence crop_region bbox "
    "label control_net_wrapper",
    defaults=(None,),
)
def segs_bitwise_and_mask(segs: tuple, mask: torch.Tensor) -> tuple:
    """#### Intersect every segment mask in `segs` with one global mask.

    The global mask is first normalised with `mask_util.make_2d_mask`. Both it and
    each segment's `cropped_mask` are then scaled by 255 and truncated to `uint8`.
    The slice of the global mask under the segment's `crop_region` is combined with
    the segment mask via `np.bitwise_and`, and the result is scaled back to
    `float32` by dividing by 255.

    NOTE: `np.bitwise_and` operates bit by bit, so it matches a logical AND only for
    binary masks (values 0 or 255 after scaling). Soft masks yield bit-level
    combinations.

    #### Args:
        - `segs` (tuple): A pair whose element `[1]` is an iterable of segments exposing `cropped_image`, `cropped_mask`, `confidence`, `crop_region`, `bbox` and `label`. Element `[0]` is returned unchanged. `cropped_mask` must support `* 255` and `.astype` (numpy-like). Presumably it is a float array in [0, 1] — TODO confirm.
        - `mask` (torch.Tensor): The mask to intersect with. Presumably values in [0, 1] and sized like the source image, so every `crop_region` falls inside it — TODO confirm.
    #### Returns:
        - `tuple`: `(segs[0], items)`, where `items` is a list of new `SEG` instances. Each carries the intersected mask as `cropped_mask`, keeps the other fields, and has `control_net_wrapper` set to `None`.
    """
    global_u8 = (mask_util.make_2d_mask(mask).cpu().numpy() * 255).astype(np.uint8)

    def _intersect(seg):
        seg_u8 = (seg.cropped_mask * 255).astype(np.uint8)
        region = seg.crop_region
        # region[0]/region[2] bound the columns and region[1]/region[3] the rows
        # of the 2-D global mask.
        left, top, right, bottom = region[0], region[1], region[2], region[3]
        window = global_u8[top:bottom, left:right]
        combined = np.bitwise_and(seg_u8, window)
        return SEG(
            seg.cropped_image,
            combined.astype(np.float32) / 255.0,
            seg.confidence,
            region,
            seg.bbox,
            seg.label,
            None,
        )

    return segs[0], [_intersect(seg) for seg in segs[1]]
class SegsBitwiseAndMask:
    """#### Class wrapper that exposes `segs_bitwise_and_mask` through a `doit` method."""

    def doit(self, segs: tuple, mask: torch.Tensor) -> tuple:
        """#### Intersect each segment mask in `segs` with `mask` via bitwise AND.
        #### Args:
            - `segs` (tuple): A pair whose element `[1]` holds the segments to process.
            - `mask` (torch.Tensor): The mask combined with every segment mask.
        #### Returns:
            - `tuple`: A one-element tuple wrapping the `(segs[0], items)` pair produced by `segs_bitwise_and_mask`.
        """
        result = segs_bitwise_and_mask(segs, mask)
        return (result,)
class SEGSLabelFilter:
    """#### Class to split segmentation items by their label."""

    @staticmethod
    def filter(segs: tuple, labels: list) -> tuple:
        """#### Split the items of `segs` into those whose label is in `labels` and the rest.
        #### Args:
            - `segs` (tuple): A pair whose element `[0]` is passed through unchanged and whose element `[1]` is an iterable of items exposing a `.label` attribute.
            - `labels` (list): Labels to keep. A comma-separated string such as `"face,hand"` is also accepted. Surrounding whitespace is ignored and blank entries are dropped. The special label `"all"`, or an empty label set, keeps every item.
        #### Returns:
            - `tuple`: `(kept, remaining)`, each a `(segs[0], items)` pair. When everything is kept, `kept` is `segs` itself and `remaining` has an empty item list.
        """
        if isinstance(labels, str):
            # Iterating a bare string would yield single characters; split it instead.
            labels = labels.split(",")
        wanted = {label.strip() for label in labels}
        wanted.discard("")

        # "all" (or no usable label) means no filtering: pass everything through.
        if not wanted or "all" in wanted:
            return (
                segs,
                (segs[0], []),
            )

        kept = []
        remaining = []
        for seg in segs[1]:
            if seg.label in wanted:
                kept.append(seg)
            else:
                remaining.append(seg)
        return (
            (segs[0], kept),
            (segs[0], remaining),
        )
|