from .module import Module
from .. import functional as F

from torch import Tensor
from ..common_types import _size_any_t

__all__ = ['Fold', 'Unfold']

class Fold(Module):
    r"""Combines an array of sliding local blocks into a large containing tensor.

    Consider a batched :attr:`input` tensor containing sliding local blocks,
    e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
    where :math:`N` is batch dimension, :math:`C \times \prod(\text{kernel\_size})`
    is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})`
    spatial locations each containing a :math:`C`-channeled vector), and
    :math:`L` is the total number of blocks. (This is exactly the
    same specification as the output shape of :class:`~torch.nn.Unfold`.) This
    operation combines these local blocks into the large :attr:`output` tensor
    of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
    by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the
    arguments must satisfy

    .. math::
        L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`d` is over all spatial dimensions.

    * :attr:`output_size` describes the spatial shape of the large containing
      tensor of the sliding local blocks. It is useful to resolve the ambiguity
      when multiple input shapes map to the same number of sliding blocks, e.g.,
      with ``stride > 0``.

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.

    * :attr:`padding` controls the amount of implicit zero padding added to
      both sides of each spatial dimension before reshaping.

    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        output_size (int or tuple): the shape of the spatial dimensions of the
                                    output (i.e., ``output.sizes()[2:]``)
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in the
                                         input spatial dimensions. Default: 1

    * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,
      :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then
      their values will be replicated across all spatial dimensions.

    * For the case of two output spatial dimensions this operation is sometimes
      called ``col2im``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, then
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to constant divisor).

    .. warning::
        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.

    Shape:
        - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)`
        - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
          or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above

    Examples::

        >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
        >>> input = torch.randn(1, 3 * 2 * 2, 12)
        >>> output = fold(input)
        >>> output.size()
        torch.Size([1, 3, 4, 5])
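
    A quick sanity check of the formula for :math:`L` above, using a small
    stand-alone helper (hypothetical, for illustration only; not part of
    ``torch``)::

        >>> def num_blocks(size, kernel, padding=0, dilation=1, stride=1):
        ...     # floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
        ...     return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
        >>> num_blocks(4, 2) * num_blocks(5, 2)  # L for output_size=(4, 5), kernel_size=(2, 2)
        12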



    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """

    __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding',
                     'stride']
    output_size: _size_any_t
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        output_size: _size_any_t,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1
    ) -> None:
        super().__init__()
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.fold(input, self.output_size, self.kernel_size, self.dilation,
                      self.padding, self.stride)

    def extra_repr(self) -> str:
        return 'output_size={output_size}, kernel_size={kernel_size}, ' \
            'dilation={dilation}, padding={padding}, stride={stride}'.format(
                **self.__dict__
            )


class Unfold(Module):
    r"""Extracts sliding local blocks from a batched input tensor.

    Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
    where :math:`N` is the batch dimension, :math:`C` is the channel dimension,
    and :math:`*` represents arbitrary spatial dimensions. This operation flattens
    each sliding :attr:`kernel_size`-sized block within the spatial dimensions
    of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output`
    tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where
    :math:`C \times \prod(\text{kernel\_size})` is the total number of values
    within each block (a block has :math:`\prod(\text{kernel\_size})` spatial
    locations each containing a :math:`C`-channeled vector), and :math:`L` is
    the total number of such blocks:

    .. math::
        L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`\text{spatial\_size}` is formed by the spatial dimensions
    of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial
    dimensions.

    Therefore, indexing :attr:`output` at the last dimension (column dimension)
    gives all values within a certain block.

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.

    * :attr:`padding` controls the amount of implicit zero padding added to
      both sides of each spatial dimension before the blocks are extracted.

    * :attr:`dilation` controls the spacing between the kernel points; also known as the à trous algorithm.
      It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in the input
                                         spatial dimensions. Default: 1

    * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or
      :attr:`stride` is an int or a tuple of length 1, their values will be
      replicated across all spatial dimensions.

    * For the case of two input spatial dimensions this operation is sometimes
      called ``im2col``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, then
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to constant divisor).

    .. warning::
        Currently, only 4-D input tensors (batched image-like tensors) are
        supported.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above

    Examples::

        >>> unfold = nn.Unfold(kernel_size=(2, 3))
        >>> input = torch.randn(2, 5, 3, 4)
        >>> output = unfold(input)
        >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels)
        >>> # 4 blocks (2x3 kernels) in total in the 3x4 input
        >>> output.size()
        torch.Size([2, 30, 4])
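
    The two numbers above can be checked with plain arithmetic from the
    formula for :math:`L` (no ``torch`` needed; each block holds
    :math:`C \times \prod(\text{kernel\_size})` values)::

        >>> C, kh, kw = 5, 2, 3
        >>> C * kh * kw  # values per block
        30
        >>> ((3 - kh) + 1) * ((4 - kw) + 1)  # L with padding=0, dilation=1, stride=1
        4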



        >>> # xdoctest: +IGNORE_WANT
        >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
        >>> inp = torch.randn(1, 3, 10, 12)
        >>> w = torch.randn(2, 3, 4, 5)
        >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
        >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
        >>> # or equivalently (and avoiding a copy),
        >>> # out = out_unf.view(1, 2, 7, 8)
        >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
        tensor(1.9073e-06)
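
    The ``(7, 8)`` spatial size passed to ``fold`` above is simply the
    convolution output shape, recoverable with plain arithmetic from the
    same block-count formula (padding=0, dilation=1, stride=1)::

        >>> (10 - 4 + 1, 12 - 5 + 1)  # (H - kh + 1, W - kw + 1)
        (7, 8)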



    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """

    __constants__ = ['kernel_size', 'dilation', 'padding', 'stride']
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.unfold(input, self.kernel_size, self.dilation,
                        self.padding, self.stride)

    def extra_repr(self) -> str:
        return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \
            ' stride={stride}'.format(**self.__dict__)