Spaces:
Sleeping
Sleeping
from sympy.external.importtools import version_tuple | |
from collections.abc import Iterable | |
from sympy.core.mul import Mul | |
from sympy.core.singleton import S | |
from sympy.codegen.cfunctions import Sqrt | |
from sympy.external import import_module | |
from sympy.printing.precedence import PRECEDENCE | |
from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter | |
import sympy | |
tensorflow = import_module('tensorflow') | |
class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter):
    """
    Tensorflow printer which handles vectorized piecewise functions,
    logical operators, max/min, and relational operators.

    Expressions are rendered as calls into the ``tensorflow`` namespace,
    with every callable name passed through ``_module_format``. The
    ``tensorflow_version`` setting chooses between legacy and current API
    names for transposes and piecewise selection. When it is ``None``, the
    version of the importable ``tensorflow`` module is used, if there is one.
    """
    printmethod = "_tensorflowcode"

    # SymPy class -> fully-qualified TensorFlow callable. ``_print_Function``
    # looks entries up by exact ``type(expr)``, so a subclass of a listed
    # class is not matched unless it is listed itself.
    mapping = {
        sympy.Abs: "tensorflow.math.abs",
        sympy.sign: "tensorflow.math.sign",

        # XXX May raise error for ints.
        sympy.ceiling: "tensorflow.math.ceil",
        sympy.floor: "tensorflow.math.floor",
        sympy.log: "tensorflow.math.log",
        sympy.exp: "tensorflow.math.exp",
        Sqrt: "tensorflow.math.sqrt",
        sympy.cos: "tensorflow.math.cos",
        sympy.acos: "tensorflow.math.acos",
        sympy.sin: "tensorflow.math.sin",
        sympy.asin: "tensorflow.math.asin",
        sympy.tan: "tensorflow.math.tan",
        sympy.atan: "tensorflow.math.atan",
        sympy.atan2: "tensorflow.math.atan2",
        # XXX Also may give NaN for complex results.
        sympy.cosh: "tensorflow.math.cosh",
        sympy.acosh: "tensorflow.math.acosh",
        sympy.sinh: "tensorflow.math.sinh",
        sympy.asinh: "tensorflow.math.asinh",
        sympy.tanh: "tensorflow.math.tanh",
        sympy.atanh: "tensorflow.math.atanh",

        sympy.re: "tensorflow.math.real",
        sympy.im: "tensorflow.math.imag",
        sympy.arg: "tensorflow.math.angle",

        # XXX May raise error for ints and complexes
        sympy.erf: "tensorflow.math.erf",
        sympy.loggamma: "tensorflow.math.lgamma",

        sympy.Eq: "tensorflow.math.equal",
        sympy.Ne: "tensorflow.math.not_equal",
        sympy.StrictGreaterThan: "tensorflow.math.greater",
        sympy.StrictLessThan: "tensorflow.math.less",
        sympy.LessThan: "tensorflow.math.less_equal",
        sympy.GreaterThan: "tensorflow.math.greater_equal",

        sympy.And: "tensorflow.math.logical_and",
        sympy.Or: "tensorflow.math.logical_or",
        sympy.Not: "tensorflow.math.logical_not",
        sympy.Max: "tensorflow.math.maximum",
        sympy.Min: "tensorflow.math.minimum",

        # Matrices
        sympy.MatAdd: "tensorflow.math.add",
        sympy.HadamardProduct: "tensorflow.math.multiply",
        sympy.Trace: "tensorflow.linalg.trace",

        # XXX May raise error for integer matrices.
        sympy.Determinant : "tensorflow.linalg.det",
    }

    # Inherit every generic Python-printer setting and add the TensorFlow
    # version override (None = detect from the installed module).
    _default_settings = dict(
        AbstractPythonCodePrinter._default_settings,
        tensorflow_version=None
    )

    def __init__(self, settings=None):
        """Initialize the printer and resolve the TensorFlow version to target.

        An explicit ``tensorflow_version`` setting wins. Otherwise the
        ``__version__`` of the imported ``tensorflow`` module is used. If that
        module is unavailable (falsy), the version stays ``None`` and the
        current API names are emitted.
        """
        super().__init__(settings)

        version = self._settings['tensorflow_version']
        if version is None and tensorflow:
            version = tensorflow.__version__
        self.tensorflow_version = version

    def _print_Function(self, expr):
        """Print ``expr`` as a call to its mapped TensorFlow function.

        If ``type(expr)`` is not in ``mapping``, printing falls back to
        ``_print_Basic``. A single argument becomes ``op(arg)``. Several
        arguments are handed to the inherited ``_expand_fold_binary_op``,
        which folds them into nested binary calls of ``op``.
        """
        op = self.mapping.get(type(expr), None)
        if op is None:
            return super()._print_Basic(expr)
        children = [self._print(arg) for arg in expr.args]
        if len(children) == 1:
            return "%s(%s)" % (self._module_format(op), children[0])
        return self._expand_fold_binary_op(op, children)

    # All of these dispatch through ``mapping``. Without the explicit aliases,
    # more specific print methods inherited from the base printers would be
    # chosen for these classes instead.
    _print_Expr = _print_Function
    _print_Application = _print_Function
    _print_MatrixExpr = _print_Function
    # TODO: a better class structure would avoid this mess:
    _print_Relational = _print_Function
    _print_Not = _print_Function
    _print_And = _print_Function
    _print_Or = _print_Function
    _print_HadamardProduct = _print_Function
    _print_Trace = _print_Function
    _print_Determinant = _print_Function

    def _print_Inverse(self, expr):
        """Print a matrix inverse as ``tensorflow.linalg.inv(arg)``."""
        op = self._module_format('tensorflow.linalg.inv')
        return "{}({})".format(op, self._print(expr.arg))

    def _print_Transpose(self, expr):
        """Print a matrix transpose.

        Uses the legacy ``tensorflow.matrix_transpose`` when the target
        version is known and older than 1.14. Otherwise it uses
        ``tensorflow.linalg.matrix_transpose``.
        """
        version = self.tensorflow_version
        if version and version_tuple(version) < version_tuple('1.14'):
            op = self._module_format('tensorflow.matrix_transpose')
        else:
            op = self._module_format('tensorflow.linalg.matrix_transpose')
        return "{}({})".format(op, self._print(expr.arg))

    def _print_Derivative(self, expr):
        """Print a derivative as nested ``tensorflow.gradients(...)[0]`` calls.

        The first differentiation variable is applied innermost and the last
        one outermost.

        Raises:
            NotImplementedError: if any entry of ``expr.variables`` is
                itself iterable.
        """
        variables = expr.variables
        if any(isinstance(i, Iterable) for i in variables):
            raise NotImplementedError("derivation by multiple variables is not supported")

        def unfold(expr, args):
            # Peel off the last variable and wrap the derivative of the rest.
            if not args:
                return self._print(expr)
            return "%s(%s, %s)[0]" % (
                self._module_format("tensorflow.gradients"),
                unfold(expr, args[:-1]),
                self._print(args[-1]),
            )
        return unfold(expr.expr, variables)

    def _print_Piecewise(self, expr):
        """Print a Piecewise as nested ``where`` (or legacy ``select``) calls.

        Each ``(value, condition)`` pair becomes
        ``where(condition, value, <rest>)``. Here ``<rest>`` is the remaining
        pieces printed recursively, or the literal ``0`` after the last piece.
        ``tensorflow.select`` is used only for known versions older than 1.0.
        """
        version = self.tensorflow_version
        if version and version_tuple(version) < version_tuple('1.0'):
            tensorflow_piecewise = "tensorflow.select"
        else:
            tensorflow_piecewise = "tensorflow.where"

        from sympy.functions.elementary.piecewise import Piecewise
        e, cond = expr.args[0].args
        if len(expr.args) == 1:
            return '{}({}, {}, {})'.format(
                self._module_format(tensorflow_piecewise),
                self._print(cond),
                self._print(e),
                0)

        return '{}({}, {}, {})'.format(
            self._module_format(tensorflow_piecewise),
            self._print(cond),
            self._print(e),
            self._print(Piecewise(*expr.args[1:])))

    def _print_Pow(self, expr):
        """Print a power as ``tensorflow.math.pow(base, exp)``.

        An exponent of exactly 1/2 is printed as ``tensorflow.math.sqrt``.
        """
        # XXX May raise error for
        # int**float or int**complex or float**complex
        base, exp = expr.args
        if expr.exp == S.Half:
            return "{}({})".format(
                self._module_format("tensorflow.math.sqrt"), self._print(base))
        return "{}({}, {})".format(
            self._module_format("tensorflow.math.pow"),
            self._print(base), self._print(exp))

    def _print_MatrixBase(self, expr):
        """Print an explicit matrix as a nested-list TensorFlow tensor.

        A matrix with free symbols becomes ``tensorflow.Variable``. A fully
        numeric matrix becomes ``tensorflow.constant``.
        """
        tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
        rows = ", ".join(
            "[" + ", ".join(self._print(entry) for entry in row) + "]"
            for row in expr.tolist()
        )
        return "%s(%s)" % (self._module_format(tensorflow_f), "[" + rows + "]")

    def _print_MatMul(self, expr):
        """Print a matrix product as folded ``tensorflow.linalg.matmul`` calls.

        Scalar (non-``MatrixExpr``) factors are collected into one ``Mul``.
        That ``Mul`` is parenthesized at ``Mul`` precedence and prefixed with
        ``*`` to the matmul chain.
        """
        from sympy.matrices.expressions import MatrixExpr
        # One pass instead of an O(n^2) ``arg not in mat_args`` filter.
        mat_args = []
        scalar_args = []
        for arg in expr.args:
            if isinstance(arg, MatrixExpr):
                mat_args.append(arg)
            else:
                scalar_args.append(arg)

        product = self._expand_fold_binary_op(
            "tensorflow.linalg.matmul", mat_args)
        if not scalar_args:
            return product
        return "%s*%s" % (
            self.parenthesize(Mul.fromiter(scalar_args), PRECEDENCE["Mul"]),
            product,
        )

    def _print_MatPow(self, expr):
        """Print ``A**n`` as ``n`` copies of ``A`` folded with ``matmul``.

        Raises:
            NotImplementedError: if the exponent is not a positive integer.
        """
        exp = expr.exp
        # ``[base] * n`` is empty when n <= 0, and it cannot be built for a
        # symbolic n. The fold helper needs at least one operand, so such
        # exponents are rejected with a clear error instead of recursing
        # without end.
        if not (exp.is_Integer and exp.is_positive):
            raise NotImplementedError(
                "matrix power with exponent %s is not supported; "
                "only positive integer exponents are" % exp)
        return self._expand_fold_binary_op(
            "tensorflow.linalg.matmul", [expr.base] * int(exp))

    def _print_CodeBlock(self, expr):
        """Print each statement of a CodeBlock on its own line."""
        # TODO: is this necessary?
        return "\n".join(self._print(subexpr) for subexpr in expr.args)

    # Hooks read by ArrayPrinter. The relative names are presumably resolved
    # against ``_module`` (confirm in ArrayPrinter).
    _module = "tensorflow"
    _einsum = "linalg.einsum"
    _add = "math.add"
    _transpose = "transpose"
    _ones = "ones"
    _zeros = "zeros"
def tensorflow_code(expr, **settings):
    """Return TensorFlow source code that evaluates ``expr``.

    All keyword ``settings`` (for example ``tensorflow_version``) are passed
    as a single dict to a fresh ``TensorflowPrinter``, which then prints
    ``expr``.
    """
    return TensorflowPrinter(settings).doprint(expr)