Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings | |
from typing import List, Sequence, Union | |
import numpy as np | |
import torch | |
from .base_data_element import BaseDataElement | |
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    Examples:
        >>> import random
        >>> metainfo = dict(
        ...     img_id=random.randint(0, 100),
        ...     img_shape=(random.randint(400, 600), random.randint(400, 600)))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)

        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)

        >>> # set: a 2-D value gets a leading channel dimension
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> # The dimension must be 3 or 2
        >>> pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        Traceback (most recent call last):
        ...
        AssertionError: The dim of value must be 2 or 3, but got 4
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        If the dimension of value is 2 and its shape meets the requirements,
        a channel dimension is automatically prepended (with a warning).

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If ``name`` is one of the private bookkeeping
                attributes and it has already been set.
            AssertionError: If ``value`` has an unsupported type, a height
                and width that differ from the existing data fields, or a
                dimension other than 2 or 3.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # The private bookkeeping attributes may be assigned only once.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
        else:
            assert isinstance(value, (torch.Tensor, np.ndarray)), \
                f'Can not set {type(value)}, only support' \
                f' {(torch.Tensor, np.ndarray)}'

            # ``self.shape`` is None until the first data field is stored; after
            # that, every new field must share the same (height, width).
            if self.shape is not None:
                assert tuple(value.shape[-2:]) == self.shape, (
                    'The height and width of '
                    f'values {tuple(value.shape[-2:])} is '
                    'not consistent with '
                    'the shape of this '
                    ':obj:`PixelData` '
                    f'{self.shape}')
            assert value.ndim in [
                2, 3
            ], f'The dim of value must be 2 or 3, but got {value.ndim}'
            if value.ndim == 2:
                # Prepend a channel dimension so all fields are (C, H, W).
                value = value[None]
                warnings.warn('The shape of value will convert from '
                              f'{value.shape[-2:]} to {value.shape}')
            super().__setattr__(name, value)

    # TODO torch.Long/bool
    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data field along its height and width dimensions.

        Args:
            item (Sequence[Union[int, slice]]): A 2-tuple
                ``(height_index, width_index)``. Each element is an ``int`` or
                a ``slice``. An ``int`` keeps a size-1 dimension instead of
                dropping it, so every sliced field stays 3-dimensional.

        Returns:
            :obj:`PixelData`: A new instance built with this instance's
            metainfo, whose data fields are the corresponding slices.

        Raises:
            TypeError: If ``item`` is not a tuple, or one of its elements is
                neither ``int`` nor ``slice``.
            AssertionError: If ``item`` does not contain exactly two elements.
            IndexError: If an ``int`` index is out of range for its dimension.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        assert len(item) == 2, 'Only support to slice height and width'

        shape = self.shape
        # The channel dimension is never sliced.
        index: List[slice] = [slice(None, None, None)]
        for dim, single_item in enumerate(item):
            if isinstance(single_item, int):
                if shape is not None and not (
                        -shape[dim] <= single_item < shape[dim]):
                    raise IndexError(
                        f'index {single_item} is out of bounds for '
                        f'dimension {dim} with size {shape[dim]}')
                # ``i + 1 or None`` maps the stop of index -1 to None, so the
                # last row/column is selected rather than an empty range.
                index.append(slice(single_item, single_item + 1 or None))
            elif isinstance(single_item, slice):
                index.append(single_item)
            else:
                raise TypeError(
                    'The type of element in input must be int or slice, '
                    f'but got {type(single_item)}')

        new_data = self.__class__(metainfo=self.metainfo)
        full_index = tuple(index)
        for k, v in self.items():
            setattr(new_data, k, v[full_index])
        return new_data

    @property
    def shape(self):
        """The (height, width) shared by all data fields.

        Returns:
            tuple or None: The last two dimensions of the first data field,
            or ``None`` if no data field has been set yet.
        """
        if len(self._data_fields) > 0:
            return tuple(self.values()[0].shape[-2:])
        else:
            return None

    # TODO padding, resize