from typing import Optional, Tuple, Union

import mmcv
import mmengine
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.transforms.base import BaseTransform
from mmdet.datasets.transforms import LoadAnnotations as MMDET_LoadAnnotations
from mmdet.datasets.transforms import LoadPanopticAnnotations
from mmdet.registry import TRANSFORMS
from mmdet.structures.bbox import autocast_box_type
from mmdet.structures.mask import BitmapMasks
from mmengine.fileio import get

from seg.models.utils import NO_OBJ

@TRANSFORMS.register_module()
class LoadPanopticAnnotationsHB(LoadPanopticAnnotations):

    def _load_masks_and_semantic_segs(self, results: dict) -> None:
        """Private function to load mask and semantic segmentation annotations.

        In ``gt_seg_map``, the foreground label runs from ``0`` to
        ``num_things - 1`` and the background label from ``num_things`` to
        ``num_things + num_stuff - 1``; ``NO_OBJ`` marks the ignored label
        (``VOID``).

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.
        """
        # seg_map_path is None when running inference on a dataset without gts.
        if results.get('seg_map_path', None) is None:
            return

        img_bytes = get(
            results['seg_map_path'], backend_args=self.backend_args)
        pan_png = mmcv.imfrombytes(
            img_bytes, flag='color', channel_order='rgb').squeeze()
        pan_png = self.rgb2id(pan_png)

        gt_masks = []
        # Initialize the semantic map with the NO_OBJ ignore label.
        gt_seg = np.full_like(pan_png, NO_OBJ, dtype=np.int32)
        for segment_info in results['segments_info']:
            mask = (pan_png == segment_info['id'])
            gt_seg = np.where(mask, segment_info['category'], gt_seg)
            # Only legal "thing" segments contribute instance masks.
            if segment_info.get('is_thing'):
                gt_masks.append(mask.astype(np.uint8))

        if self.with_mask:
            h, w = results['ori_shape']
            gt_masks = BitmapMasks(gt_masks, h, w)
            results['gt_masks'] = gt_masks
        if self.with_seg:
            results['gt_seg_map'] = gt_seg
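
# Illustrative usage sketch (not part of the original file): the transform is
# a drop-in replacement for mmdet's LoadPanopticAnnotations in a data
# pipeline, e.g.
def _example_panoptic_pipeline():
    """Hypothetical helper showing a typical pipeline placement."""
    return [
        dict(type='LoadImageFromFile'),
        dict(type='LoadPanopticAnnotationsHB',
             with_bbox=True, with_mask=True, with_seg=True),
    ]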

@TRANSFORMS.register_module()
class LoadVideoSegAnnotations(LoadPanopticAnnotations):

    def _load_instances_ids(self, results: dict) -> None:
        """Private function to load instance id annotations.

        Args:
            results (dict): Result dict from
                :obj:`mmengine.dataset.BaseDataset`.

        Returns:
            dict: The dict containing instance id annotations.
        """
        gt_instances_ids = []
        for instance in results['instances']:
            gt_instances_ids.append(instance['instance_id'])
        results['gt_instances_ids'] = np.array(
            gt_instances_ids, dtype=np.int32)

    def _load_masks_and_semantic_segs(self, results: dict) -> None:
        h, w = results['ori_shape']
        gt_masks = []
        gt_seg = np.full((h, w), NO_OBJ, dtype=np.int32)
        for segment_info in results['segments_info']:
            # Masks are stored as COCO RLE rather than in a panoptic PNG.
            mask = maskUtils.decode(segment_info['mask'])
            gt_seg = np.where(mask, segment_info['category'], gt_seg)
            # Only legal "thing" segments contribute instance masks.
            if segment_info.get('is_thing'):
                gt_masks.append(mask.astype(np.uint8))

        if self.with_mask:
            gt_masks = BitmapMasks(gt_masks, h, w)
            results['gt_masks'] = gt_masks
        if self.with_seg:
            results['gt_seg_map'] = gt_seg

    def transform(self, results: dict) -> dict:
        """Function to load multiple types of panoptic annotations.

        Args:
            results (dict): Result dict from :obj:`mmdet.CustomDataset`.

        Returns:
            dict: The dict containing loaded bounding box, label, mask and
            semantic segmentation annotations.
        """
        super().transform(results)
        self._load_instances_ids(results)
        return results
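
# Illustrative sketch (not part of the original file) of the per-frame input
# this transform expects: RLE-encoded segment masks plus tracking ids.
def _example_video_seg_results():
    """Hypothetical sample ``results`` dict for LoadVideoSegAnnotations."""
    inst_mask = np.zeros((4, 4), dtype=np.uint8)
    inst_mask[:2, :2] = 1
    rle = maskUtils.encode(np.asfortranarray(inst_mask))
    return dict(
        ori_shape=(4, 4),
        # 'mask' holds COCO RLE; 'category'/'is_thing' mirror panoptic infos.
        segments_info=[dict(id=1, category=0, is_thing=True, mask=rle)],
        # 'instance_id' feeds gt_instances_ids for tracking across frames.
        instances=[dict(instance_id=7)],
    )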

@TRANSFORMS.register_module()
class LoadJSONFromFile(BaseTransform):
    """Load annotation info from a JSON file.

    Required Keys:

    - info_path

    Added Keys:

    - height
    - width
    - instances

    Args:
        backend_args (dict, optional): Instantiates the corresponding file
            backend. It may contain a `backend` key to specify the file
            backend. If it does, the file backend corresponding to this
            value will be used and initialized with the remaining values;
            otherwise, the file backend will be selected based on the
            prefix of the file path. Defaults to None.
            New in version 2.0.0rc4.
    """

    def __init__(self, backend_args: Optional[dict] = None) -> None:
        self.backend_args: Optional[dict] = None
        if backend_args is not None:
            self.backend_args = backend_args.copy()

    def transform(self, results: dict) -> Optional[dict]:
        """Function to load annotation info from a JSON file.

        Args:
            results (dict): Result dict from
                :class:`mmengine.dataset.BaseDataset`.

        Returns:
            dict: The dict containing image size and instance annotations.
        """
        filename = results['info_path']
        data_info = mmengine.load(filename, backend_args=self.backend_args)

        results['height'] = data_info['image']['height']
        results['width'] = data_info['image']['width']

        # The logic below mirrors `parse_data_info` in the COCO dataset.
        instances = []
        for ann in sorted(data_info['annotations'], key=lambda x: -x['area']):
            instance = {}
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, results['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, results['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            instance['ignore_flag'] = 0
            instance['bbox'] = bbox
            instance['bbox_label'] = 0
            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']
            if ann.get('point_coords', None):
                instance['point_coords'] = ann['point_coords']
            instances.append(instance)
        results['instances'] = instances
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'backend_args={self.backend_args})')
        return repr_str
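
# Illustrative sketch (not part of the original file) of the JSON layout this
# transform parses; it matches SA-1B-style per-image annotation files with
# COCO XYWH boxes.
def _example_info_json():
    """Hypothetical content of a file pointed to by ``results['info_path']``."""
    return {
        'image': {'height': 480, 'width': 640},
        'annotations': [{
            'bbox': [10.0, 20.0, 100.0, 50.0],  # XYWH, converted to XYXY
            'area': 5000.0,
            'segmentation': {'size': [480, 640], 'counts': '...'},  # RLE
            'point_coords': [[60.0, 45.0]],
        }],
    }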

@TRANSFORMS.register_module()
class LoadAnnotationsSAM(MMDET_LoadAnnotations):
    """LoadAnnotations variant that can additionally load SAM-style point
    prompts from ``instance['point_coords']``."""

    def __init__(self, *args, with_point_coords=False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_point_coords = with_point_coords

    def _load_point_coords(self, results: dict) -> None:
        assert self.with_point_coords
        gt_point_coords = []
        for instance in results.get('instances', []):
            gt_point_coords.append(instance['point_coords'])
        results['gt_point_coords'] = np.array(
            gt_point_coords, dtype=np.float32)

    def transform(self, results: dict) -> Optional[dict]:
        super().transform(results)
        if self.with_point_coords:
            self._load_point_coords(results)
        return results
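
# Illustrative sketch (not part of the original file): LoadJSONFromFile fills
# results['instances'], from which LoadAnnotationsSAM materializes boxes,
# masks and (optionally) point prompts.
def _example_sam_pipeline():
    """Hypothetical pipeline combining the two SAM-oriented transforms."""
    return [
        dict(type='LoadImageFromFile'),
        dict(type='LoadJSONFromFile'),
        dict(type='LoadAnnotationsSAM',
             with_bbox=True, with_mask=True, with_point_coords=True),
    ]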

@TRANSFORMS.register_module()
class FilterAnnotationsHB(BaseTransform):
    """Filter invalid annotations.

    Unlike mmdet's ``FilterAnnotations``, this variant does not drop the
    filtered instances; it marks them in ``gt_ignore_flags`` and returns
    None only when every instance ends up ignored.

    Required Keys:

    - gt_bboxes (BaseBoxes[torch.float32]) (optional)
    - gt_bboxes_labels (np.int64) (optional)
    - gt_masks (BitmapMasks | PolygonMasks) (optional)
    - gt_ignore_flags (bool) (optional)

    Modified Keys:

    - gt_ignore_flags

    Args:
        min_gt_bbox_wh (tuple[int, int]): Minimum width and height of ground
            truth boxes. Defaults to (1, 1).
        min_gt_mask_area (int): Minimum foreground area of ground truth
            masks. Defaults to 1.
        by_box (bool): Filter instances with bounding boxes not meeting the
            min_gt_bbox_wh threshold. Defaults to True.
        by_mask (bool): Filter instances with masks not meeting the
            min_gt_mask_area threshold. Defaults to False.
    """

    def __init__(self,
                 min_gt_bbox_wh: Tuple[int, int] = (1, 1),
                 min_gt_mask_area: int = 1,
                 by_box: bool = True,
                 by_mask: bool = False) -> None:
        assert by_box or by_mask
        self.min_gt_bbox_wh = min_gt_bbox_wh
        self.min_gt_mask_area = min_gt_mask_area
        self.by_box = by_box
        self.by_mask = by_mask

    @autocast_box_type()
    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict, or None if no valid instance remains.
        """
        assert 'gt_bboxes' in results
        gt_bboxes = results['gt_bboxes']
        if gt_bboxes.shape[0] == 0:
            return None

        tests = []
        if self.by_box:
            tests.append(
                ((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
                 (gt_bboxes.heights > self.min_gt_bbox_wh[1])).numpy())
        if self.by_mask:
            assert 'gt_masks' in results
            gt_masks = results['gt_masks']
            tests.append(gt_masks.areas >= self.min_gt_mask_area)

        keep = tests[0]
        for t in tests[1:]:
            keep = keep & t

        # Instead of deleting failing instances, flag them as ignored.
        results['gt_ignore_flags'] = np.logical_or(
            results['gt_ignore_flags'], np.logical_not(keep))
        if results['gt_ignore_flags'].all():
            return None
        return results

    def __repr__(self):
        return self.__class__.__name__
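
# Illustrative check (not part of the original file) of the flag-only
# filtering behaviour described above.
def _example_filter_hb():
    """Hypothetical demonstration: the undersized box is flagged, not dropped."""
    from mmdet.structures.bbox import HorizontalBoxes
    results = dict(
        gt_bboxes=HorizontalBoxes(
            torch.tensor([[0., 0., 10., 10.],     # passes the (1, 1) threshold
                          [0., 0., 0.5, 0.5]])),  # too small -> ignored
        gt_ignore_flags=np.array([False, False]),
    )
    out = FilterAnnotationsHB()(results)
    assert out['gt_ignore_flags'].tolist() == [False, True]
    return out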

@TRANSFORMS.register_module()
class GTNMS(BaseTransform):
    """Greedy de-overlap of ground truth annotations: instances are visited
    in annotation order, and an instance whose overlap with the union of the
    already-kept instances exceeds 80% of its own area is flagged as ignored.
    """

    def __init__(self,
                 by_box: bool = True,
                 by_mask: bool = False
                 ) -> None:
        # Exactly one of by_box / by_mask must be enabled.
        assert (by_box or by_mask) and not (by_box and by_mask)
        self.by_box = by_box
        self.by_mask = by_mask

    def transform(self, results: dict) -> Union[dict, None]:
        """Transform function to filter annotations.

        Args:
            results (dict): Result dict.

        Returns:
            dict: Updated result dict.
        """
        gt_ignore_flags = results['gt_ignore_flags']
        if self.by_box:
            raise NotImplementedError
        if self.by_mask:
            assert 'gt_masks' in results
            gt_masks = results['gt_masks'].masks
            tot_mask = np.zeros_like(gt_masks[0], dtype=np.uint8)
            for idx, mask in enumerate(gt_masks):
                if gt_ignore_flags[idx]:
                    continue
                overlapping = mask * tot_mask
                ratio = overlapping.sum() / mask.sum()
                if ratio > 0.8:
                    # Mostly covered by previously kept masks; ignore it.
                    gt_ignore_flags[idx] = True
                    continue
                tot_mask = (tot_mask + mask).clip(max=1)
        results['gt_ignore_flags'] = gt_ignore_flags
        return results

    def __repr__(self):
        return self.__class__.__name__
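
# Illustrative check (not part of the original file) of the greedy overlap
# rule: the second mask lies entirely inside the first, so its overlap ratio
# is 1.0 > 0.8 and it gets flagged.
def _example_gt_nms():
    """Hypothetical demonstration of mask-based GTNMS."""
    first = np.ones((4, 4), dtype=np.uint8)   # kept, fills the union
    second = np.zeros((4, 4), dtype=np.uint8)
    second[:2, :2] = 1                        # fully covered by `first`
    results = dict(
        gt_masks=BitmapMasks([first, second], 4, 4),
        gt_ignore_flags=np.array([False, False]),
    )
    out = GTNMS(by_box=False, by_mask=True)(results)
    assert out['gt_ignore_flags'].tolist() == [False, True]
    return out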

@TRANSFORMS.register_module()
class LoadFeatFromFile(BaseTransform):
    """Load pre-computed backbone features cached next to each image."""

    def __init__(self, model_name='vit_h'):
        self.cache_suffix = f'_{model_name}_cache.pth'

    def transform(self, results: dict) -> Optional[dict]:
        img_path = results['img_path']
        feat_path = img_path.replace('.jpg', self.cache_suffix)
        assert mmengine.exists(feat_path)
        results['feat'] = torch.load(feat_path)
        return results

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}'
        return repr_str
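
# Illustrative sketch (not part of the original file) of the on-disk layout
# the transform expects; the feature files are produced offline.
def _example_feat_cache_path():
    """Hypothetical mapping from an image path to its feature cache."""
    loader = LoadFeatFromFile(model_name='vit_h')
    img_path = 'data/imgs/000001.jpg'  # hypothetical path
    return img_path.replace('.jpg', loader.cache_suffix)
    # -> 'data/imgs/000001_vit_h_cache.pth'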

@TRANSFORMS.register_module()
class ResizeOri(BaseTransform):
    """Overwrite ``ori_shape`` and ``scale_factor`` with the current image
    shape so that downstream post-processing keeps predictions at the
    network input resolution instead of rescaling them back to the original
    image size."""

    def __init__(
            self,
            backend: str = 'cv2',
            interpolation='bilinear'
    ):
        self.backend = backend
        self.interpolation = interpolation

    def transform(self, results: dict) -> Optional[dict]:
        results['ori_shape'] = results['img_shape']
        results['scale_factor'] = (1., 1.)
        return results

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}'
        return repr_str
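
# Illustrative sketch (not part of the original file): ResizeOri is placed
# after Resize so later steps treat the resized image as the "original".
def _example_resize_ori_pipeline():
    """Hypothetical pipeline placement for ResizeOri."""
    return [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', scale=(1024, 1024)),
        dict(type='ResizeOri'),
    ]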