Spaces:
Sleeping
Sleeping
from sympy.core.containers import Tuple | |
from sympy.core.numbers import oo | |
from sympy.core.relational import (Gt, Lt) | |
from sympy.core.symbol import (Dummy, Symbol) | |
from sympy.functions.elementary.complexes import Abs | |
from sympy.functions.elementary.miscellaneous import Min, Max | |
from sympy.logic.boolalg import And | |
from sympy.codegen.ast import ( | |
Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition, | |
Print, Return, Scope, While, Variable, Pointer, real | |
) | |
from sympy.codegen.cfunctions import isnan | |
""" This module collects functions for constructing ASTs representing algorithms. """ | |
def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False,
                   itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x),
                   cse=False, handle_nan=None,
                   bounds=None):
    """ Generates an AST for Newton-Raphson method (a root-finding algorithm).

    Explanation
    ===========

    Builds an abstract syntax tree (AST) out of ``sympy.codegen.ast`` nodes
    that iterates Newton's method on ``expr`` with respect to ``wrt``. The
    generated code declares the step variable (initialized to infinity),
    optionally an iteration counter, and then loops while
    ``Abs(delta) > atol + rtol*Abs(wrt)`` (and, if ``itermax`` is given, while
    the counter is below ``itermax``). In every iteration the step is
    computed and added to ``wrt``.

    Parameters
    ==========

    expr : expression
        Expression whose root is sought.
    wrt : Symbol
        With respect to, i.e. what is the variable.
    atol : number or expression
        Absolute tolerance (stopping criterion)
    rtol : number or expression
        Relative tolerance (stopping criterion)
    delta : Symbol
        Symbol holding the step. Will be a ``Dummy`` if ``None``, in which
        case the returned code block is wrapped in a ``Scope``.
    debug : bool
        Whether to print convergence information during iterations
    itermax : number or expr
        Maximum number of iterations.
    counter : Symbol
        Iteration counter, only used when ``itermax`` is given.
        Will be a ``Dummy`` if ``None``.
    delta_fn: Callable[[Expr, Symbol], Expr]
        computes the step, default is newtons method. For e.g. Halley's method
        use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2))
    cse: bool
        Perform common sub-expression elimination on delta expression
    handle_nan: Token
        How to handle occurrence of not-a-number (NaN). When given, it is
        emitted (followed by a ``break``) guarded by ``isnan(delta)``.
    bounds: Optional[tuple[Expr, Expr]]
        Perform optimization within bounds: after each update ``wrt`` is
        clamped to ``[bounds[0], bounds[1]]``.

    Examples
    ========

    >>> from sympy import symbols, cos
    >>> from sympy.codegen.ast import Assignment
    >>> from sympy.codegen.algorithms import newtons_method
    >>> x, dx, atol = symbols('x dx atol')
    >>> expr = cos(x) - x**3
    >>> algo = newtons_method(expr, x, atol=atol, delta=dx)
    >>> algo.has(Assignment(dx, -expr/expr.diff(x)))
    True

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Newton%27s_method

    """
    # A step symbol generated here is private to the algorithm, so the final
    # block gets wrapped in a Scope; a caller-supplied one is left unwrapped.
    anonymous_step = delta is None
    if anonymous_step:
        delta = Dummy()
        step_label = 'delta'
    else:
        step_label = delta.name

    step_expr = delta_fn(expr, wrt)
    if cse:
        # Imported under an alias; the ``cse`` argument is only a flag.
        from sympy.simplify.cse_main import cse as eliminate_common
        replacements, reduced = eliminate_common([step_expr.factor()])
        (step_reduced,) = reduced
        loop_stmts = [Assignment(tmp, sub) for tmp, sub in replacements]
        loop_stmts.append(Assignment(delta, step_reduced))
    else:
        loop_stmts = [Assignment(delta, step_expr)]

    if handle_nan is not None:
        # The body ends in ``break``, so this While runs its body at most once.
        loop_stmts.append(While(isnan(delta), CodeBlock(handle_nan, break_)))

    loop_stmts.append(AddAugmentedAssignment(wrt, delta))

    if bounds is not None:
        clamped = Min(Max(wrt, bounds[0]), bounds[1])
        loop_stmts.append(Assignment(wrt, clamped))

    if debug:
        loop_stmts.append(
            Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, step_label))
        )

    keep_going = Gt(Abs(delta), atol + rtol*Abs(wrt))
    # Starting the step at infinity guarantees the loop condition holds initially
    # with respect to the tolerance test.
    preamble = [Declaration(Variable(delta, type=real, value=oo))]

    if itermax is not None:
        counter = counter or Dummy(integer=True)
        preamble.append(Declaration(Variable.deduced(counter, 0)))
        loop_stmts.append(AddAugmentedAssignment(counter, 1))
        keep_going = And(keep_going, Lt(counter, itermax))

    loop = While(keep_going, CodeBlock(*loop_stmts))

    if debug:
        # Report the starting value before the first iteration.
        preamble.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name)))

    program = CodeBlock(*preamble, loop)
    return Scope(program) if anonymous_step else program
def _symbol_of(arg):
    """ Return the symbol behind a ``Declaration`` or ``Variable``.

    Any other object is returned unchanged.
    """
    if isinstance(arg, Declaration):
        return arg.variable.symbol
    if isinstance(arg, Variable):
        return arg.symbol
    return arg
def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
    """ Generates an AST for a function implementing the Newton-Raphson method.

    The algorithm produced by :func:`newtons_method` is wrapped in a
    ``FunctionDefinition`` returning ``wrt``. The function's return type and
    all of its parameters are declared as ``real``.

    Parameters
    ==========

    expr : expression
    wrt : Symbol
        With respect to, i.e. what is the variable
    params : iterable of symbols
        Symbols appearing in expr that are taken as constants during the iterations
        (these will be accepted as parameters to the generated function).
        Defaults to ``(wrt,)``. Entries that are ``Pointer`` instances are
        referenced as ``(*name)`` inside the generated algorithm.
    func_name : str
        Name of the generated function.
    attrs : Tuple
        Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
    delta : Symbol
        Symbol holding the step. If ``None``, a symbol named ``d_`` followed by
        the name of ``wrt`` is used, unless ``expr`` already contains that
        symbol, in which case :func:`newtons_method` picks a ``Dummy``.
    \\*\\*kwargs :
        Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.

    Raises
    ======

    ValueError
        If ``expr`` has free symbols that are not covered by ``params``.

    Examples
    ========

    >>> from sympy import symbols, cos
    >>> from sympy.codegen.algorithms import newtons_method_function
    >>> from sympy.codegen.pyutils import render_as_module
    >>> x = symbols('x')
    >>> expr = cos(x) - x**3
    >>> func = newtons_method_function(expr, x)
    >>> py_mod = render_as_module(func)  # source code as string
    >>> namespace = {}
    >>> exec(py_mod, namespace, namespace)
    >>> res = eval('newton(0.5)', namespace)
    >>> abs(res - 0.865474033102) < 1e-12
    True

    See Also
    ========

    sympy.codegen.algorithms.newtons_method

    """
    if params is None:
        params = (wrt,)

    # Pointer parameters must be dereferenced wherever the algorithm uses them.
    deref = {}
    for p in params:
        if isinstance(p, Pointer):
            deref[p.symbol] = Symbol('(*%s)' % p.symbol.name)

    if delta is None:
        candidate = Symbol('d_' + wrt.name)
        # Fall back to None (i.e. a Dummy) when the readable name would clash.
        delta = None if expr.has(candidate) else candidate

    algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(deref)
    if isinstance(algo, Scope):
        # Unwrap the Scope and use its body directly in the function body.
        algo = algo.body

    declared = {_symbol_of(p) for p in params}
    missing = expr.free_symbols.difference(declared)
    if missing:
        raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, missing)))

    signature = tuple(Variable(p, real) for p in params)
    return FunctionDefinition(real, func_name, signature,
                              CodeBlock(algo, Return(wrt)), attrs=attrs)