Kano001's picture
Upload 3077 files
6a86ad5 verified
raw
history blame
12.5 kB
from collections.abc import Callable
from sympy.core.basic import Basic
from sympy.external import import_module
import sympy.plotting.backends.base_backend as base_backend
from sympy.printing.latex import latex
# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py`
def _str_or_latex(label):
if isinstance(label, Basic):
return latex(label, mode='inline')
return str(label)
def _matplotlib_list(interval_list):
"""
Returns lists for matplotlib ``fill`` command from a list of bounding
rectangular intervals
"""
xlist = []
ylist = []
if len(interval_list):
for intervals in interval_list:
intervalx = intervals[0]
intervaly = intervals[1]
xlist.extend([intervalx.start, intervalx.start,
intervalx.end, intervalx.end, None])
ylist.extend([intervaly.start, intervaly.end,
intervaly.end, intervaly.start, None])
else:
#XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
xlist.extend((None, None, None, None))
ylist.extend((None, None, None, None))
return xlist, ylist
# Don't have to check for the success of importing matplotlib in each case;
# we will only be using this backend if we can successfully import matploblib
class MatplotlibBackend(base_backend.Plot):
""" This class implements the functionalities to use Matplotlib with SymPy
plotting functions.
"""
def __init__(self, *series, **kwargs):
super().__init__(*series, **kwargs)
self.matplotlib = import_module('matplotlib',
import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
min_module_version='1.1.0', catch=(RuntimeError,))
self.plt = self.matplotlib.pyplot
self.cm = self.matplotlib.cm
self.LineCollection = self.matplotlib.collections.LineCollection
self.aspect = kwargs.get('aspect_ratio', 'auto')
if self.aspect != 'auto':
self.aspect = float(self.aspect[1]) / self.aspect[0]
# PlotGrid can provide its figure and axes to be populated with
# the data from the series.
self._plotgrid_fig = kwargs.pop("fig", None)
self._plotgrid_ax = kwargs.pop("ax", None)
def _create_figure(self):
def set_spines(ax):
ax.spines['left'].set_position('zero')
ax.spines['right'].set_color('none')
ax.spines['bottom'].set_position('zero')
ax.spines['top'].set_color('none')
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('left')
if self._plotgrid_fig is not None:
self.fig = self._plotgrid_fig
self.ax = self._plotgrid_ax
if not any(s.is_3D for s in self._series):
set_spines(self.ax)
else:
self.fig = self.plt.figure(figsize=self.size)
if any(s.is_3D for s in self._series):
self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
else:
self.ax = self.fig.add_subplot(1, 1, 1)
set_spines(self.ax)
@staticmethod
def get_segments(x, y, z=None):
""" Convert two list of coordinates to a list of segments to be used
with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.
Parameters
==========
x : list
List of x-coordinates
y : list
List of y-coordinates
z : list
List of z-coordinates for a 3D line.
"""
np = import_module('numpy')
if z is not None:
dim = 3
points = (x, y, z)
else:
dim = 2
points = (x, y)
points = np.ma.array(points).T.reshape(-1, 1, dim)
return np.ma.concatenate([points[:-1], points[1:]], axis=1)
def _process_series(self, series, ax):
np = import_module('numpy')
mpl_toolkits = import_module(
'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})
# XXX Workaround for matplotlib issue
# https://github.com/matplotlib/matplotlib/issues/17130
xlims, ylims, zlims = [], [], []
for s in series:
# Create the collections
if s.is_2Dline:
if s.is_parametric:
x, y, param = s.get_data()
else:
x, y = s.get_data()
if (isinstance(s.line_color, (int, float)) or
callable(s.line_color)):
segments = self.get_segments(x, y)
collection = self.LineCollection(segments)
collection.set_array(s.get_color_array())
ax.add_collection(collection)
else:
lbl = _str_or_latex(s.label)
line, = ax.plot(x, y, label=lbl, color=s.line_color)
elif s.is_contour:
ax.contour(*s.get_data())
elif s.is_3Dline:
x, y, z, param = s.get_data()
if (isinstance(s.line_color, (int, float)) or
callable(s.line_color)):
art3d = mpl_toolkits.mplot3d.art3d
segments = self.get_segments(x, y, z)
collection = art3d.Line3DCollection(segments)
collection.set_array(s.get_color_array())
ax.add_collection(collection)
else:
lbl = _str_or_latex(s.label)
ax.plot(x, y, z, label=lbl, color=s.line_color)
xlims.append(s._xlim)
ylims.append(s._ylim)
zlims.append(s._zlim)
elif s.is_3Dsurface:
if s.is_parametric:
x, y, z, u, v = s.get_data()
else:
x, y, z = s.get_data()
collection = ax.plot_surface(x, y, z,
cmap=getattr(self.cm, 'viridis', self.cm.jet),
rstride=1, cstride=1, linewidth=0.1)
if isinstance(s.surface_color, (float, int, Callable)):
color_array = s.get_color_array()
color_array = color_array.reshape(color_array.size)
collection.set_array(color_array)
else:
collection.set_color(s.surface_color)
xlims.append(s._xlim)
ylims.append(s._ylim)
zlims.append(s._zlim)
elif s.is_implicit:
points = s.get_data()
if len(points) == 2:
# interval math plotting
x, y = _matplotlib_list(points[0])
ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
else:
# use contourf or contour depending on whether it is
# an inequality or equality.
# XXX: ``contour`` plots multiple lines. Should be fixed.
ListedColormap = self.matplotlib.colors.ListedColormap
colormap = ListedColormap(["white", s.line_color])
xarray, yarray, zarray, plot_type = points
if plot_type == 'contour':
ax.contour(xarray, yarray, zarray, cmap=colormap)
else:
ax.contourf(xarray, yarray, zarray, cmap=colormap)
elif s.is_generic:
if s.type == "markers":
# s.rendering_kw["color"] = s.line_color
ax.plot(*s.args, **s.rendering_kw)
elif s.type == "annotations":
ax.annotate(*s.args, **s.rendering_kw)
elif s.type == "fill":
# s.rendering_kw["color"] = s.line_color
ax.fill_between(*s.args, **s.rendering_kw)
elif s.type == "rectangles":
# s.rendering_kw["color"] = s.line_color
ax.add_patch(
self.matplotlib.patches.Rectangle(
*s.args, **s.rendering_kw))
else:
raise NotImplementedError(
'{} is not supported in the SymPy plotting module '
'with matplotlib backend. Please report this issue.'
.format(ax))
Axes3D = mpl_toolkits.mplot3d.Axes3D
if not isinstance(ax, Axes3D):
ax.autoscale_view(
scalex=ax.get_autoscalex_on(),
scaley=ax.get_autoscaley_on())
else:
# XXX Workaround for matplotlib issue
# https://github.com/matplotlib/matplotlib/issues/17130
if xlims:
xlims = np.array(xlims)
xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1]))
ax.set_xlim(xlim)
else:
ax.set_xlim([0, 1])
if ylims:
ylims = np.array(ylims)
ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1]))
ax.set_ylim(ylim)
else:
ax.set_ylim([0, 1])
if zlims:
zlims = np.array(zlims)
zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1]))
ax.set_zlim(zlim)
else:
ax.set_zlim([0, 1])
# Set global options.
# TODO The 3D stuff
# XXX The order of those is important.
if self.xscale and not isinstance(ax, Axes3D):
ax.set_xscale(self.xscale)
if self.yscale and not isinstance(ax, Axes3D):
ax.set_yscale(self.yscale)
if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check
ax.set_autoscale_on(self.autoscale)
if self.axis_center:
val = self.axis_center
if isinstance(ax, Axes3D):
pass
elif val == 'center':
ax.spines['left'].set_position('center')
ax.spines['bottom'].set_position('center')
elif val == 'auto':
xl, xh = ax.get_xlim()
yl, yh = ax.get_ylim()
pos_left = ('data', 0) if xl*xh <= 0 else 'center'
pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
ax.spines['left'].set_position(pos_left)
ax.spines['bottom'].set_position(pos_bottom)
else:
ax.spines['left'].set_position(('data', val[0]))
ax.spines['bottom'].set_position(('data', val[1]))
if not self.axis:
ax.set_axis_off()
if self.legend:
if ax.legend():
ax.legend_.set_visible(self.legend)
if self.margin:
ax.set_xmargin(self.margin)
ax.set_ymargin(self.margin)
if self.title:
ax.set_title(self.title)
if self.xlabel:
xlbl = _str_or_latex(self.xlabel)
ax.set_xlabel(xlbl, position=(1, 0))
if self.ylabel:
ylbl = _str_or_latex(self.ylabel)
ax.set_ylabel(ylbl, position=(0, 1))
if isinstance(ax, Axes3D) and self.zlabel:
zlbl = _str_or_latex(self.zlabel)
ax.set_zlabel(zlbl, position=(0, 1))
# xlim and ylim should always be set at last so that plot limits
# doesn't get altered during the process.
if self.xlim:
ax.set_xlim(self.xlim)
if self.ylim:
ax.set_ylim(self.ylim)
self.ax.set_aspect(self.aspect)
def process_series(self):
"""
Iterates over every ``Plot`` object and further calls
_process_series()
"""
self._create_figure()
self._process_series(self._series, self.ax)
def show(self):
self.process_series()
#TODO after fixing https://github.com/ipython/ipython/issues/1255
# you can uncomment the next line and remove the pyplot.show() call
#self.fig.show()
if base_backend._show:
self.fig.tight_layout()
self.plt.show()
else:
self.close()
def save(self, path):
self.process_series()
self.fig.savefig(path)
def close(self):
self.plt.close(self.fig)