Generateurv2/backend/env/lib/python3.10/site-packages/sympy/codegen/algorithms.py
2022-06-24 17:14:37 +02:00

147 lines
4.8 KiB
Python

from sympy import And, Gt, Lt, Abs, Dummy, oo, Tuple, Symbol
from sympy.codegen.ast import (
Assignment, AddAugmentedAssignment, CodeBlock, Declaration, FunctionDefinition,
Print, Return, Scope, While, Variable, Pointer, real
)
""" This module collects functions for constructing ASTs representing algorithms. """
def newtons_method(expr, wrt, atol=1e-12, delta=None, debug=False,
                   itermax=None, counter=None):
    """ Builds an AST for the Newton-Raphson root-finding iteration.

    Explanation
    ===========

    The returned abstract syntax tree (AST) is made of ``sympy.codegen.ast``
    nodes: a declaration of the step variable (initialized to infinity),
    optionally a declaration of an iteration counter (initialized to zero),
    and a ``While`` loop which repeatedly assigns ``-expr/expr.diff(wrt)`` to
    the step variable and adds it to ``wrt`` for as long as the absolute value
    of the step exceeds ``atol`` (and, if ``itermax`` is given, the counter is
    below ``itermax``).

    Parameters
    ==========

    expr : expression
        The expression whose root is sought.
    wrt : Symbol
        With respect to, i.e. the variable being iterated on.
    atol : number or expr
        Absolute tolerance (stopping criterion).
    delta : Symbol
        Holds the Newton step. A ``Dummy`` is used if ``None``, and in that
        case the resulting code block is wrapped in a ``Scope``.
    debug : bool
        Whether to print convergence information during iterations.
    itermax : number or expr
        Maximum number of iterations. No counter is emitted if ``None``.
    counter : Symbol
        Iteration counter, only used when ``itermax`` is given.
        Will be a ``Dummy`` if ``None``.

    Examples
    ========

    >>> from sympy import symbols, cos
    >>> from sympy.codegen.ast import Assignment
    >>> from sympy.codegen.algorithms import newtons_method
    >>> x, dx, atol = symbols('x dx atol')
    >>> expr = cos(x) - x**3
    >>> algo = newtons_method(expr, x, atol, dx)
    >>> algo.has(Assignment(dx, -expr/expr.diff(x)))
    True

    References
    ==========

    .. [1] https://en.wikipedia.org/wiki/Newton%27s_method

    """
    # An auto-generated step variable is private to the algorithm, hence the
    # enclosing Scope; a caller-supplied one is left visible to the caller.
    scoped = delta is None
    if scoped:
        delta = Dummy()
        delta_label = 'delta'
    else:
        delta_label = delta.name

    newton_step = -expr/expr.diff(wrt)

    # Loop body, in execution order: compute step, (report it), apply it.
    loop_body = [Assignment(delta, newton_step)]
    if debug:
        fmt = r"{}=%12.5g {}=%12.5g\n".format(wrt.name, delta_label)
        loop_body.append(Print([wrt, delta], fmt))
    loop_body.append(AddAugmentedAssignment(wrt, delta))

    # Starting the step at infinity guarantees the loop is entered.
    keep_going = Gt(Abs(delta), atol)
    declarations = [Declaration(Variable(delta, type=real, value=oo))]

    if itermax is not None:
        counter = counter or Dummy(integer=True)
        declarations.append(Declaration(Variable.deduced(counter, 0)))
        loop_body.append(AddAugmentedAssignment(counter, 1))
        keep_going = And(keep_going, Lt(counter, itermax))

    block = CodeBlock(*declarations, While(keep_going, CodeBlock(*loop_body)))
    return Scope(block) if scoped else block
def _symbol_of(arg):
    """ Unwraps ``arg`` to the symbol it carries.

    A ``Declaration`` yields ``arg.variable.symbol``, a ``Variable`` yields
    ``arg.symbol``, and anything else is returned unchanged.
    """
    if isinstance(arg, Declaration):
        return arg.variable.symbol
    if isinstance(arg, Variable):
        return arg.symbol
    return arg
def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
    """ Generates an AST for a function implementing the Newton-Raphson method.

    Parameters
    ==========

    expr : expression
    wrt : Symbol
        With respect to, i.e. what is the variable
    params : iterable of symbols
        Symbols appearing in expr that are taken as constants during the iterations
        (these will be accepted as parameters to the generated function).
        Defaults to ``(wrt,)`` if ``None``. Any iterable is accepted, including
        one-shot iterators such as generators.
    func_name : str
        Name of the generated function.
    attrs : Tuple
        Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
    delta : Symbol
        Symbol holding the Newton step. If ``None``, a symbol named
        ``'d_' + wrt.name`` is used, unless ``expr`` already contains that
        symbol, in which case :func:`newtons_method` falls back to a ``Dummy``.
    \\*\\*kwargs :
        Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.

    Raises
    ======

    ValueError
        If ``expr`` contains free symbols that are not covered by ``params``.

    Examples
    ========

    >>> from sympy import symbols, cos
    >>> from sympy.codegen.algorithms import newtons_method_function
    >>> from sympy.codegen.pyutils import render_as_module
    >>> x = symbols('x')
    >>> expr = cos(x) - x**3
    >>> func = newtons_method_function(expr, x)
    >>> py_mod = render_as_module(func)  # source code as string
    >>> namespace = {}
    >>> exec(py_mod, namespace, namespace)
    >>> res = eval('newton(0.5)', namespace)
    >>> abs(res - 0.865474033102) < 1e-12
    True

    See Also
    ========

    sympy.codegen.algorithms.newtons_method

    """
    if params is None:
        params = (wrt,)
    else:
        # ``params`` is traversed three times below; materialize it so that a
        # one-shot iterator is not exhausted after the first pass (which would
        # otherwise trigger a spurious "Missing symbols" error).
        params = tuple(params)

    # Pointer parameters must be dereferenced wherever they occur in the body.
    pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name)
                    for p in params if isinstance(p, Pointer)}

    if delta is None:
        delta = Symbol('d_' + wrt.name)
        if expr.has(delta):
            delta = None  # will use Dummy

    algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs)
    if isinstance(algo, Scope):
        # The function body already provides a scope of its own.
        algo = algo.body

    not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params})
    if not_in_params:
        # Sorted so that the message does not depend on set iteration order.
        raise ValueError("Missing symbols in params: %s" % ', '.join(sorted(map(str, not_in_params))))

    declars = tuple(Variable(p, real) for p in params)
    body = CodeBlock(algo, Return(wrt))
    return FunctionDefinition(real, func_name, declars, body, attrs=attrs)