Spaces:
Sleeping
Sleeping
| from collections.abc import Iterable | |
| from functools import singledispatch | |
| from sympy.core.expr import Expr | |
| from sympy.core.mul import Mul | |
| from sympy.core.singleton import S | |
| from sympy.core.sympify import sympify | |
| from sympy.core.parameters import global_parameters | |
class TensorProduct(Expr):
    """
    Generic class for tensor products.

    On construction (when evaluation is enabled) the arguments are sorted
    into three groups:

    * array-like arguments (iterables, explicit matrices, N-dim arrays) are
      converted to ``Array`` and multiplied eagerly with ``tensorproduct``;
    * matrix expressions are kept as symbolic arguments;
    * everything else is multiplied into a scalar coefficient.

    If no matrix expression remains, the evaluated coefficient (an array or
    a scalar) is returned instead of a ``TensorProduct`` instance.
    """
    is_number = False

    def __new__(cls, *args, **kwargs):
        from sympy.tensor.array import NDimArray, tensorproduct, Array
        from sympy.matrices.expressions.matexpr import MatrixExpr
        from sympy.matrices.matrixbase import MatrixBase
        from sympy.strategies import flatten

        args = [sympify(arg) for arg in args]
        evaluate = kwargs.get("evaluate", global_parameters.evaluate)
        if not evaluate:
            # Keep the arguments exactly as given, without any grouping.
            return Expr.__new__(cls, *args)

        arrays = []
        other = []
        scalar = S.One
        for arg in args:
            if isinstance(arg, (Iterable, MatrixBase, NDimArray)):
                arrays.append(Array(arg))
            elif isinstance(arg, MatrixExpr):
                other.append(arg)
            else:
                scalar *= arg

        coeff = scalar*tensorproduct(*arrays)
        if not other:
            return coeff

        # A trivial coefficient of 1 is dropped from the argument list.
        newargs = [coeff] + other if coeff != 1 else other
        # Basic.__new__ only accepts positional arguments; forwarding
        # ``evaluate=`` (or any other keyword) here raises TypeError.
        obj = Expr.__new__(cls, *newargs)
        return flatten(obj)

    @property
    def rank(self):
        """Number of indices of the product, i.e. ``len(self.shape)``."""
        return len(self.shape)

    def _get_args_shapes(self):
        """Return the shape of each argument; args without a ``shape``
        attribute are wrapped in ``Array`` to obtain one."""
        from sympy.tensor.array import Array
        return [i.shape if hasattr(i, "shape") else Array(i).shape for i in self.args]

    @property
    def shape(self):
        """Concatenation of the shapes of all arguments."""
        shape_list = self._get_args_shapes()
        return sum(shape_list, ())

    def __getitem__(self, index):
        """
        Return the component of the product at ``index``.

        ``index`` must provide one entry per axis of ``self.shape``. It is cut
        into consecutive pieces, one per argument (sized by that argument's
        rank). Each argument is indexed with its piece, and the results are
        multiplied.

        Raises
        ======

        IndexError : if the number of indices differs from ``self.rank``.
        """
        index = tuple(index)
        shapes = self._get_args_shapes()
        expected = sum(len(shp) for shp in shapes)
        if len(index) != expected:
            raise IndexError(
                "expected %d indices, got %d" % (expected, len(index)))
        factors = []
        pos = 0
        for arg, shp in zip(self.args, shapes):
            width = len(shp)
            if width == 0 and not hasattr(arg, "__getitem__"):
                # A scalar coefficient (e.g. the 2 in TensorProduct(2, M)) has
                # no components; it multiplies every entry unchanged.
                factors.append(arg)
            else:
                factors.append(arg[index[pos:pos + width]])
            pos += width
        return Mul.fromiter(factors)
@singledispatch
def shape(expr):
    """
    Return the shape of the *expr* as a tuple. *expr* should represent a
    suitable object such as a matrix or an array.

    Parameters
    ==========

    expr : SymPy object having ``MatrixKind`` or ``ArrayKind``.

    Raises
    ======

    NoShapeError : Raised when object with wrong kind is passed.

    Examples
    ========

    This function returns the shape of any object representing matrix or array.

    >>> from sympy import shape, Array, ImmutableDenseMatrix, Integral
    >>> from sympy.abc import x
    >>> A = Array([1, 2])
    >>> shape(A)
    (2,)
    >>> shape(Integral(A, x))
    (2,)
    >>> M = ImmutableDenseMatrix([1, 2])
    >>> shape(M)
    (2, 1)
    >>> shape(Integral(M, x))
    (2, 1)

    You can support a new type by dispatching.

    >>> from sympy import Expr
    >>> class NewExpr(Expr):
    ...     pass
    >>> @shape.register(NewExpr)
    ... def _(expr):
    ...     return shape(expr.args[0])
    >>> shape(NewExpr(M))
    (2, 1)

    If an unsuitable expression is passed, ``NoShapeError()`` will be raised.

    >>> shape(Integral(x, x))
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: Integral(x, x) does not have shape, or its type is not registered to shape().

    Notes
    =====

    Array-like classes (such as ``Matrix`` or ``NDimArray``) have a ``shape``
    property which returns their shape, but it cannot be used for non-array
    classes containing an array. This function returns the shape of any
    registered object representing an array.
    """
    # Default implementation: fall back to a ``shape`` attribute. Other types
    # are supported by registering an overload via ``shape.register``.
    if hasattr(expr, "shape"):
        return expr.shape
    raise NoShapeError(
        "%s does not have shape, or its type is not registered to shape()." % expr)
class NoShapeError(Exception):
    """
    Raised when ``shape()`` is called on a non-array object.

    This error can be imported from ``sympy.tensor.functions``.

    Examples
    ========

    >>> from sympy import shape
    >>> from sympy.abc import x
    >>> shape(x)
    Traceback (most recent call last):
      ...
    sympy.tensor.functions.NoShapeError: x does not have shape, or its type is not registered to shape().
    """