Spaces:
Build error
Build error
File size: 4,450 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from numbers import Number
from typing import Sequence, Union
import mmengine
import numpy as np
import torch
from mmengine.structures import BaseDataElement, LabelData
def format_label(value: Union[torch.Tensor, np.ndarray, Sequence, int],
                 num_classes: Union[int, None] = None) -> LabelData:
    """Convert label of various python types to :obj:`mmengine.LabelData`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` (including numpy integer scalars such
    as ``np.int64``).

    Args:
        value (torch.Tensor | numpy.ndarray | Sequence | int): Label value.
        num_classes (int, optional): The number of classes. If not None, set
            it to the metainfo and check that no label is ``>= num_classes``.
            Defaults to None.

    Returns:
        :obj:`mmengine.LabelData`: The formatted label data.

    Raises:
        TypeError: If ``value`` is of an unsupported type (e.g. ``str`` or
            ``float``).
        ValueError: If ``num_classes`` is given and some label is
            ``>= num_classes``.
    """
    # Collapse single-number inputs to a plain Python int so that they take
    # the ``int`` branch below and become a 1-element LongTensor.
    if isinstance(value, (torch.Tensor, np.ndarray)) and value.ndim == 0:
        value = int(value.item())
    elif isinstance(value, np.integer):
        # numpy integer scalars are neither ndarrays nor Python ints.
        value = int(value)

    if isinstance(value, np.ndarray):
        value = torch.from_numpy(value)
    elif isinstance(value, Sequence) and not mmengine.utils.is_str(value):
        # Strings are Sequences too, but are not valid labels.
        value = torch.tensor(value)
    elif isinstance(value, int):
        value = torch.LongTensor([value])
    elif not isinstance(value, torch.Tensor):
        raise TypeError(f'Type {type(value)} is not an available label type.')

    metainfo = {}
    if num_classes is not None:
        metainfo['num_classes'] = num_classes
        # ``Tensor.max()`` raises on an empty tensor, so only range-check
        # when there is at least one label.
        if value.numel() > 0 and value.max() >= num_classes:
            raise ValueError(f'The label data ({value}) should not '
                             f'exceed num_classes ({num_classes}).')
    return LabelData(label=value, metainfo=metainfo)
class ReIDDataSample(BaseDataElement):
    """A data structure interface of ReID task.

    It's used as interfaces between different components.

    Meta field:
        img_shape (Tuple): The shape of the corresponding input image.
            Used for visualization.
        ori_shape (Tuple): The original shape of the corresponding image.
            Used for visualization.
        num_classes (int): The number of all categories.
            Used for label format conversion.

    Data field:
        gt_label (LabelData): The ground truth label. Its ``label`` is set by
            :meth:`set_gt_label` and its ``score`` by :meth:`set_gt_score`.
        pred_feature (torch.Tensor): The predicted feature; must be a
            ``torch.Tensor``.
    """

    @property
    def gt_label(self) -> LabelData:
        """The ground-truth :obj:`LabelData`."""
        return self._gt_label

    @gt_label.setter
    def gt_label(self, value: LabelData) -> None:
        # Register through ``set_field`` with ``dtype=LabelData`` so the
        # field is tracked by ``BaseDataElement`` and type-checked.
        self.set_field(value, '_gt_label', dtype=LabelData)

    @gt_label.deleter
    def gt_label(self) -> None:
        del self._gt_label

    def set_gt_label(
        self, value: Union[np.ndarray, torch.Tensor, Sequence[Number], Number]
    ) -> 'ReIDDataSample':
        """Set label of ``gt_label``.

        ``value`` is converted by :func:`format_label`, which also checks it
        against the ``num_classes`` metainfo when that is present.

        Args:
            value: Label value. Only the types accepted by
                :func:`format_label` are supported; e.g. a plain ``float``
                raises ``TypeError``.

        Returns:
            ReIDDataSample: ``self``, to allow chaining.
        """
        label = format_label(value, self.get('num_classes'))
        if 'gt_label' in self:
            # Setting for the second time: keep the existing LabelData (and
            # any ``score`` already stored on it); only replace ``label``.
            self.gt_label.label = label.label
        else:
            # Setting for the first time.
            self.gt_label = label
        return self

    def set_gt_score(self, value: torch.Tensor) -> 'ReIDDataSample':
        """Set score of ``gt_label``.

        Args:
            value (torch.Tensor): A 1-D tensor. If the ``num_classes``
                metainfo is present, its length must equal it.

        Returns:
            ReIDDataSample: ``self``, to allow chaining.
        """
        assert isinstance(value, torch.Tensor), \
            f'The value should be a torch.Tensor but got {type(value)}.'
        assert value.ndim == 1, \
            f'The dims of value should be 1, but got {value.ndim}.'

        if 'num_classes' in self:
            assert value.size(0) == self.num_classes, \
                f"The length of value ({value.size(0)}) doesn't "\
                f'match the num_classes ({self.num_classes}).'
            metainfo = {'num_classes': self.num_classes}
        else:
            # Without an explicit num_classes, infer it from the score length.
            metainfo = {'num_classes': value.size(0)}

        if 'gt_label' in self:
            # Setting for the second time: the metainfo built above is only
            # used when creating a new LabelData.
            self.gt_label.score = value
        else:
            # Setting for the first time.
            self.gt_label = LabelData(score=value, metainfo=metainfo)
        return self

    @property
    def pred_feature(self) -> torch.Tensor:
        """The predicted feature tensor."""
        return self._pred_feature

    @pred_feature.setter
    def pred_feature(self, value: torch.Tensor) -> None:
        # Register through ``set_field`` with ``dtype=torch.Tensor`` so the
        # field is tracked by ``BaseDataElement`` and type-checked.
        self.set_field(value, '_pred_feature', dtype=torch.Tensor)

    @pred_feature.deleter
    def pred_feature(self) -> None:
        del self._pred_feature
|