Spaces:
Running
Running
File size: 3,956 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
""" Generic SymPy-Independent Strategies """
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import TypeVar
from sys import stdout
# Type variables used in the generic rule signatures throughout this module.
_S = TypeVar('_S')
_T = TypeVar('_T')
def identity(x: _T) -> _T:
    """ Return the argument unchanged (the do-nothing rule) """
    return x
def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """ Apply a rule repeatedly until it has no effect

    The returned rule keeps feeding the output of ``rule`` back into it and
    stops as soon as two consecutive results compare equal.
    """
    def exhaustive_rl(expr: _T) -> _T:
        previous = expr
        current = rule(previous)
        # Stop at the first fixed point: applying the rule changed nothing.
        while current != previous:
            previous, current = current, rule(current)
        return current

    return exhaustive_rl
def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]:
    """Memoized version of a rule

    The result computed for an expression is stored in a dict keyed by that
    expression and returned directly on later calls with an equal
    expression, so expressions must be hashable.

    Notes
    =====

    This cache can grow without bound, so ``functools.lru_cache`` is
    usually the better choice unless the computation is very heavy.
    """
    cache: dict[_S, _T] = {}
    # Private sentinel so that a cached ``None`` result still counts as a hit.
    absent = object()

    def memoized_rl(expr: _S) -> _T:
        hit = cache.get(expr, absent)
        if hit is absent:
            hit = rule(expr)
            cache[expr] = hit
        return hit

    return memoized_rl
def condition(
    cond: Callable[[_T], bool], rule: Callable[[_T], _T]
) -> Callable[[_T], _T]:
    """ Only apply rule if condition is true

    When ``cond(expr)`` is false the expression is returned untouched and
    ``rule`` is not called.
    """
    def conditioned_rl(expr: _T) -> _T:
        return rule(expr) if cond(expr) else expr

    return conditioned_rl
def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """
    Compose a sequence of rules so that they apply to the expr sequentially

    Rules run left to right, each receiving the previous rule's output.
    With no rules the expression is returned unchanged.
    """
    def chain_rl(expr: _T) -> _T:
        result = expr
        for step in rules:
            result = step(result)
        return result

    return chain_rl
def debug(rule, file=None):
    """ Print out before and after expressions each time rule is used

    Parameters
    ==========

    rule : callable
        The rule to wrap. It is called with exactly the arguments given to
        the wrapper.
    file : object with a ``write`` method, optional
        Where the report is written. Defaults to the ``stdout`` object
        imported by this module.

    Returns
    =======

    callable
        A wrapper that returns whatever ``rule`` returns and, whenever that
        result differs from the first positional argument, writes the
        rule's name together with the input and output expressions.
    """
    if file is None:
        file = stdout

    def debug_rl(*args, **kwargs):
        expr = args[0]
        result = rule(*args, **kwargs)
        if result != expr:
            # Not every callable has ``__name__`` (``functools.partial``
            # objects and instances defining ``__call__`` do not), so fall
            # back to the repr instead of raising AttributeError.
            label = getattr(rule, "__name__", None)
            if label is None:
                label = repr(rule)
            file.write("Rule: %s\n" % (label,))
            file.write("In: %s\nOut: %s\n\n" % (expr, result))
        return result

    return debug_rl
def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]:
    """ Return original expr if rule returns None

    Any other result, including falsy values such as ``0``, is passed
    through unchanged.
    """
    def null_safe_rl(expr: _T) -> _T:
        outcome = rule(expr)
        return expr if outcome is None else outcome

    return null_safe_rl
def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]:
    """ Return original expr if rule raises exception

    ``exception`` is used directly in an ``except`` clause, so it may be
    whatever that clause accepts. Exceptions that do not match propagate
    to the caller.
    """
    def try_rl(expr: _T) -> _T:
        try:
            outcome = rule(expr)
        except exception:
            outcome = expr
        return outcome

    return try_rl
def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """ Try each of the rules until one works. Then stop.

    A rule "works" when its result differs from the input expression. Rules
    are tried in the order given; if none changes the expression, the
    expression itself is returned.
    """
    def do_one_rl(expr: _T) -> _T:
        for attempt in rules:
            outcome = attempt(expr)
            if outcome != expr:
                # First rule that had an effect wins; later ones never run.
                return outcome
        return expr

    return do_one_rl
def switch(
    key: Callable[[_S], _T],
    ruledict: Mapping[_T, Callable[[_S], _S]]
) -> Callable[[_S], _S]:
    """ Select a rule based on the result of key called on the expression

    ``key(expr)`` is looked up in ``ruledict`` and the rule found there is
    applied to ``expr``. When there is no entry for that key, ``identity``
    is used, so the expression is returned unchanged.
    """
    def switch_rl(expr: _S) -> _S:
        return ruledict.get(key(expr), identity)(expr)

    return switch_rl
# XXX The default objective for ``minimize`` is deliberately left untyped:
# typing it would need the SupportsRichComparison type that python requires.
def _identity(x):
    return x
def minimize(
    *rules: Callable[[_S], _T],
    objective=_identity
) -> Callable[[_S], _T]:
    """ Select result of rules that minimizes objective

    Every rule is applied to the expression first; the result with the
    smallest ``objective`` value is then returned.

    >>> from sympy.strategies import minimize
    >>> inc = lambda x: x + 1
    >>> dec = lambda x: x - 1
    >>> rl = minimize(inc, dec)
    >>> rl(4)
    3

    >>> rl = minimize(inc, dec, objective=lambda x: -x)  # maximize
    >>> rl(4)
    5
    """
    def minrule(expr: _S) -> _T:
        candidates = [candidate_rule(expr) for candidate_rule in rules]
        return min(candidates, key=objective)

    return minrule
|