from typing import Optional, Tuple, Union

import mmcv
import mmengine
import numpy as np
import pycocotools.mask as maskUtils
import torch

from mmcv.transforms.base import BaseTransform
from mmdet.registry import TRANSFORMS
from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
from mmdet.structures.bbox import autocast_box_type
from mmdet.structures.mask import BitmapMasks
from mmdet.datasets.transforms import LoadPanopticAnnotations
from mmengine.fileio import get

from seg.models.utils import NO_OBJ


@TRANSFORMS.register_module()
class LoadPanopticAnnotationsHB(LoadPanopticAnnotations):
    def _load_masks_and_semantic_segs(self, results: dict) -> None:
        """Private function to load mask and semantic segmentation annotations.

        In gt_semantic_seg, the foreground label is from ``0`` to
        ``num_things - 1``, the background label is from ``num_things`` to
        ``num_things + num_stuff - 1``, 255 means the ignored label (``VOID``).

        Args:
            results (dict): Result dict from :obj:``mmdet.CustomDataset``.
        """
        # seg_map_path is None, when inference on the dataset without gts.
        if results.get('seg_map_path', None) is None:
            return

        img_bytes = get(
            results['seg_map_path'], backend_args=self.backend_args)
        pan_png = mmcv.imfrombytes(
            img_bytes, flag='color', channel_order='rgb').squeeze()
        pan_png = self.rgb2id(pan_png)

        gt_masks = []
        gt_seg = np.zeros_like(pan_png).astype(np.int32) + NO_OBJ  # NO_OBJ marks ignored pixels

        for segment_info in results['segments_info']:
            mask = (pan_png == segment_info['id'])
            gt_seg = np.where(mask, segment_info['category'], gt_seg)

            # Collect instance masks only for valid `thing` segments.
            if segment_info.get('is_thing'):
                gt_masks.append(mask.astype(np.uint8))

        if self.with_mask:
            h, w = results['ori_shape']
            gt_masks = BitmapMasks(gt_masks, h, w)
            results['gt_masks'] = gt_masks

        if self.with_seg:
            results['gt_seg_map'] = gt_seg
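
# A hedged usage sketch (not from this repo): how `LoadPanopticAnnotationsHB`
# might appear in a dataset pipeline config. The surrounding transforms are
# assumptions; the `with_*` arguments follow the inherited
# `mmdet.LoadPanopticAnnotations` signature:
#
#   train_pipeline = [
#       dict(type='LoadImageFromFile'),
#       dict(type='LoadPanopticAnnotationsHB',
#            with_bbox=True, with_mask=True, with_seg=True),
#   ]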


@TRANSFORMS.register_module()
class LoadVideoSegAnnotations(LoadPanopticAnnotations):
    """Load per-frame segmentation annotations together with instance ids
    for video segmentation datasets."""

    def __init__(
            self,
            **kwargs
    ) -> None:
        super().__init__(**kwargs)

    def _load_instances_ids(self, results: dict) -> None:
        """Private function to load instance id annotations.

        Args:
            results (dict): Result dict from
                :obj:`mmengine.dataset.BaseDataset`.
        """
        gt_instances_ids = []
        for instance in results['instances']:
            gt_instances_ids.append(instance['instance_id'])
        results['gt_instances_ids'] = np.array(
            gt_instances_ids, dtype=np.int32)

    def _load_masks_and_semantic_segs(self, results: dict) -> None:
        """Private function to load mask and semantic segmentation
        annotations from RLE-encoded ``segments_info``."""
        h, w = results['ori_shape']
        gt_masks = []
        gt_seg = np.zeros((h, w), dtype=np.int32) + NO_OBJ

        for segment_info in results['segments_info']:
            mask = maskUtils.decode(segment_info['mask'])
            gt_seg = np.where(mask, segment_info['category'], gt_seg)

            # Collect instance masks only for valid `thing` segments.
            if segment_info.get('is_thing'):
                gt_masks.append(mask.astype(np.uint8))

        if self.with_mask:
            h, w = results['ori_shape']
            gt_masks = BitmapMasks(gt_masks, h, w)
            results['gt_masks'] = gt_masks

        if self.with_seg:
            results['gt_seg_map'] = gt_seg

    def transform(self, results: dict) -> dict:
        """Function to load multiple types panoptic annotations.

        Args:
            results (dict): Result dict from :obj:``mmdet.CustomDataset``.

        Returns:
            dict: The dict contains loaded bounding box, label, mask and
                semantic segmentation annotations.
        """

        super().transform(results)
        self._load_instances_ids(results)
        return results
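
# For reference, a hypothetical entry of `results['segments_info']` that this
# transform expects (field names inferred from the code above; `mask` is a
# pycocotools RLE dict decoded via `maskUtils.decode`):
#
#   segment_info = {
#       'category': 3,        # semantic label written into `gt_seg_map`
#       'is_thing': True,     # thing segments additionally produce `gt_masks`
#       'mask': {'size': [h, w], 'counts': ...},
#   }
#
# Each entry of `results['instances']` must additionally carry an
# `instance_id`, which is collected into `gt_instances_ids`.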


@TRANSFORMS.register_module()
class LoadJSONFromFile(BaseTransform):
    """Load an json from file.

    Required Keys:

    - info_path

    Modified Keys:

    Args:
        backend_args (dict, optional): Instantiates the corresponding file
            backend. It may contain `backend` key to specify the file
            backend. If it contains, the file backend corresponding to this
            value will be used and initialized with the remaining values,
            otherwise the corresponding file backend will be selected
            based on the prefix of the file path. Defaults to None.
            New in version 2.0.0rc4.
    """

    def __init__(self, backend_args: Optional[dict] = None) -> None:
        self.backend_args: Optional[dict] = None
        if backend_args is not None:
            self.backend_args = backend_args.copy()

    def transform(self, results: dict) -> Optional[dict]:
        """Functions to load image.

        Args:
            results (dict): Result dict from
                :class:`mmengine.dataset.BaseDataset`.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['info_path']
        data_info = mmengine.load(filename, backend_args=self.backend_args)

        results['height'] = data_info['image']['height']
        results['width'] = data_info['image']['width']

        # The code below is similar to ``parse_data_info`` of the COCO dataset.
        instances = []
        for ann in sorted(data_info['annotations'], key=lambda x: -x['area']):
            instance = {}

            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            # Skip boxes with no visible area inside the image.
            inter_w = max(0, min(x1 + w, results['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, results['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            # Convert [x, y, w, h] to [x1, y1, x2, y2].
            bbox = [x1, y1, x1 + w, y1 + h]

            instance['ignore_flag'] = 0
            instance['bbox'] = bbox
            instance['bbox_label'] = 0

            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            if ann.get('point_coords', None):
                instance['point_coords'] = ann['point_coords']

            instances.append(instance)

        results['instances'] = instances
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'backend_args={self.backend_args})')

        return repr_str
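
# A sketch of the JSON layout this transform parses, inferred from the key
# accesses above (the concrete values are illustrative only); the format
# mirrors SAM-style per-image annotation files with COCO-format fields:
#
#   {
#       "image": {"height": 1024, "width": 2048},
#       "annotations": [
#           {"bbox": [x, y, w, h], "area": 123.0,
#            "segmentation": {...}, "point_coords": [[px, py]]}
#       ]
#   }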


@TRANSFORMS.register_module()
class LoadAnnotationsSAM(MMDET_LoadAnnotations):
    """``mmdet.LoadAnnotations`` extended with optional loading of
    per-instance point coordinates (SAM-style point prompts)."""

    def __init__(self, *args, with_point_coords=False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_point_coords = with_point_coords

    def _load_point_coords(self, results: dict) -> None:
        """Private function to load per-instance point prompts."""
        assert self.with_point_coords
        gt_point_coords = []
        for instance in results.get('instances', []):
            gt_point_coords.append(instance['point_coords'])
        results['gt_point_coords'] = np.array(gt_point_coords, dtype=np.float32)

    def transform(self, results: dict) -> Optional[dict]:
        super().transform(results)
        if self.with_point_coords:
            self._load_point_coords(results)
        return results
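
# A hedged pipeline sketch (the transform order is an assumption): combined
# with `LoadJSONFromFile` above, this produces `gt_point_coords` as a float32
# array stacked from each instance's `point_coords`:
#
#   pipeline = [
#       dict(type='LoadImageFromFile'),
#       dict(type='LoadJSONFromFile'),
#       dict(type='LoadAnnotationsSAM',
#            with_bbox=True, with_mask=True, with_point_coords=True),
#   ]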


@TRANSFORMS.register_module()
class FilterAnnotationsHB(BaseTransform):
    """Filter invalid annotations.

    Required Keys:

    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - gt_bboxes (optional)
    - gt_bboxes_labels (optional)
    - gt_masks (optional)
    - gt_ignore_flags (optional)

    Args:
        min_gt_bbox_wh (tuple[float]): Minimum width and height of ground truth
            boxes. Default: (1., 1.)
        min_gt_mask_area (int): Minimum foreground area of ground truth masks.
            Default: 1
        by_box (bool): Filter instances with bounding boxes not meeting the
            min_gt_bbox_wh threshold. Default: True
        by_mask (bool): Filter instances with masks not meeting
            min_gt_mask_area threshold. Default: False
        keep_empty (bool): Whether to return None when it
            becomes an empty bbox after filtering. Defaults to True.
    """

    def __init__(self,
                 min_gt_bbox_wh: Tuple[int, int] = (1, 1),
                 min_gt_mask_area: int = 1,
                 by_box: bool = True,
                 by_mask: bool = False) -> None:
        assert by_box or by_mask
        self.min_gt_bbox_wh = min_gt_bbox_wh
        self.min_gt_mask_area = min_gt_mask_area
        self.by_box = by_box
        self.by_mask = by_mask

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        assert 'gt_bboxes' in results
        gt_bboxes = results['gt_bboxes']
        if gt_bboxes.shape[0] == 0:
            return None

        tests = []
        if self.by_box:
            tests.append(
                ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
                 (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
        if self.by_mask:
            assert 'gt_masks' in results
            gt_masks = results['gt_masks']
            tests.append(gt_masks.areas >= self.min_gt_mask_area)

        keep = tests[0]
        for t in tests[1:]:
            keep = keep & t

        results['gt_ignore_flags'] = np.logical_or(
            results['gt_ignore_flags'], np.logical_not(keep))
        if results['gt_ignore_flags'].all():
            return None
        return results

    def __repr__(self):
        return self.__class__.__name__
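
# The filtering logic above in brief (a sketch, not executable as-is):
#
#   keep = (widths > min_w) & (heights > min_h)   # by_box test
#   keep &= mask_areas >= min_gt_mask_area        # by_mask test, if enabled
#   gt_ignore_flags |= ~keep                      # flag instead of deleting
#
# so filtered instances stay in `results` but are marked as ignored; `None`
# is returned only when nothing valid remains.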


@TRANSFORMS.register_module()
class GTNMS(BaseTransform):
    """NMS-style de-duplication on ground-truth annotations: greedily
    ignore instances that heavily overlap previously kept ones."""

    def __init__(self,
                 by_box: bool = True,
                 by_mask: bool = False
                 ) -> None:
        assert (by_box or by_mask) and not (by_box and by_mask), \
            'Exactly one of `by_box` and `by_mask` must be enabled.'
        self.by_box = by_box
        self.by_mask = by_mask

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        gt_ignore_flags = results['gt_ignore_flags']
        if self.by_box:
            raise NotImplementedError('Box-based GT NMS is not implemented.')
        if self.by_mask:
            assert 'gt_masks' in results
            gt_masks = results['gt_masks'].masks
            # Union of all masks kept so far.
            tot_mask = np.zeros_like(gt_masks[0], dtype=np.uint8)
            for idx, mask in enumerate(gt_masks):
                if gt_ignore_flags[idx]:
                    continue
                overlapping = mask * tot_mask
                ratio = overlapping.sum() / mask.sum()
                if ratio > 0.8:
                    # Ignore masks that mostly overlap already-kept ones.
                    gt_ignore_flags[idx] = True
                    continue
                tot_mask = (tot_mask + mask).clip(max=1)

        results['gt_ignore_flags'] = gt_ignore_flags
        return results

    def __repr__(self):
        return self.__class__.__name__
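
# The loop in `GTNMS.transform` is a greedy, first-kept-wins suppression:
# masks are visited in annotation order (note that `LoadJSONFromFile` sorts
# annotations by descending area), and any mask whose overlap with the union
# of previously kept masks exceeds 80% of its own area is flagged as ignored.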


@TRANSFORMS.register_module()
class LoadFeatFromFile(BaseTransform):
    """Load pre-extracted image features cached next to each image file."""

    def __init__(self, model_name='vit_h'):
        self.cache_suffix = f'_{model_name}_cache.pth'

    def transform(self, results: dict) -> Optional[dict]:
        img_path = results['img_path']
        feat_path = img_path.replace('.jpg', self.cache_suffix)
        assert mmengine.exists(feat_path), \
            f'Feature cache not found: {feat_path}'
        feat = torch.load(feat_path)
        results['feat'] = feat
        return results

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}'

        return repr_str
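
# Hypothetical example of the cache-path convention used above: with
# model_name='vit_h', an image at 'data/xxx/0001.jpg' is expected to have a
# pre-extracted feature tensor at 'data/xxx/0001_vit_h_cache.pth' (the paths
# are illustrative only).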


@TRANSFORMS.register_module()
class ResizeOri(BaseTransform):
    """Treat the current (resized) image shape as the original shape.

    This transform only rewrites meta information: ``ori_shape`` is set to
    ``img_shape`` and ``scale_factor`` is reset to ``(1., 1.)``, so that
    downstream post-processing does not rescale predictions back to the
    original image size. ``backend`` and ``interpolation`` are accepted
    for config compatibility but are not used by this transform.
    """

    def __init__(
            self,
            backend: str = 'cv2',
            interpolation='bilinear'
    ):
        self.backend = backend
        self.interpolation = interpolation

    def transform(self, results: dict) -> Optional[dict]:
        results['ori_shape'] = results['img_shape']
        results['scale_factor'] = (1., 1.)
        return results

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}'
        return repr_str
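
# A hedged end-of-pipeline sketch: placing `ResizeOri` after a `Resize` step
# makes downstream evaluation treat the resized image as the original, so
# predictions are kept at the network input resolution (the `Resize` scale
# is an assumption):
#
#   test_pipeline = [
#       dict(type='LoadImageFromFile'),
#       dict(type='Resize', scale=(1024, 1024), keep_ratio=True),
#       dict(type='ResizeOri'),
#   ]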