Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
from typing import Any | |
from functools import wraps | |
from sympy.core import Add, Mul, Pow, S, sympify, Float | |
from sympy.core.basic import Basic | |
from sympy.core.expr import UnevaluatedExpr | |
from sympy.core.function import Lambda | |
from sympy.core.mul import _keep_coeff | |
from sympy.core.sorting import default_sort_key | |
from sympy.core.symbol import Symbol | |
from sympy.functions.elementary.complexes import re | |
from sympy.printing.str import StrPrinter | |
from sympy.printing.precedence import precedence, PRECEDENCE | |
class requires:
    """Decorator that records requirements whenever a print method runs.

    Every keyword argument maps a printer attribute name to an iterable of
    values.  Each time the decorated method is called, those values are
    merged into the named attribute of the printer instance via ``update``.
    The wrapped method's result is returned unchanged.
    """

    def __init__(self, **kwargs):
        self._req = kwargs

    def __call__(self, method):
        decorator = self

        @wraps(method)
        def _method_wrapper(self_, *args, **kwargs):
            # Register every requirement before delegating to the method.
            for attr_name, needed in decorator._req.items():
                target = getattr(self_, attr_name)
                target.update(needed)
            return method(self_, *args, **kwargs)

        return _method_wrapper
class AssignmentError(Exception):
    """Signal that loop code was requested without an assignment target.

    Raised when printing loops needs an accumulator variable but no
    ``assign_to`` was supplied.
    """
class PrintMethodNotImplementedError(NotImplementedError):
    """Signal that the printer has no ``_print_*`` method for an expression.

    Subclasses ``NotImplementedError`` so existing handlers keep working.
    """
def _convert_python_lists(arg):
    """Recursively turn Python lists into SymPy ``List`` nodes.

    Tuples are rebuilt as tuples with their items converted.  Anything that
    is neither a list nor a tuple is returned unchanged.
    """
    if isinstance(arg, tuple):
        return tuple(map(_convert_python_lists, arg))
    if isinstance(arg, list):
        # Imported lazily to avoid pulling in sympy.codegen at module import.
        from sympy.codegen.abstract_nodes import List
        return List(*map(_convert_python_lists, arg))
    return arg
class CodePrinter(StrPrinter):
    """
    The base class for code-printing subclasses.

    This class references several names it does not define, which
    subclasses are expected to provide:

    - the hook methods below that raise ``NotImplementedError``
      (``_rate_index_position``, ``_get_statement``, ``_get_comment``,
      ``_declare_number_const``, ``_format_code``,
      ``_get_loop_opening_ending``);
    - a ``known_functions`` mapping (read by ``_print_Function`` and
      ``_can_print``);
    - a ``language`` name (used in the "Not supported" header of
      ``doprint``);
    - a ``precision`` setting (read by ``_print_NumberSymbol``);
    - ``_traverse_matrix_indices`` (used when assigning to a
      ``MatrixSymbol``).
    """

    _operators = {
        'and': '&&',
        'or': '||',
        'not': '!',
    }

    _default_settings: dict[str, Any] = {
        'order': None,
        'full_prec': 'auto',
        'error_on_reserved': False,
        'reserved_word_suffix': '_',
        'human': True,
        'inline': False,
        'allow_unknown_functions': False,
        'strict': None  # True or False; None => True if human == True
    }

    # Functions which are "simple" to rewrite to other functions that
    # may be supported
    # function_to_rewrite : (function_to_rewrite_to, iterable_with_other_functions_required)
    _rewriteable_functions = {
        'cot': ('tan', []),
        'csc': ('sin', []),
        'sec': ('cos', []),
        'acot': ('atan', []),
        'acsc': ('asin', []),
        'asec': ('acos', []),
        'coth': ('exp', []),
        'csch': ('exp', []),
        'sech': ('exp', []),
        'acoth': ('log', []),
        'acsch': ('log', []),
        'asech': ('log', []),
        'catalan': ('gamma', []),
        'fibonacci': ('sqrt', []),
        'lucas': ('sqrt', []),
        'beta': ('gamma', []),
        'sinc': ('sin', ['Piecewise']),
        'Mod': ('floor', []),
        'factorial': ('gamma', []),
        'factorial2': ('gamma', ['Piecewise']),
        'subfactorial': ('uppergamma', []),
        'RisingFactorial': ('gamma', ['Piecewise']),
        'FallingFactorial': ('gamma', ['Piecewise']),
        'binomial': ('gamma', []),
        'frac': ('floor', []),
        'Max': ('Piecewise', []),
        'Min': ('Piecewise', []),
        'Heaviside': ('Piecewise', []),
        'erf2': ('erf', []),
        'erfc': ('erf', []),
        'Li': ('li', []),
        'Ei': ('li', []),
        'dirichlet_eta': ('zeta', []),
        'riemann_xi': ('zeta', ['gamma']),
        'SingularityFunction': ('Piecewise', []),
    }

    def __init__(self, settings=None):
        super().__init__(settings=settings)
        if self._settings.get('strict', True) is None:
            # for backwards compatibility, human=False need not to throw:
            self._settings['strict'] = self._settings.get('human', True) == True
        if not hasattr(self, 'reserved_words'):
            self.reserved_words = set()

    def _handle_UnevaluatedExpr(self, expr):
        """Drop ``re(...)`` around ``UnevaluatedExpr`` arguments known to be real.

        Every ``re(arg)`` node is kept as-is unless ``arg`` is an
        ``UnevaluatedExpr`` whose wrapped expression has ``is_real`` true.
        In that case it is replaced by ``arg`` itself.
        """
        return expr.replace(re, lambda arg: arg if isinstance(
            arg, UnevaluatedExpr) and arg.args[0].is_real else re(arg))

    def doprint(self, expr, assign_to=None):
        """
        Print the expression as code.

        Parameters
        ----------
        expr : Expression
            The expression to be printed.

        assign_to : Symbol, string, MatrixSymbol, list of strings or Symbols (optional)
            If provided, the printed code will set the expression to a variable or multiple variables
            with the name or names given in ``assign_to``.

        Returns
        -------
        str or tuple
            If the ``human`` setting is true, a single string of code.  It is
            preceded by comments naming the unsupported expression types and
            by declarations of the number constants encountered.  Otherwise
            the tuple ``(number_symbols, not_supported, code)``, where
            ``number_symbols`` is a set of ``(symbol, printed_value)`` pairs.

        Raises
        ------
        ValueError
            If ``assign_to`` is a list/tuple whose length differs from ``expr``.
        TypeError
            If ``assign_to`` is neither a string, a sequence nor a ``Basic``.
        """
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.codegen.ast import CodeBlock, Assignment

        def _handle_assign_to(expr, assign_to):
            if assign_to is None:
                return sympify(expr)
            if isinstance(assign_to, (list, tuple)):
                if len(expr) != len(assign_to):
                    raise ValueError('Failed to assign an expression of length {} to {} variables'.format(len(expr), len(assign_to)))
                # Pair every sub-expression with its own target variable.
                return CodeBlock(*[_handle_assign_to(sub_expr, target)
                                   for sub_expr, target in zip(expr, assign_to)])
            if isinstance(assign_to, str):
                if expr.is_Matrix:
                    assign_to = MatrixSymbol(assign_to, *expr.shape)
                else:
                    assign_to = Symbol(assign_to)
            elif not isinstance(assign_to, Basic):
                raise TypeError("{} cannot assign to object of type {}".format(
                        type(self).__name__, type(assign_to)))
            return Assignment(assign_to, expr)

        expr = _convert_python_lists(expr)
        expr = _handle_assign_to(expr, assign_to)

        # Remove re(...) nodes due to UnevaluatedExpr.is_real always is None:
        expr = self._handle_UnevaluatedExpr(expr)

        # keep a set of expressions that are not strictly translatable to Code
        # and number constants that must be declared and initialized
        self._not_supported = set()
        self._number_symbols = set()

        lines = self._print(expr).splitlines()

        # format the output
        if self._settings["human"]:
            frontlines = []
            if self._not_supported:
                frontlines.append(self._get_comment(
                        "Not supported in {}:".format(self.language)))
                for unsupported in sorted(self._not_supported, key=str):
                    frontlines.append(self._get_comment(type(unsupported).__name__))
            for name, value in sorted(self._number_symbols, key=str):
                frontlines.append(self._declare_number_const(name, value))
            lines = frontlines + lines
            lines = self._format_code(lines)
            result = "\n".join(lines)
        else:
            lines = self._format_code(lines)
            num_syms = {(k, self._print(v)) for k, v in self._number_symbols}
            result = (num_syms, self._not_supported, "\n".join(lines))
        # Reset the per-call bookkeeping; ``result`` keeps its own reference.
        self._not_supported = set()
        self._number_symbols = set()
        return result

    def _doprint_loops(self, expr, assign_to=None):
        """Print ``expr`` (containing ``Indexed`` objects) as array loops.

        Terms without summation indices are assigned first.  Each group of
        contracted (dummy) indices then gets its own loop nest that
        accumulates into ``assign_to``.

        Raises
        ------
        AssignmentError
            If summations are present but ``assign_to`` is None.
        ValueError
            If a summed term already contains ``assign_to``.
        NotImplementedError
            If a factor of a term has its own internal contractions.
        """
        # Here we print an expression that contains Indexed objects, they
        # correspond to arrays in the generated code.  The low-level implementation
        # involves looping over array elements and possibly storing results in temporary
        # variables or accumulate it in the assign_to object.

        if self._settings.get('contract', True):
            from sympy.tensor import get_contraction_structure
            # Setup loops over non-dummy indices  --  all terms need these
            indices = self._get_expression_indices(expr, assign_to)
            # Setup loops over dummy indices  --  each term needs separate treatment
            dummies = get_contraction_structure(expr)
        else:
            indices = []
            dummies = {None: (expr,)}
        openloop, closeloop = self._get_loop_opening_ending(indices)

        # terms with no summations first
        if None in dummies:
            text = StrPrinter.doprint(self, Add(*dummies[None]))
        else:
            # If all terms have summations we must initialize array to Zero
            text = StrPrinter.doprint(self, 0)

        # skip redundant assignments (where lhs == rhs)
        lhs_printed = self._print(assign_to)
        lines = []
        if text != lhs_printed:
            lines.extend(openloop)
            if assign_to is not None:
                text = self._get_statement("%s = %s" % (lhs_printed, text))
            lines.append(text)
            lines.extend(closeloop)

        # then terms with summations
        for d in dummies:
            if isinstance(d, tuple):
                indices = self._sort_optimized(d, expr)
                openloop_d, closeloop_d = self._get_loop_opening_ending(
                    indices)

                for term in dummies[d]:
                    if term in dummies and not ([list(f.keys()) for f in dummies[term]]
                            == [[None] for f in dummies[term]]):
                        # If one factor in the term has it's own internal
                        # contractions, those must be computed first.
                        # (temporary variables?)
                        raise NotImplementedError(
                            "FIXME: no support for contractions in factor yet")
                    else:

                        # We need the lhs expression as an accumulator for
                        # the loops, i.e
                        #
                        # for (int d=0; d < dim; d++){
                        #    lhs[] = lhs[] + term[][d]
                        # }           ^.................. the accumulator
                        #
                        # We check if the expression already contains the
                        # lhs, and raise an exception if it does, as that
                        # syntax is currently undefined.  FIXME: What would be
                        # a good interpretation?
                        if assign_to is None:
                            raise AssignmentError(
                                "need assignment variable for loops")
                        if term.has(assign_to):
                            # Implicit concatenation: a backslash continuation
                            # inside the literal would embed the source
                            # indentation in the message.
                            raise ValueError("FIXME: lhs present in rhs, "
                                             "this is undefined in CodePrinter")

                        lines.extend(openloop)
                        lines.extend(openloop_d)
                        text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
                            self, assign_to + term))
                        lines.append(self._get_statement(text))
                        lines.extend(closeloop_d)
                        lines.extend(closeloop)

        return "\n".join(lines)

    def _get_expression_indices(self, expr, assign_to):
        """Return the free loop indices shared by ``expr`` and ``assign_to``.

        A scalar right-hand side is broadcast over the left-hand side indices.
        Otherwise the two index sets must be equal (``ValueError`` if not).
        The result is ordered by ``_sort_optimized``.
        """
        from sympy.tensor import get_indices
        rinds, junk = get_indices(expr)
        linds, junk = get_indices(assign_to)

        # support broadcast of scalar
        if linds and not rinds:
            rinds = linds
        if rinds != linds:
            raise ValueError("lhs indices must match non-dummy"
                    " rhs indices in %s" % expr)

        return self._sort_optimized(rinds, assign_to)

    def _sort_optimized(self, indices, expr):
        """Order ``indices`` for loop nesting, innermost loop last.

        Each index is scored by summing ``_rate_index_position(p)`` over every
        position ``p`` where it appears in an ``Indexed`` atom of ``expr``.
        The indices are returned sorted by ascending score.
        """

        from sympy.tensor.indexed import Indexed

        if not indices:
            return []

        # determine optimized loop order by giving a score to each index
        # the index with the highest score are put in the innermost loop.
        score_table = dict.fromkeys(indices, 0)

        arrays = expr.atoms(Indexed)
        for arr in arrays:
            for p, ind in enumerate(arr.indices):
                # Indices that are not being ordered here are ignored.
                if ind in score_table:
                    score_table[ind] += self._rate_index_position(p)

        return sorted(indices, key=lambda x: score_table[x])

    def _rate_index_position(self, p):
        """function to calculate score based on position among indices

        This method is used to sort loops in an optimized order, see
        CodePrinter._sort_optimized()
        """
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_statement(self, codestring):
        """Formats a codestring with the proper line ending."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_comment(self, text):
        """Formats a text string as a comment."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _declare_number_const(self, name, value):
        """Declare a numeric constant at the top of a function"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _format_code(self, lines):
        """Take in a list of lines of code, and format them accordingly.

        This may include indenting, wrapping long lines, etc..."""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _get_loop_opening_ending(self, indices):
        """Returns a tuple (open_lines, close_lines) containing lists
        of codelines"""
        raise NotImplementedError("This function must be implemented by "
                                  "subclass of CodePrinter.")

    def _print_Dummy(self, expr):
        # Names already starting with 'Dummy_' get a leading underscore;
        # others are made unique by appending the dummy index.
        if expr.name.startswith('Dummy_'):
            return '_' + expr.name
        else:
            return '%s_%d' % (expr.name, expr.dummy_index)

    def _print_CodeBlock(self, expr):
        return '\n'.join([self._print(i) for i in expr.args])

    def _print_String(self, string):
        return str(string)

    def _print_QuotedString(self, arg):
        return '"%s"' % arg.text

    def _print_Comment(self, string):
        return self._get_comment(str(string))

    def _print_Assignment(self, expr):
        """Print an ``Assignment``.

        - A ``Piecewise`` right-hand side is turned into a ``Piecewise`` of
          assignments, one per branch.
        - A ``MatrixSymbol`` left-hand side is printed element by element.
        - With the ``contract`` setting enabled, and an ``IndexedBase`` on
          either side, loops are generated.
        - Anything else is printed as a plain ``lhs = rhs`` statement.
        """
        from sympy.codegen.ast import Assignment
        from sympy.functions.elementary.piecewise import Piecewise
        from sympy.matrices.expressions.matexpr import MatrixSymbol
        from sympy.tensor.indexed import IndexedBase
        lhs = expr.lhs
        rhs = expr.rhs
        # We special case assignments that take multiple lines
        if isinstance(expr.rhs, Piecewise):
            # Here we modify Piecewise so each expression is now
            # an Assignment, and then continue on the print.
            expressions = []
            conditions = []
            for (e, c) in rhs.args:
                expressions.append(Assignment(lhs, e))
                conditions.append(c)
            temp = Piecewise(*zip(expressions, conditions))
            return self._print(temp)
        elif isinstance(lhs, MatrixSymbol):
            # Here we form an Assignment for each element in the array,
            # printing each one.
            lines = []
            for (i, j) in self._traverse_matrix_indices(lhs):
                temp = Assignment(lhs[i, j], rhs[i, j])
                code0 = self._print(temp)
                lines.append(code0)
            return "\n".join(lines)
        elif self._settings.get("contract", False) and (lhs.has(IndexedBase) or
                rhs.has(IndexedBase)):
            # Here we check if there is looping to be done, and if so
            # print the required loops.
            return self._doprint_loops(rhs, lhs)
        else:
            lhs_code = self._print(lhs)
            rhs_code = self._print(rhs)
            return self._get_statement("%s = %s" % (lhs_code, rhs_code))

    def _print_AugmentedAssignment(self, expr):
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        return self._get_statement("{} {} {}".format(
            *(self._print(arg) for arg in [lhs_code, expr.op, rhs_code])))

    def _print_FunctionCall(self, expr):
        return '%s(%s)' % (
            expr.name,
            ', '.join((self._print(arg) for arg in expr.function_args)))

    def _print_Variable(self, expr):
        return self._print(expr.symbol)

    def _print_Symbol(self, expr):
        """Print a symbol, handling names that are reserved in the target language.

        Reserved names raise ``ValueError`` when ``error_on_reserved`` is set.
        Otherwise they get ``reserved_word_suffix`` appended.
        """

        name = super()._print_Symbol(expr)

        if name in self.reserved_words:
            if self._settings['error_on_reserved']:
                msg = ('This expression includes the symbol "{}" which is a '
                       'reserved keyword in this language.')
                raise ValueError(msg.format(name))
            return name + self._settings['reserved_word_suffix']
        else:
            return name

    def _can_print(self, name):
        """ Check if function ``name`` is either a known function or has its own
            printing method. Used to check if rewriting is possible."""
        return name in self.known_functions or getattr(self, '_print_{}'.format(name), False)

    def _print_Function(self, expr):
        """Print a function application.

        Resolution order:

        1. A ``known_functions`` entry.  It is either a plain name string, or
           a list of ``(condition, func)`` pairs where the first pair with
           ``condition(*expr.args)`` true is used.  A callable ``func`` is
           called with the printed, parenthesized arguments; otherwise
           ``func`` is printed as ``func(args)``.  If no condition matches,
           nothing from the list is used.
        2. An implemented function whose ``_imp_`` is a ``Lambda`` is inlined.
        3. A rewrite from ``_rewriteable_functions`` is used when this printer
           can print the target function and all required helpers.
        4. Otherwise the function is printed as ``name(args)`` if
           ``allow_unknown_functions`` is set and ``expr.is_Function``.
           Failing that, it is handed to ``_print_not_supported``.
        """
        if expr.func.__name__ in self.known_functions:
            cond_func = self.known_functions[expr.func.__name__]
            if isinstance(cond_func, str):
                return "%s(%s)" % (cond_func, self.stringify(expr.args, ", "))
            else:
                for cond, func in cond_func:
                    if cond(*expr.args):
                        break
                else:
                    # No condition matched (or the list is empty).  Do not
                    # fall back to the last candidate, whose condition is
                    # false for these arguments.
                    func = None
                if func is not None:
                    try:
                        return func(*[self.parenthesize(item, 0) for item in expr.args])
                    except TypeError:
                        return "%s(%s)" % (func, self.stringify(expr.args, ", "))
        elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
            # inlined function
            return self._print(expr._imp_(*expr.args))
        elif expr.func.__name__ in self._rewriteable_functions:
            # Simple rewrite to supported function possible
            target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
            if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
                return '(' + self._print(expr.rewrite(target_f)) + ')'

        if expr.is_Function and self._settings.get('allow_unknown_functions', False):
            return '%s(%s)' % (self._print(expr.func), ', '.join(map(self._print, expr.args)))
        else:
            return self._print_not_supported(expr)

    _print_Expr = _print_Function

    # Don't inherit the str-printer method for Heaviside to the code printers
    _print_Heaviside = None

    def _print_NumberSymbol(self, expr):
        """Print a named numeric constant such as ``pi``.

        With ``inline`` enabled, the value is printed as a ``Float`` at the
        configured ``precision``.  Otherwise the name is printed and the
        ``(symbol, value)`` pair is recorded so that ``doprint`` can declare it.
        """
        if self._settings.get("inline", False):
            return self._print(Float(expr.evalf(self._settings["precision"])))
        else:
            # A Number symbol that is not implemented here or with _printmethod
            # is registered and evaluated
            self._number_symbols.add((expr,
                Float(expr.evalf(self._settings["precision"]))))
            return str(expr)

    def _print_Catalan(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_EulerGamma(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_GoldenRatio(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_TribonacciConstant(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Exp1(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_Pi(self, expr):
        return self._print_NumberSymbol(expr)

    def _print_And(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
                for a in sorted(expr.args, key=default_sort_key))

    def _print_Or(self, expr):
        PREC = precedence(expr)
        return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
                for a in sorted(expr.args, key=default_sort_key))

    def _print_Xor(self, expr):
        # Without a native operator, fall back to negation normal form.
        if self._operators.get('xor') is None:
            return self._print(expr.to_nnf())
        PREC = precedence(expr)
        return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
                for a in expr.args)

    def _print_Equivalent(self, expr):
        # Without a native operator, fall back to negation normal form.
        if self._operators.get('equivalent') is None:
            return self._print(expr.to_nnf())
        PREC = precedence(expr)
        return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
                for a in expr.args)

    def _print_Not(self, expr):
        PREC = precedence(expr)
        return self._operators['not'] + self.parenthesize(expr.args[0], PREC)

    def _print_BooleanFunction(self, expr):
        return self._print(expr.to_nnf())

    def _print_Mul(self, expr):
        """Print a product as ``numerator/denominator`` with a leading sign.

        Commutative factors with a negative rational exponent go to the
        denominator.  A negative numeric coefficient becomes a leading ``-``.
        """

        prec = precedence(expr)

        c, e = expr.as_coeff_Mul()
        if c < 0:
            expr = _keep_coeff(-c, e)
            sign = "-"
        else:
            sign = ""

        a = []  # items in the numerator
        b = []  # items that are in the denominator (if any)

        pow_paren = []  # Will collect all pow with more than one base element and exp = -1

        if self.order not in ('old', 'none'):
            args = expr.as_ordered_factors()
        else:
            # use make_args in case expr was something like -x -> x
            args = Mul.make_args(expr)

        # Gather args for numerator/denominator
        for item in args:
            if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
                if item.exp != -1:
                    b.append(Pow(item.base, -item.exp, evaluate=False))
                else:
                    if len(item.args[0].args) != 1 and isinstance(item.base, Mul):   # To avoid situations like #14160
                        pow_paren.append(item)
                    b.append(Pow(item.base, -item.exp))
            else:
                a.append(item)

        a = a or [S.One]

        if len(a) == 1 and sign == "-":
            # Unary minus does not have a SymPy class, and hence there's no
            # precedence weight associated with it, Python's unary minus has
            # an operator precedence between multiplication and exponentiation,
            # so we use this to compute a weight.
            a_str = [self.parenthesize(a[0], 0.5*(PRECEDENCE["Pow"]+PRECEDENCE["Mul"]))]
        else:
            a_str = [self.parenthesize(x, prec) for x in a]
        b_str = [self.parenthesize(x, prec) for x in b]

        # To parenthesize Pow with exp = -1 and having more than one Symbol
        for item in pow_paren:
            if item.base in b:
                b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]

        if not b:
            return sign + '*'.join(a_str)
        elif len(b) == 1:
            return sign + '*'.join(a_str) + "/" + b_str[0]
        else:
            return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)

    def _print_not_supported(self, expr):
        """Handle an expression the target language cannot express.

        If the ``strict`` setting is true, raise
        ``PrintMethodNotImplementedError``.  Otherwise record ``expr`` (if it
        is hashable) for the header comment in ``doprint`` and fall back to
        ``emptyPrinter``.
        """
        if self._settings.get('strict', False):
            raise PrintMethodNotImplementedError("Unsupported by %s: %s" % (str(type(self)), str(type(expr))) + \
                             "\nSet the printer option 'strict' to False in order to generate partially printed code.")
        try:
            self._not_supported.add(expr)
        except TypeError:
            # not hashable
            pass
        return self.emptyPrinter(expr)

    # The following can not be simply translated into C or Fortran
    _print_Basic = _print_not_supported
    _print_ComplexInfinity = _print_not_supported
    _print_Derivative = _print_not_supported
    _print_ExprCondPair = _print_not_supported
    _print_GeometryEntity = _print_not_supported
    _print_Infinity = _print_not_supported
    _print_Integral = _print_not_supported
    _print_Interval = _print_not_supported
    _print_AccumulationBounds = _print_not_supported
    _print_Limit = _print_not_supported
    _print_MatrixBase = _print_not_supported
    _print_DeferredVector = _print_not_supported
    _print_NaN = _print_not_supported
    _print_NegativeInfinity = _print_not_supported
    _print_Order = _print_not_supported
    _print_RootOf = _print_not_supported
    _print_RootsOf = _print_not_supported
    _print_RootSum = _print_not_supported
    _print_Uniform = _print_not_supported
    _print_Unit = _print_not_supported
    _print_Wild = _print_not_supported
    _print_WildFunction = _print_not_supported
    _print_Relational = _print_not_supported
# Code printer functions. These are included in this file so that they can be | |
# imported in the top-level __init__.py without importing the sympy.codegen | |
# module. | |
def ccode(expr, assign_to=None, standard='c99', **settings):
    """Converts an expr to a string of c code

    Parameters
    ==========

    expr : Expr
        A SymPy expression to be converted.
    assign_to : optional
        When given, the argument is used as the name of the variable to which
        the expression is assigned. Can be a string, ``Symbol``,
        ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
        line-wrapping, or for expressions that generate multi-line statements.
    standard : str, optional
        String specifying the C standard to target, e.g. ``'c89'`` or
        ``'c99'`` (matched case-insensitively). Newer standards allow the
        printer to use more math functions. [default='c99'].
    precision : integer, optional
        The precision for numbers such as pi [default=17].
    user_functions : dict, optional
        A dictionary where the keys are string representations of either
        ``FunctionClass`` or ``UndefinedFunction`` instances and the values
        are their desired C string representations. Alternatively, the
        dictionary value can be a list of tuples i.e. [(argument_test,
        cfunction_string)] or [(argument_test, cfunction_formater)]. See below
        for examples.
    dereference : iterable, optional
        An iterable of symbols that should be dereferenced in the printed code
        expression. These would be values passed by address to the function.
        For example, if ``dereference=[a]``, the resulting code would print
        ``(*a)`` instead of ``a``.
    human : bool, optional
        If True, the result is a single string that may contain some constant
        declarations for the number symbols. If False, the same information is
        returned in a tuple of (symbols_to_declare, not_supported_functions,
        code_text). [default=True].
    contract : bool, optional
        If True, ``Indexed`` instances are assumed to obey tensor contraction
        rules and the corresponding nested loops over indices are generated.
        Setting contract=False will not generate loops, instead the user is
        responsible to provide values for the indices in the code.
        [default=True].

    Examples
    ========

    >>> from sympy import ccode, symbols, Rational, sin, ceiling, Abs, Function
    >>> x, tau = symbols("x, tau")
    >>> expr = (2*tau)**Rational(7, 2)
    >>> ccode(expr)
    '8*M_SQRT2*pow(tau, 7.0/2.0)'
    >>> ccode(expr, math_macros={})
    '8*sqrt(2)*pow(tau, 7.0/2.0)'
    >>> ccode(sin(x), assign_to="s")
    's = sin(x);'
    >>> from sympy.codegen.ast import real, float80
    >>> ccode(expr, type_aliases={real: float80})
    '8*M_SQRT2l*powl(tau, 7.0L/2.0L)'

    Simple custom printing can be defined for certain types by passing a
    dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
    Alternatively, the dictionary value can be a list of tuples i.e.
    [(argument_test, cfunction_string)].

    >>> custom_functions = {
    ...   "ceiling": "CEIL",
    ...   "Abs": [(lambda x: not x.is_integer, "fabs"),
    ...           (lambda x: x.is_integer, "ABS")],
    ...   "func": "f"
    ... }
    >>> func = Function('func')
    >>> ccode(func(Abs(x) + ceiling(x)), standard='C89', user_functions=custom_functions)
    'f(fabs(x) + CEIL(x))'

    or if the C-function takes a subset of the original arguments:

    >>> ccode(2**x + 3**x, standard='C99', user_functions={'Pow': [
    ...   (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e),
    ...   (lambda b, e: b != 2, 'pow')]})
    'exp2(x) + pow(3, x)'

    ``Piecewise`` expressions are converted into conditionals. If an
    ``assign_to`` variable is provided an if statement is created, otherwise
    the ternary operator is used. Note that if the ``Piecewise`` lacks a
    default term, represented by ``(expr, True)`` then an error will be thrown.
    This is to prevent generating an expression that may not evaluate to
    anything.

    >>> from sympy import Piecewise
    >>> expr = Piecewise((x + 1, x > 0), (x, True))
    >>> print(ccode(expr, tau, standard='C89'))
    if (x > 0) {
       tau = x + 1;
    }
    else {
       tau = x;
    }

    Support for loops is provided through ``Indexed`` types. With
    ``contract=True`` these expressions will be turned into loops, whereas
    ``contract=False`` will just print the assignment expression that should be
    looped over:

    >>> from sympy import Eq, IndexedBase, Idx
    >>> len_y = 5
    >>> y = IndexedBase('y', shape=(len_y,))
    >>> t = IndexedBase('t', shape=(len_y,))
    >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
    >>> i = Idx('i', len_y-1)
    >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
    >>> ccode(e.rhs, assign_to=e.lhs, contract=False, standard='C89')
    'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'

    Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
    must be provided to ``assign_to``. Note that any expression that can be
    generated normally can also exist inside a Matrix:

    >>> from sympy import Matrix, MatrixSymbol
    >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
    >>> A = MatrixSymbol('A', 3, 1)
    >>> print(ccode(mat, A, standard='C89'))
    A[0] = pow(x, 2);
    if (x > 0) {
       A[1] = x + 1;
    }
    else {
       A[1] = x;
    }
    A[2] = sin(x);
    """
    from sympy.printing.c import c_code_printers
    return c_code_printers[standard.lower()](settings).doprint(expr, assign_to)
def print_ccode(expr, **settings):
    """Print the C code that :func:`ccode` generates for ``expr``.

    Every keyword argument is forwarded unchanged to :func:`ccode`.
    """
    code = ccode(expr, **settings)
    print(code)
def fcode(expr, assign_to=None, **settings):
    """Converts an expr to a string of fortran code

    Parameters
    ==========

    expr : Expr
        A SymPy expression to be converted.
    assign_to : optional
        When given, the argument is used as the name of the variable to which
        the expression is assigned. Can be a string, ``Symbol``,
        ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
        line-wrapping, or for expressions that generate multi-line statements.
    precision : integer, optional
        DEPRECATED. Use type_mappings instead. The precision for numbers such
        as pi [default=17].
    user_functions : dict, optional
        A dictionary where keys are ``FunctionClass`` instances and values are
        their string representations. Alternatively, the dictionary value can
        be a list of tuples i.e. [(argument_test, function_string)]. See below
        for examples.
    human : bool, optional
        If True, the result is a single string that may contain some constant
        declarations for the number symbols. If False, the same information is
        returned in a tuple of (symbols_to_declare, not_supported_functions,
        code_text). [default=True].
    contract : bool, optional
        If True, ``Indexed`` instances are assumed to obey tensor contraction
        rules and the corresponding nested loops over indices are generated.
        Setting contract=False will not generate loops, instead the user is
        responsible to provide values for the indices in the code.
        [default=True].
    source_format : optional
        The source format can be either 'fixed' or 'free'. [default='fixed']
    standard : integer, optional
        The Fortran standard to be followed. This is specified as an integer.
        Acceptable standards are 66, 77, 90, 95, 2003, and 2008. Default is 77.
        Note that currently the only distinction internally is between
        standards before 95, and those 95 and after. This may change later as
        more features are added.
    name_mangling : bool, optional
        If True, then the variables that would become identical in
        case-insensitive Fortran are mangled by appending different number
        of ``_`` at the end. If False, SymPy will not interfere with naming of
        variables. [default=True]

    Examples
    ========

    >>> from sympy import fcode, symbols, Rational, sin, ceiling, floor
    >>> x, tau = symbols("x, tau")
    >>> fcode((2*tau)**Rational(7, 2))
    '      8*sqrt(2.0d0)*tau**(7.0d0/2.0d0)'
    >>> fcode(sin(x), assign_to="s")
    '      s = sin(x)'

    Custom printing can be defined for certain types by passing a dictionary of
    "type" : "function" to the ``user_functions`` kwarg. Alternatively, the
    dictionary value can be a list of tuples i.e. [(argument_test,
    function_string)].

    >>> custom_functions = {
    ...   "ceiling": "CEIL",
    ...   "floor": [(lambda x: not x.is_integer, "FLOOR1"),
    ...             (lambda x: x.is_integer, "FLOOR2")]
    ... }
    >>> fcode(floor(x) + ceiling(x), user_functions=custom_functions)
    '      CEIL(x) + FLOOR1(x)'

    ``Piecewise`` expressions are converted into conditionals. If an
    ``assign_to`` variable is provided an if statement is created. Otherwise
    an inline conditional expression is needed; Fortran has no ternary
    operator, so this form relies on the ``merge`` intrinsic and requires
    ``standard=95`` or later. Note that if the ``Piecewise`` lacks a
    default term, represented by ``(expr, True)`` then an error will be thrown.
    This is to prevent generating an expression that may not evaluate to
    anything.

    >>> from sympy import Piecewise
    >>> expr = Piecewise((x + 1, x > 0), (x, True))
    >>> print(fcode(expr, tau))
          if (x > 0) then
             tau = x + 1
          else
             tau = x
          end if

    Support for loops is provided through ``Indexed`` types. With
    ``contract=True`` these expressions will be turned into loops, whereas
    ``contract=False`` will just print the assignment expression that should be
    looped over:

    >>> from sympy import Eq, IndexedBase, Idx
    >>> len_y = 5
    >>> y = IndexedBase('y', shape=(len_y,))
    >>> t = IndexedBase('t', shape=(len_y,))
    >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
    >>> i = Idx('i', len_y-1)
    >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
    >>> fcode(e.rhs, assign_to=e.lhs, contract=False)
    '      Dy(i) = (y(i + 1) - y(i))/(t(i + 1) - t(i))'

    Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
    must be provided to ``assign_to``. Note that any expression that can be
    generated normally can also exist inside a Matrix:

    >>> from sympy import Matrix, MatrixSymbol
    >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
    >>> A = MatrixSymbol('A', 3, 1)
    >>> print(fcode(mat, A))
          A(1, 1) = x**2
             if (x > 0) then
          A(2, 1) = x + 1
             else
          A(2, 1) = x
             end if
          A(3, 1) = sin(x)
    """
    from sympy.printing.fortran import FCodePrinter
    return FCodePrinter(settings).doprint(expr, assign_to)
def print_fcode(expr, **settings):
    """Print the Fortran code that :func:`fcode` generates for ``expr``.

    Every keyword argument is forwarded unchanged to :func:`fcode`; see its
    documentation for the available options.
    """
    code = fcode(expr, **settings)
    print(code)
def cxxcode(expr, assign_to=None, standard='c++11', **settings):
    """C++ equivalent of :func:`~.ccode`.

    ``standard`` selects the C++ printer class (matched case-insensitively).
    The remaining keyword arguments are passed to that printer as settings.
    """
    from sympy.printing.cxx import cxx_code_printers
    printer_cls = cxx_code_printers[standard.lower()]
    printer = printer_cls(settings)
    return printer.doprint(expr, assign_to)