Spaces:
Sleeping
Sleeping
""" Tools for doing common subexpression elimination. | |
""" | |
from collections import defaultdict | |
from sympy.core import Basic, Mul, Add, Pow, sympify | |
from sympy.core.containers import Tuple, OrderedSet | |
from sympy.core.exprtools import factor_terms | |
from sympy.core.singleton import S | |
from sympy.core.sorting import ordered | |
from sympy.core.symbol import symbols, Symbol | |
from sympy.matrices import (MatrixBase, Matrix, ImmutableMatrix, | |
SparseMatrix, ImmutableSparseMatrix) | |
from sympy.matrices.expressions import (MatrixExpr, MatrixSymbol, MatMul, | |
MatAdd, MatPow, Inverse) | |
from sympy.matrices.expressions.matexpr import MatrixElement | |
from sympy.polys.rootoftools import RootOf | |
from sympy.utilities.iterables import numbered_symbols, sift, \ | |
topological_sort, iterable | |
from . import cse_opts | |
# (preprocessor, postprocessor) pairs which are commonly useful. They should | |
# each take a SymPy expression and return a possibly transformed expression. | |
# When used in the function ``cse()``, the target expressions will be transformed | |
# by each of the preprocessor functions in order. After the common | |
# subexpressions are eliminated, each resulting expression will have the | |
# postprocessor functions transform them in *reverse* order in order to undo the | |
# transformation if necessary. This allows the algorithm to operate on | |
# a representation of the expressions that allows for more optimization | |
# opportunities. | |
# ``None`` can be used to specify no transformation for either the preprocessor or | |
# postprocessor. | |
# The pairs selected when ``cse`` is called with ``optimizations='basic'``.
# ``factor_terms`` is a preprocessor only: its postprocessor slot is ``None``.
basic_optimizations = [(cse_opts.sub_pre, cse_opts.sub_post),
                       (factor_terms, None)]
# sometimes we want the output in a different format; non-trivial | |
# transformations can be put here for users | |
# =============================================================== | |
def reps_toposort(r):
    """Topologically sort the replacement pairs ``r``.

    A pair ``(k1, v1)`` is placed before ``(k2, v2)`` whenever ``k1``
    occurs in the free symbols of ``v2``, i.e. a symbol's definition comes
    before the expressions that use it.  This is the order in which cse
    returns its replacements (hence, in order to use the replacements in a
    substitution option it would make sense to reverse the order).

    Examples
    ========

    >>> from sympy.simplify.cse_main import reps_toposort
    >>> from sympy.abc import x, y
    >>> from sympy import Eq
    >>> for l, r in reps_toposort([(x, y + 1), (y, 2)]):
    ...     print(Eq(l, r))
    ...
    Eq(y, 2)
    Eq(x, y + 1)

    """
    pairs = sympify(r)
    # An edge (i, j) says "pair i must be emitted before pair j".
    edges = [
        (i, j)
        for i, (key, _) in enumerate(pairs)
        for j, (_, value) in enumerate(pairs)
        if key in value.free_symbols
    ]
    ordering = topological_sort((range(len(pairs)), edges))
    return [pairs[index] for index in ordering]
def cse_separate(r, e):
    """Pull ``Eq(symbol, expr)`` items out of the expressions ``e``, add
    them to the replacements ``r`` as ``(symbol, expr)`` pairs and order the
    combined replacements with ``reps_toposort``.

    Examples
    ========

    >>> from sympy.simplify.cse_main import cse_separate
    >>> from sympy.abc import x, y, z
    >>> from sympy import cos, exp, cse, Eq, symbols
    >>> x0, x1 = symbols('x:2')
    >>> eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
    >>> cse([eq, Eq(x, z + 1), z - 2], postprocess=cse_separate) in [
    ... [[(x0, y + 1), (x, z + 1), (x1, x + 1)],
    ...  [x1 + exp(x1/x0) + cos(x0), z - 2]],
    ... [[(x1, y + 1), (x, z + 1), (x0, x + 1)],
    ...  [x0 + exp(x0/x1) + cos(x1), z - 2]]]
    ...
    True
    """
    def defines_symbol(item):
        # True only for an equality whose left-hand side is a bare Symbol.
        return item.is_Equality and item.lhs.is_Symbol

    groups = sift(e, defines_symbol)
    definitions = [equation.args for equation in groups[True]]
    remaining = groups[False]
    return [reps_toposort(r + definitions), remaining]
def cse_release_variables(r, e):
    """
    Return tuples giving ``(a, b)`` where ``a`` is a symbol and ``b`` is
    either an expression or None. The value of None is used when a
    symbol is no longer needed for subsequent expressions.

    Use of such output can reduce the memory footprint of lambdified
    expressions that contain large, repeated subexpressions.

    When ``r`` is empty the inputs are returned unchanged as ``(r, e)``.

    Examples
    ========

    >>> from sympy import cse
    >>> from sympy.simplify.cse_main import cse_release_variables
    >>> from sympy.abc import x, y
    >>> eqs = [(x + y - 1)**2, x, x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)]
    >>> defs, rvs = cse_release_variables(*cse(eqs))
    >>> for i in defs:
    ...   print(i)
    ...
    (x0, x + y)
    (x1, (x0 - 1)**2)
    (x2, 2*x + 1)
    (_3, x0/x2 + x1)
    (_4, x2**x0)
    (x2, None)
    (_0, x1)
    (x1, None)
    (_2, x0)
    (x0, None)
    (_1, x)
    >>> print(rvs)
    (_0, _1, _2, _3, _4)
    """
    if not r:
        return r, e

    rep_syms = [sym for sym, _ in r]
    pending = [sub for _, sub in r]
    esyms = symbols('_:%d' % len(e))
    live = set(rep_syms)

    def cost(pair):
        # Negated total op count of the replacements an expression refers
        # to, so that the heaviest expressions sort to the front.
        expr = pair[0]
        return -sum(pending[rep_syms.index(sym)].count_ops()
                    for sym in expr.free_symbols & live)

    ranked_exprs, ranked_syms = zip(*sorted(zip(e, esyms), key=cost))
    out_syms = list(ranked_syms)
    pending.extend(ranked_exprs)

    # Walk the combined list from the end towards the start.  The output is
    # built in reverse, so a "(symbol, None)" release pushed just before an
    # item ends up right after that item, i.e. after the symbol's last use.
    rv = []
    n_reps = len(r)
    for idx in range(len(pending) - 1, -1, -1):
        item = pending[idx]
        released = live & item.free_symbols
        if released:
            # sorting for canonical results
            rv.extend((sym, None) for sym in sorted(released, key=str))
        if idx >= n_reps:
            rv.append((out_syms.pop(), item))
        else:
            rv.append((rep_syms[idx], item))
        live -= released
    rv.reverse()
    return rv, esyms
# ====end of cse postprocess idioms=========================== | |
def preprocess_for_cse(expr, optimizations):
    """ Run every available preprocessor over an expression, in order, to
    prepare it for common subexpression elimination.

    Parameters
    ==========

    expr : SymPy expression
        The target expression to optimize.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs.  A preprocessor of ``None``
        is skipped.

    Returns
    =======

    expr : SymPy expression
        The transformed expression.
    """
    for pair in optimizations:
        preprocessor, _ = pair
        if preprocessor is None:
            continue
        expr = preprocessor(expr)
    return expr
def postprocess_for_cse(expr, optimizations):
    """Run every available postprocessor over an expression after common
    subexpression elimination, to return it to canonical SymPy form.

    Parameters
    ==========

    expr : SymPy expression
        The target expression to transform.
    optimizations : list of (callable, callable) pairs, optional
        The (preprocessor, postprocessor) pairs. The postprocessors will be
        applied in reversed order to undo the effects of the preprocessors
        correctly.  A postprocessor of ``None`` is skipped.

    Returns
    =======

    expr : SymPy expression
        The transformed expression.
    """
    for pair in reversed(optimizations):
        _, postprocessor = pair
        if postprocessor is None:
            continue
        expr = postprocessor(expr)
    return expr
class FuncArgTracker:
    """
    Bookkeeping for a two-way mapping: from each function to the set of its
    arguments, and from each argument back to the set of functions using it.
    """

    def __init__(self, funcs):
        # Arguments are identified by integer "value numbers" so that later
        # work compares integers rather than symbolic objects.
        self.value_numbers = {}
        self.value_number_to_value = []

        # Both tables are indexed by integer: argument value number for the
        # first, position of the function in ``funcs`` for the second.
        self.arg_to_funcset = []
        self.func_to_argset = []

        for func_index, func in enumerate(funcs):
            arg_numbers = OrderedSet()
            for arg in func.args:
                number = self.get_or_add_value_number(arg)
                arg_numbers.add(number)
                self.arg_to_funcset[number].add(func_index)
            self.func_to_argset.append(arg_numbers)

    def get_args_in_value_order(self, argset):
        """
        Return the arguments named by ``argset`` as a list, ordered by
        ascending value number.
        """
        values = self.value_number_to_value
        return [values[number] for number in sorted(argset)]

    def get_or_add_value_number(self, value):
        """
        Return the value number of ``value``, registering it first if it has
        not been seen before.
        """
        next_number = len(self.value_numbers)
        number = self.value_numbers.setdefault(value, next_number)
        if number != next_number:
            # Already known: nothing to register.
            return number
        self.value_number_to_value.append(value)
        self.arg_to_funcset.append(OrderedSet())
        return number

    def stop_arg_tracking(self, func_i):
        """
        Drop function ``func_i`` from the argument-to-function table.
        """
        for arg_number in self.func_to_argset[func_i]:
            users = self.arg_to_funcset[arg_number]
            users.remove(func_i)

    def get_common_arg_candidates(self, argset, min_func_i=0):
        """Return a dict mapping function numbers to the number of arguments
        that function shares with ``argset``.  Only functions sharing at
        least 2 arguments are reported, and every key is at least
        ``min_func_i``.
        """
        counts = defaultdict(int)
        if not argset:
            return counts

        funcsets = [self.arg_to_funcset[arg] for arg in argset]
        # The largest funcset is not counted in the main loop; it is folded
        # in afterwards by iterating whichever container is smaller.
        biggest = max(funcsets, key=len)

        for funcset in funcsets:
            if funcset is not biggest:
                for func_i in funcset:
                    if func_i < min_func_i:
                        continue
                    counts[func_i] += 1

        if len(counts) < len(biggest):
            smaller, larger = counts, biggest
        else:
            smaller, larger = biggest, counts

        for func_i in smaller:
            # A function that is not yet counted shares at most one argument
            # (the one from ``biggest``), so it can never reach 2.
            if counts[func_i] >= 1 and func_i in larger:
                counts[func_i] += 1

        return {func_i: n for func_i, n in counts.items() if n >= 2}

    def get_subset_candidates(self, argset, restrict_to_funcset=None):
        """
        Return the set of functions whose arguments include every member of
        ``argset``, optionally keeping only those that are also in
        ``restrict_to_funcset``.
        """
        remaining = iter(argset)
        first = next(remaining)
        candidates = OrderedSet(
            func_i for func_i in self.arg_to_funcset[first])

        if restrict_to_funcset is not None:
            candidates &= restrict_to_funcset

        for arg in remaining:
            candidates &= self.arg_to_funcset[arg]

        return candidates

    def update_func_argset(self, func_i, new_argset):
        """
        Replace the arguments recorded for function ``func_i`` with
        ``new_argset``, keeping both tables in agreement.
        """
        incoming = OrderedSet(new_argset)
        current = self.func_to_argset[func_i]

        for removed in current - incoming:
            self.arg_to_funcset[removed].remove(func_i)
        for added in incoming - current:
            self.arg_to_funcset[added].add(func_i)

        # Mutate the stored set in place rather than rebinding it.
        current.clear()
        current.update(incoming)
class Unevaluated:
    """Placeholder for a ``func(*args)`` call that has not been evaluated.

    The object only stores ``func`` and ``args``; nothing is called until
    ``as_unevaluated_basic`` is used.  No ``__eq__``/``__hash__`` is defined,
    so instances compare and hash by identity.

    Parameters
    ==========

    func : callable
        The callable to apply; ``as_unevaluated_basic`` calls it with the
        keyword ``evaluate=False``.
    args : iterable
        The arguments of the call.  Each one must expose a ``free_symbols``
        attribute for ``free_symbols`` to work.
    """

    def __init__(self, func, args):
        self.func = func
        self.args = args

    def __str__(self):
        return "Uneval<{}>({})".format(
            self.func, ", ".join(str(a) for a in self.args))

    __repr__ = __str__

    def as_unevaluated_basic(self):
        """Return ``func(*args, evaluate=False)``."""
        return self.func(*self.args, evaluate=False)

    @property
    def free_symbols(self):
        """The union of the ``free_symbols`` of all arguments, as a set.

        This is a property so that it can be read the same way as the
        ``free_symbols`` attribute of the arguments themselves, e.g.
        ``sym in obj.free_symbols``.  As a plain method, that expression
        raised ``TypeError``, and so did an ``Unevaluated`` nested inside
        another one.
        """
        return set().union(*[a.free_symbols for a in self.args])
def match_common_args(func_class, funcs, opt_subs):
    """
    Recognize and extract common subexpressions of function arguments within a
    set of function calls. For instance, for the following function calls::

        x + z + y
        sin(x + y)

    this will extract a common subexpression of `x + y`::

        w = x + y
        w + z
        sin(w)

    The function we work with is assumed to be associative and commutative.

    Nothing is returned; results are recorded by adding entries to
    ``opt_subs``.

    Parameters
    ==========

    func_class: class
        The function class (e.g. Add, Mul)

    funcs: list of functions
        A list of function calls.

    opt_subs: dict
        A dictionary of substitutions which this function may update.
    """

    # Sort to ensure that whole-function subexpressions come before the items
    # that use them.
    funcs = sorted(funcs, key=lambda f: len(f.args))
    arg_tracker = FuncArgTracker(funcs)

    # Indices of functions whose tracked argument set has been rewritten.
    changed = OrderedSet()

    for i in range(len(funcs)):
        # Only functions after ``i`` are considered as partners.
        common_arg_candidates_counts = arg_tracker.get_common_arg_candidates(
                arg_tracker.func_to_argset[i], min_func_i=i + 1)

        # Sort the candidates in order of match size.
        # This makes us try combining smaller matches first.
        common_arg_candidates = OrderedSet(sorted(
                common_arg_candidates_counts.keys(),
                key=lambda k: (common_arg_candidates_counts[k], k)))

        while common_arg_candidates:
            j = common_arg_candidates.pop(last=False)

            # Recomputed from the current argument sets, which earlier
            # iterations of this loop may already have rewritten.
            com_args = arg_tracker.func_to_argset[i].intersection(
                    arg_tracker.func_to_argset[j])

            if len(com_args) <= 1:
                # This may happen if a set of common arguments was already
                # combined in a previous iteration.
                continue

            # For all sets, replace the common symbols by the function
            # over them, to allow recursive matches.

            diff_i = arg_tracker.func_to_argset[i].difference(com_args)
            if diff_i:
                # com_func needs to be unevaluated to allow for recursive matches.
                com_func = Unevaluated(
                        func_class, arg_tracker.get_args_in_value_order(com_args))
                com_func_number = arg_tracker.get_or_add_value_number(com_func)
                arg_tracker.update_func_argset(i, diff_i | OrderedSet([com_func_number]))
                changed.add(i)
            else:
                # Treat the whole expression as a CSE.
                #
                # The reason this needs to be done is somewhat subtle. Within
                # tree_cse(), to_eliminate only contains expressions that are
                # seen more than once. The problem is unevaluated expressions
                # do not compare equal to the evaluated equivalent. So
                # tree_cse() won't mark funcs[i] as a CSE if we use an
                # unevaluated version.
                com_func_number = arg_tracker.get_or_add_value_number(funcs[i])

            # Replace the common arguments of j by the single common term.
            diff_j = arg_tracker.func_to_argset[j].difference(com_args)
            arg_tracker.update_func_argset(j, diff_j | OrderedSet([com_func_number]))
            changed.add(j)

            # Do the same for every remaining candidate that contains all of
            # the common arguments.
            for k in arg_tracker.get_subset_candidates(
                    com_args, common_arg_candidates):
                diff_k = arg_tracker.func_to_argset[k].difference(com_args)
                arg_tracker.update_func_argset(k, diff_k | OrderedSet([com_func_number]))
                changed.add(k)

        if i in changed:
            # Record the rewritten form of funcs[i] as an unevaluated call
            # over its current arguments.
            opt_subs[funcs[i]] = Unevaluated(func_class,
                arg_tracker.get_args_in_value_order(arg_tracker.func_to_argset[i]))

        # Function i is finished; remove it from the argument lookup table.
        arg_tracker.stop_arg_tracking(i)
def opt_cse(exprs, order='canonical'):
    """Find optimization opportunities in Adds, Muls, Pows and negative
    coefficient Muls.

    Parameters
    ==========

    exprs : list of SymPy expressions
        The expressions to optimize.
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. For large
        expressions where speed is a concern, use the setting order='none'.
        NOTE(review): this argument is not referenced anywhere in the body
        of this function.

    Returns
    =======

    opt_subs : dictionary of expression substitutions
        The expression substitutions which can be useful to optimize CSE.

    Examples
    ========

    >>> from sympy.simplify.cse_main import opt_cse
    >>> from sympy.abc import x
    >>> opt_subs = opt_cse([x**-2])
    >>> k, v = list(opt_subs.keys())[0], list(opt_subs.values())[0]
    >>> print((k, v.as_unevaluated_basic()))
    (x**(-2), 1/(x**2))
    """
    opt_subs = {}

    # State shared with the nested ``_find_opts`` traversal below.
    adds = OrderedSet()
    muls = OrderedSet()

    seen_subexp = set()
    collapsible_subexp = set()

    def _find_opts(expr):
        # Recursively visit ``expr``, filling ``opt_subs``, ``adds``,
        # ``muls`` and ``collapsible_subexp``.  Callers ignore the return
        # value.

        if not isinstance(expr, (Basic, Unevaluated)):
            return

        if expr.is_Atom or expr.is_Order:
            return

        if iterable(expr):
            list(map(_find_opts, expr))
            return

        if expr in seen_subexp:
            return expr
        seen_subexp.add(expr)

        # Visit the arguments before looking at the node itself.
        list(map(_find_opts, expr.args))

        if not isinstance(expr, MatrixExpr) and expr.could_extract_minus_sign():
            # XXX -expr does not always work rigorously for some expressions
            # containing UnevaluatedExpr.
            # https://github.com/sympy/sympy/issues/24818
            if isinstance(expr, Add):
                neg_expr = Add(*(-i for i in expr.args))
            else:
                neg_expr = -expr

            if not neg_expr.is_Atom:
                # Record expr as (-1)*neg_expr and classify neg_expr below
                # in its place.
                opt_subs[expr] = Unevaluated(Mul, (S.NegativeOne, neg_expr))
                seen_subexp.add(neg_expr)
                expr = neg_expr

        if isinstance(expr, (Mul, MatMul)):
            if len(expr.args) == 1:
                collapsible_subexp.add(expr)
            else:
                muls.add(expr)

        elif isinstance(expr, (Add, MatAdd)):
            if len(expr.args) == 1:
                collapsible_subexp.add(expr)
            else:
                adds.add(expr)

        elif isinstance(expr, Inverse):
            # Do not want to treat `Inverse` as a `MatPow`
            pass

        elif isinstance(expr, (Pow, MatPow)):
            base, exp = expr.base, expr.exp
            if exp.could_extract_minus_sign():
                # Record base**exp as (base**(-exp))**(-1).
                opt_subs[expr] = Unevaluated(Pow, (Pow(base, -exp), -1))

    for e in exprs:
        if isinstance(e, (Basic, Unevaluated)):
            _find_opts(e)

    # Handle collapsing of multinary operations with single arguments
    edges = [(s, s.args[0]) for s in collapsible_subexp
             if s.args[0] in collapsible_subexp]
    for e in reversed(topological_sort((collapsible_subexp, edges))):
        # Map e to its only argument, or to that argument's own
        # substitution when one has already been recorded.
        opt_subs[e] = opt_subs.get(e.args[0], e.args[0])

    # split muls into commutative
    commutative_muls = OrderedSet()
    for m in muls:
        c, nc = m.args_cnc(cset=False)
        if c:
            c_mul = m.func(*c)
            if nc:
                # Regroup m as (commutative part) * (non-commutative part).
                if c_mul == 1:
                    new_obj = m.func(*nc)
                else:
                    if isinstance(m, MatMul):
                        new_obj = m.func(c_mul, *nc, evaluate=False)
                    else:
                        new_obj = m.func(c_mul, m.func(*nc), evaluate=False)
                opt_subs[m] = new_obj
            if len(c) > 1:
                commutative_muls.add(c_mul)

    # Look for arguments shared between the collected Adds, and between the
    # collected commutative products.
    match_common_args(Add, adds, opt_subs)
    match_common_args(Mul, commutative_muls, opt_subs)

    return opt_subs
def tree_cse(exprs, symbols, opt_subs=None, order='canonical', ignore=()):
    """Perform raw CSE on expression tree, taking opt_subs into account.

    Parameters
    ==========

    exprs : list of SymPy expressions
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out.
    opt_subs : dictionary of expression substitutions
        The expressions to be substituted before any CSE action is performed.
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. For large
        expressions where speed is a concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.

    Returns
    =======

    replacements : list of (symbol, expression) pairs
        One pair per subexpression that was pulled out, in the order the
        pairs were created while rebuilding.
    reduced_exprs : list
        One entry per item of ``exprs``; items that are not ``Basic``
        instances are passed through unchanged.

    Raises
    ======

    ValueError
        If ``symbols`` is exhausted while a new symbol is needed.
    """
    if opt_subs is None:
        opt_subs = {}

    ## Find repeated sub-expressions

    to_eliminate = set()

    seen_subexp = set()
    # Names of the symbols found in the expressions; used below to filter
    # the ``symbols`` iterator.
    excluded_symbols = set()

    def _find_repeated(expr):
        # Recursively visit ``expr``; a subexpression met a second time is
        # added to ``to_eliminate``.
        if not isinstance(expr, (Basic, Unevaluated)):
            return

        if isinstance(expr, RootOf):
            return

        if isinstance(expr, Basic) and (
                expr.is_Atom or
                expr.is_Order or
                isinstance(expr, (MatrixSymbol, MatrixElement))):
            if expr.is_Symbol:
                excluded_symbols.add(expr.name)
            return

        if iterable(expr):
            args = expr

        else:
            if expr in seen_subexp:
                # Seen before: mark for elimination unless it contains an
                # ignored symbol, in which case it is traversed again.
                for ign in ignore:
                    if ign in expr.free_symbols:
                        break
                else:
                    to_eliminate.add(expr)
                    return

            seen_subexp.add(expr)

            if expr in opt_subs:
                expr = opt_subs[expr]

            args = expr.args

        list(map(_find_repeated, args))

    for e in exprs:
        if isinstance(e, Basic):
            _find_repeated(e)

    ## Rebuild tree

    # Remove symbols from the generator that conflict with names in the expressions.
    symbols = (_ for _ in symbols if _.name not in excluded_symbols)

    replacements = []

    # Maps an eliminated original expression to its replacement symbol.
    subs = {}

    def _rebuild(expr):
        # Return ``expr`` with every subexpression listed in ``to_eliminate``
        # replaced by a symbol, appending the definitions to
        # ``replacements`` along the way.
        if not isinstance(expr, (Basic, Unevaluated)):
            return expr

        if not expr.args:
            return expr

        if iterable(expr):
            new_args = [_rebuild(arg) for arg in expr.args]
            return expr.func(*new_args)

        if expr in subs:
            return subs[expr]

        # Keep the original: ``to_eliminate`` and ``subs`` are keyed by it,
        # not by its ``opt_subs`` substitute.
        orig_expr = expr
        if expr in opt_subs:
            expr = opt_subs[expr]

        # If enabled, parse Muls and Adds arguments by order to ensure
        # replacement order independent from hashes
        if order != 'none':
            if isinstance(expr, (Mul, MatMul)):
                c, nc = expr.args_cnc()
                if c == [1]:
                    args = nc
                else:
                    args = list(ordered(c)) + nc
            elif isinstance(expr, (Add, MatAdd)):
                args = list(ordered(expr.args))
            else:
                args = expr.args
        else:
            args = expr.args

        new_args = list(map(_rebuild, args))
        # An Unevaluated placeholder is always turned into a real call; any
        # other expression is rebuilt only when its arguments changed.
        if isinstance(expr, Unevaluated) or new_args != args:
            new_expr = expr.func(*new_args)
        else:
            new_expr = expr

        if orig_expr in to_eliminate:
            try:
                sym = next(symbols)
            except StopIteration:
                raise ValueError("Symbols iterator ran out of symbols.")

            if isinstance(orig_expr, MatrixExpr):
                # Matrix expressions get a MatrixSymbol with the same name
                # and the shape of the expression it replaces.
                sym = MatrixSymbol(sym.name, orig_expr.rows,
                    orig_expr.cols)

            subs[orig_expr] = sym
            replacements.append((sym, new_expr))
            return sym

        else:
            return new_expr

    reduced_exprs = []
    for e in exprs:
        if isinstance(e, Basic):
            reduced_e = _rebuild(e)
        else:
            reduced_e = e
        reduced_exprs.append(reduced_e)
    return replacements, reduced_exprs
def cse(exprs, symbols=None, optimizations=None, postprocess=None,
        order='canonical', ignore=(), list=True):
    """ Perform common subexpression elimination on an expression.

    Parameters
    ==========

    exprs : list of SymPy expressions, or a single SymPy expression
        The expressions to reduce.
    symbols : infinite iterator yielding unique Symbols
        The symbols used to label the common subexpressions which are pulled
        out. The ``numbered_symbols`` generator is useful. The default is a
        stream of symbols of the form "x0", "x1", etc. This must be an
        infinite iterator.
    optimizations : list of (callable, callable) pairs
        The (preprocessor, postprocessor) pairs of external optimization
        functions. Optionally 'basic' can be passed for a set of predefined
        basic optimizations. Such 'basic' optimizations were used by default
        in old implementation, however they can be really slow on larger
        expressions. Now, no pre or post optimizations are made by default.
    postprocess : a function which accepts the two return values of cse and
        returns the desired form of output from cse, e.g. if you want the
        replacements reversed the function might be the following lambda:
        lambda r, e: (reversed(r), e)
    order : string, 'none' or 'canonical'
        The order by which Mul and Add arguments are processed. If set to
        'canonical', arguments will be canonically ordered. If set to 'none',
        ordering will be faster but dependent on expressions hashes, thus
        machine dependent and variable. For large expressions where speed is a
        concern, use the setting order='none'.
    ignore : iterable of Symbols
        Substitutions containing any Symbol from ``ignore`` will be ignored.
    list : bool, (default True)
        Returns expression in list or else with same type as input (when False).

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : list of SymPy expressions
        The reduced expressions with all of the replacements above.

    If ``postprocess`` is given, its return value is returned instead.

    Examples
    ========

    >>> from sympy import cse, SparseMatrix
    >>> from sympy.abc import x, y, z, w
    >>> cse(((w + x + y + z)*(w + y + z))/(w + x)**3)
    ([(x0, y + z), (x1, w + x)], [(w + x0)*(x0 + x1)/x1**3])


    List of expressions with recursive substitutions:

    >>> m = SparseMatrix([x + y, x + y + z])
    >>> cse([(x+y)**2, x + y + z, y + z, x + z + y, m])
    ([(x0, x + y), (x1, x0 + z)], [x0**2, x1, y + z, x1, Matrix([
    [x0],
    [x1]])])

    Note: the type and mutability of input matrices is retained.

    >>> isinstance(_[1][-1], SparseMatrix)
    True

    The user may disallow substitutions containing certain symbols:

    >>> cse([y**2*(x + 1), 3*y**2*(x + 1)], ignore=(y,))
    ([(x0, x + 1)], [x0*y**2, 3*x0*y**2])

    The default return value for the reduced expression(s) is a list, even if there is only
    one expression. The `list` flag preserves the type of the input in the output:

    >>> cse(x)
    ([], [x])
    >>> cse(x, list=False)
    ([], x)
    """
    # NOTE: the parameter ``list`` shadows the builtin of the same name for
    # the whole of this function body.
    if not list:
        return _cse_homogeneous(exprs,
            symbols=symbols, optimizations=optimizations,
            postprocess=postprocess, order=order, ignore=ignore)

    if isinstance(exprs, (int, float)):
        exprs = sympify(exprs)

    # Handle the case if just one expression was passed.
    if isinstance(exprs, (Basic, MatrixBase)):
        exprs = [exprs]

    # Keep the original inputs; they are consulted again at the end to
    # rebuild matrices of the right type.
    copy = exprs
    temp = []
    for e in exprs:
        # Matrices are flattened into Tuples for processing: dense ones by
        # their entries, sparse ones by their (key, value) items.
        if isinstance(e, (Matrix, ImmutableMatrix)):
            temp.append(Tuple(*e.flat()))
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            temp.append(Tuple(*e.todok().items()))
        else:
            temp.append(e)
    exprs = temp
    del temp

    if optimizations is None:
        optimizations = []
    elif optimizations == 'basic':
        optimizations = basic_optimizations

    # Preprocess the expressions to give us better optimization opportunities.
    reduced_exprs = [preprocess_for_cse(e, optimizations) for e in exprs]

    if symbols is None:
        symbols = numbered_symbols(cls=Symbol)
    else:
        # In case we get passed an iterable with an __iter__ method instead of
        # an actual iterator.
        symbols = iter(symbols)

    # Find other optimization opportunities.
    opt_subs = opt_cse(reduced_exprs, order)

    # Main CSE algorithm.
    replacements, reduced_exprs = tree_cse(reduced_exprs, symbols, opt_subs,
                                           order, ignore)

    # Postprocess the expressions to return the expressions to canonical form.
    exprs = copy
    for i, (sym, subtree) in enumerate(replacements):
        subtree = postprocess_for_cse(subtree, optimizations)
        replacements[i] = (sym, subtree)
    reduced_exprs = [postprocess_for_cse(e, optimizations)
                     for e in reduced_exprs]

    # Get the matrices back
    for i, e in enumerate(exprs):
        if isinstance(e, (Matrix, ImmutableMatrix)):
            reduced_exprs[i] = Matrix(e.rows, e.cols, reduced_exprs[i])
            if isinstance(e, ImmutableMatrix):
                reduced_exprs[i] = reduced_exprs[i].as_immutable()
        elif isinstance(e, (SparseMatrix, ImmutableSparseMatrix)):
            m = SparseMatrix(e.rows, e.cols, {})
            for k, v in reduced_exprs[i]:
                m[k] = v
            if isinstance(e, ImmutableSparseMatrix):
                m = m.as_immutable()
            reduced_exprs[i] = m

    if postprocess is None:
        return replacements, reduced_exprs

    return postprocess(replacements, reduced_exprs)
def _cse_homogeneous(exprs, **kwargs):
    """
    Same as ``cse`` but the ``reduced_exprs`` are returned
    with the same type as ``exprs`` or a sympified version of the same.

    Parameters
    ==========

    exprs : an Expr, iterable of Expr or dictionary with Expr values
        the expressions in which repeated subexpressions will be identified
    kwargs : additional arguments for the ``cse`` function

    Returns
    =======

    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced. Subexpressions
        earlier in this list might show up in subexpressions later in this
        list.
    reduced_exprs : same kind of object as ``exprs``
        The reduced expressions with all of the replacements above: a
        string for string input, a list/tuple/set for such input, a dict
        with the same keys for dict input, otherwise a single expression.

    Examples
    ========

    >>> from sympy.simplify.cse_main import cse
    >>> from sympy import cos, Tuple, Matrix
    >>> from sympy.abc import x
    >>> output = lambda x: type(cse(x, list=False)[1])
    >>> output(1)
    <class 'sympy.core.numbers.One'>
    >>> output('cos(x)')
    <class 'str'>
    >>> output(cos(x))
    cos
    >>> output(Tuple(1, x))
    <class 'sympy.core.containers.Tuple'>
    >>> output(Matrix([[1,0], [0,1]]))
    <class 'sympy.matrices.dense.MutableDenseMatrix'>
    >>> output([1, x])
    <class 'list'>
    >>> output((1, x))
    <class 'tuple'>
    >>> output({1, x})
    <class 'set'>
    """
    # Strings: parse, process the parsed object, hand back its repr.
    if isinstance(exprs, str):
        parsed = sympify(exprs)
        replacements, reduced = _cse_homogeneous(parsed, **kwargs)
        return replacements, repr(reduced)

    # Builtin sequences and sets: rebuild a container of the input's type.
    if isinstance(exprs, (list, tuple, set)):
        container = type(exprs)
        replacements, reduced = cse(exprs, **kwargs)
        return replacements, container(reduced)

    # Dictionaries: process the values and re-attach them to their keys.
    if isinstance(exprs, dict):
        keys = list(exprs.keys())  # fixes the order of the elements
        replacements, values = cse([exprs[key] for key in keys], **kwargs)
        return replacements, {key: value for key, value in zip(keys, values)}

    # Anything else is treated as a single expression.
    try:
        replacements, (reduced,) = cse(exprs, **kwargs)
    except TypeError:  # For example 'mpf' objects
        return [], exprs
    return replacements, reduced