# Extracted from a hosted file viewer (file size: 4,796 bytes; revision 6a86ad5).
# The viewer's "Spaces: Sleeping" status banner and its 1-130 line-number gutter
# were page residue, not part of this module.
from __future__ import annotations
from sympy.core.expr import Expr
from sympy.core.function import Derivative
from sympy.core.numbers import Integer
from sympy.matrices.matrixbase import MatrixBase
from .ndim_array import NDimArray
from .arrayop import derive_by_array
from sympy.matrices.expressions.matexpr import MatrixExpr
from sympy.matrices.expressions.special import ZeroMatrix
from sympy.matrices.expressions.matexpr import _matrix_derivative
class ArrayDerivative(Derivative):
    """Derivative whose expression and/or variables may carry a ``shape``.

    The shape is computed once, when the object is created, and is exposed
    through the read-only :attr:`shape` property.  Evaluation is routed by
    :meth:`_dispatch_eval_derivative_n_times` to a handler chosen from the
    kinds (scalar, explicit matrix, matrix expression, N-dim array) of the
    expression and of the differentiation variable.
    """

    is_scalar = False

    def __new__(cls, expr, *variables, **kwargs):
        """Create the derivative and store its shape on the new object.

        The shape is only attached when the parent constructor actually
        produced an ``ArrayDerivative``; any other object it returns is
        passed through untouched.
        """
        instance = super().__new__(cls, expr, *variables, **kwargs)
        if not isinstance(instance, ArrayDerivative):
            return instance
        instance._shape = instance._get_shape()
        return instance

    def _get_shape(self):
        """Return the shape tuple of this derivative.

        The result is the shape of every differentiation variable, repeated
        once per derivative order, followed by the shape of the expression.
        Variables (and an expression) without a ``shape`` attribute
        contribute nothing.
        """
        dims = ()
        for variable, order in self.variable_count:
            if not hasattr(variable, "shape"):
                continue
            # One copy of the variable's shape per order of differentiation.
            for _ in range(order):
                dims += variable.shape
        if hasattr(self.expr, "shape"):
            dims += self.expr.shape
        return dims

    @property
    def shape(self):
        """Shape computed at construction time by :meth:`_get_shape`."""
        return self._shape

    @classmethod
    def _get_zero_with_shape_like(cls, expr):
        """Return a zero object with the same shape as ``expr``.

        Explicit matrices and N-dim arrays use their own ``zeros``
        constructor; matrix expressions yield a ``ZeroMatrix``.

        Raises:
            RuntimeError: if ``expr`` is none of the supported kinds.
        """
        if isinstance(expr, (MatrixBase, NDimArray)):
            return expr.zeros(*expr.shape)
        if isinstance(expr, MatrixExpr):
            return ZeroMatrix(*expr.shape)
        raise RuntimeError("Unable to determine shape of array-derivative.")

    @staticmethod
    def _call_derive_scalar_by_matrix(expr: Expr, v: MatrixBase) -> Expr:
        """Differentiate ``expr`` by each entry of the explicit matrix ``v``."""
        def differentiate(entry):
            return expr.diff(entry)

        return v.applyfunc(differentiate)

    @staticmethod
    def _call_derive_scalar_by_matexpr(expr: Expr, v: MatrixExpr) -> Expr:
        """Differentiate ``expr`` by the matrix expression ``v``.

        When ``expr`` does not contain ``v`` the result is a ``ZeroMatrix``
        with the shape of ``v``; otherwise ``_matrix_derivative`` is used.
        """
        if not expr.has(v):
            return ZeroMatrix(*v.shape)
        return _matrix_derivative(expr, v)

    @staticmethod
    def _call_derive_scalar_by_array(expr: Expr, v: NDimArray) -> Expr:
        """Differentiate ``expr`` by each entry of the N-dim array ``v``."""
        def differentiate(entry):
            return expr.diff(entry)

        return v.applyfunc(differentiate)

    @staticmethod
    def _call_derive_matrix_by_scalar(expr: MatrixBase, v: Expr) -> Expr:
        """Differentiate the explicit matrix ``expr`` via ``_matrix_derivative``."""
        return _matrix_derivative(expr, v)

    @staticmethod
    def _call_derive_matexpr_by_scalar(expr: MatrixExpr, v: Expr) -> Expr:
        """Differentiate the matrix expression ``expr`` via its own hook."""
        return expr._eval_derivative(v)

    @staticmethod
    def _call_derive_array_by_scalar(expr: NDimArray, v: Expr) -> Expr:
        """Differentiate every entry of the N-dim array ``expr`` by ``v``."""
        def differentiate(entry):
            return entry.diff(v)

        return expr.applyfunc(differentiate)

    @staticmethod
    def _call_derive_default(expr: Expr, v: Expr) -> Expr | None:
        """Use ``_matrix_derivative`` when ``expr`` contains ``v``, else ``None``."""
        return _matrix_derivative(expr, v) if expr.has(v) else None

    @classmethod
    def _dispatch_eval_derivative_n_times(cls, expr, v, count):
        """Differentiate ``expr`` by ``v`` a total of ``count`` times.

        If `_eval_derivative_n_times` is not overridden by the current
        object, the default in `Basic` will call a loop over
        `_eval_derivative`.

        Returns:
            The evaluated derivative, or ``None`` when ``count`` is not a
            positive ``int``/``Integer`` or when the combination of ``expr``
            and ``v`` is left unevaluated.
        """
        # The explicit comparison with ``True`` is kept exactly as written.
        count_is_valid = (
            isinstance(count, (int, Integer))
            and not ((count <= 0) == True)
        )
        if not count_is_valid:
            return None

        # TODO: this could be done with multiple-dispatching:
        if expr.is_scalar:
            # Order matters: explicit matrices are tested first.
            by_variable_kind = (
                (MatrixBase, cls._call_derive_scalar_by_matrix),
                (MatrixExpr, cls._call_derive_scalar_by_matexpr),
                (NDimArray, cls._call_derive_scalar_by_array),
            )
            handler = next(
                (fn for kind, fn in by_variable_kind if isinstance(v, kind)),
                None,
            )
            if handler is not None:
                result = handler(expr, v)
            elif v.is_scalar:
                # Scalar by scalar: defer to the parent implementation.
                return super()._dispatch_eval_derivative_n_times(expr, v, count)
            else:
                return None
        elif v.is_scalar:
            by_expr_kind = (
                (MatrixBase, cls._call_derive_matrix_by_scalar),
                (MatrixExpr, cls._call_derive_matexpr_by_scalar),
                (NDimArray, cls._call_derive_array_by_scalar),
            )
            handler = next(
                (fn for kind, fn in by_expr_kind if isinstance(expr, kind)),
                None,
            )
            if handler is None:
                return None
            result = handler(expr, v)
        else:
            # Both `expr` and `v` are some array/matrix type:
            any_explicit = isinstance(expr, MatrixBase) or isinstance(v, MatrixBase)
            symbolic_operands = isinstance(expr, MatrixExpr) + isinstance(v, MatrixExpr)
            if any_explicit or symbolic_operands == 0:
                result = derive_by_array(expr, v)
            elif symbolic_operands == 2:
                result = cls._call_derive_default(expr, v)
            else:
                # If one expression is a symbolic matrix expression while
                # the other isn't, don't evaluate:
                return None

        if result is None or count == 1:
            return result
        # Remaining orders are handled by re-dispatching on the new result.
        return cls._dispatch_eval_derivative_n_times(result, v, count - 1)
# End of module (a stray "|" left over from the file viewer's table markup was here).