Spaces:
Sleeping
Sleeping
| from .pycode import ( | |
| PythonCodePrinter, | |
| MpmathPrinter, | |
| ) | |
| from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility | |
| from sympy.core.sorting import default_sort_key | |
# Public names of this module. Each name appears exactly once; the previous
# list repeated 'NumPyPrinter'.
__all__ = [
    'PythonCodePrinter',
    'MpmathPrinter',  # MpmathPrinter is published for backward compatibility
    'NumPyPrinter',  # NumPyPrinter is published for backward compatibility
    'LambdaPrinter',
    'IntervalPrinter',
    'lambdarepr',
]
class LambdaPrinter(PythonCodePrinter):
    """
    This printer converts expressions into strings that can be used by
    lambdify.
    """
    printmethod = "_lambdacode"

    def _join_boolean_args(self, expr, keyword, empty):
        """Print the arguments of ``expr`` joined by a Python keyword.

        Arguments are sorted with ``default_sort_key`` so the output is
        deterministic. Each argument is wrapped in its own parentheses and
        the whole result in an outer pair, e.g. ``((x) and (y))``.

        Parameters
        ----------
        expr : Boolean
            Expression whose ``args`` are printed.
        keyword : str
            Python operator keyword placed between the arguments.
        empty : str
            Returned when ``expr`` has no arguments, so the output is never
            an unbalanced string.
        """
        parts = ['(%s)' % self._print(arg)
                 for arg in sorted(expr.args, key=default_sort_key)]
        if not parts:
            return empty
        return '(%s)' % (' %s ' % keyword).join(parts)

    def _print_And(self, expr):
        # An empty conjunction is True.
        return self._join_boolean_args(expr, 'and', 'True')

    def _print_Or(self, expr):
        # An empty disjunction is False.
        return self._join_boolean_args(expr, 'or', 'False')

    def _print_Not(self, expr):
        result = ['(', 'not (', self._print(expr.args[0]), '))']
        return ''.join(result)

    def _print_BooleanTrue(self, expr):
        return "True"

    def _print_BooleanFalse(self, expr):
        return "False"

    def _print_ITE(self, expr):
        # ITE(cond, then, else) -> Python conditional expression.
        result = [
            '((', self._print(expr.args[1]),
            ') if (', self._print(expr.args[0]),
            ') else (', self._print(expr.args[2]), '))'
        ]
        return ''.join(result)

    def _print_NumberSymbol(self, expr):
        return str(expr)

    def _print_Pow(self, expr, **kwargs):
        # XXX Temporary workaround. Should Python math printer be
        # isolated from PythonCodePrinter?
        # Deliberately skips PythonCodePrinter._print_Pow and uses the
        # next implementation after it in the MRO.
        return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs)
# numexpr works by altering the string passed to numexpr.evaluate rather
# than by populating a namespace, so it needs its own printer that emits
# numexpr-compatible expression strings.
class NumExprPrinter(LambdaPrinter):
    """Printer producing code for ``numexpr.evaluate``.

    Bare expressions handed to :meth:`doprint` (and the values inside
    ``Return``, ``Assignment`` and ``CodeBlock`` nodes) are wrapped in
    ``NumExprEvaluate`` and printed by :meth:`_print_NumExprEvaluate` as an
    ``evaluate('<expr>', truediv=True)`` call.
    """
    printmethod = "_numexprcode"

    # Maps SymPy function names (``e.func.__name__``) to numexpr names.
    # _print_Function raises TypeError for a function missing from this
    # dict unless the function carries an ``_imp_`` implementation.
    _numexpr_functions = {
        'sin' : 'sin',
        'cos' : 'cos',
        'tan' : 'tan',
        'asin': 'arcsin',
        'acos': 'arccos',
        'atan': 'arctan',
        'atan2' : 'arctan2',
        'sinh' : 'sinh',
        'cosh' : 'cosh',
        'tanh' : 'tanh',
        'asinh': 'arcsinh',
        'acosh': 'arccosh',
        'atanh': 'arctanh',
        'ln' : 'log',
        'log': 'log',
        'exp': 'exp',
        'sqrt' : 'sqrt',
        'Abs' : 'abs',
        'conjugate' : 'conj',
        'im' : 'imag',
        're' : 'real',
        'where' : 'where',
        'complex' : 'complex',
        'contains' : 'contains',
    }

    module = 'numexpr'

    def _print_ImaginaryUnit(self, expr):
        return '1j'

    def _print_seq(self, seq, delimiter=', '):
        # simplified _print_seq taken from pretty.py; an empty seq gives "".
        return delimiter.join(self._print(item) for item in seq)

    def _print_Function(self, e):
        """Print a function call using its numexpr name.

        Raises
        ------
        TypeError
            If the function has no numexpr counterpart and no ``_imp_``.
        """
        func_name = e.func.__name__
        nstr = self._numexpr_functions.get(func_name)
        if nstr is None:
            # check for implemented_function: inline its implementation
            if hasattr(e, '_imp_'):
                return "(%s)" % self._print(e._imp_(*e.args))
            raise TypeError("numexpr does not support function '%s'" %
                            func_name)
        return "%s(%s)" % (nstr, self._print_seq(e.args))

    def _print_Piecewise(self, expr):
        "Piecewise function printer"
        exprs = [self._print(arg.expr) for arg in expr.args]
        conds = [self._print(arg.cond) for arg in expr.args]
        # Each (expr, cond) pair becomes a nested ``where(cond, expr, <rest>)``.
        # The first condition that prints as 'True' ends the chain: its
        # expression becomes the innermost fallback and later pairs are
        # never printed.
        ans = []
        parenthesis_count = 0
        for cond, piece in zip(conds, exprs):
            if cond == 'True':
                ans.append(piece)
                break
            ans.append('where(%s, %s, ' % (cond, piece))
            parenthesis_count += 1
        else:
            # No unconditional fallback: produce nan for uncovered inputs.
            # See https://github.com/pydata/numexpr/issues/298
            #
            # simplest way to put a nan but raises
            # 'RuntimeWarning: invalid value encountered in log'
            #
            # There are other ways to do this such as
            #
            # >>> import numexpr as ne
            # >>> nan = float('nan')
            # >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan})
            # array([-1., nan, nan])
            #
            # That needs to be handled in the lambdified function though rather
            # than here in the printer.
            ans.append('log(-1)')
        return ''.join(ans) + ')' * parenthesis_count

    def _print_ITE(self, expr):
        from sympy.functions.elementary.piecewise import Piecewise
        # numexpr has no conditional expression; reuse the where() chain.
        return self._print(expr.rewrite(Piecewise))

    def blacklisted(self, expr):
        """Reject ``expr``: raise TypeError naming its class."""
        raise TypeError("numexpr cannot be used with %s" %
                        expr.__class__.__name__)

    # blacklist all Matrix printing
    _print_SparseRepMatrix = \
    _print_MutableSparseMatrix = \
    _print_ImmutableSparseMatrix = \
    _print_Matrix = \
    _print_DenseMatrix = \
    _print_MutableDenseMatrix = \
    _print_ImmutableMatrix = \
    _print_ImmutableDenseMatrix = \
    blacklisted
    # blacklist some Python expressions
    _print_list = \
    _print_tuple = \
    _print_Tuple = \
    _print_dict = \
    _print_Dict = \
    blacklisted

    def _print_NumExprEvaluate(self, expr):
        evaluate = self._module_format(self.module +".evaluate")
        # NOTE(review): the printed expression is placed inside single quotes
        # without escaping; a quote character in it would break the literal.
        return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr))

    def doprint(self, expr):
        from sympy.codegen.ast import CodegenAST
        from sympy.codegen.pynodes import NumExprEvaluate
        # Bare expressions are wrapped so they print as an evaluate(...) call;
        # codegen AST nodes are passed through unchanged.
        if not isinstance(expr, CodegenAST):
            expr = NumExprEvaluate(expr)
        return super().doprint(expr)

    def _print_Return(self, expr):
        from sympy.codegen.pynodes import NumExprEvaluate
        r, = expr.args
        if not isinstance(r, NumExprEvaluate):
            expr = expr.func(NumExprEvaluate(r))
        return super()._print_Return(expr)

    def _print_Assignment(self, expr):
        from sympy.codegen.pynodes import NumExprEvaluate
        lhs, rhs, *args = expr.args
        if not isinstance(rhs, NumExprEvaluate):
            expr = expr.func(lhs, NumExprEvaluate(rhs), *args)
        return super()._print_Assignment(expr)

    def _print_CodeBlock(self, expr):
        from sympy.codegen.ast import CodegenAST
        from sympy.codegen.pynodes import NumExprEvaluate
        args = [arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg)
                for arg in expr.args]
        # super() is already bound to self; passing self again (as before)
        # shifted the arguments and raised TypeError.
        return super()._print_CodeBlock(expr.func(*args))
class IntervalPrinter(MpmathPrinter, LambdaPrinter):
    """Use ``lambda`` printer but print numbers as ``mpi`` intervals. """

    # Numeric literals are printed by the classes that follow
    # PythonCodePrinter in the MRO (super(PythonCodePrinter, self)) and the
    # resulting text is quoted inside an ``mpi(...)`` constructor.
    def _print_Integer(self, expr):
        text = super(PythonCodePrinter, self)._print_Integer(expr)
        return "mpi('%s')" % text

    def _print_Rational(self, expr):
        text = super(PythonCodePrinter, self)._print_Rational(expr)
        return "mpi('%s')" % text

    # The Half singleton is printed exactly like any other Rational.
    _print_Half = _print_Rational

    def _print_Pow(self, expr):
        # Bypass MpmathPrinter's power printing: delegate to the next class
        # after MpmathPrinter in the MRO, requesting ``rational=True``.
        return super(MpmathPrinter, self)._print_Pow(expr, rational=True)
# Give every SymPy function listed in NumExprPrinter._numexpr_functions its
# own ``_print_<name>`` hook, all routed to the generic
# NumExprPrinter._print_Function (presumably so they take precedence over any
# same-named hooks inherited from parent printers -- confirm).
for k in NumExprPrinter._numexpr_functions.keys():
    setattr(NumExprPrinter, '_print_' + k, NumExprPrinter._print_Function)
def lambdarepr(expr, **settings):
    """
    Returns a string usable for lambdifying.

    The keyword arguments are collected into a dict and handed to
    :class:`LambdaPrinter` as its settings; the printer's ``doprint``
    output for ``expr`` is returned.
    """
    printer = LambdaPrinter(settings)
    return printer.doprint(expr)