Spaces:
Running
Running
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from onnx.reference.ops.experimental._op_run_experimental import OpRunExperimental | |
from onnx.reference.ops_optimized.op_conv_optimized import im2col_fast | |
class Im2Col(OpRunExperimental):
    """Reference implementation of the experimental ``Im2Col`` operator.

    The column extraction itself is delegated to ``im2col_fast``; this class
    only resolves default attributes and handles dilation.
    """

    def _run(self, img, kernel_shape, dilations=None, pads=None, strides=None):  # type: ignore
        """Return ``(columns,)`` for ``img``.

        Args:
            img: input tensor. Axes 0 and 1 are skipped when building the
                defaults below; the remaining axes are the spatial axes.
            kernel_shape: kernel shape whose entries ``[2:]`` are the spatial
                kernel sizes (the first two entries are ignored).
            dilations: one dilation per spatial axis; defaults to all ones.
            pads: begin pads followed by end pads (``2 * n_spatial`` values);
                defaults to all zeros.
            strides: one stride per spatial axis; defaults to all ones.

        Returns:
            A one-element tuple holding the column matrix, i.e. the first
            element of what ``im2col_fast`` returns, for both the undilated
            and the dilated path.

        Raises:
            ValueError: if dilation is requested and ``dilations`` does not
                have one entry per spatial kernel dimension.
        """
        import math

        n_spatial = len(img.shape) - 2
        if dilations is None:
            dilations = [1] * n_spatial
        if pads is None:
            pads = [0] * (2 * n_spatial)
        if strides is None:
            strides = [1] * n_spatial

        spatial_kernel = tuple(kernel_shape[2:])
        if all(d == 1 for d in dilations):
            return (im2col_fast(img, spatial_kernel, pads, strides)[0],)  # type: ignore

        if len(dilations) != len(spatial_kernel):
            raise ValueError(
                f"Im2Col expects one dilation per spatial kernel dimension, "
                f"got dilations={list(dilations)} for kernel {spatial_kernel}."
            )

        # Extent spanned by a dilated kernel along each axis:
        # k taps spaced d apart cover k + (k - 1) * (d - 1) positions.
        dilated_kernel = tuple(
            k + (k - 1) * (d - 1) for k, d in zip(spatial_kernel, dilations)
        )
        # Extract every position of the dilated window ...
        cols = im2col_fast(img, dilated_kernel, pads, strides)[0]  # type: ignore

        # ... then keep only the offsets a dilated kernel actually samples
        # (0, d, 2d, ...), leaving exactly k offsets per spatial axis.
        # NOTE(review): assumes im2col_fast returns a 2-D matrix whose rows are
        # ordered channel-major, then kernel offsets in C order — confirm
        # against op_conv_optimized.
        n_channels = img.shape[1]
        n_cols = cols.shape[-1]
        grid = cols.reshape((n_channels, *dilated_kernel, n_cols))
        index = (slice(None), *[slice(None, None, d) for d in dilations], slice(None))
        sampled = grid[index]
        return (sampled.reshape((n_channels * math.prod(spatial_kernel), n_cols)),)