Spaces:
Sleeping
Sleeping
""" Generic Rules for SymPy

This file assumes knowledge of Basic and little else.
"""
from sympy.utilities.iterables import sift | |
from .util import new | |
# Functions that create rules | |
def rm_id(isid, new=new):
    """ Create a rule to remove identities.

    isid - fn :: x -> Bool --- whether or not this element is an identity.
           The result is interpreted by truthiness, so a predicate that
           returns ``None`` for "unknown" is accepted and such an element
           is treated as a non-identity (kept).
    new  - constructor called as ``new(cls, *args)`` to rebuild the
           expression when some args are removed.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic, S
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(S(1), S(0), S(2)))
    Basic(1, 2)
    >>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one
    Basic(0)

    See Also:
        unpack
    """
    def ident_remove(expr):
        """ Remove identities """
        args = expr.args
        # Coerce once to bool: summing raw predicate results fails on None,
        # and the identity count is needed twice below.
        ids = [bool(isid(arg)) for arg in args]
        n_ids = sum(ids)
        if n_ids == 0:  # No identities. Common case
            return expr
        elif n_ids != len(ids):  # there is at least one non-identity
            return new(expr.__class__,
                       *[arg for arg, x in zip(args, ids) if not x])
        else:
            # Every arg is an identity: keep the first one so the
            # expression does not become empty.
            return new(expr.__class__, args[0])
    return ident_remove
def glom(key, count, combine):
    """ Create a rule to conglomerate identical args.

    key     - fn :: arg -> k; args with equal keys are grouped together.
    count   - fn :: arg -> n; per-arg amounts that are summed within a group.
    combine - fn :: (total, k) -> arg; builds the single merged arg for a
              group from its summed count and its key.

    The rule returns ``expr`` unchanged when the merged args are the same
    multiset as the original args; otherwise it rebuilds the expression
    with ``new(type(expr), *merged_args)``.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x
    >>> key = lambda x: x.as_coeff_Mul()[1]
    >>> count = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)
    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    How do key, count and combine work together?

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """
    from collections import Counter

    def conglomerate(expr):
        """ Conglomerate together identical args x + x -> 2x """
        groups = sift(expr.args, key)
        counts = {k: sum(map(count, args)) for k, args in groups.items()}
        newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
        # Compare as multisets, not sets: a set comparison misses the case
        # where duplicates collapse into one arg equal to an original arg
        # (e.g. an idempotent ``combine`` turning (x, x) into (x,)), which
        # would wrongly report "no change".
        if Counter(newargs) != Counter(expr.args):
            return new(type(expr), *newargs)
        else:
            return expr
    return conglomerate
def sort(key, new=new):
    """ Create a rule that reorders an expression's args by ``key``.

    key - fn :: arg -> comparable; passed to ``sorted`` as its key.
    new - constructor called as ``new(cls, *sorted_args)``.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic, S
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(S(3), S(1), S(2)))
    Basic(1, 2, 3)
    """
    def sort_rl(expr):
        ordered = sorted(expr.args, key=key)
        return new(expr.__class__, *ordered)
    return sort_rl
def distribute(A, B):
    """ Create a rule turning an A containing Bs into a B of As.

    A, B are container types: callables taking positional args, whose
    instances expose those args through ``.args``. Only the first arg of
    type B is distributed over; expressions with no B arg are returned
    unchanged.

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """
    def distribute_rl(expr):
        args = expr.args
        for idx, candidate in enumerate(args):
            if not isinstance(candidate, B):
                continue
            head, tail = args[:idx], args[idx + 1:]
            # Replace the B-arg by each of its own args in turn.
            return B(*[A(*head, term, *tail) for term in candidate.args])
        return expr
    return distribute_rl
def subs(a, b):
    """ Create a rule replacing an expression equal to ``a`` with ``b``.

    Only the whole expression is compared; sub-expressions are not
    searched. Anything not equal to ``a`` is returned unchanged.
    """
    def subs_rl(expr):
        return b if expr == a else expr
    return subs_rl
# Functions that are rules | |
def unpack(expr):
    """ Rule that replaces a single-argument expression by that argument.

    Expressions with zero or several args are returned unchanged.

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic, S
    >>> unpack(Basic(S(2)))
    2
    """
    args = expr.args
    return args[0] if len(args) == 1 else expr
def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e)).

    Only one level is flattened: args whose class is exactly the class of
    ``expr`` are spliced in; all other args are kept as they are. The
    result is built with ``new(cls, *args)``.
    """
    cls = expr.__class__
    flat = []
    for arg in expr.args:
        flat.extend(arg.args if arg.__class__ == cls else (arg,))
    return new(cls, *flat)
def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    Recursively re-invokes ``expr.func`` on the rebuilt args of every
    non-atomic node, bottom-up. Atoms are returned as-is. This forces
    canonicalization and removes ugliness introduced by the use of
    Basic.__new__.
    """
    if not expr.is_Atom:
        constructor = expr.func
        rebuilt_args = [rebuild(sub) for sub in expr.args]
        return constructor(*rebuilt_args)
    return expr