Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| import torch | |
| from .base_data_element import BaseDataElement | |
class LabelData(BaseDataElement):
    """Data structure for label-level annotations or predictions.

    Derives from ``BaseDataElement``, which is imported from
    ``.base_data_element``; the storage and access interface is defined
    there rather than in this class.
    """
def onehot_to_label(onehot: torch.Tensor) -> torch.Tensor:
    """Convert a one-hot tensor to the indices of its non-zero entries.

    Args:
        onehot (torch.Tensor): A 1-D tensor whose values all lie in
            ``[0, 1]``. An empty 1-D tensor is accepted and produces an
            empty result.

    Returns:
        torch.Tensor: A 1-D tensor with the positions of the non-zero
        entries of ``onehot``.

    Raises:
        AssertionError: If ``onehot`` is not a ``torch.Tensor``.
        ValueError: If ``onehot`` is not 1-D, or if any value falls
            outside ``[0, 1]``.
    """
    assert isinstance(onehot, torch.Tensor)
    if onehot.ndim != 1:
        raise ValueError(
            'input is not one-hot and can not convert to label')
    # ``Tensor.max()`` / ``Tensor.min()`` raise RuntimeError on an empty
    # tensor, so the range check only runs when there is something to
    # check. The ``not (a and b)`` form keeps NaN inputs rejected.
    if onehot.numel() > 0 and not (onehot.max().item() <= 1
                                   and onehot.min().item() >= 0):
        raise ValueError(
            'input is not one-hot and can not convert to label')
    # ``nonzero()`` yields shape (k, 1) for a 1-D input; drop the last axis.
    return onehot.nonzero().squeeze(-1)
def label_to_onehot(label: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Convert class-index labels to a one-hot (multi-hot) vector.

    Args:
        label (torch.Tensor): Class indices. May be empty, a 0-d scalar
            tensor, or hold several indices; every index must satisfy
            ``0 <= index < num_classes``.
        num_classes (int): The number of classes, i.e. the length of the
            returned vector.

    Returns:
        torch.Tensor: A tensor of shape ``(num_classes,)`` created with
        ``label.new_zeros`` (so it shares ``label``'s dtype and device),
        with the positions listed in ``label`` set to 1.

    Raises:
        AssertionError: If ``label`` is not a ``torch.Tensor``, or if any
            index is negative or not smaller than ``num_classes``.
    """
    assert isinstance(label, torch.Tensor)
    onehot = label.new_zeros((num_classes, ))
    # Use tensor reductions instead of the Python builtin ``max``: the
    # builtin iterates element by element and cannot iterate a 0-d tensor.
    # An empty label keeps the historical default of 0 for the upper bound.
    has_items = label.numel() > 0
    largest = label.max().item() if has_items else 0
    assert largest < num_classes, (
        f'label {largest} is out of range for {num_classes} classes')
    if has_items:
        # Negative indices would otherwise wrap around silently and mark
        # the wrong class.
        smallest = label.min().item()
        assert smallest >= 0, f'label {smallest} must be non-negative'
    onehot[label] = 1
    return onehot