Spaces:
Sleeping
Sleeping
File size: 6,383 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
from sympy.core.containers import Tuple
from sympy.core.numbers import oo
from sympy.core.relational import (Gt, Lt)
from sympy.core.symbol import (Dummy, Symbol)
from sympy.functions.elementary.complexes import Abs
from sympy.functions.elementary.miscellaneous import Min, Max
from sympy.logic.boolalg import And
from sympy.codegen.ast import (
Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition,
Print, Return, Scope, While, Variable, Pointer, real
)
from sympy.codegen.cfunctions import isnan
""" This module collects functions for constructing ASTs representing algorithms. """
def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False,
                   itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x),
                   cse=False, handle_nan=None,
                   bounds=None):
    """ Generates an AST for Newton-Raphson method (a root-finding algorithm).

    Explanation
    ===========

    Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Newton's
    method of root-finding. The generated code declares the step variable,
    then loops: compute the step, add it to ``wrt``, and repeat while
    ``Abs(delta) > atol + rtol*Abs(wrt)`` (and, when ``itermax`` is given,
    while the iteration counter is below ``itermax``).

    Parameters
    ==========

    expr : expression
        Expression whose root is sought.
    wrt : Symbol
        With respect to, i.e. what is the variable.
    atol : number or expression
        Absolute tolerance (stopping criterion)
    rtol : number or expression
        Relative tolerance (stopping criterion)
    delta : Symbol
        Symbol holding the step. Will be a ``Dummy`` if ``None``, in which
        case the returned code block is wrapped in a ``Scope``.
    debug : bool
        Whether to print convergence information during iterations
    itermax : number or expr
        Maximum number of iterations.
    counter : Symbol
        Iteration counter (only used when ``itermax`` is given).
        Will be a ``Dummy`` if ``None``.
    delta_fn: Callable[[Expr, Symbol], Expr]
        computes the step, default is newtons method. For e.g. Halley's method
        use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2))
    cse: bool
        Perform common sub-expression elimination on delta expression
    handle_nan: Token
        How to handle occurrence of not-a-number (NaN). The token is emitted
        inside a guard that is entered when the step is NaN.
    bounds: Optional[tuple[Expr, Expr]]
        Perform optimization within bounds: after every update ``wrt`` is
        clamped to ``[bounds[0], bounds[1]]``.

    Examples
    ========

    >>> from sympy import symbols, cos
    >>> from sympy.codegen.ast import Assignment
    >>> from sympy.codegen.algorithms import newtons_method
    >>> x, dx, atol = symbols('x dx atol')
    >>> expr = cos(x) - x**3
    >>> algo = newtons_method(expr, x, atol=atol, delta=dx)
    >>> algo.has(Assignment(dx, -expr/expr.diff(x)))
    True

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Newton%27s_method

    """
    # An anonymous step variable is private to the algorithm, so the result
    # is wrapped in its own Scope in that case (see the return statement).
    anonymous_step = delta is None
    if anonymous_step:
        delta = Dummy()
    step_label = 'delta' if anonymous_step else delta.name

    step_expr = delta_fn(expr, wrt)

    # --- statements executed on every iteration -------------------------
    loop_body = []
    if cse:
        # Aliased so the imported function does not shadow the ``cse`` flag.
        from sympy.simplify.cse_main import cse as eliminate_common
        replacements, (reduced,) = eliminate_common([step_expr.factor()])
        loop_body.extend(Assignment(lhs, rhs) for lhs, rhs in replacements)
        loop_body.append(Assignment(delta, reduced))
    else:
        loop_body.append(Assignment(delta, step_expr))

    if handle_nan is not None:
        # A while-loop ending in ``break`` runs its body at most once,
        # i.e. it acts as a conditional on ``isnan(delta)``.
        loop_body.append(While(isnan(delta), CodeBlock(handle_nan, break_)))

    loop_body.append(AddAugmentedAssignment(wrt, delta))

    if bounds is not None:
        lower, upper = bounds[0], bounds[1]
        loop_body.append(Assignment(wrt, Min(Max(wrt, lower), upper)))

    if debug:
        loop_body.append(
            Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, step_label))
        )

    # --- loop condition and declarations --------------------------------
    keep_going = Gt(Abs(delta), atol + rtol*Abs(wrt))
    # The step starts out infinite so that the loop is entered at least once.
    preamble = [Declaration(Variable(delta, type=real, value=oo))]

    if itermax is not None:
        counter = counter or Dummy(integer=True)
        preamble.append(Declaration(Variable.deduced(counter, 0)))
        loop_body.append(AddAugmentedAssignment(counter, 1))
        keep_going = And(keep_going, Lt(counter, itermax))

    if debug:
        # Print the starting guess before the first iteration.
        preamble.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name)))

    statements = CodeBlock(*preamble, While(keep_going, CodeBlock(*loop_body)))
    return Scope(statements) if anonymous_step else statements
def _symbol_of(arg):
    """ Returns the symbol underlying ``arg``.

    A ``Declaration`` yields the symbol of its variable, a ``Variable``
    yields its own symbol, and any other object is returned unchanged.
    """
    if isinstance(arg, Declaration):
        return arg.variable.symbol
    if isinstance(arg, Variable):
        return arg.symbol
    return arg
def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
    """ Generates an AST for a function implementing the Newton-Raphson method.

    Parameters
    ==========

    expr : expression
    wrt : Symbol
        With respect to, i.e. what is the variable
    params : iterable of symbols
        Symbols appearing in expr that are taken as constants during the iterations
        (these will be accepted as parameters to the generated function).
        Defaults to ``(wrt,)``. Any iterable is accepted, including one-shot
        iterators such as generators.
    func_name : str
        Name of the generated function.
    attrs : Tuple
        Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
    delta : Symbol
        Symbol used for the step in the generated code. If ``None``, a symbol
        named ``'d_' + wrt.name`` is used, unless ``expr`` already contains
        that symbol, in which case a ``Dummy`` is used instead.
    \\*\\*kwargs :
        Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.

    Returns
    =======

    FunctionDefinition
        A function of type ``real`` taking ``params`` as arguments, whose
        body runs the iteration and then returns ``wrt``.

    Raises
    ======

    ValueError
        If ``expr`` contains free symbols that are not covered by ``params``.

    Examples
    ========

    >>> from sympy import symbols, cos
    >>> from sympy.codegen.algorithms import newtons_method_function
    >>> from sympy.codegen.pyutils import render_as_module
    >>> x = symbols('x')
    >>> expr = cos(x) - x**3
    >>> func = newtons_method_function(expr, x)
    >>> py_mod = render_as_module(func)  # source code as string
    >>> namespace = {}
    >>> exec(py_mod, namespace, namespace)
    >>> res = eval('newton(0.5)', namespace)
    >>> abs(res - 0.865474033102) < 1e-12
    True

    See Also
    ========

    sympy.codegen.algorithms.newtons_method

    """
    if params is None:
        params = (wrt,)
    else:
        # ``params`` is traversed three times below (pointer substitutions,
        # missing-symbol check, declarations); materialize it once so that
        # one-shot iterators are not exhausted after the first pass.
        params = tuple(params)

    # Pointer parameters must be dereferenced wherever they are used in
    # the body of the generated function.
    pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name)
                    for p in params if isinstance(p, Pointer)}

    if delta is None:
        delta = Symbol('d_' + wrt.name)
        if expr.has(delta):
            delta = None  # will use Dummy

    algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs)
    if isinstance(algo, Scope):
        # The function body provides the scope already; unwrap the block.
        algo = algo.body

    not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params})
    if not_in_params:
        # Sorted so that the error message is deterministic (sets are unordered).
        missing = ', '.join(sorted(map(str, not_in_params)))
        raise ValueError("Missing symbols in params: %s" % missing)

    declars = tuple(Variable(p, real) for p in params)
    body = CodeBlock(algo, Return(wrt))
    return FunctionDefinition(real, func_name, declars, body, attrs=attrs)
|