Spaces:
Sleeping
Sleeping
from collections.abc import Callable | |
from sympy.core.basic import Basic | |
from sympy.external import import_module | |
import sympy.plotting.backends.base_backend as base_backend | |
from sympy.printing.latex import latex | |
# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the `SymPyDocTestFinder` in `sympy/testing/runtests.py`
def _str_or_latex(label):
    """Return *label* as text suitable for a Matplotlib label.

    SymPy objects (instances of ``Basic``) are rendered as inline LaTeX;
    any other value is converted with ``str``.
    """
    is_sympy_object = isinstance(label, Basic)
    return latex(label, mode='inline') if is_sympy_object else str(label)
def _matplotlib_list(interval_list): | |
""" | |
Returns lists for matplotlib ``fill`` command from a list of bounding | |
rectangular intervals | |
""" | |
xlist = [] | |
ylist = [] | |
if len(interval_list): | |
for intervals in interval_list: | |
intervalx = intervals[0] | |
intervaly = intervals[1] | |
xlist.extend([intervalx.start, intervalx.start, | |
intervalx.end, intervalx.end, None]) | |
ylist.extend([intervaly.start, intervaly.end, | |
intervaly.end, intervaly.start, None]) | |
else: | |
#XXX Ugly hack. Matplotlib does not accept empty lists for ``fill`` | |
xlist.extend((None, None, None, None)) | |
ylist.extend((None, None, None, None)) | |
return xlist, ylist | |
# Don't have to check for the success of importing matplotlib in each case;
# we will only be using this backend if we can successfully import matplotlib
class MatplotlibBackend(base_backend.Plot):
    """ This class implements the functionalities to use Matplotlib with SymPy
    plotting functions.
    """

    def __init__(self, *series, **kwargs):
        super().__init__(*series, **kwargs)
        self.matplotlib = import_module('matplotlib',
            import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
            min_module_version='1.1.0', catch=(RuntimeError,))
        self.plt = self.matplotlib.pyplot
        self.cm = self.matplotlib.cm
        self.LineCollection = self.matplotlib.collections.LineCollection
        # ``aspect_ratio`` is either 'auto' or an indexable pair; a pair is
        # reduced to the single number ``pair[1] / pair[0]`` that is later
        # passed to ``set_aspect``.
        self.aspect = kwargs.get('aspect_ratio', 'auto')
        if self.aspect != 'auto':
            self.aspect = float(self.aspect[1]) / self.aspect[0]
        # PlotGrid can provide its figure and axes to be populated with
        # the data from the series.
        self._plotgrid_fig = kwargs.pop("fig", None)
        self._plotgrid_ax = kwargs.pop("ax", None)

    def _create_figure(self):
        """Create the figure and axes (``self.fig``, ``self.ax``) to draw on.

        If a PlotGrid supplied ``fig``/``ax`` they are reused. Otherwise a
        new figure of size ``self.size`` with a single subplot is created,
        using a 3D projection when any series is 3D. On non-3D axes the
        left/bottom spines are moved to zero and the top/right ones hidden.
        """
        def set_spines(ax):
            ax.spines['left'].set_position('zero')
            ax.spines['right'].set_color('none')
            ax.spines['bottom'].set_position('zero')
            ax.spines['top'].set_color('none')
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')

        if self._plotgrid_fig is not None:
            self.fig = self._plotgrid_fig
            self.ax = self._plotgrid_ax
            if not any(s.is_3D for s in self._series):
                set_spines(self.ax)
        else:
            self.fig = self.plt.figure(figsize=self.size)
            if any(s.is_3D for s in self._series):
                self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
            else:
                self.ax = self.fig.add_subplot(1, 1, 1)
                set_spines(self.ax)

    # A staticmethod because callers use ``self.get_segments(x, y)``; without
    # the decorator the instance would be bound to ``x``.
    @staticmethod
    def get_segments(x, y, z=None):
        """ Convert two list of coordinates to a list of segments to be used
        with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`.

        Parameters
        ==========

        x : list
            List of x-coordinates

        y : list
            List of y-coordinates

        z : list
            List of z-coordinates for a 3D line.

        Returns
        =======

        A numpy masked array of shape ``(n - 1, 2, dim)`` (``dim`` is 2, or 3
        when ``z`` is given) holding consecutive point pairs, where ``n`` is
        the number of points. The coordinate lists presumably must all have
        the same length.
        """
        np = import_module('numpy')
        if z is not None:
            dim = 3
            points = (x, y, z)
        else:
            dim = 2
            points = (x, y)
        points = np.ma.array(points).T.reshape(-1, 1, dim)
        return np.ma.concatenate([points[:-1], points[1:]], axis=1)

    def _process_series(self, series, ax):
        """Draw every series on *ax*, then apply the plot-wide options.

        Parameters
        ==========

        series : iterable
            Data series; each one is dispatched on its ``is_2Dline``,
            ``is_contour``, ``is_3Dline``, ``is_3Dsurface``, ``is_implicit``
            or ``is_generic`` flag.

        ax : matplotlib axes
            The (2D or ``Axes3D``) axes to draw on.

        Raises
        ======

        NotImplementedError
            If a series matches none of the flags above.
        """
        np = import_module('numpy')
        mpl_toolkits = import_module(
            'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})
        Axes3D = mpl_toolkits.mplot3d.Axes3D

        # XXX Workaround for matplotlib issue
        # https://github.com/matplotlib/matplotlib/issues/17130
        # Ranges reported by the 3D series; combined into the axes limits below.
        xlims, ylims, zlims = [], [], []

        for s in series:
            # Create the collections
            if s.is_2Dline:
                if s.is_parametric:
                    x, y, param = s.get_data()
                else:
                    x, y = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    # Numeric/callable colors: draw a LineCollection colored
                    # by the series' color array.
                    segments = self.get_segments(x, y)
                    collection = self.LineCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, label=lbl, color=s.line_color)
            elif s.is_contour:
                ax.contour(*s.get_data())
            elif s.is_3Dline:
                x, y, z, param = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    art3d = mpl_toolkits.mplot3d.art3d
                    segments = self.get_segments(x, y, z)
                    collection = art3d.Line3DCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, z, label=lbl, color=s.line_color)
                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_3Dsurface:
                if s.is_parametric:
                    x, y, z, u, v = s.get_data()
                else:
                    x, y, z = s.get_data()
                collection = ax.plot_surface(x, y, z,
                    cmap=getattr(self.cm, 'viridis', self.cm.jet),
                    rstride=1, cstride=1, linewidth=0.1)
                if isinstance(s.surface_color, (float, int, Callable)):
                    color_array = s.get_color_array()
                    color_array = color_array.reshape(color_array.size)
                    collection.set_array(color_array)
                else:
                    collection.set_color(s.surface_color)
                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_implicit:
                points = s.get_data()
                if len(points) == 2:
                    # interval math plotting
                    x, y = _matplotlib_list(points[0])
                    ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
                else:
                    # use contourf or contour depending on whether it is
                    # an inequality or equality.
                    # XXX: ``contour`` plots multiple lines. Should be fixed.
                    ListedColormap = self.matplotlib.colors.ListedColormap
                    colormap = ListedColormap(["white", s.line_color])
                    xarray, yarray, zarray, plot_type = points
                    if plot_type == 'contour':
                        ax.contour(xarray, yarray, zarray, cmap=colormap)
                    else:
                        ax.contourf(xarray, yarray, zarray, cmap=colormap)
            elif s.is_generic:
                if s.type == "markers":
                    ax.plot(*s.args, **s.rendering_kw)
                elif s.type == "annotations":
                    ax.annotate(*s.args, **s.rendering_kw)
                elif s.type == "fill":
                    ax.fill_between(*s.args, **s.rendering_kw)
                elif s.type == "rectangles":
                    ax.add_patch(
                        self.matplotlib.patches.Rectangle(
                            *s.args, **s.rendering_kw))
            else:
                # Report the offending series, not the axes.
                raise NotImplementedError(
                    '{} is not supported in the SymPy plotting module '
                    'with matplotlib backend. Please report this issue.'
                    .format(s))

        # Done once after all collections are added; this keeps ``Axes3D``
        # bound even when ``series`` is empty.
        if not isinstance(ax, Axes3D):
            ax.autoscale_view(
                scalex=ax.get_autoscalex_on(),
                scaley=ax.get_autoscaley_on())
        else:
            # XXX Workaround for matplotlib issue
            # https://github.com/matplotlib/matplotlib/issues/17130
            # Use the union of the 3D series' ranges, or [0, 1] when no
            # series reported a range for that axis.
            for lims, set_lim in ((xlims, ax.set_xlim),
                                  (ylims, ax.set_ylim),
                                  (zlims, ax.set_zlim)):
                if lims:
                    arr = np.array(lims)
                    set_lim((np.amin(arr[:, 0]), np.amax(arr[:, 1])))
                else:
                    set_lim([0, 1])

        # Set global options.
        # TODO The 3D stuff
        # XXX The order of those is important.
        if self.xscale and not isinstance(ax, Axes3D):
            ax.set_xscale(self.xscale)
        if self.yscale and not isinstance(ax, Axes3D):
            ax.set_yscale(self.yscale)
        if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0':  # XXX in the distant future remove this check
            ax.set_autoscale_on(self.autoscale)
        if self.axis_center:
            val = self.axis_center
            if isinstance(ax, Axes3D):
                pass
            elif val == 'center':
                ax.spines['left'].set_position('center')
                ax.spines['bottom'].set_position('center')
            elif val == 'auto':
                # Put a spine at 0 when 0 lies inside the current limits,
                # otherwise centre it.
                xl, xh = ax.get_xlim()
                yl, yh = ax.get_ylim()
                pos_left = ('data', 0) if xl*xh <= 0 else 'center'
                pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
                ax.spines['left'].set_position(pos_left)
                ax.spines['bottom'].set_position(pos_bottom)
            else:
                ax.spines['left'].set_position(('data', val[0]))
                ax.spines['bottom'].set_position(('data', val[1]))
        if not self.axis:
            ax.set_axis_off()
        if self.legend:
            if ax.legend():
                ax.legend_.set_visible(self.legend)
        if self.margin:
            ax.set_xmargin(self.margin)
            ax.set_ymargin(self.margin)
        if self.title:
            ax.set_title(self.title)
        if self.xlabel:
            xlbl = _str_or_latex(self.xlabel)
            ax.set_xlabel(xlbl, position=(1, 0))
        if self.ylabel:
            ylbl = _str_or_latex(self.ylabel)
            ax.set_ylabel(ylbl, position=(0, 1))
        if isinstance(ax, Axes3D) and self.zlabel:
            zlbl = _str_or_latex(self.zlabel)
            ax.set_zlabel(zlbl, position=(0, 1))

        # xlim and ylim should always be set at last so that plot limits
        # doesn't get altered during the process.
        if self.xlim:
            ax.set_xlim(self.xlim)
        if self.ylim:
            ax.set_ylim(self.ylim)
        ax.set_aspect(self.aspect)

    def process_series(self):
        """
        Create the figure and axes, then draw every series of this plot on
        them via ``_process_series()``.
        """
        self._create_figure()
        self._process_series(self._series, self.ax)

    def show(self):
        """Draw the plot and display it with ``pyplot.show()``.

        When ``base_backend._show`` is false the figure is closed instead of
        being displayed.
        """
        self.process_series()
        #TODO after fixing https://github.com/ipython/ipython/issues/1255
        # you can uncomment the next line and remove the pyplot.show() call
        #self.fig.show()
        if base_backend._show:
            self.fig.tight_layout()
            self.plt.show()
        else:
            self.close()

    def save(self, path):
        """Draw the plot and save the figure to *path* via ``savefig``."""
        self.process_series()
        self.fig.savefig(path)

    def close(self):
        """Close this plot's figure through ``pyplot.close``."""
        self.plt.close(self.fig)