|
from __future__ import annotations |
|
|
|
from ._dtypes import ( |
|
_floating_dtypes, |
|
_numeric_dtypes, |
|
float32, |
|
float64, |
|
complex64, |
|
complex128 |
|
) |
|
from ._manipulation_functions import reshape |
|
from ._elementwise_functions import conj |
|
from ._array_object import Array |
|
|
|
from ..core.numeric import normalize_axis_tuple |
|
|
|
from typing import TYPE_CHECKING |
|
if TYPE_CHECKING: |
|
from ._typing import Literal, Optional, Sequence, Tuple, Union, Dtype |
|
|
|
from typing import NamedTuple |
|
|
|
import numpy.linalg |
|
import numpy as np |
|
|
|
class EighResult(NamedTuple):
    """Result of :func:`eigh`: the eigenvalues and eigenvectors, each wrapped as an ``Array``."""
    # Eigenvalues as returned by np.linalg.eigh, wrapped in Array.
    eigenvalues: Array
    # Eigenvectors as returned by np.linalg.eigh, wrapped in Array.
    eigenvectors: Array
|
|
|
class QRResult(NamedTuple):
    """Result of :func:`qr`: the ``Q`` and ``R`` factors, each wrapped as an ``Array``."""
    # First factor returned by np.linalg.qr.
    Q: Array
    # Second factor returned by np.linalg.qr.
    R: Array
|
|
|
class SlogdetResult(NamedTuple):
    """Result of :func:`slogdet`: sign and log-absolute-determinant, each wrapped as an ``Array``."""
    # First value returned by np.linalg.slogdet (the sign of the determinant).
    sign: Array
    # Second value returned by np.linalg.slogdet (log of |determinant|).
    logabsdet: Array
|
|
|
class SVDResult(NamedTuple):
    """Result of :func:`svd`: the ``U``, ``S`` and ``Vh`` arrays from np.linalg.svd, each wrapped as an ``Array``."""
    # Left singular vectors.
    U: Array
    # Singular values.
    S: Array
    # Right singular vectors (already conjugate-transposed, per the name).
    Vh: Array
|
|
|
|
|
|
|
def cholesky(x: Array, /, *, upper: bool = False) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`.

    NumPy always returns the lower-triangular factor ``L``. When ``upper`` is
    true, the conjugate transpose of ``L`` is returned instead. The
    conjugation is only applied for complex dtypes.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in cholesky')

    lower_factor = Array._new(np.linalg.cholesky(x._array))
    if not upper:
        return lower_factor

    # U = L^H: transpose the last two axes, then conjugate for complex input.
    upper_factor = lower_factor.mT
    if upper_factor.dtype in [complex64, complex128]:
        upper_factor = conj(upper_factor)
    return upper_factor
|
|
|
|
|
def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.

    Stricter than NumPy: both inputs must be numeric, have identical shapes,
    be at least 1-D, and have length exactly 3 along ``axis``.

    See its docstring for more information.
    """
    inputs_numeric = x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes
    if not inputs_numeric:
        raise TypeError('Only numeric dtypes are allowed in cross')

    if x1.shape != x2.shape:
        raise ValueError('x1 and x2 must have the same shape')
    if x1.ndim == 0:
        raise ValueError('cross() requires arrays of dimension at least 1')
    # np.cross would also accept length-2 vectors; the array API does not.
    if x1.shape[axis] != 3:
        raise ValueError('cross() dimension must equal 3')

    product = np.cross(x1._array, x2._array, axis=axis)
    return Array._new(product)
|
|
|
def det(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.

    Raises ``TypeError`` unless ``x`` has a floating-point dtype.

    See its docstring for more information.
    """
    if x.dtype in _floating_dtypes:
        return Array._new(np.linalg.det(x._array))
    raise TypeError('Only floating-point dtypes are allowed in det')
|
|
|
|
|
def diagonal(x: Array, /, *, offset: int = 0) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.

    Unlike NumPy's default (first two axes), the diagonal is always taken
    over the last two axes; ``offset`` selects which diagonal.

    See its docstring for more information.
    """
    diag = np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1)
    return Array._new(diag)
|
|
|
|
|
def eigh(x: Array, /) -> EighResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.

    Returns an :class:`EighResult` whose fields are the two arrays produced by
    NumPy, each wrapped as an ``Array``. Only floating-point input is accepted.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in eigh')

    decomposition = np.linalg.eigh(x._array)
    return EighResult(*[Array._new(part) for part in decomposition])
|
|
|
|
|
def eigvalsh(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.

    Raises ``TypeError`` unless ``x`` has a floating-point dtype.

    See its docstring for more information.
    """
    if x.dtype in _floating_dtypes:
        eigenvalues = np.linalg.eigvalsh(x._array)
        return Array._new(eigenvalues)
    raise TypeError('Only floating-point dtypes are allowed in eigvalsh')
|
|
|
def inv(x: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.

    Raises ``TypeError`` unless ``x`` has a floating-point dtype.

    See its docstring for more information.
    """
    if x.dtype in _floating_dtypes:
        inverse = np.linalg.inv(x._array)
        return Array._new(inverse)
    raise TypeError('Only floating-point dtypes are allowed in inv')
|
|
|
|
|
|
|
def matmul(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.

    Both operands must have numeric dtypes; otherwise ``TypeError`` is raised.

    See its docstring for more information.
    """
    operands_numeric = x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes
    if not operands_numeric:
        raise TypeError('Only numeric dtypes are allowed in matmul')

    product = np.matmul(x1._array, x2._array)
    return Array._new(product)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.

    Always computes a matrix norm over the last two axes. ``ord`` and
    ``keepdims`` are forwarded to NumPy unchanged.

    See its docstring for more information.
    """
    if x.dtype in _floating_dtypes:
        norms = np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord)
        return Array._new(norms)
    raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
|
|
|
|
|
def matrix_power(x: Array, n: int, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.matrix_power <numpy.linalg.matrix_power>`.

    Raises ``TypeError`` unless ``x`` has a floating-point dtype (NumPy itself
    also accepts integer matrices).

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power')

    # Integer exponents (including negative ones) are handled entirely by NumPy.
    return Array._new(np.linalg.matrix_power(x._array, n))
|
|
|
|
|
def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.matrix_rank <numpy.linalg.matrix_rank>`.

    The rank is the number of singular values strictly greater than a
    tolerance. That tolerance is always relative to the largest singular value
    of each matrix:

    * ``rtol is None``: ``max(M, N) * eps(dtype)`` times the largest singular
      value;
    * otherwise: ``rtol`` (a float or an ``Array`` broadcastable against the
      stack of matrices) times the largest singular value.

    Empty matrices (a zero-length trailing dimension) have rank 0.

    Raises ``np.linalg.LinAlgError`` if ``x`` has fewer than two dimensions
    (NumPy's own matrix_rank accepts 1-D input).

    See its docstring for more information.
    """
    if x.ndim < 2:
        raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
    S = np.linalg.svd(x._array, compute_uv=False)

    # initial=0.0 lets the reduction succeed on empty matrices (no singular
    # values). Singular values are non-negative, so it never changes the
    # maximum of a non-empty matrix.
    S_max = S.max(axis=-1, keepdims=True, initial=0.0)

    if rtol is None:
        tol = S_max * max(x.shape[-2:]) * np.finfo(S.dtype).eps
    else:
        if isinstance(rtol, Array):
            rtol = rtol._array
        # Unlike np.linalg.matrix_rank, the user tolerance is scaled by the
        # largest singular value. The trailing new axis lets a per-matrix
        # rtol broadcast against S's singular-value axis.
        tol = S_max * np.asarray(rtol)[..., np.newaxis]
    return Array._new(np.count_nonzero(S > tol, axis=-1))
|
|
|
|
|
|
|
|
|
def matrix_transpose(x: Array, /) -> Array:
    """
    Transpose the last two axes of ``x`` (each matrix in a stack).

    Raises ``ValueError`` if ``x`` has fewer than two dimensions.
    """
    if x.ndim >= 2:
        return Array._new(np.swapaxes(x._array, -1, -2))
    raise ValueError("x must be at least 2-dimensional for matrix_transpose")
|
|
|
|
|
def outer(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.

    Stricter than NumPy: both inputs must be numeric and exactly 1-D
    (np.outer would silently flatten higher-dimensional input).

    See its docstring for more information.
    """
    if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
        raise TypeError('Only numeric dtypes are allowed in outer')

    if (x1.ndim, x2.ndim) != (1, 1):
        raise ValueError('The input arrays to outer must be 1-dimensional')

    return Array._new(np.outer(x1._array, x2._array))
|
|
|
|
|
def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.

    ``rtol`` is forwarded to NumPy as ``rcond``. When omitted, it defaults to
    ``max(M, N) * eps`` for the dtype of ``x``; NumPy's own default differs.
    An ``Array`` ``rtol`` is unwrapped to its underlying ndarray first, as in
    :func:`matrix_rank`.

    Raises ``TypeError`` unless ``x`` has a floating-point dtype.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in pinv')

    if rtol is None:
        rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
    elif isinstance(rtol, Array):
        # Hand NumPy a plain ndarray rather than relying on it converting an
        # Array wrapper implicitly.
        rtol = rtol._array
    return Array._new(np.linalg.pinv(x._array, rcond=rtol))
|
|
|
def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.

    ``mode`` is forwarded to NumPy unchanged. The factors NumPy returns are
    wrapped as ``Array`` objects inside a :class:`QRResult`.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in qr')

    factors = np.linalg.qr(x._array, mode=mode)
    return QRResult(*[Array._new(factor) for factor in factors])
|
|
|
def slogdet(x: Array, /) -> SlogdetResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.

    The (sign, logabsdet) pair from NumPy is wrapped as ``Array`` objects
    inside a :class:`SlogdetResult`.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in slogdet')

    parts = np.linalg.slogdet(x._array)
    return SlogdetResult(*[Array._new(part) for part in parts])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _solve(a, b):
    """
    Solve ``a @ x = b`` for ``x`` using NumPy's private linalg gufuncs.

    Operates on plain ndarrays, not ``Array`` objects. ``a`` must be a
    (stack of) square matrices. Only a 1-D ``b`` is treated as a single
    right-hand-side vector; any ``b`` with more dimensions is treated as a
    (stack of) matrices.
    """
    # Private helpers, imported lazily from numpy.linalg. Only names actually
    # used below are imported.
    from ..linalg.linalg import (_makearray, _assert_stacked_2d,
                                 _assert_stacked_square, _commonType,
                                 isComplexType, _raise_linalgerror_singular)
    from ..linalg import _umath_linalg

    a, _ = _makearray(a)
    _assert_stacked_2d(a)
    _assert_stacked_square(a)
    b, wrap = _makearray(b)
    t, result_t = _commonType(a, b)

    # Choose the vector or matrix kernel purely from b's dimensionality.
    if b.ndim == 1:
        gufunc = _umath_linalg.solve1
    else:
        gufunc = _umath_linalg.solve

    # Compute in double precision (complex or real), then cast back to the
    # common result type below.
    signature = 'DD->D' if isComplexType(t) else 'dd->d'
    # An 'invalid' FP error triggers the callback, presumably raising
    # LinAlgError for singular input. Other FP warnings are silenced.
    with np.errstate(call=_raise_linalgerror_singular, invalid='call',
                     over='ignore', divide='ignore', under='ignore'):
        r = gufunc(a, b, signature=signature)

    return wrap(r.astype(result_t, copy=False))
|
|
|
def solve(x1: Array, x2: Array, /) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.

    Both operands must be floating-point. The actual solve is delegated to the
    module's ``_solve`` helper, which only treats a 1-D ``x2`` as a vector.

    See its docstring for more information.
    """
    if x1.dtype in _floating_dtypes and x2.dtype in _floating_dtypes:
        return Array._new(_solve(x1._array, x2._array))
    raise TypeError('Only floating-point dtypes are allowed in solve')
|
|
|
def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
    """
    Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.

    ``full_matrices`` is forwarded to NumPy. The resulting arrays are wrapped
    as ``Array`` objects inside an :class:`SVDResult`.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in svd')

    factors = np.linalg.svd(x._array, full_matrices=full_matrices)
    return SVDResult(*[Array._new(factor) for factor in factors])
|
|
|
|
|
|
|
def svdvals(x: Array, /) -> Array:
    """
    Return only the singular values of ``x`` (or of each matrix in a stack).

    Array API compatible wrapper around
    :py:func:`np.linalg.svd <numpy.linalg.svd>` with ``compute_uv=False``.
    The result is always a single ``Array``.

    Raises ``TypeError`` unless ``x`` has a floating-point dtype.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in svdvals')
    return Array._new(np.linalg.svd(x._array, compute_uv=False))
|
|
|
|
|
|
|
|
|
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.tensordot <numpy.tensordot>`.

    ``axes`` is forwarded to NumPy unchanged. Unlike np.tensordot, only
    numeric dtypes are accepted.

    See its docstring for more information.
    """
    numeric = _numeric_dtypes
    if not (x1.dtype in numeric and x2.dtype in numeric):
        raise TypeError('Only numeric dtypes are allowed in tensordot')

    contracted = np.tensordot(x1._array, x2._array, axes=axes)
    return Array._new(contracted)
|
|
|
|
|
def trace(x: Array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.

    Sums the diagonal selected by ``offset`` over the last two axes (NumPy
    defaults to the first two). If ``dtype`` is omitted, float32 input is
    accumulated as float64 and complex64 as complex128. Every other dtype is
    passed to NumPy as ``dtype=None``.

    See its docstring for more information.
    """
    if x.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in trace')

    # Default upcasting of single-precision input.
    if dtype is None and x.dtype in (float32, complex64):
        dtype = float64 if x.dtype == float32 else complex128

    diag_sum = np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=dtype)
    # np.asarray turns a NumPy scalar result (2-D input) into a 0-d ndarray.
    return Array._new(np.asarray(diag_sum))
|
|
|
|
|
def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
    """
    Array API compatible vector dot product along ``axis``.

    ``x1`` and ``x2`` are broadcast against each other and reduced along
    ``axis`` as ``sum(conj(x1) * x2)``. For complex input the first argument
    is conjugated, as the array API specification requires. For real input
    this is the ordinary dot product.

    Raises ``TypeError`` for non-numeric dtypes. Raises ``ValueError`` when
    the inputs (left-padded with length-1 axes to equal rank) differ in
    length along ``axis``.
    """
    if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
        raise TypeError('Only numeric dtypes are allowed in vecdot')
    ndim = max(x1.ndim, x2.ndim)
    # Left-pad shapes as broadcasting would, so ``axis`` indexes both alike.
    x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
    x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
    # The contracted axis must match exactly; it is not broadcast.
    if x1_shape[axis] != x2_shape[axis]:
        raise ValueError("x1 and x2 must have the same size along the given axis")

    x1_, x2_ = np.broadcast_arrays(x1._array, x2._array)
    x1_ = np.moveaxis(x1_, axis, -1)
    x2_ = np.moveaxis(x2_, axis, -1)
    if x1.dtype in (complex64, complex128):
        # Hermitian inner product: conjugate the first operand.
        x1_ = np.conj(x1_)

    # (..., 1, n) @ (..., n, 1) -> (..., 1, 1): a batched row-times-column product.
    res = x1_[..., None, :] @ x2_[..., None]
    return Array._new(res[..., 0, 0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`,
    restricted to vector norms.

    ``axis=None`` reduces over every element of ``x``. A tuple of axes is
    reduced as if those axes were flattened into one vector. An integer
    reduces over that single axis. With ``keepdims=True`` the reduced axes
    remain as length-1 dimensions.

    See its docstring for more information.
    """
    if x.dtype not in _floating_dtypes:
        raise TypeError('Only floating-point dtypes are allowed in norm')

    data = x._array
    if axis is None:
        # Flatten so np.linalg.norm always sees a 1-D vector (it does not
        # handle 0-D input).
        data = data.ravel()
        norm_axis = 0
    elif isinstance(axis, tuple):
        # np.linalg.norm computes vector norms over a single axis only (a
        # 2-tuple selects a matrix norm). Move the requested axes to the
        # front and merge them into one leading axis.
        reduced = normalize_axis_tuple(axis, x.ndim)
        kept = tuple(i for i in range(data.ndim) if i not in reduced)
        merged_len = np.prod([data.shape[i] for i in axis], dtype=int)
        kept_lens = [data.shape[i] for i in kept]
        data = np.transpose(data, axis + kept).reshape((merged_len, *kept_lens))
        norm_axis = 0
    else:
        norm_axis = axis

    result = Array._new(np.linalg.norm(data, axis=norm_axis, ord=ord))
    if not keepdims:
        return result

    # keepdims cannot be forwarded to np.linalg.norm because of the reshaping
    # above, so reinsert the reduced axes as length-1 dimensions by hand.
    dropped = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
    kept_shape = tuple(1 if i in dropped else n for i, n in enumerate(x.shape))
    return reshape(result, kept_shape)
|
|
|
# Public API of this module: the functions exported by ``from <module> import *``.
# The *Result NamedTuple classes are intentionally not listed.
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']
|
|