Spaces:
Sleeping
Sleeping
| """ | |
| The objects in this module allow the usage of the MatchPy pattern matching | |
| library on SymPy expressions. | |
| """ | |
import re
from typing import List, Callable, NamedTuple, Any, Dict, Optional

from sympy.core.sympify import _sympify
from sympy.external import import_module
from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma)
from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch
from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec
from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei
from sympy.core.add import Add
from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.mul import Mul
from sympy.core.power import Pow
from sympy.core.relational import (Equality, Unequality)
from sympy.core.symbol import Symbol
from sympy.functions.elementary.exponential import exp
from sympy.integrals.integrals import Integral
from sympy.printing.repr import srepr
from sympy.utilities.decorator import doctest_depends_on
matchpy = import_module("matchpy")

# Every doctest in this module needs MatchPy.
__doctest_requires__ = {('*',): ['matchpy']}


if matchpy:
    from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation
    from matchpy.expressions.functions import op_iter, create_operation_expression, op_len

    # Declare to MatchPy which SymPy classes behave as operations (nodes with
    # operands) and which algebraic properties each one has.
    Operation.register(Integral)

    Operation.register(Pow)
    OneIdentityOperation.register(Pow)

    for _op in (Add, Mul):
        Operation.register(_op)
        OneIdentityOperation.register(_op)
        CommutativeOperation.register(_op)
        AssociativeOperation.register(_op)

    for _op in (Equality, Unequality):
        Operation.register(_op)
        CommutativeOperation.register(_op)

    # Function heads registered as plain operations only.
    for _op in (exp, log, gamma, uppergamma,
                fresnels, fresnelc, erf, Ei, erfc, erfi,
                sin, cos, tan, cot, csc, sec,
                sinh, cosh, tanh, coth, csch, sech,
                asin, acos, atan, acot, acsc, asec,
                asinh, acosh, atanh, acoth, acsch, asech):
        Operation.register(_op)
    del _op  # do not leak the loop variable into the module namespace

    # The hooks below are only effective when registered with MatchPy's
    # dispatchers; without the decorators they would be dead code.

    @op_iter.register(Integral)  # type: ignore
    def _(operation):
        # Operands of an Integral: the integrand followed by the entries of
        # its first limits tuple.
        return iter((operation._args[0],) + operation._args[1])

    @op_iter.register(Basic)  # type: ignore
    def _(operation):
        # Generic SymPy object: its operands are its args.
        return iter(operation._args)

    @op_len.register(Integral)  # type: ignore
    def _(operation):
        # Must agree with op_iter for Integral above.
        return 1 + len(operation._args[1])

    @op_len.register(Basic)  # type: ignore
    def _(operation):
        return len(operation._args)

    @create_operation_expression.register(Basic)
    def sympy_op_factory(old_operation, new_operands, variable_name=True):
        """Rebuild a SymPy object of the same type as ``old_operation``.

        ``new_operands`` become the positional args of the new object;
        ``variable_name`` is not used here.
        """
        return type(old_operation)(*new_operands)
if not matchpy:
    class Wildcard:  # type: ignore
        """Stand-in for ``matchpy.Wildcard`` used when MatchPy is unavailable.

        It only records its constructor arguments as attributes, so that the
        wildcard classes of this module can still be defined and created.
        """

        def __init__(self, min_length, fixed_size, variable_name, optional):
            # Stored as ``min_count``, the attribute name MatchPy's own class uses.
            self.min_count = min_length
            self.fixed_size = fixed_size
            self.variable_name = variable_name
            self.optional = optional
else:
    from matchpy import Wildcard
class _WildAbstract(Wildcard, Symbol):
    """Common base of the SymPy-aware MatchPy wildcards.

    Instances are SymPy ``Symbol`` objects (so they can appear inside SymPy
    expressions) and also ``Wildcard`` objects. Concrete subclasses set the
    ``min_length`` and ``fixed_size`` class attributes.

    Parameters
    ==========

    variable_name : str
        Used both as the symbol name and as the wildcard's variable name.
    optional : optional
        Default value (sympified) to use when no match is found; ``None``
        means the wildcard is not optional.
    **assumptions
        SymPy assumptions forwarded to ``Symbol``.
    """

    min_length: int  # abstract field required in subclasses
    fixed_size: bool  # abstract field required in subclasses

    def __init__(self, variable_name=None, optional=None, **assumptions):
        min_length = self.min_length
        fixed_size = self.fixed_size
        if optional is not None:
            optional = _sympify(optional)
        Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional)

    def __getstate__(self):
        return {
            "min_length": self.min_length,
            "fixed_size": self.fixed_size,
            "min_count": self.min_count,
            "variable_name": self.variable_name,
            "optional": self.optional,
        }

    def __new__(cls, variable_name=None, optional=None, **assumptions):
        cls._sanitize(assumptions, cls)
        return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions)

    def __getnewargs__(self):
        return self.variable_name, self.optional

    @staticmethod
    def __xnew__(cls, variable_name=None, optional=None, **assumptions):
        # Static like ``Symbol.__xnew__``; ``cls`` is passed explicitly.
        obj = Symbol.__xnew__(cls, variable_name, **assumptions)
        return obj

    def _hashable_content(self):
        content = super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name)
        # Compare against None rather than testing truthiness: a falsy default
        # such as ``optional=0`` must still be part of the identity, and the
        # truth value of e.g. a Relational default cannot be evaluated.
        if self.optional is not None:
            return content + (self.optional,)
        return content

    def __copy__(self) -> '_WildAbstract':
        return type(self)(variable_name=self.variable_name, optional=self.optional)

    def __repr__(self):
        return str(self)

    def __str__(self):
        return self.name
class WildDot(_WildAbstract):
    """Wildcard of fixed size one: it stands for exactly one operand."""

    fixed_size = True
    min_length = 1
class WildPlus(_WildAbstract):
    """Variable-size wildcard standing for one or more operands."""

    fixed_size = False
    min_length = 1
class WildStar(_WildAbstract):
    """Variable-size wildcard standing for zero or more operands."""

    fixed_size = False
    min_length = 0
# A wildcard printed by srepr(), e.g. ``WildDot('a_')``. The optional
# ``, key=value`` tail accepts wildcards created with SymPy assumptions, which
# would otherwise be left verbatim in the generated source.
_WILD_SREPR_RE = re.compile(r"\b(WildDot|WildPlus|WildStar)\('(\w+)'(?:, [^()]*)?\)")


def _get_srepr(expr):
    """Return ``srepr(expr)`` with every wildcard replaced by a bare name.

    The result is Python source in which a ``WildDot`` becomes its name and
    a ``WildPlus``/``WildStar`` becomes ``*name``. The star unpacks the
    sequence matched by a variable-size wildcard into the enclosing call.
    """
    def _wild_to_source(match):
        star = "" if match.group(1) == "WildDot" else "*"
        return star + match.group(2)

    return _WILD_SREPR_RE.sub(_wild_to_source, srepr(expr))
class ReplacementInfo(NamedTuple):
    """Payload stored for a rule when a ``Replacer`` is created with ``info=True``.

    Pairs the rule's replacement with the arbitrary ``info`` object given to
    ``Replacer.add``.
    """

    replacement: Any  # what the matched subexpression is rewritten to
    info: Any  # user data reported back by Replacer.replace
@doctest_depends_on(modules=('matchpy',))
class Replacer:
    """
    Replacer object to perform multiple pattern matching and subexpression
    replacements in SymPy expressions.

    Parameters
    ==========

    common_constraints : list, optional
        MatchPy constraints added to every rule registered with ``add``.
    lambdify : bool, optional
        If True, each replacement is compiled into a Python lambda that is
        called with the matched wildcards as keyword arguments. Otherwise the
        replacement is instantiated with ``xreplace``.
    info : bool, optional
        If True, every rule carries the ``info`` object passed to ``add``, and
        ``replace`` returns ``(expression, infos)`` where ``infos`` lists the
        payloads of the rules applied, in order.

    Examples
    ========

    Example to construct a simple first degree equation solver:

    >>> from sympy.utilities.matchpy_connector import WildDot, Replacer
    >>> from sympy import Equality, Symbol
    >>> x = Symbol("x")
    >>> a_ = WildDot("a_", optional=1)
    >>> b_ = WildDot("b_", optional=0)

    The lines above have defined two wildcards, ``a_`` and ``b_``, the
    coefficients of the equation `a x + b = 0`. The optional values specified
    indicate which expression to return in case no match is found, they are
    necessary in equations like `a x = 0` and `x + b = 0`.

    Create two constraints to make sure that ``a_`` and ``b_`` will not match
    any expression containing ``x``:

    >>> from matchpy import CustomConstraint
    >>> free_x_a = CustomConstraint(lambda a_: not a_.has(x))
    >>> free_x_b = CustomConstraint(lambda b_: not b_.has(x))

    Now create the rule replacer with the constraints:

    >>> replacer = Replacer(common_constraints=[free_x_a, free_x_b])

    Add the matching rule:

    >>> replacer.add(Equality(a_*x + b_, 0), -b_/a_)

    Let's try it:

    >>> replacer.replace(Equality(3*x + 4, 0))
    -4/3

    Notice that it will not match equations expressed with other patterns:

    >>> eq = Equality(3*x, 4)
    >>> replacer.replace(eq)
    Eq(3*x, 4)

    In order to extend the matching patterns, define another one (we also need
    to clear the cache, because the previous result has already been memorized
    and the pattern matcher will not iterate again if given the same expression)

    >>> replacer.add(Equality(a_*x, b_), b_/a_)
    >>> replacer._matcher.clear()
    >>> replacer.replace(eq)
    4/3
    """

    def __init__(self, common_constraints: Optional[list] = None, lambdify: bool = False, info: bool = False):
        self._matcher = matchpy.ManyToOneMatcher()
        # The caller's list is kept as-is; ``add`` copies it for every rule.
        self._common_constraint = [] if common_constraints is None else common_constraints
        self._lambdify = lambdify
        self._info = info
        # Wildcard name -> wildcard object. Used to turn MatchPy's name-keyed
        # substitutions back into SymPy objects for ``xreplace``.
        self._wildcards: Dict[str, Wildcard] = {}

    def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]:
        """Evaluate ``lambda_str`` in a namespace holding all public SymPy names."""
        # An explicit namespace is required: since Python 3.13 (PEP 667),
        # exec()/locals() inside a function operate on independent snapshots,
        # so names star-imported by a bare exec() would not be visible to eval().
        # NOTE(review): lambda_str is generated from rule expressions via srepr;
        # never feed it text from an untrusted source.
        namespace: Dict[str, Any] = {}
        exec("from sympy import *", namespace)
        return eval(lambda_str, namespace)

    def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]:
        """Build a ``matchpy.CustomConstraint`` from a SymPy condition.

        The constraint is a lambda whose parameters are the names of the
        wildcards in ``constraint_expr``. Its body is ``condition_template``
        formatted with the source form of ``constraint_expr``.
        """
        wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)]
        lambdaargs = ', '.join(wilds)
        fullexpr = _get_srepr(constraint_expr)
        condition = condition_template.format(fullexpr)
        return matchpy.CustomConstraint(
            self._get_lambda(f"lambda {lambdaargs}: ({condition})"))

    def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]:
        # Passes unless the condition evaluates to False (undecided passes).
        return self._get_custom_constraint(constraint_expr, "({}) != False")

    def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]:
        # Passes only when the condition evaluates to True.
        return self._get_custom_constraint(constraint_expr, "({}) == True")

    def add(self, expr: Expr, replacement, conditions_true: Optional[List[Expr]] = None,
            conditions_nonfalse: Optional[List[Expr]] = None, info: Any = None) -> None:
        """Register the rewrite rule ``expr -> replacement``.

        Parameters
        ==========

        expr : Expr
            Pattern to match, usually containing wildcards.
        replacement
            Expression substituted for a match; may use the wildcards of ``expr``.
        conditions_true : list of Expr, optional
            Conditions on the wildcards that must evaluate to ``True``.
        conditions_nonfalse : list of Expr, optional
            Conditions on the wildcards that must not evaluate to ``False``.
        info : optional
            Payload stored with the rule when the replacer has ``info=True``.
        """
        if conditions_true is None:
            conditions_true = []
        if conditions_nonfalse is None:
            conditions_nonfalse = []
        expr = _sympify(expr)
        replacement = _sympify(replacement)
        constraints = self._common_constraint[:]
        constraints.extend(self._get_custom_constraint_true(cond) for cond in conditions_true)
        constraints.extend(self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse)
        pattern = matchpy.Pattern(expr, *constraints)
        if self._lambdify:
            wild_names = ', '.join(x.name for x in expr.atoms(_WildAbstract))
            lambda_str = f"lambda {wild_names}: {_get_srepr(replacement)}"
            replacement = self._get_lambda(lambda_str)
        else:
            self._wildcards.update({str(i): i for i in expr.atoms(Wildcard)})
        if self._info:
            replacement = ReplacementInfo(replacement, info)
        self._matcher.add(pattern, replacement)

    def replace(self, expression, max_count: int = -1):
        """Repeatedly rewrite ``expression`` with the registered rules.

        Each pass walks the subexpressions in preorder. The first subexpression
        for which the matcher yields a match is rewritten with that match's
        replacement, and a new pass starts. Iteration stops when a pass
        performs no replacement, or after ``max_count`` replacements
        (a negative ``max_count`` means no limit).

        Returns the rewritten expression. If the replacer was created with
        ``info=True``, returns ``(expression, infos)`` instead, where ``infos``
        holds the info payload of every applied rule, in order.
        """
        # This method partly rewrites the .replace method of ManyToOneReplacer
        # in MatchPy.
        # License: https://github.com/HPAC/matchpy/blob/master/LICENSE
        infos = []
        replaced = True
        replace_count = 0
        while replaced and (max_count < 0 or replace_count < max_count):
            replaced = False
            for subexpr, pos in matchpy.preorder_iter_with_position(expression):
                # Only the first match is used. The lookup is kept separate so
                # that errors raised while building the result are not swallowed.
                first_match = next(iter(self._matcher.match(subexpr)), None)
                if first_match is None:
                    continue
                replacement_data, subst = first_match
                if self._info:
                    replacement = replacement_data.replacement
                    infos.append(replacement_data.info)
                else:
                    replacement = replacement_data
                if self._lambdify:
                    result = replacement(**subst)
                else:
                    result = replacement.xreplace({self._wildcards[k]: v for k, v in subst.items()})
                expression = matchpy.functions.replace(expression, pos, result)
                replaced = True
                break
            replace_count += 1
        if self._info:
            return expression, infos
        return expression