Spaces:
Running
Running
# mypy: ignore-errors | |
from __future__ import annotations | |
import builtins | |
import math | |
import operator | |
from typing import Sequence | |
import torch | |
from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util | |
from ._normalizations import ( | |
ArrayLike, | |
normalize_array_like, | |
normalizer, | |
NotImplementedType, | |
) | |
# Same value as NumPy's ``newaxis``: plain ``None``.
newaxis = None

# Every flag name accepted by ``Flags``; lookups under any other name fail.
FLAGS = [
    # primary flags
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    # combined flags
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

# Abbreviated spellings accepted by ``Flags.__getitem__``, mapped to the full
# names listed in ``FLAGS``.
SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}
class Flags:
    """Read-only mapping of array flags.

    A flag can be read as an item under its full name (``flags["OWNDATA"]``),
    as an item under a shorthand from ``SHORTHAND_TO_FLAGS`` (``flags["O"]``),
    or as a lower-case attribute (``flags.owndata``).  A name that is valid
    but has no stored value raises ``NotImplementedError``; assigning to any
    known flag raises ``NotImplementedError`` as well.
    """

    def __init__(self, flag_to_value: dict):
        unknown = [name for name in flag_to_value if name not in FLAGS]
        assert not unknown  # sanity check
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Only reached when normal attribute lookup fails.
        if not (attr.islower() and attr.upper() in FLAGS):
            raise AttributeError(f"No flag attribute '{attr}'")
        return self[attr.upper()]

    def __getitem__(self, key):
        # Expand a shorthand to its full flag name; other keys pass through.
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # A known flag for which no value was supplied.
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        is_flag = attr.islower() and attr.upper() in FLAGS
        if not is_flag:
            # Ordinary attributes (e.g. ``_flag_to_value``) are stored normally.
            super().__setattr__(attr, value)
            return
        self[attr.upper()] = value

    def __setitem__(self, key, value):
        if key not in FLAGS and key not in SHORTHAND_TO_FLAGS:
            raise KeyError(f"No flag key '{key}'")
        raise NotImplementedError("Modifying flags is not implemented")
def create_method(fn, name=None):
    """Return a forwarding wrapper around ``fn`` presented as ``ndarray.<name>``.

    The wrapper passes all positional and keyword arguments straight to
    ``fn``.  Its ``__name__`` is ``name`` (falling back to ``fn.__name__``
    when ``name`` is falsy) and its ``__qualname__`` is ``ndarray.<name>``.
    """
    if not name:
        name = fn.__name__

    def f(*args, **kwargs):
        return fn(*args, **kwargs)

    f.__name__, f.__qualname__ = name, f"ndarray.{name}"
    return f
# ndarray method name -> name of the NumPy-level function implementing it.
# A value of ``None`` means the function has the same name as the method.
methods = {
    # shape manipulation / element-wise helpers
    "clip": None,
    "nonzero": None,
    "repeat": None,
    "round": None,
    "squeeze": None,
    "swapaxes": None,
    "ravel": None,
    # linalg
    "diagonal": None,
    "dot": None,
    "trace": None,
    # sorting
    "argsort": None,
    "searchsorted": None,
    # reductions
    "argmax": None,
    "argmin": None,
    "any": None,
    "all": None,
    "max": None,
    "min": None,
    "ptp": None,
    "sum": None,
    "prod": None,
    "mean": None,
    "var": None,
    "std": None,
    # scans
    "cumsum": None,
    "cumprod": None,
    # advanced indexing
    "take": None,
    "choose": None,
}
# Operator name (without the surrounding double underscores) -> name of the
# function implementing it; ``None`` means the two names coincide.
dunder = {
    # unary operators
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    # ordering comparisons
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}
# Binary operators that also get a reflected (``__r*__``) and an in-place
# (``__i*__``) variant.  Same convention as above: the value names the
# implementing function, ``None`` meaning it matches the key.
ri_dunder = {
    # arithmetic
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    # bitwise
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    # matrix product
    "matmul": None,
}
def _upcast_int_indices(index):
    """Widen narrow integer index tensors to ``torch.int64``.

    Tensors of dtype int8, int16, int32 or uint8 are converted to int64.
    Tuples are processed element by element (recursively).  Anything else,
    including tensors of other dtypes, is returned unchanged.
    """
    if isinstance(index, tuple):
        return tuple(map(_upcast_int_indices, index))

    narrow_dtypes = (torch.int8, torch.int16, torch.int32, torch.uint8)
    if isinstance(index, torch.Tensor) and index.dtype in narrow_dtypes:
        return index.to(torch.int64)
    return index
class _Unspecified:
    """Marker type for "argument not passed", as distinct from an explicit ``None``."""


# The single sentinel instance; compare against it with ``is``.
_Unspecified.unspecified = _Unspecified()
############################################################### | |
# ndarray class # | |
############################################################### | |
class ndarray:
    """NumPy-style array wrapping a ``torch.Tensor`` kept in ``self.tensor``.

    Much of the method surface is attached while the class body executes:
    names listed in ``methods`` are looked up in ``_funcs``; operator names
    in ``dunder`` and ``ri_dunder`` are looked up in ``_ufuncs``.
    """

    def __init__(self, t=None):
        """Wrap tensor ``t``; with no argument, wrap an empty ``torch.Tensor()``.

        Raises:
            ValueError: if ``t`` is neither ``None`` nor a ``torch.Tensor``.
        """
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)

        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)

        # ``fn=fn`` binds the current function; a plain closure would see
        # only the last value of the loop variable.
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)

        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        """Shape of the wrapped tensor as a plain tuple."""
        return tuple(self.tensor.shape)

    @property
    def size(self):
        """Total number of elements."""
        return self.tensor.numel()

    @property
    def ndim(self):
        """Number of dimensions."""
        return self.tensor.ndim

    @property
    def dtype(self):
        """``_dtypes.dtype`` corresponding to the tensor's torch dtype."""
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        """Strides in bytes (torch reports strides in elements)."""
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        """Size of one element in bytes."""
        return self.tensor.element_size()

    @property
    def flags(self):
        """``Flags`` object with the flags derivable from the tensor."""
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        """Value of ``data_ptr()`` for the wrapped tensor."""
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        """Bytes taken by the array's own elements (``size * itemsize``)."""
        # Computed from the element count rather than the backing storage,
        # which for a view can be larger than the array itself.
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        """The array with its axes transposed (``self.transpose()``)."""
        return self.transpose()

    @property
    def real(self):
        """Real part, computed by ``_funcs.real``."""
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        """Imaginary part, computed by ``_funcs.imag``."""
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        """Return a new ``ndarray`` whose tensor is converted to ``dtype``.

        Only the default values of ``order``, ``casting``, ``subok`` and
        ``copy`` are supported.

        Raises:
            NotImplementedError: if any of those four arguments is changed.
        """
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order} is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting} is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok} is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    # NOTE(review): the body calls tensor methods on ``self``, so it relies on
    # ``normalizer`` converting ``ArrayLike``-annotated arguments to tensors
    # (and presumably wrapping the result) -- confirm in ``_normalizations``.
    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        """Return a copy of the array (``clone`` of the underlying tensor)."""
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        """Return the array flattened to one dimension via ``torch.flatten``."""
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        """Resize the array in place to ``new_shape``.

        Accepts both ``x.resize((2, 2))`` and ``x.resize(2, 2)``.  Calling it
        with no shape, or with a single ``None``, does nothing.

        Raises:
            NotImplementedError: if ``refcheck`` is truthy.
            ValueError: if any requested dimension is negative.
        """
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck} is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        # ``builtins.any`` is spelled out explicitly, as in the original code.
        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        """Return a new ``ndarray`` viewing the same tensor as ``dtype``.

        When ``dtype`` is not given, the array's current dtype is used.

        Raises:
            NotImplementedError: if ``type`` is passed.
        """
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        """Fill the array in place with ``value``."""
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        """Return the contents as (nested) Python lists."""
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    # Pass the name explicitly so the wrapper is reported as ``__repr__``.
    __repr__ = create_method(__str__, "__repr__")

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        """Return True if the array holds a single value equal to its ``int()``.

        Any failure while extracting or converting the value yields False.
        """
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        """Return the array with axes permuted; ``axes`` is passed as a tuple."""
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        """Return the array reshaped; accepts ``reshape(shape)`` and ``reshape(*shape)``."""
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        """Sort the array in place."""
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        """Return an element selected by no argument, a flat index or a multi-index."""
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        """Index the array, emulating negative-step slices.

        A slice with a negative step is rewritten as a flip of the tensor
        along that position followed by an equivalent positive-step slice.
        """
        tensor = self.tensor

        def neg_step(i, s):
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None

            # NOTE(review): ``i`` is the position inside ``index``; it equals
            # the tensor dimension only when no ``None``/``Ellipsis`` entry
            # precedes it -- confirm whether callers rely on those.
            length = tensor.shape[i]
            # ``indices`` resolves omitted, negative and out-of-range bounds
            # the way Python does for a sequence of this length.
            start, stop, step = s.indices(length)
            tensor = torch.flip(tensor, (i,))
            # Original position ``p`` now sits at ``length - 1 - p``.
            return slice(length - 1 - start, length - 1 - stop, -step)

        if isinstance(index, Sequence):
            index = type(index)(neg_step(i, s) for i, s in enumerate(index))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    # Bound directly to the ``_funcs`` implementations; this ``take`` replaces
    # the wrapper registered from ``methods`` above.
    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
def _tolist(obj):
    """Return a list copy of ``obj`` with nested ndarrays turned into lists.

    Nested lists and tuples are converted recursively (tuples become lists);
    each ``ndarray`` element is replaced by ``elem.tensor.tolist()``; every
    other element is kept as is.
    """
    converted = []
    for item in obj:
        if isinstance(item, (list, tuple)):
            item = _tolist(item)
        converted.append(item.tensor.tolist() if isinstance(item, ndarray) else item)
    return converted
# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.
def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Build an ``ndarray`` from ``obj``.

    Args:
        obj: an ``ndarray``, a list/tuple (possibly nested, possibly holding
            tensors or ndarrays), or anything ``_util._coerce_to_tensor``
            accepts.
        dtype: requested dtype, or ``None`` to leave the choice to the
            coercion helper.
        copy: forwarded to the coercion helper; when it is ``False`` and
            ``obj`` is already a suitable ``ndarray``, ``obj`` itself is
            returned.
        order: only ``"K"`` is supported.
        subok: only ``False`` is supported.
        ndmin: minimum number of dimensions, forwarded to the coercion helper.
        like: only ``None`` is supported.

    Raises:
        NotImplementedError: for unsupported ``subok``, ``like`` or ``order``.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError(f"array(..., order={order!r}) is not implemented.")

    # a happy path
    if (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    ):
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # is obj an ndarray already?
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # is a specific dtype requested?
    torch_dtype = None
    if dtype is not None:
        torch_dtype = _dtypes.dtype(dtype).torch_dtype

    tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin)
    return ndarray(tensor)
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert ``a`` to an ``ndarray`` via ``array`` with ``copy=False`` and ``ndmin=0``."""
    return array(a, dtype=dtype, copy=False, order=order, ndmin=0, like=like)
def ascontiguousarray(a, dtype=None, *, like=None):
    """Return ``a`` as an ``ndarray`` whose tensor is contiguous.

    If the converted array is already contiguous it is returned as is;
    otherwise a new ``ndarray`` wrapping a contiguous copy is returned.
    The input is never modified.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if arr.tensor.is_contiguous():
        return arr
    # ``asarray`` can return ``a`` itself, so rebinding ``arr.tensor`` would
    # silently change the caller's array; wrap the copy in a new object.
    return ndarray(arr.tensor.contiguous())
def from_dlpack(x, /):
    """Wrap the tensor produced by ``torch.from_dlpack(x)`` in an ``ndarray``."""
    return ndarray(torch.from_dlpack(x))
def _extract_dtype(entry):
    """Return the dtype denoted by ``entry``.

    ``entry`` is first interpreted as a dtype specification; if that fails
    for any reason it is converted with ``asarray`` and that array's dtype
    is returned instead.
    """
    try:
        return _dtypes.dtype(entry)
    except Exception:
        return asarray(entry).dtype
def can_cast(from_, to, casting="safe"):
    """Report whether ``from_`` can be cast to ``to`` under the ``casting`` rule.

    Both arguments are resolved with ``_extract_dtype``; the decision itself
    is delegated to ``_dtypes_impl.can_cast_impl`` on the torch dtypes.
    """
    source = _extract_dtype(from_)
    target = _extract_dtype(to)
    return _dtypes_impl.can_cast_impl(source.torch_dtype, target.torch_dtype, casting)
def result_type(*arrays_and_dtypes):
    """Return the dtype that results from combining the given arrays and dtypes.

    Each argument is turned into a tensor: array-likes through ``asarray``,
    and anything ``asarray`` rejects is treated as a dtype specification and
    represented by a one-element empty tensor of that dtype.  The tensors are
    then handed to ``_dtypes_impl.result_type_impl``.
    """

    def as_tensor(entry):
        try:
            return asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            spec = _dtypes.dtype(entry)
            return torch.empty(1, dtype=spec.torch_dtype)

    tensors = [as_tensor(entry) for entry in arrays_and_dtypes]
    return _dtypes.dtype(_dtypes_impl.result_type_impl(*tensors))