Spaces:
Sleeping
Sleeping
| from sympy.core.containers import Tuple | |
| from sympy.core.numbers import oo | |
| from sympy.core.relational import (Gt, Lt) | |
| from sympy.core.symbol import (Dummy, Symbol) | |
| from sympy.functions.elementary.complexes import Abs | |
| from sympy.functions.elementary.miscellaneous import Min, Max | |
| from sympy.logic.boolalg import And | |
| from sympy.codegen.ast import ( | |
| Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition, | |
| Print, Return, Scope, While, Variable, Pointer, real | |
| ) | |
| from sympy.codegen.cfunctions import isnan | |
| """ This module collects functions for constructing ASTs representing algorithms. """ | |
| def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False, | |
| itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x), | |
| cse=False, handle_nan=None, | |
| bounds=None): | |
| """ Generates an AST for Newton-Raphson method (a root-finding algorithm). | |
| Explanation | |
| =========== | |
| Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Netwon's | |
| method of root-finding. | |
| Parameters | |
| ========== | |
| expr : expression | |
| wrt : Symbol | |
| With respect to, i.e. what is the variable. | |
| atol : number or expression | |
| Absolute tolerance (stopping criterion) | |
| rtol : number or expression | |
| Relative tolerance (stopping criterion) | |
| delta : Symbol | |
| Will be a ``Dummy`` if ``None``. | |
| debug : bool | |
| Whether to print convergence information during iterations | |
| itermax : number or expr | |
| Maximum number of iterations. | |
| counter : Symbol | |
| Will be a ``Dummy`` if ``None``. | |
| delta_fn: Callable[[Expr, Symbol], Expr] | |
| computes the step, default is newtons method. For e.g. Halley's method | |
| use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2)) | |
| cse: bool | |
| Perform common sub-expression elimination on delta expression | |
| handle_nan: Token | |
| How to handle occurrence of not-a-number (NaN). | |
| bounds: Optional[tuple[Expr, Expr]] | |
| Perform optimization within bounds | |
| Examples | |
| ======== | |
| >>> from sympy import symbols, cos | |
| >>> from sympy.codegen.ast import Assignment | |
| >>> from sympy.codegen.algorithms import newtons_method | |
| >>> x, dx, atol = symbols('x dx atol') | |
| >>> expr = cos(x) - x**3 | |
| >>> algo = newtons_method(expr, x, atol=atol, delta=dx) | |
| >>> algo.has(Assignment(dx, -expr/expr.diff(x))) | |
| True | |
| References | |
| ========== | |
| .. [1] https://en.wikipedia.org/wiki/Newton%27s_method | |
| """ | |
| if delta is None: | |
| delta = Dummy() | |
| Wrapper = Scope | |
| name_d = 'delta' | |
| else: | |
| Wrapper = lambda x: x | |
| name_d = delta.name | |
| delta_expr = delta_fn(expr, wrt) | |
| if cse: | |
| from sympy.simplify.cse_main import cse | |
| cses, (red,) = cse([delta_expr.factor()]) | |
| whl_bdy = [Assignment(dum, sub_e) for dum, sub_e in cses] | |
| whl_bdy += [Assignment(delta, red)] | |
| else: | |
| whl_bdy = [Assignment(delta, delta_expr)] | |
| if handle_nan is not None: | |
| whl_bdy += [While(isnan(delta), CodeBlock(handle_nan, break_))] | |
| whl_bdy += [AddAugmentedAssignment(wrt, delta)] | |
| if bounds is not None: | |
| whl_bdy += [Assignment(wrt, Min(Max(wrt, bounds[0]), bounds[1]))] | |
| if debug: | |
| prnt = Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, name_d)) | |
| whl_bdy += [prnt] | |
| req = Gt(Abs(delta), atol + rtol*Abs(wrt)) | |
| declars = [Declaration(Variable(delta, type=real, value=oo))] | |
| if itermax is not None: | |
| counter = counter or Dummy(integer=True) | |
| v_counter = Variable.deduced(counter, 0) | |
| declars.append(Declaration(v_counter)) | |
| whl_bdy.append(AddAugmentedAssignment(counter, 1)) | |
| req = And(req, Lt(counter, itermax)) | |
| whl = While(req, CodeBlock(*whl_bdy)) | |
| blck = declars | |
| if debug: | |
| blck.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name))) | |
| blck += [whl] | |
| return Wrapper(CodeBlock(*blck)) | |
| def _symbol_of(arg): | |
| if isinstance(arg, Declaration): | |
| arg = arg.variable.symbol | |
| elif isinstance(arg, Variable): | |
| arg = arg.symbol | |
| return arg | |
| def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs): | |
| """ Generates an AST for a function implementing the Newton-Raphson method. | |
| Parameters | |
| ========== | |
| expr : expression | |
| wrt : Symbol | |
| With respect to, i.e. what is the variable | |
| params : iterable of symbols | |
| Symbols appearing in expr that are taken as constants during the iterations | |
| (these will be accepted as parameters to the generated function). | |
| func_name : str | |
| Name of the generated function. | |
| attrs : Tuple | |
| Attribute instances passed as ``attrs`` to ``FunctionDefinition``. | |
| \\*\\*kwargs : | |
| Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`. | |
| Examples | |
| ======== | |
| >>> from sympy import symbols, cos | |
| >>> from sympy.codegen.algorithms import newtons_method_function | |
| >>> from sympy.codegen.pyutils import render_as_module | |
| >>> x = symbols('x') | |
| >>> expr = cos(x) - x**3 | |
| >>> func = newtons_method_function(expr, x) | |
| >>> py_mod = render_as_module(func) # source code as string | |
| >>> namespace = {} | |
| >>> exec(py_mod, namespace, namespace) | |
| >>> res = eval('newton(0.5)', namespace) | |
| >>> abs(res - 0.865474033102) < 1e-12 | |
| True | |
| See Also | |
| ======== | |
| sympy.codegen.algorithms.newtons_method | |
| """ | |
| if params is None: | |
| params = (wrt,) | |
| pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name) | |
| for p in params if isinstance(p, Pointer)} | |
| if delta is None: | |
| delta = Symbol('d_' + wrt.name) | |
| if expr.has(delta): | |
| delta = None # will use Dummy | |
| algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs) | |
| if isinstance(algo, Scope): | |
| algo = algo.body | |
| not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params}) | |
| if not_in_params: | |
| raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, not_in_params))) | |
| declars = tuple(Variable(p, real) for p in params) | |
| body = CodeBlock(algo, Return(wrt)) | |
| return FunctionDefinition(real, func_name, declars, body, attrs=attrs) | |