Spaces:
Sleeping
Sleeping
from collections.abc import Iterable | |
from functools import singledispatch | |
from sympy.core.expr import Expr | |
from sympy.core.mul import Mul | |
from sympy.core.singleton import S | |
from sympy.core.sympify import sympify | |
from sympy.core.parameters import global_parameters | |
class TensorProduct(Expr):
    """
    Generic class for tensor products.

    On evaluation every argument is sympified and sorted into one of three
    groups:

    * iterables, ``MatrixBase`` and ``NDimArray`` instances are converted to
      ``Array`` and combined with ``tensorproduct``;
    * ``MatrixExpr`` instances are kept symbolically as arguments;
    * everything else is multiplied into a scalar coefficient.

    If no ``MatrixExpr`` argument is present, the evaluated coefficient
    (scalar times the tensor product of the arrays) is returned directly
    instead of a ``TensorProduct`` instance.
    """
    is_number = False

    def __new__(cls, *args, **kwargs):
        from sympy.tensor.array import NDimArray, tensorproduct, Array
        from sympy.matrices.expressions.matexpr import MatrixExpr
        from sympy.matrices.matrixbase import MatrixBase
        from sympy.strategies import flatten

        args = [sympify(arg) for arg in args]
        evaluate = kwargs.get("evaluate", global_parameters.evaluate)

        if not evaluate:
            return Expr.__new__(cls, *args)

        arrays = []
        other = []
        scalar = S.One
        for arg in args:
            if isinstance(arg, (Iterable, MatrixBase, NDimArray)):
                arrays.append(Array(arg))
            elif isinstance(arg, MatrixExpr):
                other.append(arg)
            else:
                scalar *= arg

        coeff = scalar*tensorproduct(*arrays)
        if not other:
            # Nothing symbolic remains: the product is fully evaluated.
            return coeff
        newargs = [coeff] + other if coeff != 1 else other
        # ``kwargs`` only carries construction options such as ``evaluate``;
        # forward just the operands, exactly as the unevaluated branch does.
        obj = Expr.__new__(cls, *newargs)
        # NOTE(review): ``flatten`` (sympy.strategies) presumably merges
        # nested TensorProduct arguments -- confirm against its definition.
        return flatten(obj)

    def rank(self):
        """Return the number of indices of the product, ``len(self.shape)``."""
        return len(self.shape)

    def _get_args_shapes(self):
        """Return the shape of every argument, in argument order.

        Arguments without a ``shape`` attribute are wrapped in ``Array`` to
        obtain one (a scalar yields the empty shape ``()``).
        """
        from sympy.tensor.array import Array
        return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args]

    @property
    def shape(self):
        """Shape of the product: the concatenation of all argument shapes.

        This must be a property because ``rank()`` takes ``len(self.shape)``.
        """
        return sum(self._get_args_shapes(), ())

    def __getitem__(self, index):
        """Return the product of the arguments' entries at ``index``.

        ``index`` is consumed left to right; each argument takes as many
        entries as its own rank.

        Raises
        ======

        IndexError : if ``index`` has fewer entries than ``self.rank()``.
        """
        from itertools import islice

        positions = iter(index)
        factors = []
        for arg, shp in zip(self.args, self._get_args_shapes()):
            sub_index = tuple(islice(positions, len(shp)))
            if len(sub_index) < len(shp):
                raise IndexError(
                    "too few indices for TensorProduct of rank %d" % self.rank())
            if not shp and not hasattr(arg, "__getitem__"):
                # A scalar coefficient (shape ``()``) is not subscriptable;
                # it contributes itself to the product.
                factors.append(arg)
            else:
                factors.append(arg.__getitem__(sub_index))
        return Mul.fromiter(factors)
@singledispatch
def shape(expr):
    """
    Return the shape of the *expr* as a tuple. *expr* should represent
    a suitable object such as a matrix or an array.

    Parameters
    ==========

    expr : SymPy object having ``MatrixKind`` or ``ArrayKind``.

    Raises
    ======

    NoShapeError : Raised when *expr* has no ``shape`` attribute and no
        implementation is registered for its type.

    Examples
    ========

    This function returns the shape of any object representing a matrix or
    an array.

    >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral
    >>> from sympy.abc import x
    >>> A = Array([1, 2])
    >>> shape(A)
    (2,)
    >>> shape(Integral(A, x))
    (2,)
    >>> M = ImmutableDenseMatrix([1, 2])
    >>> shape(M)
    (2, 1)
    >>> shape(Integral(M, x))
    (2, 1)

    New types can be supported by registering an implementation.

    >>> from sympy import Expr
    >>> class NewExpr(Expr):
    ...     pass
    >>> @shape.register(NewExpr)
    ... def _(expr):
    ...     return shape(expr.args[0])
    >>> shape(NewExpr(M))
    (2, 1)

    If an unsuitable expression is passed, ``NoShapeError()`` will be raised.

    >>> shape(Integral(x, x))
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: shape() called on non-array object: Integral(x, x)

    Notes
    =====

    Array-like classes (such as ``Matrix`` or ``NDimArray``) have a ``shape``
    property which returns their shape, but it cannot be used for non-array
    classes containing arrays. This function returns the shape of any
    registered object representing an array.
    """
    # Default implementation, used for every type without a registered
    # handler: fall back to the object's own ``shape`` attribute.
    if hasattr(expr, "shape"):
        return expr.shape
    # Wrap ``expr`` in a 1-tuple so a tuple argument is formatted as a
    # single value instead of being unpacked by ``%``.
    raise NoShapeError(
        "%s does not have shape, or its type is not registered to shape()." % (expr,))
class NoShapeError(Exception):
    """
    Raised when ``shape()`` is called on an object that has no ``shape``
    attribute and whose type has no registered ``shape()`` implementation.

    This error can be imported from ``sympy.tensor.functions``.

    Examples
    ========

    >>> from sympy import shape
    >>> from sympy.abc import x
    >>> shape(x)
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: x does not have shape, or its type is not registered to shape().
    """