Spaces:
Sleeping
Sleeping
File size: 4,403 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
""" Generic Rules for SymPy
This file assumes knowledge of Basic and little else.
"""
from collections import Counter

from sympy.utilities.iterables import sift

from .util import new
# Functions that create rules
def rm_id(isid, new=new):
    """ Create a rule to remove identities.

    isid - fn :: x -> Bool  --- whether or not this element is an identity.

    The returned rule drops every argument flagged by ``isid``.  If no
    argument is an identity the expression is returned untouched; if every
    argument is an identity, the first one is kept so the container is not
    left empty.

    Examples
    ========

    >>> from sympy.strategies import rm_id
    >>> from sympy import Basic, S
    >>> remove_zeros = rm_id(lambda x: x==0)
    >>> remove_zeros(Basic(S(1), S(0), S(2)))
    Basic(1, 2)
    >>> remove_zeros(Basic(S(0), S(0))) # If only identities then we keep one
    Basic(0)

    See Also:
        unpack
    """
    def ident_remove(expr):
        """ Remove identities """
        flags = [isid(arg) for arg in expr.args]
        n_ids = sum(flags)
        if n_ids == 0:
            # Common case: nothing to strip, so hand back the same object.
            return expr
        if n_ids == len(flags):
            # Every argument is an identity: keep the first one.
            return new(expr.__class__, expr.args[0])
        kept = [arg for arg, flag in zip(expr.args, flags) if not flag]
        return new(expr.__class__, *kept)
    return ident_remove
def glom(key, count, combine):
    """ Create a rule to conglomerate identical args.

    Parameters
    ==========

    key : callable
        Maps an argument to the term it is grouped under (args are grouped
        with ``sift`` on this key).
    count : callable
        Maps an argument to its contribution; contributions within a group
        are added together with ``sum``.
    combine : callable
        ``combine(cnt, term)`` builds the single replacement argument for a
        group from its summed count and its key.

    If the resulting arguments equal the original ones as a multiset
    (ordering ignored), the expression is returned unchanged.

    Examples
    ========

    >>> from sympy.strategies import glom
    >>> from sympy import Add
    >>> from sympy.abc import x
    >>> key = lambda x: x.as_coeff_Mul()[1]
    >>> count = lambda x: x.as_coeff_Mul()[0]
    >>> combine = lambda cnt, arg: cnt * arg
    >>> rl = glom(key, count, combine)

    >>> rl(Add(x, -x, 3*x, 2, 3, evaluate=False))
    3*x + 5

    How ``key``, ``count`` and ``combine`` act on a single argument:

    >>> key(2*x)
    x
    >>> count(2*x)
    2
    >>> combine(2, x)
    2*x
    """
    def conglomerate(expr):
        """ Conglomerate together identical args x + x -> 2x """
        groups = sift(expr.args, key)
        counts = {k: sum(map(count, args)) for k, args in groups.items()}
        newargs = [combine(cnt, mat) for mat, cnt in counts.items()]
        # Compare as multisets, not plain sets.  With sets, duplicated args
        # that collapse into one (e.g. Basic(x, x) -> [x] with a
        # deduplicating ``combine``) looked like "nothing changed" and the
        # rule silently returned the un-glommed expression.  Ordering is
        # still ignored, so a mere reordering does not trigger a rebuild.
        if Counter(newargs) != Counter(expr.args):
            return new(type(expr), *newargs)
        return expr
    return conglomerate
def sort(key, new=new):
    """ Create a rule to sort by a key function.

    The returned rule rebuilds ``expr`` (via ``new``) with the same class
    and its args ordered by ``key``.

    Examples
    ========

    >>> from sympy.strategies import sort
    >>> from sympy import Basic, S
    >>> sort_rl = sort(str)
    >>> sort_rl(Basic(S(3), S(1), S(2)))
    Basic(1, 2, 3)
    """
    def sort_rl(expr):
        ordered = sorted(expr.args, key=key)
        return new(expr.__class__, *ordered)
    return sort_rl
def distribute(A, B):
    """ Turns an A containing Bs into a B of As

    where A, B are container types.  Only the first argument of ``expr``
    that is an instance of ``B`` is distributed over; the remaining args
    (including any later ``B`` instances) are copied into every new ``A``.
    If no argument is a ``B``, ``expr`` is returned unchanged.

    >>> from sympy.strategies import distribute
    >>> from sympy import Add, Mul, symbols
    >>> x, y = symbols('x,y')
    >>> dist = distribute(Mul, Add)
    >>> expr = Mul(2, x+y, evaluate=False)
    >>> expr
    2*(x + y)
    >>> dist(expr)
    2*x + 2*y
    """
    def distribute_rl(expr):
        args = expr.args
        pos = next((i for i, a in enumerate(args) if isinstance(a, B)), None)
        if pos is None:
            return expr
        head, inner, rest = args[:pos], args[pos], args[pos + 1:]
        return B(*[A(*(head + (item,) + rest)) for item in inner.args])
    return distribute_rl
def subs(a, b):
    """ Replace expressions exactly

    Build a rule that returns ``b`` when its input compares equal (``==``)
    to ``a`` and otherwise returns the input unchanged.  The rule does not
    look inside sub-expressions.
    """
    def subs_rl(expr):
        return b if expr == a else expr
    return subs_rl
# Functions that are rules
def unpack(expr):
    """ Rule to unpack singleton args

    Return the sole argument of ``expr`` when it has exactly one; otherwise
    return ``expr`` itself.

    >>> from sympy.strategies import unpack
    >>> from sympy import Basic, S
    >>> unpack(Basic(S(2)))
    2
    """
    args = expr.args
    if len(args) != 1:
        return expr
    return args[0]
def flatten(expr, new=new):
    """ Flatten T(a, b, T(c, d), T2(e)) to T(a, b, c, d, T2(e))

    Only args whose class is exactly that of ``expr`` are spliced in, and
    only one level deep; args of any other class (including subclasses)
    are kept as they are.
    """
    cls = expr.__class__
    flat = []
    for arg in expr.args:
        flat.extend(arg.args if arg.__class__ == cls else (arg,))
    return new(cls, *flat)
def rebuild(expr):
    """ Rebuild a SymPy tree.

    Explanation
    ===========

    Atoms are returned as-is.  Every other node is reconstructed by calling
    ``expr.func`` on its recursively rebuilt args, so each constructor runs
    again.  This forces canonicalization, e.g. to clean up trees that were
    assembled with ``Basic.__new__``.
    """
    if not expr.is_Atom:
        ctor = expr.func
        return ctor(*[rebuild(child) for child in expr.args])
    return expr
|