Spaces:
Build error
Build error
File size: 1,653 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from .base_data_element import BaseDataElement
class LabelData(BaseDataElement):
    """Data structure for label-level annotations or predictions.

    Besides the data-element behaviour inherited from ``BaseDataElement``,
    it offers static helpers that convert between a tensor of class indices
    and a one-hot (or multi-hot) vector.
    """

    @staticmethod
    def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
        """Convert the one-hot input to label.

        Args:
            onehot (torch.Tensor): The one-hot input. It must be a 1-D
                tensor whose values lie in ``[0, 1]``. More than one
                nonzero entry is accepted (multi-hot).

        Returns:
            torch.Tensor: 1-D tensor holding the indices of the nonzero
            entries of ``onehot``. An empty or all-zero input yields an
            empty tensor.

        Raises:
            ValueError: If ``onehot`` is not 1-D or has values outside
                ``[0, 1]``.
        """
        assert isinstance(onehot, torch.Tensor)
        # ``max()``/``min()`` raise on an empty tensor, so only range-check
        # non-empty input. An empty vector has no positive entries and maps
        # to an empty label.
        # NOTE(review): this is a range check only; fractional values such
        # as 0.5 pass it and are treated as positive entries.
        is_onehot = onehot.ndim == 1 and (
            onehot.numel() == 0
            or (onehot.max().item() <= 1 and onehot.min().item() >= 0))
        if not is_onehot:
            raise ValueError(
                'input is not one-hot and can not convert to label')
        return onehot.nonzero().squeeze(-1)

    @staticmethod
    def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
        """Convert the label-format input to one-hot.

        Args:
            label (torch.Tensor): Integer class indices. A 0-D tensor holds
                a single label. Every index must lie in
                ``[0, num_classes)``.
            num_classes (int): The number of classes, i.e. the length of the
                returned vector.

        Returns:
            torch.Tensor: 1-D tensor of length ``num_classes``, created with
            ``label.new_zeros`` (so it shares ``label``'s dtype and device).
            The positions listed in ``label`` are set to 1.

        Raises:
            AssertionError: If ``label`` is not a tensor or contains an
                index outside ``[0, num_classes)``.
        """
        assert isinstance(label, torch.Tensor)
        onehot = label.new_zeros((num_classes, ))
        if label.numel() > 0:
            # Reduce on the tensor instead of iterating it in Python.
            # Iterating a 0-D tensor raises TypeError, and a Python loop is
            # slow for long label lists. Negative indices are rejected
            # because ``onehot[-1] = 1`` would silently mark the last class.
            assert (label.min().item() >= 0
                    and label.max().item() < num_classes), (
                        f'label must be in [0, {num_classes})')
        onehot[label] = 1
        return onehot
|