# Spaces:
# Build error
# Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy | |
from typing import Any, Iterator, Optional, Tuple, Type, Union | |
import numpy as np | |
import torch | |
class BaseDataElement:
    """A base data interface that supports Tensor-like and dict-like
    operations.

    A typical data element refers to predicted results or ground truth labels
    on a task, such as predicted bboxes, instance masks, semantic
    segmentation masks, etc. Because ground truth labels and predicted results
    often have similar properties (for example, the predicted bboxes and the
    ground truth bboxes), MMEngine uses the same abstract data interface to
    encapsulate predicted results and ground truth labels, and it is
    recommended to use different naming conventions to distinguish them, such
    as using ``gt_instances`` and ``pred_instances`` to distinguish between
    labels and predicted results. Additionally, data elements are
    distinguished at instance level, pixel level, and label level. Each of
    these types has its own characteristics. Therefore, MMEngine defines the
    base class ``BaseDataElement``, and implements ``InstanceData``,
    ``PixelData``, and ``LabelData`` inheriting from ``BaseDataElement`` to
    represent different types of ground truth labels or predictions.

    Another common data element is sample data. A sample data consists of
    input data (such as an image) and its annotations and predictions. In
    general, an image can have multiple types of annotations and/or
    predictions at the same time (for example, both pixel-level semantic
    segmentation annotations and instance-level detection bboxes
    annotations). All labels and predictions of a training sample are often
    passed between Dataset, Model, Visualizer, and Evaluator components. In
    order to simplify the interface between components, we can treat them as
    a large data element and encapsulate them. Such data elements are
    generally called XXDataSample in OpenMMLab. Therefore, similar to
    `nn.Module`, a `BaseDataElement` allows another `BaseDataElement` as its
    attribute. Such a class generally encapsulates all the data of a sample
    in the algorithm library, and its attributes generally are various types
    of data elements.

    The attributes in ``BaseDataElement`` are divided into two parts,
    the ``metainfo`` and the ``data`` respectively.

        - ``metainfo``: Usually contains the information about the image
          such as filename, image_shape, pad_shape, etc. The attributes can
          be accessed or modified by dict-like or object-like operations,
          such as ``.`` (for data access and modification), ``in``, ``del``,
          ``pop(str)``, ``get(str)``, ``metainfo_keys()``,
          ``metainfo_values()``, ``metainfo_items()``, ``set_metainfo()``
          (for setting or changing key-value pairs in metainfo).

        - ``data``: Annotations or model predictions are stored here. The
          attributes can be accessed or modified by dict-like or object-like
          operations, such as ``.``, ``in``, ``del``, ``pop(str)``,
          ``get(str)``, ``keys()``, ``values()``, ``items()``. Users can
          also apply tensor-like methods to all :obj:`torch.Tensor` in the
          ``data_fields``, such as ``.cuda()``, ``.cpu()``, ``.numpy()``,
          ``.to()``, ``to_tensor()``, ``.detach()``.

    Args:
        metainfo (dict, optional): A dict containing the meta information
            of a single image, such as ``dict(img_shape=(512, 512, 3),
            scale_factor=(1, 1, 1, 1))``. Defaults to None.
        kwargs (dict, optional): Annotations of a single image or model
            predictions, given as keyword arguments. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmengine.structures import BaseDataElement
        >>> gt_instances = BaseDataElement()
        >>> bboxes = torch.rand((5, 4))
        >>> scores = torch.rand((5,))
        >>> img_id = 0
        >>> img_shape = (800, 1333)
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=img_id, img_shape=img_shape),
        ...     bboxes=bboxes, scores=scores)
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=img_id, img_shape=(640, 640)))

        >>> # new
        >>> gt_instances1 = gt_instances.new(
        ...     metainfo=dict(img_id=1, img_shape=(640, 640)),
        ...     bboxes=torch.rand((5, 4)),
        ...     scores=torch.rand((5,)))
        >>> gt_instances2 = gt_instances1.new()

        >>> # add and process property
        >>> gt_instances = BaseDataElement()
        >>> gt_instances.set_metainfo(dict(img_id=9, img_shape=(100, 100)))
        >>> assert 'img_shape' in gt_instances.metainfo_keys()
        >>> assert 'img_shape' in gt_instances
        >>> assert 'img_shape' not in gt_instances.keys()
        >>> assert 'img_shape' in gt_instances.all_keys()
        >>> print(gt_instances.img_shape)
        (100, 100)
        >>> gt_instances.scores = torch.rand((5,))
        >>> assert 'scores' in gt_instances.keys()
        >>> assert 'scores' in gt_instances
        >>> assert 'scores' in gt_instances.all_keys()
        >>> assert 'scores' not in gt_instances.metainfo_keys()
        >>> gt_instances.bboxes = torch.rand((5, 4))
        >>> assert 'bboxes' in gt_instances.keys()
        >>> assert 'bboxes' in gt_instances
        >>> assert 'bboxes' in gt_instances.all_keys()
        >>> assert 'bboxes' not in gt_instances.metainfo_keys()

        >>> # delete and change property
        >>> gt_instances = BaseDataElement(
        ...     metainfo=dict(img_id=0, img_shape=(640, 640)),
        ...     bboxes=torch.rand((6, 4)), scores=torch.rand((6,)))
        >>> gt_instances.set_metainfo(dict(img_shape=(1280, 1280)))
        >>> gt_instances.img_shape  # (1280, 1280)
        >>> gt_instances.bboxes = gt_instances.bboxes * 2
        >>> gt_instances.get('img_shape', None)  # (1280, 1280)
        >>> gt_instances.get('bboxes', None)  # 6x4 tensor
        >>> del gt_instances.img_shape
        >>> del gt_instances.bboxes
        >>> assert 'img_shape' not in gt_instances
        >>> assert 'bboxes' not in gt_instances
        >>> gt_instances.pop('img_shape', None)  # None
        >>> gt_instances.pop('bboxes', None)  # None

        >>> # Tensor-like
        >>> cuda_instances = gt_instances.cuda()
        >>> cuda_instances = gt_instances.to('cuda:0')
        >>> cpu_instances = cuda_instances.cpu()
        >>> cpu_instances = cuda_instances.to('cpu')
        >>> fp16_instances = cuda_instances.to(
        ...     device=None, dtype=torch.float16, non_blocking=False,
        ...     copy=False, memory_format=torch.preserve_format)
        >>> cpu_instances = cuda_instances.detach()
        >>> np_instances = cpu_instances.numpy()

        >>> # nesting and printing
        >>> metainfo = dict(img_shape=(800, 1196, 3))
        >>> gt_instances = BaseDataElement(
        ...     metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3]))
        >>> sample = BaseDataElement(metainfo=metainfo,
        ...                          gt_instances=gt_instances)
        >>> print(sample)  # "META INFORMATION" / "DATA FIELDS" listing,
        ...                # nested elements are rendered recursively

        >>> # inheritance
        >>> class DetDataSample(BaseDataElement):
        ...     @property
        ...     def proposals(self):
        ...         return self._proposals
        ...     @proposals.setter
        ...     def proposals(self, value):
        ...         self.set_field(value, '_proposals', dtype=BaseDataElement)
        ...     @proposals.deleter
        ...     def proposals(self):
        ...         del self._proposals
        ...     @property
        ...     def gt_instances(self):
        ...         return self._gt_instances
        ...     @gt_instances.setter
        ...     def gt_instances(self, value):
        ...         self.set_field(value, '_gt_instances',
        ...                        dtype=BaseDataElement)
        ...     @gt_instances.deleter
        ...     def gt_instances(self):
        ...         del self._gt_instances
        ...     @property
        ...     def pred_instances(self):
        ...         return self._pred_instances
        ...     @pred_instances.setter
        ...     def pred_instances(self, value):
        ...         self.set_field(value, '_pred_instances',
        ...                        dtype=BaseDataElement)
        ...     @pred_instances.deleter
        ...     def pred_instances(self):
        ...         del self._pred_instances
        >>> det_sample = DetDataSample()
        >>> proposals = BaseDataElement(bboxes=torch.rand((5, 4)))
        >>> det_sample.proposals = proposals
        >>> assert 'proposals' in det_sample
        >>> assert det_sample.proposals == proposals
        >>> del det_sample.proposals
        >>> assert 'proposals' not in det_sample
        >>> try:
        ...     det_sample.proposals = torch.rand((5, 4))
        ... except AssertionError:
        ...     pass  # the setter's ``dtype`` check rejects a raw tensor
    """

    def __init__(self, *, metainfo: Optional[dict] = None, **kwargs) -> None:
        # These two sets record which attribute names belong to metainfo and
        # which to data. ``__setattr__`` lets them be assigned exactly once.
        self._metainfo_fields: set = set()
        self._data_fields: set = set()

        if metainfo is not None:
            self.set_metainfo(metainfo=metainfo)
        if kwargs:
            self.set_data(kwargs)

    def set_metainfo(self, metainfo: dict) -> None:
        """Set or change key-value pairs in ``metainfo_field`` by parameter
        ``metainfo``.

        The input dict is deep-copied first, so later changes to the
        caller's dict (or to mutable values inside it) do not affect this
        element.

        Args:
            metainfo (dict): A dict containing the meta information
                of an image, such as ``img_shape``, ``scale_factor``, etc.
        """
        assert isinstance(
            metainfo,
            dict), f'metainfo should be a ``dict`` but got {type(metainfo)}'
        meta = copy.deepcopy(metainfo)
        for k, v in meta.items():
            self.set_field(name=k, value=v, field_type='metainfo', dtype=None)

    def set_data(self, data: dict) -> None:
        """Set or change key-value pairs in ``data_field`` by parameter
        ``data``.

        Values are stored by reference (no copy is made).

        Args:
            data (dict): A dict containing annotations of an image or
                model predictions.
        """
        # Report the type (like ``set_metainfo``/``update``) rather than the
        # value itself, which could be an arbitrarily large object.
        assert isinstance(
            data, dict), f'data should be a `dict` but got {type(data)}'
        for k, v in data.items():
            # Use `setattr()` rather than `self.set_field` to allow `set_data`
            # to set property method.
            setattr(self, k, v)

    def update(self, instance: 'BaseDataElement') -> None:
        """Update this element with the metainfo and data of another
        ``BaseDataElement``.

        Args:
            instance (BaseDataElement): Another BaseDataElement object used
                to update the current object.
        """
        assert isinstance(
            instance, BaseDataElement
        ), f'instance should be a `BaseDataElement` but got {type(instance)}'
        self.set_metainfo(dict(instance.metainfo_items()))
        self.set_data(dict(instance.items()))

    def new(self,
            *,
            metainfo: Optional[dict] = None,
            **kwargs) -> 'BaseDataElement':
        """Return a new data element with the same type.

        If ``metainfo`` is None the new element receives this element's
        metainfo; otherwise it receives the given one. Likewise, if no
        keyword arguments are given the new element receives this element's
        data; otherwise it receives only the given data.

        Args:
            metainfo (dict, optional): A dict containing the meta information
                of an image, such as ``img_shape``, ``scale_factor``, etc.
                Defaults to None.
            kwargs (dict): Annotations of an image or model predictions.

        Returns:
            BaseDataElement: A new data element with the same type.
        """
        new_data = self.__class__()

        if metainfo is not None:
            new_data.set_metainfo(metainfo)
        else:
            new_data.set_metainfo(dict(self.metainfo_items()))
        if kwargs:
            new_data.set_data(kwargs)
        else:
            new_data.set_data(dict(self.items()))
        return new_data

    def clone(self):
        """Copy the current data element.

        Metainfo is deep-copied (``set_metainfo`` deep-copies its input),
        while data values are assigned by reference, so the clone shares its
        data objects (e.g. tensors) with this element.

        Returns:
            BaseDataElement: The copy of the current data element.
        """
        clone_data = self.__class__()
        clone_data.set_metainfo(dict(self.metainfo_items()))
        clone_data.set_data(dict(self.items()))
        return clone_data

    def keys(self) -> list:
        """
        Returns:
            list: Contains all keys in data_fields.
        """
        # We assume that the name of the attribute related to property is
        # '_' + the name of the property. We use this rule to filter out
        # private keys.
        # TODO: Use a more robust way to solve this problem
        private_keys = {
            '_' + key
            for key in self._data_fields
            if isinstance(getattr(type(self), key, None), property)
        }
        return list(self._data_fields - private_keys)

    def metainfo_keys(self) -> list:
        """
        Returns:
            list: Contains all keys in metainfo_fields.
        """
        return list(self._metainfo_fields)

    def values(self) -> list:
        """
        Returns:
            list: Contains all values in data.
        """
        return [getattr(self, k) for k in self.keys()]

    def metainfo_values(self) -> list:
        """
        Returns:
            list: Contains all values in metainfo.
        """
        return [getattr(self, k) for k in self.metainfo_keys()]

    def all_keys(self) -> list:
        """
        Returns:
            list: Contains all keys in metainfo and data.
        """
        return self.metainfo_keys() + self.keys()

    def all_values(self) -> list:
        """
        Returns:
            list: Contains all values in metainfo and data.
        """
        return self.metainfo_values() + self.values()

    def all_items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``metainfo`` and ``data``.
        """
        for k in self.all_keys():
            yield (k, getattr(self, k))

    def items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``data``.
        """
        for k in self.keys():
            yield (k, getattr(self, k))

    def metainfo_items(self) -> Iterator[Tuple[str, Any]]:
        """
        Returns:
            iterator: An iterator object whose element is (key, value) tuple
            pairs for ``metainfo``.
        """
        for k in self.metainfo_keys():
            yield (k, getattr(self, k))

    # NOTE(review): the ``dict: ...`` docstring style suggests this may have
    # been intended as a property; it is a plain method here - confirm with
    # callers before changing.
    def metainfo(self) -> dict:
        """dict: A dict contains metainfo of current data element."""
        return dict(self.metainfo_items())

    def __setattr__(self, name: str, value: Any):
        """setattr is only used to set data."""
        if name in ('_metainfo_fields', '_data_fields'):
            # The bookkeeping sets may be created once (in ``__init__``) and
            # never rebound afterwards.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
        else:
            self.set_field(
                name=name, value=value, field_type='data', dtype=None)

    def __delattr__(self, item: str):
        """Delete the item in dataelement.

        Args:
            item (str): The key to delete.
        """
        if item in ('_metainfo_fields', '_data_fields'):
            raise AttributeError(f'{item} has been used as a '
                                 'private attribute, which is immutable.')
        # Delete the attribute first; if that raises, the bookkeeping sets
        # are left untouched.
        super().__delattr__(item)
        if item in self._metainfo_fields:
            self._metainfo_fields.remove(item)
        elif item in self._data_fields:
            self._data_fields.remove(item)

    # dict-like methods
    __delitem__ = __delattr__

    def get(self, key, default=None) -> Any:
        """Get property in data and metainfo as the same as python."""
        # Use `getattr()` rather than `self.__dict__.get()` to allow getting
        # properties.
        return getattr(self, key, default)

    def pop(self, *args) -> Any:
        """Pop property in data and metainfo as the same as python.

        Args:
            *args: ``(name,)`` or ``(name, default)``.

        Returns:
            Any: The removed value, or ``default`` when ``name`` is neither
            a metainfo nor a data field and a default was supplied.

        Raises:
            KeyError: If ``name`` is not a field and no default was given.
        """
        assert len(args) < 3, '``pop`` get more than 2 arguments'
        name = args[0]
        if name in self._metainfo_fields or name in self._data_fields:
            # Go through ``getattr``/``delattr`` instead of touching
            # ``__dict__`` directly: property-backed fields keep their value
            # under a private name, so popping the public name from
            # ``__dict__`` would fail (or silently return the default) and
            # leave the field sets inconsistent. ``__delattr__`` also takes
            # care of removing the name from the field sets.
            value = getattr(self, name)
            delattr(self, name)
            return value
        # with default value
        if len(args) == 2:
            return args[1]
        # Only keys tracked as metainfo or data may be popped.
        raise KeyError(f'{args[0]} is not contained in metainfo or data')

    def __contains__(self, item: str) -> bool:
        """Whether the item is in dataelement.

        Args:
            item (str): The key to inquire.
        """
        return item in self._data_fields or item in self._metainfo_fields

    def set_field(self,
                  value: Any,
                  name: str,
                  dtype: Optional[Union[Type, Tuple[Type, ...]]] = None,
                  field_type: str = 'data') -> None:
        """Special method for set union field, used as property.setter
        functions.

        Args:
            value (Any): The value to store.
            name (str): The attribute name to store it under.
            dtype (type or tuple of types, optional): If given, ``value``
                must be an instance of it. Defaults to None.
            field_type (str): Either ``'metainfo'`` or ``'data'``.
                Defaults to ``'data'``.

        Raises:
            AttributeError: If ``name`` is already registered in the other
                kind of field.
        """
        assert field_type in ['metainfo', 'data']
        if dtype is not None:
            assert isinstance(
                value,
                dtype), f'{value} should be a {dtype} but got {type(value)}'

        if field_type == 'metainfo':
            if name in self._data_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of metainfo '
                    f'because {name} is already a data field')
            self._metainfo_fields.add(name)
        else:
            if name in self._metainfo_fields:
                raise AttributeError(
                    f'Cannot set {name} to be a field of data '
                    f'because {name} is already a metainfo field')
            self._data_fields.add(name)
        # Bypass this class's ``__setattr__`` so the value is actually stored.
        super().__setattr__(name, value)

    def _apply_to_tensors(self, method_name: str) -> 'BaseDataElement':
        """Return a copy in which ``method_name()`` has been called on every
        data value that is a tensor or a nested ``BaseDataElement``."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, BaseDataElement)):
                new_data.set_data({k: getattr(v, method_name)()})
        return new_data

    # Tensor-like methods
    def to(self, *args, **kwargs) -> 'BaseDataElement':
        """Apply same name function to all tensors in data_fields."""
        new_data = self.new()
        for k, v in self.items():
            # Duck-typed: anything exposing ``to`` (tensors, nested data
            # elements, ...) is converted; other values are kept as they are.
            if hasattr(v, 'to'):
                v = v.to(*args, **kwargs)
                data = {k: v}
                new_data.set_data(data)
        return new_data

    # Tensor-like methods
    def cpu(self) -> 'BaseDataElement':
        """Convert all tensors to CPU in data."""
        return self._apply_to_tensors('cpu')

    # Tensor-like methods
    def cuda(self) -> 'BaseDataElement':
        """Convert all tensors to GPU in data."""
        return self._apply_to_tensors('cuda')

    # Tensor-like methods
    def npu(self) -> 'BaseDataElement':
        """Convert all tensors to NPU in data."""
        return self._apply_to_tensors('npu')

    def mlu(self) -> 'BaseDataElement':
        """Convert all tensors to MLU in data."""
        return self._apply_to_tensors('mlu')

    # Tensor-like methods
    def detach(self) -> 'BaseDataElement':
        """Detach all tensors in data."""
        return self._apply_to_tensors('detach')

    # Tensor-like methods
    def numpy(self) -> 'BaseDataElement':
        """Convert all tensors to np.ndarray in data."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, (torch.Tensor, BaseDataElement)):
                # For a nested data element the same chain resolves to its
                # own ``detach``/``cpu``/``numpy`` methods, i.e. it recurses.
                v = v.detach().cpu().numpy()
                data = {k: v}
                new_data.set_data(data)
        return new_data

    def to_tensor(self) -> 'BaseDataElement':
        """Convert all np.ndarray to tensor in data."""
        new_data = self.new()
        for k, v in self.items():
            if isinstance(v, np.ndarray):
                new_data.set_data({k: torch.from_numpy(v)})
            elif isinstance(v, BaseDataElement):
                new_data.set_data({k: v.to_tensor()})
        return new_data

    def to_dict(self) -> dict:
        """Convert BaseDataElement to dict.

        Metainfo and data are merged into one dict; nested data elements
        are converted recursively.
        """
        return {
            k: v.to_dict() if isinstance(v, BaseDataElement) else v
            for k, v in self.all_items()
        }

    def __repr__(self) -> str:
        """Represent the object."""

        def _addindent(s_: str, num_spaces: int) -> str:
            """This func is modified from `pytorch` https://github.com/pytorch/
            pytorch/blob/b17b2b1cc7b017c3daaeff8cc7ec0f514d42ec37/torch/nn/modu
            les/module.py#L29.

            Args:
                s_ (str): The string to add spaces.
                num_spaces (int): The num of space to add.

            Returns:
                str: The string after add indent.
            """
            s = s_.split('\n')
            # don't do anything for single-line stuff
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)  # type: ignore
            s = first + '\n' + s  # type: ignore
            return s  # type: ignore

        def dump(obj: Any) -> str:
            """Represent the object.

            Args:
                obj (Any): The obj to represent.

            Returns:
                str: The represented str.
            """
            if isinstance(obj, dict):
                # One "key: value" entry per line; multi-line values are
                # indented under their key.
                return ''.join(f'\n{k}: {_addindent(dump(v), 4)}'
                               for k, v in obj.items())
            if isinstance(obj, BaseDataElement):
                _repr = '\n\n META INFORMATION'
                metainfo_items = dict(obj.metainfo_items())
                _repr += _addindent(dump(metainfo_items), 4)
                _repr += '\n\n DATA FIELDS'
                items = dict(obj.items())
                _repr += _addindent(dump(items), 4)
                classname = obj.__class__.__name__
                return f'<{classname}({_repr}\n) at {hex(id(obj))}>'
            return repr(obj)

        return dump(self)