Spaces:
Sleeping
Sleeping
import math | |
from sympy.sets.sets import Interval | |
from sympy.calculus.singularities import is_increasing, is_decreasing | |
from sympy.codegen.rewriting import Optimization | |
from sympy.core.function import UndefinedFunction | |
""" | |
This module collects classes useful for approimate rewriting of expressions. | |
This can be beneficial when generating numeric code for which performance is | |
of greater importance than precision (e.g. for preconditioners used in iterative | |
methods). | |
""" | |
class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    Bounds for a term are taken from ``bounds`` when the term itself is a key.
    Otherwise, if the term depends on exactly one symbol that has bounds and
    the term can be determined to be monotonic over that symbol's interval,
    bounds for the term are derived by evaluating it at the interval end
    points. Derived bounds are computed on the fly and are not written back
    into ``bounds``. A sum is only simplified when every non-numeric term
    could be bounded.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Threshold for when to ignore a term. A term is dropped when its
        largest possible magnitude is smaller than ``reltol`` times the
        largest magnitude guaranteed by any term of the sum (a term whose
        bounds contain zero guarantees no magnitude).

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        # Kept as a reference (not copied) so later updates by the caller are
        # honoured; this class never writes to it.
        self.bounds = bounds
        self.reltol = reltol

    def __call__(self, expr):
        # Factor first so that sums appear as explicit Add nodes, then
        # simplify every Add in the expression tree.
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        return expr.is_Add

    def _term_bounds(self, term):
        """Return ``(lo, hi)`` bounding ``term``, or ``None`` if unknown."""
        if term.is_number:
            return (term, term)
        if term in self.bounds:
            return self.bounds[term]
        if len(term.free_symbols) != 1:
            return None
        fs, = term.free_symbols
        if fs not in self.bounds:
            return None
        lo, hi = self.bounds[fs]
        intrvl = Interval(lo, hi)
        # A monotonic term attains its extremes at the interval end points;
        # order them so that the returned tuple is always (low, high).
        if is_increasing(term, intrvl, fs):
            return (term.subs({fs: lo}), term.subs({fs: hi}))
        if is_decreasing(term, intrvl, fs):
            return (term.subs({fs: hi}), term.subs({fs: lo}))
        return None

    def value(self, add):
        """Return ``add`` without its negligible terms.

        ``add`` is returned unchanged if any of its terms cannot be bounded.
        """
        term_bounds = []
        for term in add.args:
            tb = self._term_bounds(term)
            if tb is None:
                return add
            term_bounds.append(tb)

        # Largest magnitude that some term is guaranteed to reach; intervals
        # containing zero guarantee nothing.
        largest_abs_guarantee = 0
        for lo, hi in term_bounds:
            if lo <= 0 <= hi:
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))

        threshold = largest_abs_guarantee*self.reltol
        new_terms = [term for term, (lo, hi) in zip(add.args, term_bounds)
                     if max(abs(lo), abs(hi)) >= threshold]
        return add.func(*new_terms)
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Largest relative error ``|1 - series/function|`` accepted at each
        check point for an expansion to be used.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even, at least 2)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """

    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        # An odd count would place a check point exactly at the expansion
        # point (the interval midpoint), where the series is trivially exact.
        if n_point_checks % 2 == 1:
            raise ValueError("Checking the solution at expansion point is not helpful")
        # With fewer than two points no check would be performed at all and
        # the lowest-order expansion would always be accepted.
        if n_point_checks < 2:
            raise ValueError("n_point_checks must be at least 2")
        self.n_point_checks = n_point_checks
        # Significant digits used when evaluating the relative error.
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        return expr.factor().replace(self.query, self.value)

    def query(self, expr):
        # An applied undefined function such as f(x) is an instance of the
        # class ``f``; it is ``f`` itself (``expr.func``) that is an instance
        # of the ``UndefinedFunction`` metaclass. Such functions cannot be
        # evaluated numerically, so they are never approximated.
        return (expr.is_Function
                and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def _accurate_enough(self, fexpr, fseri, symb, lo, hi):
        """Return True if ``fseri`` matches ``fexpr`` within reltol at all check points."""
        for idx in range(self.n_point_checks):
            x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
            val = fseri.xreplace({symb: x})
            ref = fexpr.xreplace({symb: x})
            err = abs((1 - val/ref).evalf(self._prec))
            # A NaN (e.g. 0/0) or otherwise non-real error cannot be compared
            # against reltol; treat it as a failed check instead of raising.
            if not err.is_comparable or err > self.reltol:
                return False
        return True

    def value(self, fexpr):
        """Return the lowest-order series of ``fexpr`` meeting reltol.

        ``fexpr`` is returned unchanged if it does not depend on exactly one
        bounded symbol, or if no expansion passes the check.
        """
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr
        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2
        cheapest = None
        # Try decreasing orders and keep the last (i.e. cheapest) expansion
        # that still satisfies the tolerance; stop at the first failure.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            if not self._accurate_enough(fexpr, fseri, symb, lo, hi):
                break
            cheapest = fseri
        return fexpr if cheapest is None else cheapest