Spaces:
Running
Running
""" Generic SymPy-Independent Strategies """ | |
from __future__ import annotations | |
from collections.abc import Callable, Mapping | |
from typing import TypeVar | |
from sys import stdout | |
_S = TypeVar('_S') | |
_T = TypeVar('_T') | |
def identity(x: _T) -> _T: | |
return x | |
def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]: | |
""" Apply a rule repeatedly until it has no effect """ | |
def exhaustive_rl(expr: _T) -> _T: | |
new, old = rule(expr), expr | |
while new != old: | |
new, old = rule(new), new | |
return new | |
return exhaustive_rl | |
def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]: | |
"""Memoized version of a rule | |
Notes | |
===== | |
This cache can grow infinitely, so it is not recommended to use this | |
than ``functools.lru_cache`` unless you need very heavy computation. | |
""" | |
cache: dict[_S, _T] = {} | |
def memoized_rl(expr: _S) -> _T: | |
if expr in cache: | |
return cache[expr] | |
else: | |
result = rule(expr) | |
cache[expr] = result | |
return result | |
return memoized_rl | |
def condition( | |
cond: Callable[[_T], bool], rule: Callable[[_T], _T] | |
) -> Callable[[_T], _T]: | |
""" Only apply rule if condition is true """ | |
def conditioned_rl(expr: _T) -> _T: | |
if cond(expr): | |
return rule(expr) | |
return expr | |
return conditioned_rl | |
def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]: | |
""" | |
Compose a sequence of rules so that they apply to the expr sequentially | |
""" | |
def chain_rl(expr: _T) -> _T: | |
for rule in rules: | |
expr = rule(expr) | |
return expr | |
return chain_rl | |
def debug(rule, file=None): | |
""" Print out before and after expressions each time rule is used """ | |
if file is None: | |
file = stdout | |
def debug_rl(*args, **kwargs): | |
expr = args[0] | |
result = rule(*args, **kwargs) | |
if result != expr: | |
file.write("Rule: %s\n" % rule.__name__) | |
file.write("In: %s\nOut: %s\n\n" % (expr, result)) | |
return result | |
return debug_rl | |
def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]: | |
""" Return original expr if rule returns None """ | |
def null_safe_rl(expr: _T) -> _T: | |
result = rule(expr) | |
if result is None: | |
return expr | |
return result | |
return null_safe_rl | |
def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]: | |
""" Return original expr if rule raises exception """ | |
def try_rl(expr: _T) -> _T: | |
try: | |
return rule(expr) | |
except exception: | |
return expr | |
return try_rl | |
def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]: | |
""" Try each of the rules until one works. Then stop. """ | |
def do_one_rl(expr: _T) -> _T: | |
for rl in rules: | |
result = rl(expr) | |
if result != expr: | |
return result | |
return expr | |
return do_one_rl | |
def switch( | |
key: Callable[[_S], _T], | |
ruledict: Mapping[_T, Callable[[_S], _S]] | |
) -> Callable[[_S], _S]: | |
""" Select a rule based on the result of key called on the function """ | |
def switch_rl(expr: _S) -> _S: | |
rl = ruledict.get(key(expr), identity) | |
return rl(expr) | |
return switch_rl | |
# XXX Untyped default argument for minimize function | |
# where python requires SupportsRichComparison type | |
def _identity(x): | |
return x | |
def minimize( | |
*rules: Callable[[_S], _T], | |
objective=_identity | |
) -> Callable[[_S], _T]: | |
""" Select result of rules that minimizes objective | |
>>> from sympy.strategies import minimize | |
>>> inc = lambda x: x + 1 | |
>>> dec = lambda x: x - 1 | |
>>> rl = minimize(inc, dec) | |
>>> rl(4) | |
3 | |
>>> rl = minimize(inc, dec, objective=lambda x: -x) # maximize | |
>>> rl(4) | |
5 | |
""" | |
def minrule(expr: _S) -> _T: | |
return min([rule(expr) for rule in rules], key=objective) | |
return minrule | |