Spaces:
Build error
Build error
File size: 5,118 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import List, Sequence, Union
import numpy as np
import torch
from .base_data_element import BaseDataElement
class PixelData(BaseDataElement):
    """Data structure for pixel-level annotations or predictions.

    All data items in ``data_fields`` of ``PixelData`` meet the following
    requirements:

    - They all have 3 dimensions in orders of channel, height, and width.
    - They should have the same height and width.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> metainfo = dict(img_id=1, img_shape=(500, 500))
        >>> image = np.random.randint(0, 255, (4, 20, 40))
        >>> featmap = torch.randint(0, 255, (10, 20, 40))
        >>> pixel_data = PixelData(metainfo=metainfo,
        ...                        image=image,
        ...                        featmap=featmap)
        >>> print(pixel_data.shape)
        (20, 40)
        >>> # slice
        >>> slice_data = pixel_data[10:20, 20:40]
        >>> assert slice_data.shape == (10, 20)
        >>> slice_data = pixel_data[10, 20]
        >>> assert slice_data.shape == (1, 1)
        >>> # set
        >>> pixel_data.map3 = torch.randint(0, 255, (20, 40))
        >>> assert tuple(pixel_data.map3.shape) == (1, 20, 40)
        >>> try:
        ...     # The dimension must be 3 or 2
        ...     pixel_data.map2 = torch.randint(0, 255, (1, 3, 20, 40))
        ... except AssertionError:
        ...     pass
    """

    def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray]):
        """Set attributes of ``PixelData``.

        If the dimension of value is 2 and its height and width match the
        existing data, a leading channel dimension is added automatically.

        Args:
            name (str): The key to access the value, stored in `PixelData`.
            value (Union[torch.Tensor, np.ndarray]): The value to store in.
                The type of value must be `torch.Tensor` or `np.ndarray`,
                and its shape must meet the requirements of `PixelData`.

        Raises:
            AttributeError: If ``name`` is ``_metainfo_fields`` or
                ``_data_fields`` and that attribute is already set.
            AssertionError: If ``value`` is not a tensor/array, its last two
                dimensions differ from the existing data's height and width,
                or it is not 2- or 3-dimensional.
        """
        if name in ('_metainfo_fields', '_data_fields'):
            # These bookkeeping attributes may be assigned once only.
            if not hasattr(self, name):
                super().__setattr__(name, value)
            else:
                raise AttributeError(f'{name} has been used as a '
                                     'private attribute, which is immutable.')
        else:
            # Validation intentionally raises AssertionError (the class
            # docstring example relies on that exception type).
            assert isinstance(value, (torch.Tensor, np.ndarray)), \
                f'Can not set {type(value)}, only support' \
                f' {(torch.Tensor, np.ndarray)}'

            current_shape = self.shape
            if current_shape:
                # Every data item must share the same height and width.
                assert tuple(value.shape[-2:]) == current_shape, (
                    'The height and width of '
                    f'values {tuple(value.shape[-2:])} is '
                    'not consistent with '
                    'the shape of this '
                    ':obj:`PixelData` '
                    f'{current_shape}')
            assert value.ndim in [
                2, 3
            ], f'The dim of value must be 2 or 3, but got {value.ndim}'
            if value.ndim == 2:
                # Promote (H, W) to (1, H, W) so all items are (C, H, W).
                value = value[None]
                warnings.warn('The shape of value will convert from '
                              f'{value.shape[-2:]} to {value.shape}')
            super().__setattr__(name, value)

    @staticmethod
    def _index_to_slice(single_item, size):
        """Convert one height/width index into a slice keeping that axis.

        Args:
            single_item (int | np.integer | slice): The index along one axis.
            size (int | None): Length of that axis, or ``None`` when the
                element holds no data fields yet.

        Returns:
            slice: ``single_item`` itself if it is a slice; otherwise a
            length-1 slice selecting the (normalised) integer position, so
            the sliced data keeps its 3-dimensional layout.

        Raises:
            IndexError: If an integer index is outside ``[-size, size)``.
            TypeError: If ``single_item`` is neither an int nor a slice.
        """
        if isinstance(single_item, slice):
            return single_item
        if isinstance(single_item, (int, np.integer)):
            single_item = int(single_item)
            if size is None:
                # No data to index; the slice is never applied.
                return slice(single_item, None)
            if not -size <= single_item < size:
                raise IndexError(
                    f'Index {single_item} is out of range for axis of '
                    f'size {size}')
            start = single_item % size  # normalise negative indices
            return slice(start, start + 1)
        raise TypeError('The type of element in input must be int or slice, '
                        f'but got {type(single_item)}')

    # TODO torch.Long/bool
    def __getitem__(self, item: Sequence[Union[int, slice]]) -> 'PixelData':
        """Slice every data item along height and width.

        Args:
            item (Sequence[Union[int, slice]]): Get the corresponding values
                according to item. Must be a 2-tuple ``(height_index,
                width_index)``; each entry is an int or a slice. Integer
                entries keep their axis with length 1.

        Returns:
            :obj:`PixelData`: A new element with the same metainfo and every
            data item sliced as ``value[:, height_index, width_index]``.

        Raises:
            TypeError: If ``item`` is not a tuple or contains an element that
                is neither an int nor a slice.
            AssertionError: If ``item`` does not have exactly 2 entries.
            IndexError: If an integer entry is out of range.
        """
        if not isinstance(item, tuple):
            raise TypeError(
                f'Unsupported type {type(item)} for slicing PixelData')
        assert len(item) == 2, 'Only support to slice height and width'

        new_data = self.__class__(metainfo=self.metainfo)
        shape = self.shape
        sizes = shape if shape is not None else (None, None)
        # Keep every channel; slice only the trailing height/width axes.
        index = (slice(None, None, None), ) + tuple(
            self._index_to_slice(single_item, size)
            for single_item, size in zip(item, sizes))
        for k, v in self.items():
            setattr(new_data, k, v[index])
        return new_data

    @property
    def shape(self):
        """The shape of pixel data.

        Returns:
            tuple | None: ``(height, width)`` taken from the last two
            dimensions of a stored data item, or ``None`` when there are no
            data fields.
        """
        if len(self._data_fields) > 0:
            return tuple(self.values()[0].shape[-2:])
        else:
            return None

    # TODO padding, resize
|