File size: 1,455 Bytes
dc2106c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0


from onnx.reference.ops.experimental._op_run_experimental import OpRunExperimental
from onnx.reference.ops_optimized.op_conv_optimized import im2col_fast


def _dilated_kernel_shape(kernel_shape, dilations):
    """Return *kernel_shape* with its trailing spatial dims enlarged by *dilations*.

    The leading ``len(kernel_shape) - len(dilations)`` entries are copied
    unchanged. Each of the last ``len(dilations)`` entries ``k`` (paired with
    dilation ``d``) becomes ``k + (k - 1) * (d - 1)``, the extent spanned by a
    kernel of size ``k`` dilated by ``d``.
    """
    leading = len(kernel_shape) - len(dilations)
    shape = list(kernel_shape[:leading])
    for k, d in zip(kernel_shape[leading:], dilations):
        shape.append(k + (k - 1) * (d - 1))
    return shape


class Im2Col(OpRunExperimental):
    def _run(self, img, kernel_shape, dilations=None, pads=None, strides=None):  # type: ignore
        """Unfold *img* into columns via ``im2col_fast``.

        Args:
            img: input tensor; dims from index 2 onward are treated as spatial
                (the defaults below are sized from ``img.shape[2:]``).
            kernel_shape: kernel shape whose first two entries are skipped;
                only ``kernel_shape[2:]`` is forwarded as the spatial kernel.
                Presumably a full weight shape ``(M, C, k1, ...)`` — confirm
                against callers.
            dilations: one dilation per spatial dim; defaults to all 1.
            pads: padding values; defaults to ``2 * spatial_rank`` zeros.
            strides: one stride per spatial dim; defaults to all 1.

        Returns:
            A 1-tuple holding element ``[0]`` of ``im2col_fast``'s result.
        """
        spatial_rank = len(img.shape[2:])
        if dilations is None:
            dilations = [1] * spatial_rank
        if pads is None:
            pads = [0] * (spatial_rank * 2)
        if strides is None:
            strides = [1] * spatial_rank

        if min(dilations) != 1 or max(dilations) != 1:
            # NOTE(review): im2col_fast takes no dilation argument, so dilation
            # is emulated by enlarging the spatial kernel footprint. The
            # resulting columns cover every position of the dilated window,
            # not only the sampled taps; a consumer presumably pairs them with
            # a zero-dilated kernel — confirm.
            kernel_shape = _dilated_kernel_shape(kernel_shape, dilations)

        # Every path takes element [0] of im2col_fast's result, so dilated and
        # undilated inputs produce outputs of the same structure.
        return (im2col_fast(img, tuple(kernel_shape[2:]), pads, strides)[0],)  # type: ignore