Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from .base_data_element import BaseDataElement | |
class LabelData(BaseDataElement):
    """Container for label-level data, i.e. annotations or predictions.

    Specializes ``BaseDataElement`` for information attached at the
    label level.
    """
def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
    """Convert the one-hot input to label.

    Args:
        onehot (torch.Tensor): The one-hot input. It must be a 1-D tensor
            whose elements are all exactly 0 or 1. Several elements may be
            1, in which case every set index is returned.

    Returns:
        torch.Tensor: 1-D tensor with the indices of the non-zero elements
        of ``onehot`` (empty if ``onehot`` has no non-zero element or is
        itself empty).

    Raises:
        ValueError: If ``onehot`` is not 1-D or contains any value other
            than 0 and 1.
    """
    assert isinstance(onehot, torch.Tensor)
    # Require every element to be exactly 0 or 1. A min/max range check alone
    # would accept fractional values such as 0.5 and report their positions
    # as labels. The elementwise check is also vacuously true for an empty
    # tensor, where ``onehot.max()`` would raise instead of converting.
    if onehot.ndim == 1 and bool(((onehot == 0) | (onehot == 1)).all()):
        return onehot.nonzero().squeeze(-1)
    raise ValueError('input is not one-hot and can not convert to label')
def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert the label-format input to one-hot.

    Args:
        label (torch.Tensor): The label-format input, i.e. a tensor of class
            indices (a 0-d tensor holding a single index is also accepted).
            Every index must lie in ``[0, num_classes)``.
        num_classes (int): The number of classes, i.e. the length of the
            returned tensor.

    Returns:
        torch.Tensor: 1-D tensor of length ``num_classes`` with the same
        dtype and device as ``label``. It is 1 at every index listed in
        ``label`` and 0 elsewhere.

    Raises:
        AssertionError: If ``label`` is not a tensor or contains an index
            outside ``[0, num_classes)``.
    """
    assert isinstance(label, torch.Tensor)
    onehot = label.new_zeros((num_classes, ))
    if label.numel() > 0:
        # Negative indices would silently wrap around and mark the last
        # classes, so both bounds must be enforced. Tensor reductions are
        # used instead of the Python builtin ``max``, which iterates element
        # by element and fails on 0-d tensors.
        assert label.min().item() >= 0
        assert label.max().item() < num_classes
    onehot[label] = 1
    return onehot