Spaces:
Sleeping
Sleeping
File size: 8,305 Bytes
6a86ad5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
from .pycode import (
PythonCodePrinter,
MpmathPrinter,
)
from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility
from sympy.core.sorting import default_sort_key
# Public names exported by ``from <this module> import *``.
# Each name is listed exactly once ('NumPyPrinter' used to appear twice).
__all__ = [
    'PythonCodePrinter',
    'MpmathPrinter',  # MpmathPrinter is published for backward compatibility
    'NumPyPrinter',  # NumPyPrinter is published for backward compatibility
    'LambdaPrinter',
    'IntervalPrinter',
    'lambdarepr',
]
class LambdaPrinter(PythonCodePrinter):
    """
    Printer that turns expressions into strings that lambdify can use.
    """
    printmethod = "_lambdacode"

    def _print_And(self, expr):
        """Print a conjunction as ``((a) and (b) ...)``.

        Operands are emitted in ``default_sort_key`` order, each wrapped in
        its own pair of parentheses.
        """
        pieces = ['(']
        for operand in sorted(expr.args, key=default_sort_key):
            pieces += ['(', self._print(operand), ')', ' and ']
        # Drop the trailing separator left by the last loop iteration.
        del pieces[-1]
        pieces.append(')')
        return ''.join(pieces)

    def _print_Or(self, expr):
        """Print a disjunction as ``((a) or (b) ...)``.

        Operands are emitted in ``default_sort_key`` order, each wrapped in
        its own pair of parentheses.
        """
        pieces = ['(']
        for operand in sorted(expr.args, key=default_sort_key):
            pieces += ['(', self._print(operand), ')', ' or ']
        # Drop the trailing separator left by the last loop iteration.
        del pieces[-1]
        pieces.append(')')
        return ''.join(pieces)

    def _print_Not(self, expr):
        """Print a negation of the first argument as ``(not (a))``."""
        operand = self._print(expr.args[0])
        return ''.join(('(', 'not (', operand, '))'))

    def _print_BooleanTrue(self, expr):
        """Print the boolean true constant as the Python literal."""
        return "True"

    def _print_BooleanFalse(self, expr):
        """Print the boolean false constant as the Python literal."""
        return "False"

    def _print_ITE(self, expr):
        """Print if-then-else as ``((then) if (cond) else (otherwise))``.

        ``expr.args`` is read as ``(cond, then, otherwise)``.
        """
        # The sub-expressions are printed in the order they appear in the
        # generated text: the "then" value, the condition, the "else" value.
        then_src = self._print(expr.args[1])
        cond_src = self._print(expr.args[0])
        else_src = self._print(expr.args[2])
        return ''.join((
            '((', then_src,
            ') if (', cond_src,
            ') else (', else_src, '))',
        ))

    def _print_NumberSymbol(self, expr):
        """Print a number symbol using its plain ``str`` form."""
        return str(expr)

    def _print_Pow(self, expr, **kwargs):
        """Print a power, bypassing ``PythonCodePrinter._print_Pow``."""
        # XXX Temporary workaround. Should Python math printer be
        # isolated from PythonCodePrinter?
        # Dispatch to whichever class follows PythonCodePrinter in the MRO.
        beyond_pycode = super(PythonCodePrinter, self)
        return beyond_pycode._print_Pow(expr, **kwargs)
# numexpr works by altering the string passed to numexpr.evaluate
# rather than by populating a namespace. Thus a special printer...
class NumExprPrinter(LambdaPrinter):
    """
    Printer that generates code for ``numexpr``.

    numexpr works by altering the string passed to ``numexpr.evaluate``
    rather than by populating a namespace, so the expression is printed into
    a string that is then wrapped in an ``evaluate`` call (see
    ``_print_NumExprEvaluate`` and ``doprint``).
    """
    printmethod = "_numexprcode"

    # key, value pairs correspond to SymPy name and numexpr name
    # functions not appearing in this dict will raise a TypeError
    # (unless the function object carries an ``_imp_`` implementation,
    # see _print_Function)
    _numexpr_functions = {
        'sin': 'sin',
        'cos': 'cos',
        'tan': 'tan',
        'asin': 'arcsin',
        'acos': 'arccos',
        'atan': 'arctan',
        'atan2': 'arctan2',
        'sinh': 'sinh',
        'cosh': 'cosh',
        'tanh': 'tanh',
        'asinh': 'arcsinh',
        'acosh': 'arccosh',
        'atanh': 'arctanh',
        'ln': 'log',
        'log': 'log',
        'exp': 'exp',
        'sqrt': 'sqrt',
        'Abs': 'abs',
        'conjugate': 'conj',
        'im': 'imag',
        're': 'real',
        'where': 'where',
        'complex': 'complex',
        'contains': 'contains',
    }

    # Module name used to build the qualified ``<module>.evaluate`` call.
    module = 'numexpr'

    def _print_ImaginaryUnit(self, expr):
        """Print the imaginary unit as the Python complex literal."""
        return '1j'

    def _print_seq(self, seq, delimiter=', '):
        """Print every item of ``seq`` and join the results with ``delimiter``.

        Returns an empty string for an empty sequence.
        """
        # simplified _print_seq taken from pretty.py
        s = [self._print(item) for item in seq]
        if s:
            return delimiter.join(s)
        else:
            return ""

    def _print_Function(self, e):
        """Print a function call using the numexpr name of the function.

        The name is looked up in ``_numexpr_functions`` by the class name of
        ``e.func``. If it is missing and ``e`` has an ``_imp_`` attribute,
        the result of ``e._imp_(*e.args)`` is printed in parentheses instead.

        Raises
        ======
        TypeError
            If the function is not in the table and has no ``_imp_``.
        """
        func_name = e.func.__name__

        nstr = self._numexpr_functions.get(func_name, None)
        if nstr is None:
            # check for implemented_function
            if hasattr(e, '_imp_'):
                return "(%s)" % self._print(e._imp_(*e.args))
            else:
                raise TypeError("numexpr does not support function '%s'" %
                                func_name)
        return "%s(%s)" % (nstr, self._print_seq(e.args))

    def _print_Piecewise(self, expr):
        """Print a Piecewise as a chain of nested ``where(cond, value, ...)``.

        A branch whose condition prints as ``'True'`` supplies the final
        value and ends the chain; branches after it are not emitted. If no
        such branch exists the chain is closed with ``log(-1)``.
        """
        exprs = [self._print(arg.expr) for arg in expr.args]
        conds = [self._print(arg.cond) for arg in expr.args]
        ans = []
        parenthesis_count = 0
        is_last_cond_True = False
        # Distinct loop names so the ``expr`` parameter is not shadowed.
        for cond_str, expr_str in zip(conds, exprs):
            if cond_str == 'True':
                ans.append(expr_str)
                is_last_cond_True = True
                break
            else:
                # Leave the call open; the next branch (or the fallback)
                # becomes its third argument.
                ans.append('where(%s, %s, ' % (cond_str, expr_str))
                parenthesis_count += 1
        if not is_last_cond_True:
            # See https://github.com/pydata/numexpr/issues/298
            #
            # simplest way to put a nan but raises
            # 'RuntimeWarning: invalid value encountered in log'
            #
            # There are other ways to do this such as
            #
            #   >>> import numexpr as ne
            #   >>> nan = float('nan')
            #   >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan})
            #   array([-1., nan, nan])
            #
            # That needs to be handled in the lambdified function though rather
            # than here in the printer.
            ans.append('log(-1)')
        # Close every ``where(`` opened above.
        return ''.join(ans) + ')' * parenthesis_count

    def _print_ITE(self, expr):
        """Print if-then-else by rewriting it as a Piecewise first."""
        from sympy.functions.elementary.piecewise import Piecewise
        return self._print(expr.rewrite(Piecewise))

    def blacklisted(self, expr):
        """Reject ``expr``: always raises ``TypeError`` naming its class."""
        raise TypeError("numexpr cannot be used with %s" %
                        expr.__class__.__name__)

    # blacklist all Matrix printing
    _print_SparseRepMatrix = \
    _print_MutableSparseMatrix = \
    _print_ImmutableSparseMatrix = \
    _print_Matrix = \
    _print_DenseMatrix = \
    _print_MutableDenseMatrix = \
    _print_ImmutableMatrix = \
    _print_ImmutableDenseMatrix = \
    blacklisted
    # blacklist some Python expressions
    _print_list = \
    _print_tuple = \
    _print_Tuple = \
    _print_dict = \
    _print_Dict = \
    blacklisted

    def _print_NumExprEvaluate(self, expr):
        """Print ``expr.expr`` as a quoted string passed to ``evaluate``."""
        evaluate = self._module_format(self.module + ".evaluate")
        return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr))

    def doprint(self, expr):
        """Print ``expr``, wrapping it in ``NumExprEvaluate`` when needed.

        Anything that is not a ``CodegenAST`` node is wrapped so that the
        printed result is an ``evaluate`` call.
        """
        from sympy.codegen.ast import CodegenAST
        from sympy.codegen.pynodes import NumExprEvaluate
        if not isinstance(expr, CodegenAST):
            expr = NumExprEvaluate(expr)
        return super().doprint(expr)

    def _print_Return(self, expr):
        """Print a return node, wrapping its value in ``NumExprEvaluate``."""
        from sympy.codegen.pynodes import NumExprEvaluate
        r, = expr.args
        if not isinstance(r, NumExprEvaluate):
            expr = expr.func(NumExprEvaluate(r))
        return super()._print_Return(expr)

    def _print_Assignment(self, expr):
        """Print an assignment, wrapping its rhs in ``NumExprEvaluate``."""
        from sympy.codegen.pynodes import NumExprEvaluate
        lhs, rhs, *args = expr.args
        if not isinstance(rhs, NumExprEvaluate):
            expr = expr.func(lhs, NumExprEvaluate(rhs), *args)
        return super()._print_Assignment(expr)

    def _print_CodeBlock(self, expr):
        """Print a code block, wrapping bare expressions in ``NumExprEvaluate``.

        Arguments that are already ``CodegenAST`` nodes are kept as they are.
        """
        from sympy.codegen.ast import CodegenAST
        from sympy.codegen.pynodes import NumExprEvaluate
        args = [arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg)
                for arg in expr.args]
        # ``super()`` is already bound to ``self``; passing ``self`` again
        # supplied one positional argument too many and raised TypeError.
        return super()._print_CodeBlock(expr.func(*args))
class IntervalPrinter(MpmathPrinter, LambdaPrinter):
    """Lambda-style printer that prints numbers as ``mpi`` intervals. """

    def _print_Integer(self, expr):
        """Print an integer as ``mpi('<integer>')``."""
        # Use the integer text of the class after PythonCodePrinter in the MRO.
        plain = super(PythonCodePrinter, self)._print_Integer(expr)
        return "mpi('%s')" % plain

    def _print_Rational(self, expr):
        """Print a rational as ``mpi('<rational>')``."""
        plain = super(PythonCodePrinter, self)._print_Rational(expr)
        return "mpi('%s')" % plain

    def _print_Half(self, expr):
        """Print one half the same way as any other rational."""
        plain = super(PythonCodePrinter, self)._print_Rational(expr)
        return "mpi('%s')" % plain

    def _print_Pow(self, expr):
        """Print a power with ``rational=True``, skipping ``MpmathPrinter``."""
        after_mpmath = super(MpmathPrinter, self)
        return after_mpmath._print_Pow(expr, rational=True)
# Give NumExprPrinter a ``_print_<name>`` method for every SymPy function
# name in its table; all of them share the generic ``_print_Function``.
for k in NumExprPrinter._numexpr_functions:
    setattr(
        NumExprPrinter,
        '_print_' + k,
        NumExprPrinter._print_Function,
    )
def lambdarepr(expr, **settings):
    """
    Return a string form of ``expr`` that is usable for lambdifying.

    The keyword ``settings`` are handed to ``LambdaPrinter`` as one dict.
    """
    printer = LambdaPrinter(settings)
    return printer.doprint(expr)
|