Spaces:
Sleeping
Sleeping
from sympy.core.containers import Tuple | |
from sympy.core.basic import Basic | |
from sympy.core.expr import Expr | |
from sympy.core.function import AppliedUndef | |
from sympy.core.relational import Relational | |
from sympy.core.symbol import Dummy | |
from sympy.core.sympify import sympify | |
from sympy.logic.boolalg import BooleanFunction | |
from sympy.sets.fancysets import ImageSet | |
from sympy.sets.sets import FiniteSet | |
from sympy.tensor.indexed import Indexed | |
def _get_free_symbols(exprs): | |
"""Returns the free symbols of a symbolic expression. | |
If the expression contains any of these elements, assume that they are | |
the "free symbols" of the expression: | |
* indexed objects | |
* applied undefined function (useful for sympy.physics.mechanics module) | |
""" | |
if not isinstance(exprs, (list, tuple, set)): | |
exprs = [exprs] | |
if all(callable(e) for e in exprs): | |
return set() | |
free = set().union(*[e.atoms(Indexed) for e in exprs]) | |
free = free.union(*[e.atoms(AppliedUndef) for e in exprs]) | |
return free or set().union(*[e.free_symbols for e in exprs]) | |
def extract_solution(set_sol, n=10):
    """Extract numerical solutions from a set solution (computed by solveset,
    linsolve, nonlinsolve). Often, it is not trivial to get something useful
    out of them.

    Parameters
    ==========

    set_sol : Set
        The solution set. Every ``ImageSet`` found inside it (via
        ``set_sol.find(ImageSet)``) is replaced by a ``FiniteSet``.
    n : int, optional
        In order to replace ImageSet with FiniteSet, an iterator is created
        for each ImageSet contained in `set_sol` and at most its first `n`
        elements are taken. An ImageSet with fewer than `n` elements
        contributes all of its elements; a non-positive `n` produces empty
        FiniteSets. Default value: 10.

    Returns
    =======

    ``set_sol`` with each ImageSet substituted (via ``subs``) by the
    corresponding FiniteSet.
    """
    from itertools import islice

    # Mirror ``range(0, n)``: a negative ``n`` means "no elements", whereas
    # ``islice`` would reject it with a ValueError.
    count = max(n, 0)
    for image in set_sol.find(ImageSet):
        # ``islice`` stops cleanly when a finite ImageSet runs out of
        # elements; calling ``next()`` past the end used to let a
        # StopIteration escape from this function.
        finite = FiniteSet(*islice(image, count))
        set_sol = set_sol.subs(image, finite)
    return set_sol
def _plot_sympify(args):
    """Recursively sympify the arguments passed to the plot functions.

    Generally, users provide the following arguments to a plot function:

        expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt]

    ``expr, range1, ...`` can be sympified, whereas ``label`` and
    ``rendering_kw`` can't. In particular, sympify raises an error whenever a
    special character like $, {, }, ... is used in the label.

    If ``args`` is already an ``Expr`` it is returned unchanged. Otherwise a
    new list is returned in which:

    * nested lists/tuples are processed recursively and wrapped into a
      ``Tuple`` created with ``sympify=False``;
    * strings, dicts, callables and sympy.physics.vector Vectors are kept
      as they are;
    * every other element is passed through ``sympify``.
    """
    if isinstance(args, Expr):
        return args

    def _keep_as_is(obj):
        # Labels (str), rendering keywords (dict) and numerical functions
        # (callables) must never be sympified.
        if isinstance(obj, (str, dict)) or callable(obj):
            return True
        # NOTE: check if it is a vector from sympy.physics.vector module
        # without importing the module (because it slows down SymPy's
        # import process and triggers SymPy's optional-dependencies
        # tests to fail).
        return obj.__class__.__name__ == "Vector" and not isinstance(obj, Basic)

    processed = []
    for item in args:
        if isinstance(item, (list, tuple)):
            processed.append(Tuple(*_plot_sympify(item), sympify=False))
        elif _keep_as_is(item):
            processed.append(item)
        else:
            processed.append(sympify(item))
    return processed
def _create_ranges(exprs, ranges, npar, label="", params=None):
    """This function does two things:

    1. Check if the number of free symbols is in agreement with the type of
       plot chosen. For example, plot() requires 1 free symbol;
       plot3d() requires 2 free symbols.
    2. Sometimes users create plots without providing ranges for the
       variables. Here we create the necessary ranges.

    Parameters
    ==========

    exprs : iterable
        The expressions from which to extract the free symbols
    ranges : list
        The limiting ranges provided by the user. NOTE: this list is
        extended in place with the missing ranges, and also returned.
    npar : int
        The number of free symbols required by the plot functions.
        For example,
        npar=1 for plot, npar=2 for plot3d, ...
    label : str
        Not used by this function; kept for signature compatibility.
    params : dict
        A dictionary mapping symbols to parameters for interactive plot.
        Its keys are not counted as free symbols.

    Returns
    =======

    ranges : list
        Exactly ``npar`` ranges. Each missing free symbol gets a default
        range ``Tuple(symbol, -10, 10)`` (added in order of the symbols'
        string representation); any remaining slot is filled with a
        default range over a ``Dummy`` symbol.

    Raises
    ======

    ValueError
        If there are more free symbols or more ranges than ``npar``, if two
        ranges share the same symbol, or if the ranges are incompatible with
        the free symbols of the expressions.
    """
    def get_default_range(symbol):
        return Tuple(symbol, -10, 10)

    def incompatible_error(free, range_symbols):
        return ValueError(
            "Incompatible free symbols of the expressions with "
            "the ranges.\n"
            + "Free symbols in the expressions: {}\n".format(free)
            + "Free symbols in the ranges: {}".format(range_symbols)
        )

    free_symbols = _get_free_symbols(exprs)
    if params is not None:
        free_symbols = free_symbols.difference(params.keys())

    if len(free_symbols) > npar:
        raise ValueError(
            "Too many free symbols.\n"
            + "Expected {} free symbols.\n".format(npar)
            + "Received {}: {}".format(len(free_symbols), free_symbols)
        )

    if len(ranges) > npar:
        raise ValueError(
            "Too many ranges. Received %s, expected %s" % (len(ranges), npar))

    # free symbols in the ranges provided by the user
    rfs = {r[0] for r in ranges}
    if len(rfs) != len(ranges):
        raise ValueError("Multiple ranges with the same symbol")

    if len(ranges) < npar:
        # Sort the missing symbols: iterating the set directly makes the order
        # of the generated ranges (i.e. which symbol ends up on which axis)
        # depend on hash values, which are not guaranteed to be stable
        # across interpreter runs.
        missing = sorted(free_symbols.difference(rfs), key=str)
        if len(ranges) + len(missing) > npar:
            # The user-provided ranges use symbols that do not appear in the
            # expressions, leaving no room for every actual free symbol.
            raise incompatible_error(free_symbols, rfs)
        # add a range for each missing free symbol
        for s in missing:
            ranges.append(get_default_range(s))
        # if there is still room, fill them with dummys
        for _ in range(npar - len(ranges)):
            ranges.append(get_default_range(Dummy()))

    if len(free_symbols) == npar:
        # there could be times when this condition is not met, for example
        # plotting the function f(x, y) = x (which is a plane); in this case,
        # free_symbols = {x} whereas rfs = {x, y} (or x and Dummy)
        rfs = {r[0] for r in ranges}
        if free_symbols.difference(rfs):
            raise incompatible_error(free_symbols, rfs)

    return ranges
def _is_range(r):
    """Return a truthy value if ``r`` is a plotting range.

    A range is defined as a ``Tuple`` ``(symbol, start, end)`` where start
    and end should be numbers (their ``is_number`` attribute is true).

    The elements of ``r`` are not necessarily SymPy objects: ``_plot_sympify``
    builds Tuples with ``sympify=False`` and leaves strings, dicts and
    callables untouched. An element without an ``is_number`` attribute
    (str, dict, lambda, ...) therefore simply means that ``r`` is not a
    range, instead of raising AttributeError.
    """
    # TODO: prange check goes here
    if not (isinstance(r, Tuple) and len(r) == 3):
        return False
    return (
        getattr(r.args[1], "is_number", False)
        and getattr(r.args[2], "is_number", False)
    )
def _unpack_args(*args):
    """Split arguments previously processed by _plot_sympify() and/or
    _check_arguments() into their components.

    Returns a 4-tuple ``(exprs, ranges, label, rendering_kw)`` where:

    * ``exprs`` holds every argument that is not a range, a string, a dict
      or ``None``;
    * ``ranges`` holds every argument recognized by ``_is_range``;
    * ``label`` is the first string argument, or ``None``;
    * ``rendering_kw`` is the first dict argument, or ``None``.

    Examples
    ========

    >>> from sympy import cos, sin, symbols
    >>> from sympy.plotting.utils import _plot_sympify, _unpack_args
    >>> x, y = symbols('x, y')
    >>> args = (sin(x), (x, -10, 10), "f1")
    >>> args = _plot_sympify(args)
    >>> _unpack_args(*args)
    ([sin(x)], [(x, -10, 10)], 'f1', None)

    >>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2")
    >>> args = _plot_sympify(args)
    >>> _unpack_args(*args)
    ([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None)

    >>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3")
    >>> args = _plot_sympify(args)
    >>> _unpack_args(*args)
    ([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None)
    """
    exprs, ranges, strings, dicts = [], [], [], []
    for arg in args:
        if _is_range(arg):
            ranges.append(arg)
        elif isinstance(arg, str):
            strings.append(arg)
        elif isinstance(arg, dict):
            dicts.append(arg)
        elif arg is not None:
            # NOTE: None is skipped because args might have been
            # preprocessed by _check_arguments, where None may stand for a
            # missing label or rendering_kw.
            exprs.append(arg)

    label = strings[0] if strings else None
    rendering_kw = dicts[0] if dicts else None
    return exprs, ranges, label, rendering_kw
def _check_arguments(args, nexpr, npar, **kwargs): | |
"""Checks the arguments and converts into tuples of the | |
form (exprs, ranges, label, rendering_kw). | |
Parameters | |
========== | |
args | |
The arguments provided to the plot functions | |
nexpr | |
The number of sub-expression forming an expression to be plotted. | |
For example: | |
nexpr=1 for plot. | |
nexpr=2 for plot_parametric: a curve is represented by a tuple of two | |
elements. | |
nexpr=1 for plot3d. | |
nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple | |
of three elements. | |
npar | |
The number of free symbols required by the plot functions. For example, | |
npar=1 for plot, npar=2 for plot3d, ... | |
**kwargs : | |
keyword arguments passed to the plotting function. It will be used to | |
verify if ``params`` has ben provided. | |
Examples | |
======== | |
.. plot:: | |
:context: reset | |
:format: doctest | |
:include-source: True | |
>>> from sympy import cos, sin, symbols | |
>>> from sympy.plotting.plot import _check_arguments | |
>>> x = symbols('x') | |
>>> _check_arguments([cos(x), sin(x)], 2, 1) | |
[(cos(x), sin(x), (x, -10, 10), None, None)] | |
>>> _check_arguments([cos(x), sin(x), "test"], 2, 1) | |
[(cos(x), sin(x), (x, -10, 10), 'test', None)] | |
>>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1) | |
[(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})] | |
>>> _check_arguments([x, x**2], 1, 1) | |
[(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)] | |
""" | |
if not args: | |
return [] | |
output = [] | |
params = kwargs.get("params", None) | |
if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]): | |
# In this case, with a single plot command, we are plotting either: | |
# 1. one expression | |
# 2. multiple expressions over the same range | |
exprs, ranges, label, rendering_kw = _unpack_args(*args) | |
free_symbols = set().union(*[e.free_symbols for e in exprs]) | |
ranges = _create_ranges(exprs, ranges, npar, label, params) | |
if nexpr > 1: | |
# in case of plot_parametric or plot3d_parametric_line, there will | |
# be 2 or 3 expressions defining a curve. Group them together. | |
if len(exprs) == nexpr: | |
exprs = (tuple(exprs),) | |
for expr in exprs: | |
# need this if-else to deal with both plot/plot3d and | |
# plot_parametric/plot3d_parametric_line | |
is_expr = isinstance(expr, (Expr, Relational, BooleanFunction)) | |
e = (expr,) if is_expr else expr | |
output.append((*e, *ranges, label, rendering_kw)) | |
else: | |
# In this case, we are plotting multiple expressions, each one with its | |
# range. Each "expression" to be plotted has the following form: | |
# (expr, range, label) where label is optional | |
_, ranges, labels, rendering_kw = _unpack_args(*args) | |
labels = [labels] if labels else [] | |
# number of expressions | |
n = (len(ranges) + len(labels) + | |
(len(rendering_kw) if rendering_kw is not None else 0)) | |
new_args = args[:-n] if n > 0 else args | |
# at this point, new_args might just be [expr]. But I need it to be | |
# [[expr]] in order to be able to loop over | |
# [expr, range [opt], label [opt]] | |
if not isinstance(new_args[0], (list, tuple, Tuple)): | |
new_args = [new_args] | |
# Each arg has the form (expr1, expr2, ..., range1 [optional], ..., | |
# label [optional], rendering_kw [optional]) | |
for arg in new_args: | |
# look for "local" range and label. If there is not, use "global". | |
l = [a for a in arg if isinstance(a, str)] | |
if not l: | |
l = labels | |
r = [a for a in arg if _is_range(a)] | |
if not r: | |
r = ranges.copy() | |
rend_kw = [a for a in arg if isinstance(a, dict)] | |
rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0] | |
# NOTE: arg = arg[:nexpr] may raise an exception if lambda | |
# functions are used. Execute the following instead: | |
arg = [arg[i] for i in range(nexpr)] | |
free_symbols = set() | |
if all(not callable(a) for a in arg): | |
free_symbols = free_symbols.union(*[a.free_symbols for a in arg]) | |
if len(r) != npar: | |
r = _create_ranges(arg, r, npar, "", params) | |
label = None if not l else l[0] | |
output.append((*arg, *r, label, rend_kw)) | |
return output | |