Spaces:
Sleeping
Sleeping
""" | |
Mathematica code printer | |
""" | |
from __future__ import annotations | |
from typing import Any | |
from sympy.core import Basic, Expr, Float | |
from sympy.core.sorting import default_sort_key | |
from sympy.printing.codeprinter import CodePrinter | |
from sympy.printing.precedence import precedence | |
# Used in MCodePrinter._print_Function(self)
#
# Maps a SymPy function name to a list of ``(condition, mathematica_name)``
# pairs.  The printer calls ``condition(*expr.args)`` for each pair in order
# and uses the first Mathematica name whose condition is truthy.  Conditions
# written as ``lambda x`` accept exactly one argument; ``lambda *x`` accepts
# any number.
known_functions = {
    "exp": [(lambda x: True, "Exp")],
    "log": [(lambda x: True, "Log")],
    "sin": [(lambda x: True, "Sin")],
    "cos": [(lambda x: True, "Cos")],
    "tan": [(lambda x: True, "Tan")],
    "cot": [(lambda x: True, "Cot")],
    "sec": [(lambda x: True, "Sec")],
    "csc": [(lambda x: True, "Csc")],
    "asin": [(lambda x: True, "ArcSin")],
    "acos": [(lambda x: True, "ArcCos")],
    "atan": [(lambda x: True, "ArcTan")],
    "acot": [(lambda x: True, "ArcCot")],
    "asec": [(lambda x: True, "ArcSec")],
    "acsc": [(lambda x: True, "ArcCsc")],
    # NOTE(review): Mathematica's two-argument ArcTan is documented as
    # ArcTan[x, y] while SymPy's atan2 takes (y, x); arguments are printed
    # in SymPy order -- confirm whether they should be swapped.
    "atan2": [(lambda *x: True, "ArcTan")],
    "sinh": [(lambda x: True, "Sinh")],
    "cosh": [(lambda x: True, "Cosh")],
    "tanh": [(lambda x: True, "Tanh")],
    "coth": [(lambda x: True, "Coth")],
    "sech": [(lambda x: True, "Sech")],
    "csch": [(lambda x: True, "Csch")],
    "asinh": [(lambda x: True, "ArcSinh")],
    "acosh": [(lambda x: True, "ArcCosh")],
    "atanh": [(lambda x: True, "ArcTanh")],
    "acoth": [(lambda x: True, "ArcCoth")],
    "asech": [(lambda x: True, "ArcSech")],
    "acsch": [(lambda x: True, "ArcCsch")],
    "sinc": [(lambda x: True, "Sinc")],
    "conjugate": [(lambda x: True, "Conjugate")],
    "Max": [(lambda *x: True, "Max")],
    "Min": [(lambda *x: True, "Min")],
    "erf": [(lambda x: True, "Erf")],
    "erf2": [(lambda *x: True, "Erf")],
    "erfc": [(lambda x: True, "Erfc")],
    "erfi": [(lambda x: True, "Erfi")],
    "erfinv": [(lambda x: True, "InverseErf")],
    "erfcinv": [(lambda x: True, "InverseErfc")],
    "erf2inv": [(lambda *x: True, "InverseErf")],
    "expint": [(lambda *x: True, "ExpIntegralE")],
    "Ei": [(lambda x: True, "ExpIntegralEi")],
    "fresnelc": [(lambda x: True, "FresnelC")],
    "fresnels": [(lambda x: True, "FresnelS")],
    "gamma": [(lambda x: True, "Gamma")],
    "uppergamma": [(lambda *x: True, "Gamma")],
    "polygamma": [(lambda *x: True, "PolyGamma")],
    "loggamma": [(lambda x: True, "LogGamma")],
    "beta": [(lambda *x: True, "Beta")],
    "Ci": [(lambda x: True, "CosIntegral")],
    "Si": [(lambda x: True, "SinIntegral")],
    "Chi": [(lambda x: True, "CoshIntegral")],
    "Shi": [(lambda x: True, "SinhIntegral")],
    "li": [(lambda x: True, "LogIntegral")],
    "factorial": [(lambda x: True, "Factorial")],
    "factorial2": [(lambda x: True, "Factorial2")],
    "subfactorial": [(lambda x: True, "Subfactorial")],
    "catalan": [(lambda x: True, "CatalanNumber")],
    "harmonic": [(lambda *x: True, "HarmonicNumber")],
    "lucas": [(lambda x: True, "LucasL")],
    "RisingFactorial": [(lambda *x: True, "Pochhammer")],
    "FallingFactorial": [(lambda *x: True, "FactorialPower")],
    "laguerre": [(lambda *x: True, "LaguerreL")],
    "assoc_laguerre": [(lambda *x: True, "LaguerreL")],
    "hermite": [(lambda *x: True, "HermiteH")],
    "jacobi": [(lambda *x: True, "JacobiP")],
    "gegenbauer": [(lambda *x: True, "GegenbauerC")],
    "chebyshevt": [(lambda *x: True, "ChebyshevT")],
    "chebyshevu": [(lambda *x: True, "ChebyshevU")],
    "legendre": [(lambda *x: True, "LegendreP")],
    "assoc_legendre": [(lambda *x: True, "LegendreP")],
    "mathieuc": [(lambda *x: True, "MathieuC")],
    "mathieus": [(lambda *x: True, "MathieuS")],
    "mathieucprime": [(lambda *x: True, "MathieuCPrime")],
    "mathieusprime": [(lambda *x: True, "MathieuSPrime")],
    "stieltjes": [(lambda x: True, "StieltjesGamma")],
    "elliptic_e": [(lambda *x: True, "EllipticE")],
    # Incomplete elliptic integral of the first kind is EllipticF in
    # Mathematica (this entry previously duplicated EllipticE).
    "elliptic_f": [(lambda *x: True, "EllipticF")],
    "elliptic_k": [(lambda x: True, "EllipticK")],
    "elliptic_pi": [(lambda *x: True, "EllipticPi")],
    "zeta": [(lambda *x: True, "Zeta")],
    "dirichlet_eta": [(lambda x: True, "DirichletEta")],
    "riemann_xi": [(lambda x: True, "RiemannXi")],
    "besseli": [(lambda *x: True, "BesselI")],
    "besselj": [(lambda *x: True, "BesselJ")],
    "besselk": [(lambda *x: True, "BesselK")],
    "bessely": [(lambda *x: True, "BesselY")],
    "hankel1": [(lambda *x: True, "HankelH1")],
    "hankel2": [(lambda *x: True, "HankelH2")],
    "airyai": [(lambda x: True, "AiryAi")],
    "airybi": [(lambda x: True, "AiryBi")],
    "airyaiprime": [(lambda x: True, "AiryAiPrime")],
    "airybiprime": [(lambda x: True, "AiryBiPrime")],
    "polylog": [(lambda *x: True, "PolyLog")],
    "lerchphi": [(lambda *x: True, "LerchPhi")],
    "gcd": [(lambda *x: True, "GCD")],
    "lcm": [(lambda *x: True, "LCM")],
    "jn": [(lambda *x: True, "SphericalBesselJ")],
    "yn": [(lambda *x: True, "SphericalBesselY")],
    "hyper": [(lambda *x: True, "HypergeometricPFQ")],
    "meijerg": [(lambda *x: True, "MeijerG")],
    "appellf1": [(lambda *x: True, "AppellF1")],
    "DiracDelta": [(lambda x: True, "DiracDelta")],
    "Heaviside": [(lambda x: True, "HeavisideTheta")],
    "KroneckerDelta": [(lambda *x: True, "KroneckerDelta")],
    "sqrt": [(lambda x: True, "Sqrt")],  # For automatic rewrites
}
class MCodePrinter(CodePrinter):
    """A printer to convert Python expressions to
    strings of the Wolfram's Mathematica code.

    Function names are translated through ``self.known_functions``, which
    starts as a copy of the module-level ``known_functions`` table and is
    extended with the ``user_functions`` setting.
    """
    printmethod = "_mcode"
    language = "Wolfram Language"

    _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
        'precision': 15,
        'user_functions': {},
    })

    _number_symbols: set[tuple[Expr, Float]] = set()
    _not_supported: set[Basic] = set()

    def __init__(self, settings=None):
        """Register function mappings supplied by user.

        ``settings`` is an optional dict of printer settings.  Its
        ``'user_functions'`` entry maps a function name either to a list of
        ``(condition, name)`` pairs (same shape as ``known_functions``) or
        to any other value, which is then used as the Mathematica name
        unconditionally.
        """
        # ``None`` rather than a ``{}`` default, so that no dict object is
        # shared between calls.
        if settings is None:
            settings = {}
        CodePrinter.__init__(self, settings)
        self.known_functions = dict(known_functions)
        # Wrap bare names with an always-true condition so that every entry
        # has the list-of-pairs shape that _print_Function iterates over.
        userfuncs = {
            name: target if isinstance(target, list)
            else [(lambda *x: True, target)]
            for name, target in settings.get('user_functions', {}).items()
        }
        self.known_functions.update(userfuncs)

    def _format_code(self, lines):
        """Return ``lines`` unchanged."""
        return lines

    def _print_Pow(self, expr):
        """Print a power as ``base^exp``, parenthesizing by precedence."""
        PREC = precedence(expr)
        return '%s^%s' % (self.parenthesize(expr.base, PREC),
                          self.parenthesize(expr.exp, PREC))

    def _print_Mul(self, expr):
        """Print a product.

        The first group returned by ``args_cnc()`` is printed by the parent
        class; the second group, if non-empty, is appended after ``*`` with
        its factors joined by ``**``.
        """
        PREC = precedence(expr)
        c, nc = expr.args_cnc()
        res = super()._print_Mul(expr.func(*c))
        if nc:
            res += '*'
            res += '**'.join(self.parenthesize(a, PREC) for a in nc)
        return res

    def _print_Relational(self, expr):
        """Print a relation as ``lhs op rhs`` using the relation's own
        ``rel_op`` string.
        """
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = expr.rel_op
        return "{} {} {}".format(lhs_code, op, rhs_code)

    # Primitive numbers
    def _print_Zero(self, expr):
        return '0'

    def _print_One(self, expr):
        return '1'

    def _print_NegativeOne(self, expr):
        return '-1'

    def _print_Half(self, expr):
        return '1/2'

    def _print_ImaginaryUnit(self, expr):
        return 'I'

    # Infinity and invalid numbers
    def _print_Infinity(self, expr):
        return 'Infinity'

    def _print_NegativeInfinity(self, expr):
        return '-Infinity'

    def _print_ComplexInfinity(self, expr):
        return 'ComplexInfinity'

    def _print_NaN(self, expr):
        return 'Indeterminate'

    # Mathematical constants
    def _print_Exp1(self, expr):
        return 'E'

    def _print_Pi(self, expr):
        return 'Pi'

    def _print_GoldenRatio(self, expr):
        return 'GoldenRatio'

    def _print_TribonacciConstant(self, expr):
        """Print the result of ``expr.expand(func=True)`` instead of a
        symbolic name, parenthesized by the precedence of ``expr``.
        """
        expanded = expr.expand(func=True)
        PREC = precedence(expr)
        return self.parenthesize(expanded, PREC)

    def _print_EulerGamma(self, expr):
        return 'EulerGamma'

    def _print_Catalan(self, expr):
        return 'Catalan'

    def _print_list(self, expr):
        """Print a sequence as a brace-delimited, comma-separated list."""
        return '{' + ', '.join(self.doprint(a) for a in expr) + '}'
    _print_tuple = _print_list
    _print_Tuple = _print_list

    def _print_ImmutableDenseMatrix(self, expr):
        """Print a dense matrix via its ``tolist()`` nested lists."""
        return self.doprint(expr.tolist())

    def _print_ImmutableSparseMatrix(self, expr):
        """Print a sparse matrix as ``SparseArray[{rules}, {dims}]``."""

        def print_rule(pos, val):
            # Shift the 0-based (row, column) position by one for the
            # printed index.
            return '{} -> {}'.format(
                self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val))

        def print_data():
            items = sorted(expr.todok().items(), key=default_sort_key)
            return '{' + \
                ', '.join(print_rule(k, v) for k, v in items) + \
                '}'

        def print_dims():
            return self.doprint(expr.shape)

        return 'SparseArray[{}, {}]'.format(print_data(), print_dims())

    def _print_ImmutableDenseNDimArray(self, expr):
        """Print a dense N-dim array via its ``tolist()`` nested lists."""
        return self.doprint(expr.tolist())

    def _print_ImmutableSparseNDimArray(self, expr):
        """Print a sparse N-dim array as ``SparseArray[{rules}, {dims}]``."""

        def print_string_list(string_list):
            return '{' + ', '.join(string_list) + '}'

        def to_mathematica_index(*args):
            """Helper function to change Python style indexing to
            Mathematica indexing.

            Python indexing (0, 1 ... n-1)
            -> Mathematica indexing (1, 2 ... n)
            """
            return tuple(i + 1 for i in args)

        def print_rule(pos, val):
            """Helper function to print a rule of Mathematica"""
            return '{} -> {}'.format(self.doprint(pos), self.doprint(val))

        def print_data():
            """Helper function to print data part of Mathematica
            sparse array.

            It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
            from
            https://reference.wolfram.com/language/ref/SparseArray.html

            ``data`` must be formatted with rule.
            """
            return print_string_list(
                [print_rule(
                    to_mathematica_index(*(expr._get_tuple_index(key))),
                    value)
                 for key, value in sorted(expr._sparse_array.items())]
            )

        def print_dims():
            """Helper function to print dimensions part of Mathematica
            sparse array.

            It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
            from
            https://reference.wolfram.com/language/ref/SparseArray.html
            """
            return self.doprint(expr.shape)

        return 'SparseArray[{}, {}]'.format(print_data(), print_dims())

    def _print_Function(self, expr):
        """Print a function application as ``Name[arg1, arg2, ...]``.

        The name is looked up in ``self.known_functions``; the first entry
        whose condition accepts the arguments wins.  Otherwise, if the
        function is listed in ``self._rewriteable_functions`` and the
        rewrite target is printable, the rewritten expression is printed.
        As a last resort the function's own Python name is used.
        """
        if expr.func.__name__ in self.known_functions:
            cond_mfunc = self.known_functions[expr.func.__name__]
            for cond, mfunc in cond_mfunc:
                if cond(*expr.args):
                    return "%s[%s]" % (mfunc, self.stringify(expr.args, ", "))
        elif expr.func.__name__ in self._rewriteable_functions:
            # Simple rewrite to supported function possible
            target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
            if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
                return self._print(expr.rewrite(target_f))
        return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ")

    _print_MinMaxBase = _print_Function

    def _print_LambertW(self, expr):
        """Print ``LambertW`` as ``ProductLog``.

        With two arguments the second SymPy argument is printed first.
        """
        if len(expr.args) == 1:
            return "ProductLog[{}]".format(self._print(expr.args[0]))
        return "ProductLog[{}, {}]".format(
            self._print(expr.args[1]), self._print(expr.args[0]))

    def _print_Integral(self, expr):
        """Print an integral as ``Hold[Integrate[...]]``.

        A single integration variable without bounds is printed as the bare
        variable; otherwise every element of ``expr.args`` is printed as is.
        """
        if len(expr.variables) == 1 and not expr.limits[0][1:]:
            args = [expr.args[0], expr.variables[0]]
        else:
            args = expr.args
        return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]"

    def _print_Sum(self, expr):
        """Print a sum as ``Hold[Sum[...]]`` over ``expr.args``."""
        return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]"

    def _print_Derivative(self, expr):
        """Print a derivative as ``Hold[D[f, ...]]``.

        A variable with count 1 is printed bare; any other count is printed
        as the ``(variable, count)`` pair.
        """
        dexpr = expr.expr
        dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count]
        return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]"

    def _get_comment(self, text):
        """Wrap ``text`` in Mathematica comment delimiters ``(* ... *)``."""
        return "(* {} *)".format(text)
def mathematica_code(expr, **settings):
    r"""Return ``expr`` rendered as a string of Wolfram Mathematica code.

    The keyword arguments are collected into the settings dict that is
    handed to ``MCodePrinter``.

    Examples
    ========

    >>> from sympy import mathematica_code as mcode, symbols, sin
    >>> x = symbols('x')
    >>> mcode(sin(x).series(x).removeO())
    '(1/120)*x^5 - 1/6*x^3 + x'

    """
    printer = MCodePrinter(settings)
    return printer.doprint(expr)