Spaces:
Sleeping
Sleeping
""" | |
.. deprecated:: 1.8 | |
``sympy.printing.theanocode`` is deprecated. Theano has been renamed to | |
Aesara. Use ``sympy.printing.aesaracode`` instead. See | |
:ref:`theanocode-deprecated` for more information. | |
""" | |
from __future__ import annotations | |
from typing import Any | |
from sympy.external import import_module | |
from sympy.printing.printer import Printer | |
from sympy.utilities.iterables import is_sequence | |
import sympy | |
from functools import partial | |
from sympy.utilities.decorator import doctest_depends_on | |
from sympy.utilities.exceptions import sympy_deprecation_warning | |
__doctest_requires__ = {("theano_function",): ["theano"]}

theano = import_module("theano")

if theano:
    ts = theano.scalar
    tt = theano.tensor
    from theano.sandbox import linalg as tlinalg

    # Dispatch table used by ``TheanoPrinter._print_Basic``: exact SymPy
    # class -> Theano operation (or zero-argument factory) applied to the
    # printed arguments of the expression.
    #
    # Scalar / element-wise operations first ...
    mapping = {
        sympy.Add: tt.add,
        sympy.Mul: tt.mul,
        sympy.Abs: tt.abs_,
        sympy.sign: tt.sgn,
        sympy.ceiling: tt.ceil,
        sympy.floor: tt.floor,
        sympy.log: tt.log,
        sympy.exp: tt.exp,
        sympy.sqrt: tt.sqrt,
        sympy.cos: tt.cos,
        sympy.acos: tt.arccos,
        sympy.sin: tt.sin,
        sympy.asin: tt.arcsin,
        sympy.tan: tt.tan,
        sympy.atan: tt.arctan,
        sympy.atan2: tt.arctan2,
        sympy.cosh: tt.cosh,
        sympy.acosh: tt.arccosh,
        sympy.sinh: tt.sinh,
        sympy.asinh: tt.arcsinh,
        sympy.tanh: tt.tanh,
        sympy.atanh: tt.arctanh,
        sympy.re: tt.real,
        sympy.im: tt.imag,
        sympy.arg: tt.angle,
        sympy.erf: tt.erf,
        sympy.gamma: tt.gamma,
        sympy.loggamma: tt.gammaln,
        sympy.Pow: tt.pow,
        sympy.Eq: tt.eq,
        sympy.StrictGreaterThan: tt.gt,
        sympy.StrictLessThan: tt.lt,
        sympy.LessThan: tt.le,
        sympy.GreaterThan: tt.ge,
        sympy.And: tt.and_,
        sympy.Or: tt.or_,
        sympy.Max: tt.maximum,  # SymPy accepts >2 inputs, Theano only 2
        sympy.Min: tt.minimum,  # SymPy accepts >2 inputs, Theano only 2
        sympy.conjugate: tt.conj,
        # ImaginaryUnit has no args, so this factory is called with none.
        sympy.core.numbers.ImaginaryUnit: lambda: tt.complex(0, 1),
    }

    # ... then matrix-expression operations, appended afterwards so the
    # insertion order of ``mapping`` stays the same.
    mapping.update({
        sympy.MatAdd: tt.Elemwise(ts.add),
        sympy.HadamardProduct: tt.Elemwise(ts.mul),
        sympy.Trace: tlinalg.trace,
        sympy.Determinant: tlinalg.det,
        sympy.Inverse: tlinalg.matrix_inverse,
        sympy.Transpose: tt.DimShuffle((False, False), [1, 0]),
    })
class TheanoPrinter(Printer):
    """ Code printer which creates Theano symbolic expression graphs.

    Parameters
    ==========

    cache : dict
        Cache dictionary to use. If not given (or ``None``), a new empty
        dictionary is created, so the printer neither depends on nor alters
        any shared state. Note: a dictionary that *is* passed in is not copied
        on initialization of the printer and will be updated in-place, so
        using the same dict object when creating multiple printers or making
        multiple calls to :func:`.theano_code` or :func:`.theano_function`
        means the cache is shared between all these applications.

    Attributes
    ==========

    cache : dict
        A cache of Theano variables which have been created for SymPy
        symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
        :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
        ensure that all references to a given symbol in an expression (or
        multiple expressions) are printed as the same Theano variable, which is
        created only once. Entries are keyed by the symbol's name, type and
        arguments together with the requested dtype and broadcast pattern. The
        format of the cache's contents should be considered opaque to the user.
    """
    printmethod = "_theano"

    def __init__(self, *args, **kwargs):
        cache = kwargs.pop('cache', None)
        # Keep the caller's dict object (even an empty one) so in-place updates
        # are visible to them; only a missing/None cache gets a fresh private
        # dict. Previously an explicit ``cache=None`` was stored verbatim and
        # made the first cache lookup fail with a TypeError.
        self.cache = {} if cache is None else cache
        super().__init__(*args, **kwargs)

    def _get_key(self, s, name=None, dtype=None, broadcastable=None):
        """ Get the cache key for a SymPy object.

        Parameters
        ==========

        s : sympy.core.basic.Basic
            SymPy object to get key for.

        name : str
            Name of object, if it does not have a ``name`` attribute.

        dtype, broadcastable
            Part of the key, so the same symbol requested with a different
            dtype or broadcast pattern maps to a distinct Theano variable.
        """
        if name is None:
            name = s.name

        return (name, type(s), s.args, dtype, broadcastable)

    def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
        """
        Get the Theano variable for a SymPy symbol from the cache, or create it
        if it does not exist.

        Unspecified arguments default to ``s.name``, ``'floatX'`` and ``()``
        (a scalar) respectively.
        """
        # Defaults
        if name is None:
            name = s.name
        if dtype is None:
            dtype = 'floatX'
        if broadcastable is None:
            broadcastable = ()

        key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)

        if key in self.cache:
            return self.cache[key]

        value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
        self.cache[key] = value
        return value

    def _print_Symbol(self, s, **kwargs):
        """Print a Symbol as a (cached) Theano tensor variable."""
        dtype = kwargs.get('dtypes', {}).get(s)
        bc = kwargs.get('broadcastables', {}).get(s)
        return self._get_or_create(s, dtype=dtype, broadcastable=bc)

    def _print_AppliedUndef(self, s, **kwargs):
        """Print an undefined function application, e.g. ``f(t)``, as a variable."""
        # The variable name only uses the first argument; the full argument
        # tuple is still part of the cache key (see ``_get_key``).
        name = str(type(s)) + '_' + str(s.args[0])
        dtype = kwargs.get('dtypes', {}).get(s)
        bc = kwargs.get('broadcastables', {}).get(s)
        return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)

    def _print_Basic(self, expr, **kwargs):
        """Generic fallback: apply the op registered in ``mapping`` for the
        expression's exact type (KeyError if the type is unsupported)."""
        op = mapping[type(expr)]
        children = [self._print(arg, **kwargs) for arg in expr.args]
        return op(*children)

    def _fold_binary(self, op, expr, **kwargs):
        """Print ``expr``'s arguments and combine them left-to-right with the binary ``op``."""
        from functools import reduce
        return reduce(op, [self._print(arg, **kwargs) for arg in expr.args])

    def _print_Max(self, expr, **kwargs):
        # SymPy's Max takes any number of arguments but Theano's ``maximum``
        # is binary, so fold pairwise (identical to the old result for 2 args).
        return self._fold_binary(tt.maximum, expr, **kwargs)

    def _print_Min(self, expr, **kwargs):
        # Same reasoning as ``_print_Max``.
        return self._fold_binary(tt.minimum, expr, **kwargs)

    def _print_Number(self, n, **kwargs):
        # Integers already taken care of below, interpret as float
        return float(n.evalf())

    def _print_MatrixSymbol(self, X, **kwargs):
        """Print a MatrixSymbol as a (cached) two-dimensional Theano variable."""
        dtype = kwargs.get('dtypes', {}).get(X)
        return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))

    def _print_DenseMatrix(self, X, **kwargs):
        """Print an explicit matrix by stacking its printed entries row by row."""
        if not hasattr(tt, 'stacklists'):
            raise NotImplementedError(
               "Matrix translation not yet supported in this version of Theano")

        return tt.stacklists([
            [self._print(arg, **kwargs) for arg in L]
            for L in X.tolist()
        ])

    _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix

    def _print_MatMul(self, expr, **kwargs):
        """Print a matrix product as a left-to-right chain of ``tt.dot`` calls."""
        return self._fold_binary(tt.dot, expr, **kwargs)

    def _print_MatPow(self, expr, **kwargs):
        """Print ``base**n`` for a positive Python-int ``n`` by repeated ``tt.dot``.

        Raises
        ======

        NotImplementedError
            If the exponent is not a positive integer.
        """
        base, power = [self._print(arg, **kwargs) for arg in expr.args]
        if not (isinstance(power, int) and power > 0):
            # The old message claimed "non-negative", yet 0 was always rejected.
            raise NotImplementedError(
                "Only positive integer powers of matrices can be handled by "
                "Theano at the moment")
        # Start from the base itself rather than from the scalar 1, which only
        # added a redundant ``1 * base`` node to the graph.
        result = base
        for _ in range(power - 1):
            result = tt.dot(result, base)
        return result

    def _print_MatrixSlice(self, expr, **kwargs):
        """Print a matrix slice as 2-D indexing of the printed parent."""
        parent = self._print(expr.parent, **kwargs)
        rowslice = self._print(slice(*expr.rowslice), **kwargs)
        colslice = self._print(slice(*expr.colslice), **kwargs)
        return parent[rowslice, colslice]

    def _print_BlockMatrix(self, expr, **kwargs):
        """Print a block matrix by joining blocks along columns, then rows."""
        nrows, ncols = expr.blocks.shape
        blocks = [[self._print(expr.blocks[r, c], **kwargs)
                        for c in range(ncols)]
                        for r in range(nrows)]
        return tt.join(0, *[tt.join(1, *row) for row in blocks])

    def _print_slice(self, expr, **kwargs):
        """Print a Python slice, translating any SymPy bounds it contains."""
        return slice(*[self._print(i, **kwargs)
                        if isinstance(i, sympy.Basic) else i
                        for i in (expr.start, expr.stop, expr.step)])

    def _print_Pi(self, expr, **kwargs):
        return 3.141592653589793

    def _print_Exp1(self, expr, **kwargs):
        return ts.exp(1)

    def _print_Piecewise(self, expr, **kwargs):
        """Print a Piecewise as nested ``tt.switch`` calls (NaN if no condition holds)."""
        import numpy as np
        e, cond = expr.args[0].args  # First condition and corresponding value

        # Print conditional expression and value for first condition
        p_cond = self._print(cond, **kwargs)
        p_e = self._print(e, **kwargs)

        # One condition only
        if len(expr.args) == 1:
            # Return value if condition else NaN
            return tt.switch(p_cond, p_e, np.nan)

        # Return value_1 if condition_1 else evaluate remaining conditions
        p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
        return tt.switch(p_cond, p_e, p_remaining)

    def _print_Rational(self, expr, **kwargs):
        return tt.true_div(self._print(expr.p, **kwargs),
                           self._print(expr.q, **kwargs))

    def _print_Integer(self, expr, **kwargs):
        return expr.p

    def _print_factorial(self, expr, **kwargs):
        # n! is printed as gamma(n + 1).
        return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)

    def _print_Derivative(self, deriv, **kwargs):
        """Print a derivative via ``tt.Rop`` with a ones vector per variable."""
        rv = self._print(deriv.expr, **kwargs)
        for var in deriv.variables:
            var = self._print(var, **kwargs)
            rv = tt.Rop(rv, var, tt.ones_like(var))
        return rv

    def emptyPrinter(self, expr):
        # Fallback for objects with no dedicated printer: pass them through.
        return expr

    def doprint(self, expr, dtypes=None, broadcastables=None):
        """ Convert a SymPy expression to a Theano graph variable.

        The ``dtypes`` and ``broadcastables`` arguments are used to specify the
        data type, dimension, and broadcasting behavior of the Theano variables
        corresponding to the free symbols in ``expr``. Each is a mapping from
        SymPy symbols to the value of the corresponding argument to
        ``theano.tensor.Tensor``.

        See the corresponding `documentation page`__ for more information on
        broadcasting in Theano.

        .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html

        Parameters
        ==========

        expr : sympy.core.expr.Expr
            SymPy expression to print.

        dtypes : dict
            Mapping from SymPy symbols to Theano datatypes to use when creating
            new Theano variables for those symbols. Corresponds to the ``dtype``
            argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'``
            for symbols not included in the mapping.

        broadcastables : dict
            Mapping from SymPy symbols to the value of the ``broadcastable``
            argument to ``theano.tensor.Tensor`` to use when creating Theano
            variables for those symbols. Defaults to the empty tuple for symbols
            not included in the mapping (resulting in a scalar).

        Returns
        =======

        theano.gof.graph.Variable
            A variable corresponding to the expression's value in a Theano
            symbolic expression graph.

        """
        if dtypes is None:
            dtypes = {}
        if broadcastables is None:
            broadcastables = {}

        return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
# Module-level cache of Theano variables; ``theano_code`` falls back to it
# when called without an explicit ``cache`` (or with ``cache=None``).
global_cache: dict[Any, Any] = dict()
def theano_code(expr, cache=None, **kwargs):
    """
    Convert a SymPy expression into a Theano graph variable.

    .. deprecated:: 1.8

      ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
      Aesara. Use ``sympy.printing.aesaracode`` instead. See
      :ref:`theanocode-deprecated` for more information.

    Parameters
    ==========

    expr : sympy.core.expr.Expr
        SymPy expression object to convert.

    cache : dict
        Cached Theano variables (see :class:`TheanoPrinter.cache
        <TheanoPrinter>`). Defaults to the module-level global cache.

    dtypes : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    broadcastables : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    Returns
    =======

    theano.gof.graph.Variable
        A variable corresponding to the expression's value in a Theano symbolic
        expression graph.

    Raises
    ======

    ImportError
        If Theano is not installed (raised after the deprecation warning).
    """
    sympy_deprecation_warning(
        "\n"
        "sympy.printing.theanocode is deprecated. Theano has been renamed to\n"
        "Aesara. Use sympy.printing.aesaracode instead.",
        deprecated_since_version="1.8",
        active_deprecations_target='theanocode-deprecated')

    if not theano:
        raise ImportError("theano is required for theano_code")

    # ``None`` selects the shared module-level cache; any other object
    # (including an empty dict) is used as-is.
    printer = TheanoPrinter(
        cache=global_cache if cache is None else cache,
        settings={},
    )
    return printer.doprint(expr, **kwargs)
def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
    r"""
    Get value of ``broadcastables`` argument to :func:`.theano_code` from
    keyword arguments to :func:`.theano_function`.

    Included for backwards compatibility.

    Parameters
    ==========

    inputs
        Sequence of input symbols.

    dim : int
        Common number of dimensions for all inputs. Overrides other arguments
        if given.

    dims : dict
        Mapping from input symbols to number of dimensions. Overrides
        ``broadcastables`` argument if given. Entries with fewer dimensions
        than the largest one are padded with trailing broadcastable (``True``)
        axes so every entry has the same length. An empty mapping yields an
        empty result.

    broadcastables : dict
        Explicit value of ``broadcastables`` argument to
        :meth:`.TheanoPrinter.doprint`. If not None function will return this
        value unchanged.

    Returns
    =======

    dict
        Dictionary mapping symbols to their "broadcastable" values (tuple of
        ``bool``\ s). With ``dim`` the keys are the elements of ``inputs``;
        with ``dims`` they are the keys of ``dims``.
    """
    if dim is not None:
        return dict.fromkeys(inputs, (False,) * dim)

    if dims is not None:
        # ``default=0`` keeps an empty ``dims`` mapping from raising
        # "ValueError: max() arg is an empty sequence".
        maxdim = max(dims.values(), default=0)
        return {
            s: (False,) * d + (True,) * (maxdim - d)
            for s, d in dims.items()
        }

    if broadcastables is not None:
        return broadcastables

    return {}
def theano_function(inputs, outputs, scalar=False, *,
        dim=None, dims=None, broadcastables=None, **kwargs):
    """
    Create a Theano function from SymPy expressions.

    .. deprecated:: 1.8

      ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
      Aesara. Use ``sympy.printing.aesaracode`` instead. See
      :ref:`theanocode-deprecated` for more information.

    The inputs and outputs are converted to Theano variables using
    :func:`.theano_code` and then passed to ``theano.function``.

    Parameters
    ==========

    inputs
        Sequence of symbols which constitute the inputs of the function.

    outputs
        Sequence of expressions which constitute the outputs(s) of the
        function. The free symbols of each expression must be a subset of
        ``inputs``. This must be a sequence even for a single output.

    scalar : bool
        Convert 0-dimensional arrays in output to scalars. This will return a
        Python wrapper function around the Theano function object.

    cache : dict
        Cached Theano variables (see :class:`TheanoPrinter.cache
        <TheanoPrinter>`). Defaults to a new empty dictionary for each call;
        unlike :func:`.theano_code`, the module-level global cache is only
        used if passed explicitly.

    dtypes : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    broadcastables : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    dims : dict
        Alternative to ``broadcastables`` argument. Mapping from elements of
        ``inputs`` to integers indicating the dimension of their associated
        arrays/tensors. Overrides ``broadcastables`` argument if given.

    dim : int
        Another alternative to the ``broadcastables`` argument. Common number of
        dimensions to use for all arrays/tensors.
        ``theano_function([x, y], [...], dim=2)`` is equivalent to using
        ``broadcastables={x: (False, False), y: (False, False)}``.

    **kwargs
        Any remaining keyword arguments are passed on to ``theano.function``.

    Returns
    =======

    callable
        A callable object which takes values of ``inputs`` as positional
        arguments and returns an output array for each of the expressions
        in ``outputs``. If ``outputs`` contains exactly one expression the
        function returns that single array rather than a one-element list;
        with multiple expressions it returns a list of arrays. If ``scalar``
        is true and at least one output is 0-dimensional, those
        0-dimensional outputs are returned as scalars instead of arrays.

        The returned object will either be an instance of
        ``theano.compile.function_module.Function`` or a Python wrapper
        function around one. In both cases, the returned value will have a
        ``theano_function`` attribute which points to the return value of
        ``theano.function``.

    Raises
    ======

    ImportError
        If Theano is not installed (raised after the deprecation warning).

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.printing.theanocode import theano_function

    A simple function with one input and one output:

    >>> f1 = theano_function([x], [x**2 - 1], scalar=True)
    >>> f1(3)
    8.0

    A function with multiple inputs and one output:

    >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
    >>> f2(3, 4, 2)
    5.0

    A function with multiple inputs and multiple outputs:

    >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
    >>> f3(2, 3)
    [13.0, -5.0]

    See also
    ========

    dim_handling

    """
    sympy_deprecation_warning(
        "\n"
        "sympy.printing.theanocode is deprecated. Theano has been renamed to "
        "Aesara. Use sympy.printing.aesaracode instead",
        deprecated_since_version="1.8",
        active_deprecations_target='theanocode-deprecated')

    if not theano:
        raise ImportError("theano is required for theano_function")

    # Pop off non-theano keyword args; everything left goes to theano.function.
    cache = kwargs.pop('cache', {})
    dtypes = kwargs.pop('dtypes', {})

    broadcastables = dim_handling(
        inputs, dim=dim, dims=dims, broadcastables=broadcastables,
    )

    # Print inputs/outputs
    code = partial(theano_code, cache=cache, dtypes=dtypes,
                   broadcastables=broadcastables)
    tinputs = list(map(code, inputs))
    toutputs = list(map(code, outputs))

    # Constant expressions print to plain Python numbers; wrap them so
    # theano.function receives graph variables.
    toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs]

    # A one-element output list yields a single (non-list) return value.
    if len(toutputs) == 1:
        toutputs = toutputs[0]

    # Compile theano func
    func = theano.function(tinputs, toutputs, **kwargs)

    is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]

    # No wrapper required
    if not scalar or not any(is_0d):
        func.theano_function = func
        return func

    # Create wrapper to convert 0-dimensional outputs to scalars
    def wrapper(*args):
        out = func(*args)
        # out can be array(1.0) or [array(1.0), array(2.0)]

        if is_sequence(out):
            return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
        else:
            return out[()]

    wrapper.__wrapped__ = func
    wrapper.__doc__ = func.__doc__
    wrapper.theano_function = func
    return wrapper