Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| """ | |
| Wrappers around some nn functions, mainly to support empty tensors. | |
| Ideally, add support directly in PyTorch to empty tensors in those functions. | |
| These can be removed once https://github.com/pytorch/pytorch/issues/12013 | |
| is implemented | |
| """ | |
| import warnings | |
| from typing import List, Optional | |
| import torch | |
| from torch.nn import functional as F | |
| from detectron2.utils.env import TORCH_VERSION | |
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Convert a list of integer scalars (or scalar integer Tensors) into a 1-D tensor,
    in a way that works in eager mode, under tracing and under scripting.

    Args:
        x: the values to pack. While tracing they must be scalar Tensors so that the
            result stays connected to the traced inputs; in eager mode or scripting
            they should be plain ints.
        device: optional device for the resulting tensor.

    Returns:
        A 1-D tensor holding the values of ``x``.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if not torch.jit.is_tracing():
        # Eager mode: a direct conversion is all that is needed.
        return torch.as_tensor(x, device=device)

    # Tracing: `as_tensor` would be recorded as a constant, so build the vector with
    # `stack` to keep a data dependency on the traced scalars.
    assert all(
        [isinstance(t, torch.Tensor) for t in x]
    ), "Shape should be tensor during tracing!"
    stacked = torch.stack(x)
    # Only record a device move when one is actually needed, so the trace does not
    # hard-code a device unnecessarily.
    if stacked.device != device:
        stacked = stacked.to(device=device)
    return stacked
def check_if_dynamo_compiling():
    """
    Return whether the calling code is currently being compiled by TorchDynamo
    (``torch.compile``).

    Uses the public ``torch.compiler.is_compiling`` when this PyTorch build provides
    it, and otherwise falls back to ``torch._dynamo.is_compiling``. The fallback is
    deprecated in recent releases and emits a FutureWarning on each call; this
    function runs on every ``Conv2d.forward``, so the public API is preferred.

    Returns:
        bool: False on PyTorch versions older than 1.14, which have no Dynamo.
    """
    if TORCH_VERSION >= (1, 14):
        # `getattr` with defaults so builds without `torch.compiler` (or without its
        # `is_compiling`) fall through to the legacy import below.
        compiler = getattr(torch, "compiler", None)
        is_compiling = getattr(compiler, "is_compiling", None)
        if is_compiling is None:
            from torch._dynamo import is_compiling
        return is_compiling()
    return False
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate ``tensors`` along ``dim`` like :func:`torch.cat`, but return the sole
    element unchanged (no copy) when the sequence has exactly one tensor.

    Args:
        tensors: a list or tuple of tensors.
        dim: dimension to concatenate along.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Single element: hand it back as-is instead of allocating a copy.
    return tensors[0]
def empty_input_loss_func_wrapper(loss_func):
    """
    Wrap ``loss_func`` so that a "mean"-reduced loss over an empty target evaluates
    to 0 (still connected to ``input``'s autograd graph) instead of nan.

    Args:
        loss_func: a callable ``loss_func(input, target, reduction=..., **kwargs)``.

    Returns:
        The wrapped callable, with ``reduction`` as a keyword-only argument
        defaulting to "mean".
    """

    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
        """
        empty_mean = target.numel() == 0 and reduction == "mean"
        if not empty_mean:
            return loss_func(input, target, reduction=reduction, **kwargs)
        # Multiply by zero instead of returning a constant so gradients still
        # reach `input`.
        return input.sum() * 0.0

    return wrapped_loss_func


cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy)
| class _NewEmptyTensorOp(torch.autograd.Function): | |
| def forward(ctx, x, new_shape): | |
| ctx.shape = x.shape | |
| return x.new_empty(new_shape) | |
| def backward(ctx, grad): | |
| shape = ctx.shape | |
| return _NewEmptyTensorOp.apply(grad, shape), None | |
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """
        Apply the convolution, then the optional ``norm``, then the optional
        ``activation``.

        Outside of TorchScript and TorchDynamo, an empty input in training mode is
        rejected when ``norm`` is a ``SyncBatchNorm``.
        """
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            # Dynamo doesn't support context managers yet
            is_dynamo_compiling = check_if_dynamo_compiling()
            if not is_dynamo_compiling:
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        if self.padding_mode == "zeros":
            x = F.conv2d(
                x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        else:
            # F.conv2d only zero-pads, which would silently ignore a "reflect" /
            # "replicate" / "circular" padding_mode passed to the constructor.
            # Delegate to torch.nn.Conv2d's own forward helper, which honors it.
            x = self._conv_forward(x, self.weight, self.bias)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
# Aliases to the stock PyTorch implementations; no extra wrapping is applied.
BatchNorm2d = torch.nn.BatchNorm2d
ConvTranspose2d = torch.nn.ConvTranspose2d
Linear = torch.nn.Linear
interpolate = F.interpolate
def nonzero_tuple(x):
    """
    Return the indices of non-zero elements of ``x`` as a tuple with one index
    tensor per dimension, i.e. ``torch.nonzero(x, as_tuple=True)``, in a form that
    TorchScript can compile (see https://github.com/pytorch/pytorch/issues/38718).
    """
    if torch.jit.is_scripting():
        # A 0-dim input is lifted to 1-D before `nonzero`, matching the original
        # scripted path.
        values = x.unsqueeze(0) if x.dim() == 0 else x
        return values.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)
@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Tracing-friendly way to cast ``src`` to ``dst``'s device.

    When traced, a plain ``src.to(dst.device)`` records the device as a constant.
    ``@torch.jit.script_if_tracing`` scripts this function as a whole during tracing,
    so the device is read from ``dst`` at run time instead. In eager mode the
    function runs unchanged.

    Args:
        src: tensor to move.
        dst: tensor whose device is the target.

    Returns:
        ``src`` on ``dst``'s device.
    """
    return src.to(dst.device)