Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import itertools | |
from collections.abc import Sized | |
from typing import Any, List, Union | |
import numpy as np | |
import torch | |
from mmengine.device import get_device | |
from .base_data_element import BaseDataElement | |
# Aliases for the boolean / long tensor classes that may be used as an index.
# They are declared first and bound below, because the accelerator-specific
# tensor class depends on the value returned by ``get_device()``.
BoolTypeTensor: Union[Any]
LongTypeTensor: Union[Any]

# Pair the plain ``torch`` tensor class with the class exposed by the
# matching backend namespace. ``torch.npu`` / ``torch.mlu`` are only touched
# inside their own branch, so they are never accessed on other devices.
if get_device() == 'npu':
    BoolTypeTensor = Union[torch.BoolTensor, torch.npu.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.npu.LongTensor]
elif get_device() == 'mlu':
    BoolTypeTensor = Union[torch.BoolTensor, torch.mlu.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.mlu.LongTensor]
else:
    # Any other device value falls back to the CUDA tensor classes.
    BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
    LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

# Every type accepted as an index by ``InstanceData.__getitem__``; its
# ``__args__`` tuple is used there for the ``isinstance`` check.
IndexType: Union[Any] = Union[str, slice, int, list, LongTypeTensor,
                              BoolTypeTensor, np.ndarray]
# Modified from | |
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa | |
class InstanceData(BaseDataElement):
    """Data structure for instance-level annotations or predictions.

    Subclass of :class:`BaseDataElement`. All values in `data_fields`
    should have the same length. This design refers to
    https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
    InstanceData also supports extra functions: ``index``, ``slice`` and
    ``cat`` for data fields. The type of a value in a data field can be a
    base data structure such as `torch.Tensor`, `numpy.ndarray`, `list`,
    `str`, `tuple`, or a customized data structure that has ``__len__``,
    ``__getitem__`` and ``cat`` attributes.

    Examples:
        >>> # custom data structure
        >>> import itertools
        >>> class TmpObject:
        ...     def __init__(self, tmp) -> None:
        ...         assert isinstance(tmp, list)
        ...         self.tmp = tmp
        ...     def __len__(self):
        ...         return len(self.tmp)
        ...     def __getitem__(self, item):
        ...         if isinstance(item, int):
        ...             if item >= len(self) or item < -len(self):  # type:ignore
        ...                 raise IndexError(f'Index {item} out of range!')
        ...             else:
        ...                 # keep the dimension
        ...                 item = slice(item, None, len(self))
        ...         return TmpObject(self.tmp[item])
        ...     @staticmethod
        ...     def cat(tmp_objs):
        ...         assert all(isinstance(results, TmpObject) for results in tmp_objs)
        ...         if len(tmp_objs) == 1:
        ...             return tmp_objs[0]
        ...         tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        ...         tmp_list = list(itertools.chain(*tmp_list))
        ...         new_data = TmpObject(tmp_list)
        ...         return new_data
        ...     def __repr__(self):
        ...         return str(self.tmp)
        >>> from mmengine.structures import InstanceData
        >>> import numpy as np
        >>> import torch
        >>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
        >>> instance_data = InstanceData(metainfo=img_meta)
        >>> 'img_shape' in instance_data
        True
        >>> instance_data.det_labels = torch.LongTensor([2, 3])
        >>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
        >>> instance_data.bboxes = torch.rand((2, 4))
        >>> instance_data.polygons = TmpObject([[1, 2, 3, 4], [5, 6, 7, 8]])
        >>> len(instance_data)
        2
        >>> print(instance_data)
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3])
            det_scores: tensor([0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7fb492de6280>
        >>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
        >>> sorted_results.det_scores
        tensor([0.7000, 0.8000])
        >>> print(instance_data[instance_data.det_scores > 0.75])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2])
            det_scores: tensor([0.8000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
            polygons: [[1, 2, 3, 4]]
        ) at 0x7f64ecf0ec40>
        >>> print(instance_data[instance_data.det_scores > 1])
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([], dtype=torch.int64)
            det_scores: tensor([])
            bboxes: tensor([], size=(0, 4))
            polygons: []
        ) at 0x7f660a6a7f70>
        >>> print(instance_data.cat([instance_data, instance_data]))
        <InstanceData(
            META INFORMATION
            img_shape: (800, 1196, 3)
            pad_shape: (800, 1216, 3)
            DATA FIELDS
            det_labels: tensor([2, 3, 2, 3])
            det_scores: tensor([0.8000, 0.7000, 0.8000, 0.7000])
            bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263],
                        [0.4997, 0.7707, 0.0595, 0.4188],
                        [0.8101, 0.3105, 0.5123, 0.6263]])
            polygons: [[1, 2, 3, 4], [5, 6, 7, 8], [1, 2, 3, 4], [5, 6, 7, 8]]
        ) at 0x7f203542feb0>
    """

    def __setattr__(self, name: str, value: Sized):
        """Set a data field; ``setattr`` is only used to set data.

        The value must have the attribute of `__len__` and, once this
        object is non-empty, the same length as this `InstanceData`.

        Args:
            name (str): Name of the field to set.
            value (Sized): Value to store under ``name``.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute already exists.
            AssertionError: If ``value`` is not ``Sized`` or its length
                differs from the current non-zero length of this object.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # The two bookkeeping attributes may be created once and are
            # immutable afterwards.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
        else:
            assert isinstance(value,
                              Sized), 'value must contain `__len__` attribute'

            # The length is only enforced once the object reports a
            # non-zero length.
            if len(self) > 0:
                assert len(value) == len(self), 'The length of ' \
                                                f'values {len(value)} is ' \
                                                'not consistent with ' \
                                                'the length of this ' \
                                                ':obj:`InstanceData` ' \
                                                f'{len(self)}'
            super().__setattr__(name, value)

    # ``obj[name] = value`` goes through the same validation as attribute
    # assignment.
    __setitem__ = __setattr__

    def __getitem__(self, item: IndexType) -> 'InstanceData':
        """Get a field by name, or index every data field with ``item``.

        Args:
            item (str, int, list, :obj:`slice`, :obj:`numpy.ndarray`,
                :obj:`torch.LongTensor`, :obj:`torch.BoolTensor`):
                Get the corresponding values according to item.

        Returns:
            :obj:`InstanceData`: Corresponding values. When ``item`` is a
            ``str`` the attribute stored under that name is returned
            instead.

        Raises:
            IndexError: If ``item`` is an ``int`` outside
                ``[-len(self), len(self))``.
            ValueError: If a field indexed with a tensor is of a type
                that supports neither ``__getitem__`` nor ``cat``.
        """
        assert isinstance(item, IndexType.__args__)
        if isinstance(item, list):
            # ``np.array([])`` defaults to float64, which torch rejects as
            # an index; build empty lists as int64 so ``data[[]]`` yields
            # an empty ``InstanceData`` instead of failing.
            if len(item) == 0:
                item = np.array(item, dtype=np.int64)
            else:
                item = np.array(item)
        if isinstance(item, np.ndarray):
            # The default int type of numpy is platform dependent, int32 for
            # windows and int64 for linux. `torch.Tensor` requires the index
            # should be int64, therefore we simply convert it to int64 here.
            # More details in https://github.com/numpy/numpy/issues/9464
            item = item.astype(np.int64) if item.dtype == np.int32 else item
            item = torch.from_numpy(item)

        if isinstance(item, str):
            return getattr(self, item)

        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            else:
                # keep the dimension
                item = slice(item, None, len(self))

        new_data = self.__class__(metainfo=self.metainfo)
        if isinstance(item, torch.Tensor):
            assert item.dim() == 1, 'Only support to get the' \
                                    ' values along the first dimension.'
            is_bool_index = isinstance(item, BoolTypeTensor.__args__)
            if is_bool_index:
                assert len(item) == len(self), 'The shape of the ' \
                                               'input(BoolTensor) ' \
                                               f'{len(item)} ' \
                                               'does not match the shape ' \
                                               'of the indexed tensor ' \
                                               'in results_field ' \
                                               f'{len(self)} at ' \
                                               'first dimension.'

            for k, v in self.items():
                if isinstance(v, torch.Tensor):
                    new_data[k] = v[item]
                elif isinstance(v, np.ndarray):
                    new_data[k] = v[item.cpu().numpy()]
                elif isinstance(
                        v, (str, list, tuple)) or (hasattr(v, '__getitem__')
                                                   and hasattr(v, 'cat')):
                    # convert to indexes from BoolTensor
                    if is_bool_index:
                        indexes = torch.nonzero(item).view(
                            -1).cpu().numpy().tolist()
                    else:
                        indexes = item.cpu().numpy().tolist()

                    if indexes:
                        # One length-1 slice per index keeps the container
                        # type of ``v`` for every picked element.
                        slice_list = [
                            slice(index, None, len(v)) for index in indexes
                        ]
                    else:
                        # Nothing selected: take an empty slice so the
                        # result still has the type of ``v``.
                        slice_list = [slice(None, 0, None)]

                    r_list = [v[s] for s in slice_list]
                    if isinstance(v, (str, list, tuple)):
                        new_value = r_list[0]
                        for r in r_list[1:]:
                            new_value = new_value + r
                    else:
                        new_value = v.cat(r_list)
                    new_data[k] = new_value
                else:
                    raise ValueError(
                        f'The type of `{k}` is `{type(v)}`, which has no '
                        'attribute of `cat`, so it does not '
                        'support slice with `bool`')
        else:
            # item is a slice
            for k, v in self.items():
                new_data[k] = v[item]
        return new_data  # type:ignore

    @staticmethod
    def cat(instances_list: List['InstanceData']) -> 'InstanceData':
        """Concat the instances of all :obj:`InstanceData` in the list.

        Note: To ensure that cat returns as expected, make sure that
        all elements in the list must have exactly the same keys.

        This is a static method, so it can be called either on the class
        or on an instance, e.g. ``instance_data.cat([a, b])``.

        Args:
            instances_list (list[:obj:`InstanceData`]): A list
                of :obj:`InstanceData`.

        Returns:
            :obj:`InstanceData`: The concatenated data. When the list
            holds a single element, that element itself is returned.

        Raises:
            AssertionError: If the list is empty, holds a non
                ``InstanceData`` element, or the elements have
                different keys.
            ValueError: If a field is of a type that cannot be
                concatenated.
        """
        assert all(
            isinstance(results, InstanceData) for results in instances_list)
        assert len(instances_list) > 0
        if len(instances_list) == 1:
            return instances_list[0]

        # metainfo and data_fields must be exactly the
        # same for each element to avoid exceptions.
        field_keys_list = [
            instances.all_keys() for instances in instances_list
        ]
        assert len({len(field_keys) for field_keys in field_keys_list}) \
               == 1 and len(set(itertools.chain(*field_keys_list))) \
               == len(field_keys_list[0]), 'There are different keys in ' \
                                           '`instances_list`, which may ' \
                                           'cause the cat operation ' \
                                           'to fail. Please make sure all ' \
                                           'elements in `instances_list` ' \
                                           'have the exact same key.'

        # The metainfo of the result is taken from the first element.
        new_data = instances_list[0].__class__(
            metainfo=instances_list[0].metainfo)
        for k in instances_list[0].keys():
            values = [results[k] for results in instances_list]
            v0 = values[0]
            if isinstance(v0, torch.Tensor):
                new_values = torch.cat(values, dim=0)
            elif isinstance(v0, np.ndarray):
                new_values = np.concatenate(values, axis=0)
            elif isinstance(v0, (str, list, tuple)):
                # Copy first so that the in-place ``+=`` below never
                # mutates the list stored in the first element.
                new_values = v0[:]
                for v in values[1:]:
                    new_values += v
            elif hasattr(v0, 'cat'):
                new_values = v0.cat(values)
            else:
                raise ValueError(
                    f'The type of `{k}` is `{type(v0)}` which has no '
                    'attribute of `cat`')
            new_data[k] = new_values
        return new_data  # type:ignore

    def __len__(self) -> int:
        """int: The length of InstanceData.

        This is the length of the first data field, or 0 when no data
        field has been set.
        """
        if len(self._data_fields) > 0:
            return len(self.values()[0])
        return 0