Spaces:
Running
Running
File size: 1,692 Bytes
dc2106c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from onnx.reference.op_run import OpRun
class CumSum(OpRun):
    """Reference implementation of the ONNX ``CumSum`` operator.

    Computes the cumulative sum of the input along one axis. Optionally it
    excludes the current element (``exclusive``) and/or accumulates from the
    end of the axis towards its start (``reverse``).
    """

    def _run(self, x, *axis, exclusive=None, reverse=None):  # type: ignore
        """Return ``(y,)`` where ``y`` is the cumulative sum of *x*.

        Args:
            x: Input tensor.
            *axis: Optional single input holding the axis to accumulate
                along. It may be a scalar or a one-element array, and
                negative values count from the last dimension. When it is
                omitted (or ``None``), *x* is flattened and accumulated as a
                1-D array, like ``np.cumsum(x)``.
            exclusive: If truthy, output element ``j`` sums the elements
                strictly before ``j`` along the axis, so the first one is 0.
            reverse: If truthy, the accumulation runs from the end of the
                axis towards its start.

        Returns:
            A one-element tuple holding the result. It always has the same
            dtype as *x*, and the same shape unless *axis* was omitted (then
            it is 1-D).

        Raises:
            RuntimeError: If *axis* has rank > 1 or does not hold exactly
                one value.
        """
        axis_value = axis[0] if axis else None

        if axis_value is None:
            # No axis: operate on the flattened input, matching np.cumsum(x).
            x = np.reshape(x, -1)
            ax = 0
        else:
            # Accept numpy scalars, 0-d / one-element arrays and plain ints.
            axis_array = np.asarray(axis_value)
            if axis_array.ndim > 1 or (
                axis_array.ndim == 1 and axis_array.shape[0] != 1
            ):
                raise RuntimeError(
                    f"axis must be an array of one number not {axis_array} (shape {axis_array.shape})."
                )
            ax = axis_array.item()

        if reverse:
            # Accumulating backwards == flip, accumulate forwards, flip back.
            x = np.flip(x, ax)

        if exclusive:
            # Shift the inclusive sum of all-but-the-last element one step
            # forward along the axis. The leading slot stays 0.
            res = np.zeros(x.shape, dtype=x.dtype)
            src = [slice(None)] * x.ndim
            dst = [slice(None)] * x.ndim
            src[ax] = slice(None, -1)
            dst[ax] = slice(1, None)
            # dtype=x.dtype: NumPy would otherwise promote small integer
            # types (e.g. int32, uint32) to the platform integer, while
            # ONNX requires the output type to equal the input type.
            res[tuple(dst)] = np.cumsum(x[tuple(src)], axis=ax, dtype=x.dtype)
        else:
            res = np.cumsum(x, axis=ax, dtype=x.dtype)

        if reverse:
            res = np.flip(res, ax)
        return (res,)
|