File size: 9,771 Bytes
d9a2e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
import os
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import torch

from modules.AutoDetailer import mask_util
from modules.Device import Device


def sam_predict(
    predictor: SamPredictor, points: list, plabs: list, bbox: list, threshold: float
) -> list:
    """#### Predict masks using SAM.

    #### Args:
        - `predictor` (SamPredictor): The SAM predictor (image already set).
        - `points` (list): List of prompt points; empty list means no points.
        - `plabs` (list): Point labels (1 = foreground, 0 = background).
        - `bbox` (list): Bounding box prompt, or None.
        - `threshold` (float): Minimum score for a mask to be kept.

    #### Returns:
        - `list`: Masks whose score is >= threshold; if none qualify,
          the single best-scoring mask (may be empty if SAM returns nothing).
    """
    point_coords = None if not points else np.array(points)
    point_labels = None if not plabs else np.array(plabs)

    box = np.array([bbox]) if bbox is not None else None

    cur_masks, scores, _ = predictor.predict(
        point_coords=point_coords, point_labels=point_labels, box=box
    )

    total_masks = []

    # Track the best-scoring mask so we can fall back to it when no mask
    # clears the threshold.
    max_score = 0
    max_mask = None
    for mask, score in zip(cur_masks, scores):
        if score > max_score:
            max_score = score
            max_mask = mask
        if score >= threshold:
            total_masks.append(mask)

    # Fallback: nothing passed the threshold — return the single best mask.
    if not total_masks and max_mask is not None:
        total_masks.append(max_mask)

    return total_masks


def is_same_device(a: torch.device, b: torch.device) -> bool:
    """#### Check if two devices are the same.

    #### Args:
        - `a` (torch.device): The first device (or a device string).
        - `b` (torch.device): The second device (or a device string).

    #### Returns:
        - `bool`: Whether the devices have the same type and index.
    """

    def _normalize(dev):
        # Accept either a device string or an existing torch.device.
        return torch.device(dev) if isinstance(dev, str) else dev

    left = _normalize(a)
    right = _normalize(b)
    # Compare (type, index) pairs; index is None for un-indexed devices.
    return (left.type, left.index) == (right.type, right.index)


class SafeToGPU:
    """#### Class to safely move objects to GPU, guarding against OOM."""

    def __init__(self, size: int):
        # Estimated model size in bytes; used to decide whether the target
        # device has enough free memory before transferring.
        self.size = size

    def to_device(self, obj: torch.nn.Module, device: torch.device) -> None:
        """#### Move an object to a device, skipping the move (with a warning)
        when the target device lacks sufficient free memory.

        #### Args:
            - `obj` (torch.nn.Module): The object to move.
            - `device` (torch.device): The target device.
        """
        if is_same_device(device, "cpu"):
            obj.to(device)
        else:
            if is_same_device(obj.device, "cpu"):  # cpu to gpu
                # Request ~30% headroom over the model size before moving.
                Device.free_memory(self.size * 1.3, device)
                if Device.get_free_memory(device) > self.size * 1.3:
                    try:
                        obj.to(device)
                    except Exception:
                        # Narrowed from a bare `except:` so KeyboardInterrupt /
                        # SystemExit are no longer swallowed; the move stays
                        # best-effort on genuine runtime failures (e.g. OOM).
                        print(
                            f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]"
                        )
                else:
                    print(
                        f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]"
                    )


class SAMWrapper:
    """#### Wrapper class for a SAM model, handling device placement."""

    def __init__(
        self, model: torch.nn.Module, is_auto_mode: bool, safe_to_gpu: SafeToGPU = None
    ):
        self.model = model
        # BUG FIX: the original fallback was `SafeToGPU()`, which raises
        # TypeError because SafeToGPU requires a `size` argument. A size of 0
        # yields a permissive best-effort mover when no estimate is supplied.
        self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU(0)
        # When True, the model is moved to the compute device only for the
        # duration of a prediction and returned to CPU afterwards.
        self.is_auto_mode = is_auto_mode

    def prepare_device(self) -> None:
        """#### Move the model to the compute device (auto mode only)."""
        if self.is_auto_mode:
            device = Device.get_torch_device()
            self.safe_to_gpu.to_device(self.model, device=device)

    def release_device(self) -> None:
        """#### Return the model to CPU (auto mode only)."""
        if self.is_auto_mode:
            self.model.to(device="cpu")

    def predict(
        self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float
    ) -> list:
        """#### Predict masks using the SAM model.

        #### Args:
            - `image` (np.ndarray): The input image (uint8 RGB expected by SAM).
            - `points` (list): List of prompt points.
            - `plabs` (list): Point labels (1 = foreground, 0 = background).
            - `bbox` (list): Bounding box prompt.
            - `threshold` (float): Threshold for mask selection.

        #### Returns:
            - `list`: List of predicted masks.
        """
        predictor = SamPredictor(self.model)
        predictor.set_image(image, "RGB")

        return sam_predict(predictor, points, plabs, bbox, threshold)


class SAMLoader:
    """#### Class to load SAM models from the local checkpoint directory."""

    def load_model(self, model_name: str, device_mode: str = "auto") -> tuple:
        """#### Load a SAM model.

        #### Args:
            - `model_name` (str): The checkpoint filename; the backbone variant
              (vit_h / vit_l / vit_b) is inferred from it.
            - `device_mode` (str, optional): "Prefer GPU" moves the model to the
              GPU immediately; "AUTO" moves it per-prediction. Defaults to "auto".

        #### Returns:
            - `tuple`: One-element tuple containing the loaded SAM model, with
              a `SAMWrapper` attached as `sam_wrapper`.
        """
        modelname = "./_internal/yolos/" + model_name

        # Infer the SAM backbone variant from the checkpoint filename;
        # fall back to the smallest (vit_b) when neither larger variant matches.
        if "vit_h" in model_name:
            model_kind = "vit_h"
        elif "vit_l" in model_name:
            model_kind = "vit_l"
        else:
            model_kind = "vit_b"

        sam = sam_model_registry[model_kind](checkpoint=modelname)
        # Checkpoint file size approximates the in-memory model size for the
        # free-memory check in SafeToGPU.
        size = os.path.getsize(modelname)
        safe_to = SafeToGPU(size)

        # Unless user explicitly wants to use CPU, we use GPU.
        # (Removed the dead `device = ... else "CPU"` assignment: the variable
        # was only ever used in the Prefer-GPU branch, and "CPU" is not a valid
        # torch device string.)
        if device_mode == "Prefer GPU":
            safe_to.to_device(sam, Device.get_torch_device())

        is_auto_mode = device_mode == "AUTO"

        sam_obj = SAMWrapper(sam, is_auto_mode=is_auto_mode, safe_to_gpu=safe_to)
        sam.sam_wrapper = sam_obj

        print(f"Loads SAM model: {modelname} (device:{device_mode})")
        return (sam,)


def make_sam_mask(
    sam: SAMWrapper,
    segs: tuple,
    image: torch.Tensor,
    detection_hint: bool,
    dilation: int,
    threshold: float,
    bbox_expansion: int,
    mask_hint_threshold: float,
    mask_hint_use_negative: bool,
) -> torch.Tensor:
    """#### Create a combined SAM mask from detected segments.

    #### Args:
        - `sam` (SAMWrapper): Model object carrying a `sam_wrapper` attribute.
        - `segs` (tuple): Segmentation information; `segs[1]` holds the SEG
          items (each with a `.bbox`).
        - `image` (torch.Tensor): The input image (float tensor, 0..1 range).
        - `detection_hint` (bool): Whether to use detection hint (unused here).
        - `dilation` (int): Dilation applied to the merged mask.
        - `threshold` (float): Threshold for mask selection.
        - `bbox_expansion` (int): Pixels to expand each bbox by.
        - `mask_hint_threshold` (float): Mask hint threshold (unused here).
        - `mask_hint_use_negative` (bool): Negative hint flag (unused here).

    #### Returns:
        - `torch.Tensor`: The created SAM mask, or None when nothing merged.
    """
    sam_obj = sam.sam_wrapper
    sam_obj.prepare_device()

    try:
        # Convert the tensor image to a uint8 array for the SAM predictor.
        np_image = np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(
            np.uint8
        )

        collected = []
        for seg in segs[1]:
            box = seg.bbox
            center = mask_util.center_of_bbox(box)
            # Expand the bbox by `bbox_expansion`, clamped to image bounds.
            expanded = [
                max(box[0] - bbox_expansion, 0),
                max(box[1] - bbox_expansion, 0),
                min(box[2] + bbox_expansion, np_image.shape[1]),
                min(box[3] + bbox_expansion, np_image.shape[0]),
            ]
            # Single positive point prompt at the bbox center
            # (1 = foreground point, 0 = background point).
            collected += sam_obj.predict(np_image, [center], [1], expanded, threshold)

        # Merge every collected mask into one.
        mask = mask_util.combine_masks2(collected)
    finally:
        # Always release the device, even if prediction raised.
        sam_obj.release_device()

    if mask is None:
        return None

    mask = mask.float()
    dilated = mask_util.dilate_mask(mask.cpu().numpy(), dilation)
    return mask_util.make_3d_mask(torch.from_numpy(dilated))


class SAMDetectorCombined:
    """#### Class to combine SAM detection."""

    def doit(
        self,
        sam_model: SAMWrapper,
        segs: tuple,
        image: torch.Tensor,
        detection_hint: bool,
        dilation: int,
        threshold: float,
        bbox_expansion: int,
        mask_hint_threshold: float,
        mask_hint_use_negative: bool,
    ) -> tuple:
        """#### Run combined SAM detection over the given segments.

        Thin wrapper around `make_sam_mask`; all arguments are forwarded.

        #### Args:
            - `sam_model` (SAMWrapper): The SAM wrapper.
            - `segs` (tuple): Segmentation information.
            - `image` (torch.Tensor): The input image.
            - `detection_hint` (bool): Whether to use detection hint.
            - `dilation` (int): Dilation value.
            - `threshold` (float): Threshold for mask selection.
            - `bbox_expansion` (int): Bounding box expansion value.
            - `mask_hint_threshold` (float): Mask hint threshold.
            - `mask_hint_use_negative` (bool): Whether to use negative mask hint.

        #### Returns:
            - `tuple`: One-element tuple with the mask, or None when no mask
              was produced.
        """
        mask = make_sam_mask(
            sam_model,
            segs,
            image,
            detection_hint,
            dilation,
            threshold,
            bbox_expansion,
            mask_hint_threshold,
            mask_hint_use_negative,
        )
        return (mask,) if mask is not None else None