Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| from dataclasses import make_dataclass | |
| from functools import lru_cache | |
| from typing import Any, Optional | |
| import torch | |
@lru_cache(maxsize=None)
def decorate_predictor_output_class_with_confidences(BasePredictorOutput: type) -> type:
    """
    Create a new output class from an existing one by adding new attributes
    related to confidence estimation:
    - sigma_1 (tensor)
    - sigma_2 (tensor)
    - kappa_u (tensor)
    - kappa_v (tensor)
    - fine_segm_confidence (tensor)
    - coarse_segm_confidence (tensor)

    Details on confidence estimation parameters can be found in:
    N. Neverova, D. Novotny, A. Vedaldi "Correlated Uncertainty for Learning
        Dense Correspondences from Noisy Labels", p. 918--926, in Proc. NIPS 2019
    A. Sanakoyeu et al., Transferring Dense Pose to Proximal Animal Classes, CVPR 2020

    The new class inherits the provided `BasePredictorOutput` class,
    its name is composed of the name of the provided class and
    "WithConfidences" suffix. All added attributes default to `None`.

    The result is memoized per base class, so decorating the same class
    repeatedly yields the very same derived class object.

    Args:
        BasePredictorOutput (type): output type to which confidence data
            is to be added, assumed to be a dataclass that provides
            `__getitem__(item)` and `to(device)` methods
    Return:
        New dataclass derived from the provided one that has attributes
        for confidence estimation
    """
    # Single source of truth for the added attributes; the order defines the
    # order of the generated dataclass fields.
    confidence_field_names = (
        "sigma_1",
        "sigma_2",
        "kappa_u",
        "kappa_v",
        "fine_segm_confidence",
        "coarse_segm_confidence",
    )

    PredictorOutput = make_dataclass(
        BasePredictorOutput.__name__ + "WithConfidences",
        fields=[(name, Optional[torch.Tensor], None) for name in confidence_field_names],
        bases=(BasePredictorOutput,),
    )

    # add possibility to index PredictorOutput

    def slice_if_not_none(data, item):
        if data is None:
            return None
        if isinstance(item, int):
            # restore the leading dimension dropped by integer indexing
            return data[item].unsqueeze(0)
        return data[item]

    def PredictorOutput_getitem(self, item):
        # `super` must be anchored on the class created above, NOT on
        # `type(self)`: with a subclass instance `super(type(self), self)`
        # resolves back to this very method and recurses without end.
        base_predictor_output_sliced = super(PredictorOutput, self).__getitem__(item)
        # Merge into one dict so that confidence entries possibly present in
        # the base result are overridden instead of raising a duplicate
        # keyword argument error.
        init_kwargs = dict(vars(base_predictor_output_sliced))
        init_kwargs.update(
            (name, slice_if_not_none(getattr(self, name), item))
            for name in confidence_field_names
        )
        return type(self)(**init_kwargs)

    PredictorOutput.__getitem__ = PredictorOutput_getitem

    def PredictorOutput_to(self, device: torch.device):
        """
        Transfers all tensors to the given device
        """
        # See PredictorOutput_getitem for why `type(self)` is not used here.
        base_predictor_output_to = super(PredictorOutput, self).to(device)  # pyre-ignore[16]

        def to_device_if_tensor(var: Any):
            if isinstance(var, torch.Tensor):
                return var.to(device)
            return var

        init_kwargs = dict(vars(base_predictor_output_to))
        init_kwargs.update(
            (name, to_device_if_tensor(getattr(self, name)))
            for name in confidence_field_names
        )
        return type(self)(**init_kwargs)

    PredictorOutput.to = PredictorOutput_to
    return PredictorOutput