Spaces:
Sleeping
Sleeping
import keyword as kw | |
import sympy | |
from .repr import ReprPrinter | |
from .str import StrPrinter | |
# Names of the expression classes whose printing is delegated to StrPrinter
# instead of ReprPrinter: PythonPrinter installs StrPrinter's
# ``_print_<Name>`` method for each entry.
STRPRINT = (
    "Add",
    "Infinity",
    "Integer",
    "Mul",
    "NegativeInfinity",
    "Pow",
)
class PythonPrinter(ReprPrinter, StrPrinter):
    """A printer which converts an expression into its Python interpretation.

    While printing, it records two lists, each in first-seen order and without
    duplicates:

    * ``self.symbols``: the printed names of the symbols encountered.
    * ``self.functions``: the names of the functions whose name is not an
      attribute of the ``sympy`` module.

    ``python()`` uses these lists to emit definitions ahead of the expression.
    """

    def __init__(self, settings=None):
        super().__init__(settings)
        self.symbols = []    # printed Symbol names, first-seen order
        self.functions = []  # names not found on the sympy module, first-seen order

        # Classes listed in STRPRINT are printed with StrPrinter's methods
        # instead of ReprPrinter's. NOTE: this patches the PythonPrinter class
        # itself (not just this instance) and is repeated on every
        # instantiation. STRPRINT is re-read each time.
        for cls_name in STRPRINT:
            method_name = "_print_%s" % cls_name
            setattr(PythonPrinter, method_name, getattr(StrPrinter, method_name))

    def _print_Function(self, expr):
        """Print a function application, remembering names unknown to ``sympy``."""
        name = expr.func.__name__
        is_custom = not hasattr(sympy, name)
        if is_custom and name not in self.functions:
            self.functions.append(name)
        return StrPrinter._print_Function(self, expr)

    def _print_Symbol(self, expr):
        """Print a symbol, remembering its name so python() can define it."""
        name = self._str(expr)
        if name not in self.symbols:
            self.symbols.append(name)
        return StrPrinter._print_Symbol(self, expr)

    def _print_module(self, expr):
        """Always raise ValueError; presumably reached when a module object is printed."""
        raise ValueError('Modules in the expression are unacceptable')
def _escape_keyword(name, taken):
    """Return *name*, or a non-keyword variant of it if it is a Python keyword.

    Non-keywords are returned unchanged. For a reserved keyword, one
    underscore is appended. More underscores are appended while the candidate
    is still in *taken*.
    """
    if not kw.iskeyword(name):
        return name
    while True:
        name += "_"
        if name not in taken:
            return name


def python(expr, **settings):
    """Return Python interpretation of passed expression.

    The returned string first defines every symbol
    (``name = Symbol('...')``) and every function whose name is not an
    attribute of the ``sympy`` module (``name = Function('...')``) that the
    printer encountered. It then ends with an ``e = ...`` line holding the
    expression. ``Symbol`` and ``Function`` are referenced unqualified, so the
    text is meant for ``exec()`` in a namespace where those names exist.

    Variable names are adjusted where needed:

    * Curly braces are removed from subscripted names (``x_{1}`` -> ``x_1``).
    * Names that are reserved Python keywords get trailing underscores
      (``lambda`` -> ``lambda_``).

    The original names are kept inside the ``Symbol(...)``/``Function(...)``
    calls.

    Parameters
    ==========
    expr : the expression to print.
    **settings : passed unchanged to ``PythonPrinter``.

    Returns
    =======
    str : the generated source text.

    Raises
    ======
    ValueError : propagated from ``PythonPrinter`` (its ``_print_module``
        raises it).
    """
    printer = PythonPrinter(settings)
    exprp = printer.doprint(expr)

    # Names an escaped keyword must not collide with: everything the printer
    # saw, plus every replacement name generated below.
    taken = set(printer.symbols)
    taken.update(printer.functions)

    lines = []
    renamings = {}

    # Returning found symbols and functions
    for symbolname in printer.symbols:
        newsymbolname = symbolname
        if '{' in newsymbolname:
            # Remove curly braces from subscripted variables
            newsymbolname = newsymbolname.replace('{', '').replace('}', '')
        # Escape symbol names that are reserved Python keywords
        newsymbolname = _escape_keyword(newsymbolname, taken)
        if newsymbolname != symbolname:
            # Map to a Symbol, not a plain string, so subs() does not have to
            # re-parse the new name. A stripped name such as 'x_i,j' is not a
            # single identifier and need not parse back to one Symbol.
            renamings[sympy.Symbol(symbolname)] = sympy.Symbol(newsymbolname)
            taken.add(newsymbolname)
        # %r quotes the original name correctly even if it contains quote
        # characters or backslashes. For ordinary names it yields 'name'.
        lines.append('%s = Symbol(%r)' % (newsymbolname, symbolname))

    for functionname in printer.functions:
        # Escape function names that are reserved Python keywords
        newfunctionname = _escape_keyword(functionname, taken)
        if newfunctionname != functionname:
            renamings[sympy.Function(functionname)] = sympy.Function(newfunctionname)
            taken.add(newfunctionname)
        lines.append('%s = Function(%r)' % (newfunctionname, functionname))

    if renamings:
        # NOTE(review): keys are rebuilt from the printed names alone. A symbol
        # created with extra assumptions may not compare equal to such a key
        # and would then not be renamed. Confirm whether that case matters.
        exprp = expr.subs(renamings)
    lines.append('e = ' + printer._str(exprp))
    return '\n'.join(lines)
def print_python(expr, **settings):
    """Write the result of ``python(expr, **settings)`` to standard output."""
    source = python(expr, **settings)
    print(source)