Spaces:
Sleeping
Sleeping
File size: 6,447 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import math
from sympy.sets.sets import Interval
from sympy.calculus.singularities import is_increasing, is_decreasing
from sympy.codegen.rewriting import Optimization
from sympy.core.function import UndefinedFunction
"""
This module collects classes useful for approximate rewriting of expressions.
This can be beneficial when generating numeric code for which performance is
of greater importance than precision (e.g. for preconditioners used in iterative
methods).
"""
class SumApprox(Optimization):
    """
    Approximates sum by neglecting small terms.

    Explanation
    ===========

    If terms are expressions which can be determined to be monotonic, then
    bounds for those expressions are added.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
        The mapping is copied on construction; bounds derived for monotonic
        terms are stored in the copy, never in the caller's dict.
    reltol : number
        Threshold for when to ignore a term. Taken relative to the largest
        lower bound among bounds.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y, z
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SumApprox
    >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
    >>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
    >>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
    >>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
    >>> expr = 3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx3])
    3*(x + y + exp(z))
    >>> optimize(expr, [sum_approx2])
    3*y + 3*exp(z)
    >>> optimize(expr, [sum_approx1])
    3*y

    """

    def __init__(self, bounds, reltol, **kwargs):
        super().__init__(**kwargs)
        # ``value`` records bounds it derives for monotonic terms in
        # ``self.bounds``. Keep a private copy so those extra entries do not
        # leak into the dict supplied by the caller (which, as in the example
        # above, may be shared between several optimizations).
        self.bounds = dict(bounds)
        self.reltol = reltol

    def __call__(self, expr):
        """Factor ``expr`` and rewrite every sum in it through ``value``."""
        return expr.factor().replace(self.query, lambda arg: self.value(arg))

    def query(self, expr):
        """Select the sub-expressions to rewrite: those that are sums."""
        return expr.is_Add

    def value(self, add):
        """
        Return ``add`` with the terms that are negligible under ``reltol``
        removed, or ``add`` itself when bounds cannot be established for
        every term.
        """
        # Step 1: try to derive bounds for terms that have none yet but
        # depend on exactly one symbol whose bounds are known.
        for term in add.args:
            if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
                continue
            fs, = term.free_symbols
            if fs not in self.bounds:
                continue
            sym_lo, sym_hi = self.bounds[fs][0], self.bounds[fs][1]
            intrvl = Interval(*self.bounds[fs])
            if is_increasing(term, intrvl, fs):
                # Increasing: the extremes are attained at (low, high).
                self.bounds[term] = (
                    term.subs({fs: sym_lo}),
                    term.subs({fs: sym_hi})
                )
            elif is_decreasing(term, intrvl, fs):
                # Decreasing: the end points swap roles.
                self.bounds[term] = (
                    term.subs({fs: sym_hi}),
                    term.subs({fs: sym_lo})
                )
            else:
                # Not monotonic in either direction: give up on this sum.
                return add

        # Step 2: only proceed when every term is a number or has bounds.
        if not all(term.is_number or term in self.bounds for term in add.args):
            return add

        bounds = [(term, term) if term.is_number else self.bounds[term]
                  for term in add.args]

        # Largest magnitude that some term is guaranteed to reach. A term
        # whose interval contains zero guarantees nothing and is skipped.
        largest_abs_guarantee = 0
        for lo, hi in bounds:
            if lo <= 0 <= hi:
                continue
            largest_abs_guarantee = max(largest_abs_guarantee,
                                        min(abs(lo), abs(hi)))

        # Keep the terms whose largest possible magnitude reaches the
        # threshold ``largest_abs_guarantee*reltol``; drop the others.
        threshold = largest_abs_guarantee*self.reltol
        new_terms = [term for term, (lo, hi) in zip(add.args, bounds)
                     if max(abs(lo), abs(hi)) >= threshold]
        return add.func(*new_terms)
class SeriesApprox(Optimization):
    """ Approximates functions by expanding them as a series.

    Parameters
    ==========

    bounds : dict
        Mapping expressions to length 2 tuple of bounds (low, high).
    reltol : number
        Threshold for when to ignore a term. Taken relative to the largest
        lower bound among bounds.
    max_order : int
        Largest order to include in series expansion
    n_point_checks : int (even)
        The validity of an expansion (with respect to reltol) is checked at
        discrete points (linearly spaced over the bounds of the variable). The
        number of points used in this numerical check is given by this number.

    Examples
    ========

    >>> from sympy import sin, pi
    >>> from sympy.abc import x, y
    >>> from sympy.codegen.rewriting import optimize
    >>> from sympy.codegen.approximations import SeriesApprox
    >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
    >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
    >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
    >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
    >>> expr = sin(x)*sin(y)
    >>> optimize(expr, [series_approx2])
    x*(-y + (y - pi)**3/6 + pi)
    >>> optimize(expr, [series_approx3])
    (-x**3/6 + x)*sin(y)
    >>> optimize(expr, [series_approx8])
    sin(x)*sin(y)

    """

    def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
        super().__init__(**kwargs)
        self.bounds = bounds
        self.reltol = reltol
        self.max_order = max_order
        # An odd number of linearly spaced points would include the midpoint
        # of the interval, which is the expansion point used by ``value``.
        if n_point_checks % 2 == 1:
            raise ValueError("Checking the solution at expansion point is not helpful")
        self.n_point_checks = n_point_checks
        # Number of decimal digits requested from ``evalf`` when comparing
        # against ``reltol`` (e.g. reltol=1e-3 -> 3 digits).
        self._prec = math.ceil(-math.log10(self.reltol))

    def __call__(self, expr):
        """Factor ``expr`` and rewrite every matching function call in it."""
        return expr.factor().replace(self.query, lambda arg: self.value(arg))

    def query(self, expr):
        """
        Select applied, defined functions of a single argument.

        ``UndefinedFunction`` is the metaclass of undefined function classes,
        so it has to be tested against ``expr.func`` (the class of the applied
        function); an applied function such as ``f(x)`` is itself never an
        instance of ``UndefinedFunction``.
        """
        return (expr.is_Function
                and not isinstance(expr.func, UndefinedFunction)
                and len(expr.args) == 1)

    def value(self, fexpr):
        """
        Return the cheapest truncated series of ``fexpr`` that stays within
        ``reltol`` at every check point, or ``fexpr`` itself when no
        truncation qualifies or the bounds of its variable are unknown.
        """
        free_symbols = fexpr.free_symbols
        if len(free_symbols) != 1:
            return fexpr
        symb, = free_symbols
        if symb not in self.bounds:
            return fexpr

        lo, hi = self.bounds[symb]
        x0 = (lo + hi)/2  # expand around the midpoint of the interval
        cheapest = None
        # Walk from the longest truncation towards the shortest and stop at
        # the first one that violates the tolerance; the last accepted
        # truncation is then the cheapest acceptable one.
        for n in range(self.max_order+1, 0, -1):
            fseri = fexpr.series(symb, x0=x0, n=n).removeO()
            n_ok = True
            for idx in range(self.n_point_checks):
                # Linearly spaced check points, end points included.
                x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
                val = fseri.xreplace({symb: x})
                ref = fexpr.xreplace({symb: x})
                # Relative error of the truncated series at this point.
                if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
                    n_ok = False
                    break

            if n_ok:
                cheapest = fseri
            else:
                break

        if cheapest is None:
            return fexpr
        return cheapest
|