Spaces:
Sleeping
Sleeping
# Copyright (c) ONNX Project Contributors | |
# SPDX-License-Identifier: Apache-2.0 | |
import numpy as np | |
from onnx.reference.op_run import OpRun | |
def _conv_implementation( # type: ignore | |
X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides | |
): | |
if dilations is None: | |
dilations = [1 for s in X.shape[2:]] | |
if kernel_shape is None: | |
kernel_shape = W.shape[2:] | |
if pads is None: | |
pads = [0 for s in X.shape[2:]] * 2 | |
if strides is None: | |
strides = [1 for s in X.shape[2:]] | |
if X.shape[1] != W.shape[1] * group or W.shape[0] % group != 0: | |
raise ValueError( | |
f"Shape inconsistencies, X.shape={X.shape}, W.shape={W.shape}, group={group}, " | |
f"W should be {(W.shape[0], X.shape[1] // group, np.prod(W.shape[1:]) // X.shape[1] * group)}." | |
) | |
if group > 1: | |
res = [] | |
td = 0 | |
mg = W.shape[0] // group | |
dw = W.shape[1] | |
for b in range(X.shape[0]): | |
for g in range(group): | |
gx = X[b : b + 1, g * dw : (g + 1) * dw] | |
gw = W[g * mg : (g + 1) * mg] | |
try: | |
cv = _conv_implementation( | |
gx, | |
gw, | |
None, | |
auto_pad, | |
dilations, | |
1, | |
kernel_shape, | |
pads, | |
strides, | |
) | |
except (ValueError, RuntimeError) as e: | |
raise ValueError( | |
f"Shape inconsistencies, X.shape={X.shape}, W.shape={W.shape}, group={g}/{group}, " | |
f"gx.shape={gx.shape}, gw.shape={gw.shape}, auto_pad={auto_pad}, " | |
f"dilations={dilations}, kernel_shape={kernel_shape}, pads={pads}, " | |
f"strides={strides}." | |
) from e | |
if b == 0: | |
td += cv.shape[1] | |
res.append((b, cv)) | |
new_shape = [X.shape[0], *list(res[0][1].shape[1:])] | |
new_shape[1] = td | |
final = np.zeros(tuple(new_shape), dtype=res[0][1].dtype) | |
p = 0 | |
for b, cv in res: | |
final[b : b + 1, p : p + cv.shape[1]] = cv | |
p += cv.shape[1] | |
if p >= final.shape[1]: | |
p = 0 | |
if B is not None: | |
new_shape = [1 for s in final.shape] | |
new_shape[1] = B.shape[0] | |
b = B.reshape(tuple(new_shape)) | |
final += b | |
return final | |
if dilations[0] != 1 or min(dilations) != max(dilations): | |
# Let's compute the dilated kernel. | |
nd = len(dilations) | |
new_kernel_shape = [] | |
new_shape = list(W.shape[:-nd]) | |
for i, d in enumerate(dilations): | |
di = len(W.shape) - nd + i | |
new_shape.append(W.shape[di] + (W.shape[di] - 1) * (d - 1)) | |
new_kernel_shape.append(kernel_shape[i] + (kernel_shape[i] - 1) * (d - 1)) | |
new_w = np.zeros(tuple(new_shape), dtype=W.dtype) | |
indices = [slice(0, new_w.shape[0]), slice(0, new_w.shape[1])] | |
for i, d in enumerate(dilations): | |
di = len(W.shape) - nd + i | |
indices.append(slice(0, new_w.shape[di], d)) | |
new_w[tuple(indices)] = W | |
W = new_w | |
kernel_shape = new_kernel_shape | |
if auto_pad in {"SAME_LOWER", "SAME_UPPER", "VALID"}: | |
head = [] | |
tail = [] | |
for i in range(len(X.shape) - 2): | |
d = X.shape[i] | |
target_size = (d + strides[i] - 1) // strides[i] | |
pad_needed = (target_size - 1) * strides[i] + kernel_shape[i] - d | |
if auto_pad == "SAME_LOWER": | |
pad_head = (pad_needed + 1) // 2 | |
else: | |
pad_head = pad_needed // 2 | |
pad_tail = pad_needed - pad_head | |
head.append(pad_head) | |
tail.append(pad_tail) | |
pads = head + tail | |
if len(X.shape) == 3: | |
sN, sC, sH = X.shape | |
# M, C_group, kH, kW = W.shape | |
(kh,) = kernel_shape | |
(sth,) = strides | |
h_out = int(((sH - kh + pads[0] + pads[1]) / sth) + 1) | |
h0 = pads[0] | |
oh = -1 * (kh % 2) | |
bh = -h0 | |
eh = h_out * sth | |
res = np.zeros((X.shape[0], W.shape[0], h_out)) # type: ignore[assignment] | |
if B is not None: | |
res[:, :, :] += B.reshape((1, -1, 1)) # type: ignore | |
for n in range(0, sN): | |
for nw in range(W.shape[0]): | |
for c in range(0, sC): | |
w = W[nw : nw + 1, c : c + 1] | |
for io in range(bh, eh, sth): | |
hr = (io - bh) // sth | |
if hr >= h_out: | |
continue | |
i = io + kh % 2 | |
ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH) | |
img = X[n : n + 1, c : c + 1, ih1:ih2] | |
if img.shape != w.shape: | |
jh1, jh2 = max(-oh - i, 0), min(kh, kh + sH - (i + oh + kh)) | |
w_ = w[:1, :1, jh1:jh2] | |
if img.shape != w_.shape: | |
raise RuntimeError( | |
f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, " | |
f"i={i}, kh={kh}, sH={sH}, sth={sth}." | |
) | |
s = np.dot(img.reshape((1, -1)), w_.reshape((-1, 1)))[ | |
0, 0 | |
] # (img * w_).sum() | |
else: | |
s = np.dot(img.reshape((1, -1)), w.reshape((-1, 1)))[ | |
0, 0 | |
] # (img * w).sum() | |
res[n, nw, hr] += s # type: ignore | |
return res | |
if len(X.shape) == 4: | |
sN, sC, sH, sW = X.shape | |
# M, C_group, kH, kW = W.shape | |
kh, kw = kernel_shape | |
sth, stw = strides | |
h_out = int(((sH - kh + pads[0] + pads[2]) / sth) + 1) | |
w_out = int(((sW - kw + pads[1] + pads[3]) / stw) + 1) | |
h0, w0 = pads[0], pads[1] | |
oh, ow = -1 * (kh % 2), -1 * (kw % 2) | |
bh, bw = -h0, -w0 | |
eh, ew = h_out * sth, w_out * stw | |
res = np.zeros((X.shape[0], W.shape[0], h_out, w_out)) # type: ignore[assignment] | |
if B is not None: | |
res[:, :, :, :] = B.reshape((1, -1, 1, 1)) # type: ignore | |
for n in range(0, sN): | |
for nw in range(W.shape[0]): | |
for c in range(0, sC): | |
w = W[nw : nw + 1, c : c + 1] | |
for io in range(bh, eh, sth): | |
hr = (io - bh) // sth | |
if hr >= h_out: | |
continue | |
i = io + kh % 2 | |
ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH) | |
for jo in range(bw, ew, stw): | |
wr = (jo - bw) // stw | |
if wr >= w_out: | |
continue | |
j = jo + kw % 2 | |
iw1, iw2 = max(0, j + ow), min(j + ow + kw, sW) | |
img = X[n : n + 1, c : c + 1, ih1:ih2, iw1:iw2] | |
if img.shape != w.shape: | |
jh1, jh2 = max(-oh - i, 0), min( | |
kh, kh + sH - (i + oh + kh) | |
) | |
jw1, jw2 = max(-ow - j, 0), min( | |
kw, kw + sW - (j + ow + kw) | |
) | |
w_ = w[:1, :1, jh1:jh2, jw1:jw2] | |
if img.shape != w_.shape: | |
raise RuntimeError( | |
f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, ow={ow}, " | |
f"i={i}, j={j}, kh={kh}, kw={kw}, sH={sH}, sW={sW}, sth={sth}, stw={stw}." | |
) | |
s = np.dot(img.reshape((1, -1)), w_.reshape((-1, 1)))[ | |
0, 0 | |
] # (img * w_).sum() | |
else: | |
s = np.dot(img.reshape((1, -1)), w.reshape((-1, 1)))[ | |
0, 0 | |
] # (img * w).sum() | |
res[n, nw, hr, wr] += s # type: ignore | |
return res | |
if len(X.shape) == 5: | |
sN, sC, sH, sW, sZ = X.shape | |
kh, kw, kz = kernel_shape | |
sth, stw, stz = strides | |
h_out = int(((sH - kh + pads[0] + pads[3]) / sth) + 1) | |
w_out = int(((sW - kw + pads[1] + pads[4]) / stw) + 1) | |
z_out = int(((sZ - kz + pads[2] + pads[5]) / stz) + 1) | |
h0, w0, z0 = pads[0], pads[1], pads[2] | |
oh, ow, oz = -1 * (kh % 2), -1 * (kw % 2), -1 * (kz % 2) | |
bh, bw, bz = -h0, -w0, -z0 | |
eh, ew, ez = h_out * sth, w_out * stw, z_out * stz | |
res = np.zeros((X.shape[0], W.shape[0], h_out, w_out, z_out)) # type: ignore[assignment] | |
if B is not None: | |
res[:, :, :, :, :] = B.reshape((1, -1, 1, 1, 1)) # type: ignore | |
for n in range(0, sN): | |
for nw in range(W.shape[0]): | |
for c in range(0, sC): | |
w = W[nw : nw + 1, c : c + 1] | |
for io in range(bh, eh, sth): | |
hr = (io - bh) // sth | |
if hr >= h_out: | |
continue | |
i = io + kh % 2 | |
ih1, ih2 = max(0, i + oh), min(i + oh + kh, sH) | |
for jo in range(bw, ew, stw): | |
wr = (jo - bw) // stw | |
if wr >= w_out: | |
continue | |
j = jo + kw % 2 | |
iw1, iw2 = max(0, j + ow), min(j + ow + kw, sW) | |
for zo in range(bz, ez, stz): | |
zr = (zo - bz) // stz | |
if zr >= z_out: | |
continue | |
z = zo + kz % 2 | |
iz1, iz2 = max(0, z + oz), min(z + oz + kz, sZ) | |
img = X[n : n + 1, c : c + 1, ih1:ih2, iw1:iw2, iz1:iz2] | |
if img.shape != w.shape: | |
jh1, jh2 = max(-oh - i, 0), min( | |
kh, kh + sH - (i + oh + kh) | |
) | |
jw1, jw2 = max(-ow - j, 0), min( | |
kw, kw + sW - (j + ow + kw) | |
) | |
jz1, jz2 = max(-oz - z, 0), min( | |
kz, kz + sZ - (z + oz + kz) | |
) | |
w_ = w[:1, :1, jh1:jh2, jw1:jw2, jz1:jz2] | |
if img.shape != w_.shape: | |
raise RuntimeError( | |
f"Unexpected shape {img.shape} != {w_.shape}, oh={oh}, ow={ow}, oz={oz}, " | |
f"i={i}, j={j}, z={z}, kh={kh}, kw={kw}, kz={kz}, " | |
f"sH={sH}, sW={sW}, sZ={sZ}, sth={sth}, stw={stw}, stz={stz}." | |
) | |
s = np.dot( | |
img.reshape((1, -1)), w_.reshape((-1, 1)) | |
)[ | |
0, 0 | |
] # (img * w_).sum() | |
else: | |
s = np.dot( | |
img.reshape((1, -1)), w.reshape((-1, 1)) | |
)[ | |
0, 0 | |
] # (img * w).sum() | |
res[n, nw, hr, wr, zr] += s # type: ignore | |
return res | |
raise RuntimeError( | |
f"The convolution for X.shape={X.shape}, W.shape={W.shape}, " | |
f"kernel_shape={kernel_shape} is not implemented yet." | |
) | |
class Conv(OpRun): | |
def _run( # type: ignore | |
self, | |
X, | |
W, | |
B=None, | |
auto_pad=None, | |
dilations=None, | |
group=None, | |
kernel_shape=None, | |
pads=None, | |
strides=None, | |
): | |
if len(X.shape) < 3: | |
raise ValueError( | |
f"X must have at least 3 dimensions but its shape is {X.shape}." | |
) | |
return ( | |
# _conv_implementation( | |
_conv_implementation( | |
X, W, B, auto_pad, dilations, group, kernel_shape, pads, strides | |
).astype(X.dtype), | |
) | |