Spaces:
Sleeping
Sleeping
File size: 5,548 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from .module import Module
from typing import Tuple, Union
from torch import Tensor
from torch.types import _size
# Names exported by ``from <this module> import *``: the two layer classes defined below.
__all__ = ['Flatten', 'Unflatten']
class Flatten(Module):
    r"""
    Collapses a contiguous range of dimensions of a tensor into a single dimension.

    Meant to be used as a layer inside :class:`~nn.Sequential`; the work itself is
    delegated to the tensor's ``flatten`` method, see :meth:`torch.flatten` for details.

    Shape:
        - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::

        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
    """

    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        """Store the inclusive range of dimensions that ``forward`` will flatten."""
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dims ``start_dim`` through ``end_dim`` flattened into one."""
        first = self.start_dim
        last = self.end_dim
        return input.flatten(first, last)

    def extra_repr(self) -> str:
        """Return the configured dimension range as a ``key=value`` string."""
        return f'start_dim={self.start_dim}, end_dim={self.end_dim}'
class Unflatten(Module):
    r"""
    Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Raises:
        TypeError: if :attr:`dim` is neither `int` nor `str`, or if
            :attr:`unflattened_size` does not have the container/element types
            required for the given kind of :attr:`dim`.

    Examples:
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, (2, 5, 5))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, torch.Size([2, 5, 5]))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
    """

    # A named shape is a tuple holding any number of (name, size) pairs. The
    # trailing ``...`` matters: without it the alias describes a tuple of
    # exactly one pair, which contradicts the multi-pair usage shown above.
    NamedShape = Tuple[Tuple[str, int], ...]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        """Validate the arguments against each other and store them.

        An `int` dim requires a tuple/list of ints; a `str` dim requires a
        tuple of tuples. Anything else for ``dim`` raises ``TypeError``.
        """
        super().__init__()

        # The kind of ``dim`` decides which shape format is acceptable.
        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input) -> None:
        """Raise ``TypeError`` unless ``input`` is a tuple whose elements are all tuples."""
        if not isinstance(input, tuple):
            raise TypeError("unflattened_size must be a tuple of tuples, " +
                            f"but found type {type(input).__name__}")
        for idx, elem in enumerate(input):
            if not isinstance(elem, tuple):
                raise TypeError("unflattened_size must be tuple of tuples, " +
                                f"but found element of type {type(elem).__name__} at pos {idx}")

    def _require_tuple_int(self, input) -> None:
        """Raise ``TypeError`` unless ``input`` is a tuple or list whose elements are all ints."""
        if not isinstance(input, (tuple, list)):
            raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}")
        for idx, elem in enumerate(input):
            if not isinstance(elem, int):
                raise TypeError("unflattened_size must be tuple of ints, " +
                                f"but found element of type {type(elem).__name__} at pos {idx}")

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input`` with dimension ``dim`` expanded to ``unflattened_size``."""
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        """Return the configured ``dim`` and ``unflattened_size`` as a ``key=value`` string."""
        return f'dim={self.dim}, unflattened_size={self.unflattened_size}'
|