Spaces:
Running
Running
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import numpy as np | |
from onnx.reference.op_run import OpRun | |
from onnx.reference.ops.op_col2im import col2im_naive_implementation | |
class ConvTranspose(OpRun):
    """Reference implementation of ConvTranspose, restricted to ``group == 1``."""

    def _run(  # type: ignore
        self,
        X,
        W,
        B=None,
        auto_pad=None,
        dilations=None,
        group=None,
        kernel_shape=None,
        output_padding=None,
        output_shape=None,
        pads=None,
        strides=None,
    ):
        """Compute the transposed convolution of ``X`` with the weights ``W``.

        Each image is multiplied by the flattened, transposed kernel matrix and
        the resulting columns are scattered back into an image per output
        channel with ``col2im_naive_implementation``.

        Args:
            X: Input array indexed as ``(batch, channels, *spatial)``.
            W: Weight array; ``W.shape[1]`` is used as the number of output
                channels and ``W.shape[2:]`` as the kernel extents.
            B: Optional bias; ``B[c]`` is added to output channel ``c``.
            auto_pad: Only consulted when ``pads`` is None. ``"SAME_UPPER"``
                and ``"SAME_LOWER"`` derive the pads from ``output_shape``;
                any other value means zero padding.
            dilations: Per-axis dilation, defaults to 1 on every spatial axis.
            group: Number of groups. ``None`` is treated as 1; any other
                value than 1 is rejected.
            kernel_shape: Kernel extents used for the padding / output-shape
                arithmetic, defaults to ``W.shape[2:]``. The matrix product
                itself always uses ``W.shape[2:]``.
            output_padding: Extra size added to each spatial output axis,
                defaults to 0 on every spatial axis.
            output_shape: Spatial output shape. When omitted it is
                ``X.shape[i + 2] * strides[i]`` for the ``SAME_*`` modes and
                otherwise derived from strides, kernel, dilations and pads.
            pads: Explicit padding laid out as all begin values followed by
                all end values.
            strides: Per-axis stride, defaults to 1 on every spatial axis.

        Returns:
            A one-element tuple holding the result array with dtype ``X.dtype``.

        Raises:
            RuntimeError: If ``group`` is not 1.
        """
        # ``group`` defaults to 1 in the operator specification, so a missing
        # attribute must not be rejected as an unsupported grouping.
        if group is None:
            group = 1
        if group != 1:
            raise RuntimeError(f"group={group} != 1 is not implemented yet.")

        n_spatial = len(X.shape) - 2
        if dilations is None:
            dilations = [1] * n_spatial
        if kernel_shape is None:
            kernel_shape = W.shape[2:]
        if output_padding is None:
            # Only one entry per spatial axis is ever read below.
            output_padding = [0] * n_spatial
        if strides is None:
            strides = [1] * n_spatial
        if pads is None and auto_pad not in {"SAME_UPPER", "SAME_LOWER"}:
            pads = [0] * (2 * len(strides))

        if pads is None:
            # SAME_UPPER / SAME_LOWER: choose the output shape first, then
            # derive the total padding that produces it.
            if output_shape is None:
                output_shape = [
                    X.shape[i + 2] * stride for i, stride in enumerate(strides)
                ]
            total_padding = [
                strides[i] * (X.shape[i + 2] - 1)
                + output_padding[i]
                + ((kernel_shape[i] - 1) * dilations[i] + 1)
                - output_shape[i]
                for i in range(len(output_shape))
            ]
            pads_begin = []
            pads_end = []
            for total in total_padding:
                half = total // 2
                # When the total is odd, SAME_UPPER puts the extra unit at the
                # end of the axis and SAME_LOWER puts it at the beginning.
                if auto_pad == "SAME_UPPER":
                    pads_begin.append(half)
                    pads_end.append(total - half)
                else:
                    pads_begin.append(total - half)
                    pads_end.append(half)
            pads = pads_begin + pads_end
            n_dims = len(pads) // 2
        else:
            n_dims = len(X.shape) - 2

        # One (begin, end) row per spatial axis.
        new_pads = np.array([(pads[i], pads[i + n_dims]) for i in range(n_dims)])

        if output_shape is None:
            output_shape = [
                strides[i] * (X.shape[i + 2] - 1)
                + output_padding[i]
                + ((kernel_shape[i] - 1) * dilations[i] + 1)
                - new_pads[i, :].sum()
                for i in range(n_dims)
            ]

        # From here on the kernel extents come from the weights themselves.
        kernel_shape = W.shape[2:]
        kernel_size = np.prod(kernel_shape)
        num_output_channels = W.shape[1] * group
        kernel_dim = num_output_channels // group * kernel_size

        k = X.shape[1] // group  # input channels per group
        n = np.prod(X.shape[2:])  # number of input positions per channel

        # N x C x H x W = X.shape
        # C x M/group x k1 x k2 = W.shape
        # The transposed weight matrix is identical for every image, so it is
        # built once outside the batch loop.
        w_t = W.reshape((group, k, kernel_dim))[0].T

        final = None
        for image_id in range(X.shape[0]):
            gemm = np.matmul(w_t, X[image_id].reshape((k, n)))
            gemmc = gemm.reshape((num_output_channels, -1, gemm.shape[-1]))
            for c in range(num_output_channels):
                res = col2im_naive_implementation(
                    gemmc[c], output_shape, kernel_shape, dilations, pads, strides
                )
                if final is None:
                    # Allocated lazily because the spatial extent is taken
                    # from the first col2im result.
                    final = np.empty(
                        X.shape[:1] + (num_output_channels,) + res.shape,
                        dtype=X.dtype,
                    )
                if B is not None:
                    res += B[c]
                final[image_id, c, ...] = res[...]

        if final is None:
            # Empty batch: nothing was computed, so return an empty result
            # instead of failing on ``None``.
            # NOTE(review): assumes col2im returns arrays shaped like
            # ``output_shape`` -- confirm against col2im_naive_implementation.
            final = np.empty(
                (0, num_output_channels, *output_shape), dtype=X.dtype
            )
        return (final.astype(X.dtype),)