Spaces:
Sleeping
Sleeping
| from collections.abc import Callable | |
| from sympy.core.basic import Basic | |
| from sympy.external import import_module | |
| import sympy.plotting.backends.base_backend as base_backend | |
| from sympy.printing.latex import latex | |
# N.B.
# When changing the minimum module version for matplotlib (the
# ``min_module_version`` passed to ``import_module`` in
# ``MatplotlibBackend.__init__``), please make the same change in
# ``SymPyDocTestFinder`` in ``sympy/testing/runtests.py``.
def _str_or_latex(label):
    """
    Turn a plot label into the text handed to matplotlib.

    SymPy objects (instances of ``Basic``) are rendered as inline LaTeX;
    any other value is simply converted with ``str``.
    """
    return latex(label, mode='inline') if isinstance(label, Basic) else str(label)
def _matplotlib_list(interval_list):
    """
    Build the x and y coordinate lists for matplotlib's ``fill`` command
    from a sequence of bounding rectangles.

    Each item of ``interval_list`` is indexable, with item ``[0]`` the x
    interval and item ``[1]`` the y interval; both expose ``start`` and
    ``end``. Every rectangle contributes its four corners followed by a
    ``None`` separator to each of the two returned lists.
    """
    if not len(interval_list):
        # XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
        return [None] * 4, [None] * 4

    xs, ys = [], []
    for rect in interval_list:
        ix, iy = rect[0], rect[1]
        xs += [ix.start, ix.start, ix.end, ix.end, None]
        ys += [iy.start, iy.end, iy.end, iy.start, None]
    return xs, ys
# There is no need to check in each method whether importing matplotlib
# succeeded: this backend is only used when matplotlib can be imported.
class MatplotlibBackend(base_backend.Plot):
    """ This class implements the functionalities to use Matplotlib with SymPy
    plotting functions.

    The figure and axes are created by :meth:`process_series` (called from
    :meth:`show` and :meth:`save`), so ``self.fig`` and ``self.ax`` exist
    only after one of those has run.
    """

    def __init__(self, *series, **kwargs):
        super().__init__(*series, **kwargs)
        self.matplotlib = import_module('matplotlib',
            import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
            min_module_version='1.1.0', catch=(RuntimeError,))
        self.plt = self.matplotlib.pyplot
        self.cm = self.matplotlib.cm
        self.LineCollection = self.matplotlib.collections.LineCollection
        # ``aspect_ratio`` is either a string handed straight to
        # ``Axes.set_aspect`` (e.g. 'auto' or 'equal') or an (x, y) pair,
        # which is converted to the ratio ``y / x``.
        self.aspect = kwargs.get('aspect_ratio', 'auto')
        if not isinstance(self.aspect, str):
            self.aspect = float(self.aspect[1]) / self.aspect[0]
        # PlotGrid can provide its figure and axes to be populated with
        # the data from the series.
        self._plotgrid_fig = kwargs.pop("fig", None)
        self._plotgrid_ax = kwargs.pop("ax", None)

    def _create_figure(self):
        """Create (or adopt) the figure and axes the series are drawn on.

        If a PlotGrid supplied ``fig``/``ax``, they are reused. Otherwise a
        new figure of size ``self.size`` with a single subplot is created,
        using a 3D projection when any series is 3D. For purely 2D plots
        the spines are rearranged by ``set_spines``.
        """
        def set_spines(ax):
            # Hide the top/right spines and place the left/bottom ones at
            # data coordinate zero, with ticks only on those two sides.
            ax.spines['left'].set_position('zero')
            ax.spines['right'].set_color('none')
            ax.spines['bottom'].set_position('zero')
            ax.spines['top'].set_color('none')
            ax.xaxis.set_ticks_position('bottom')
            ax.yaxis.set_ticks_position('left')

        has_3d = any(s.is_3D for s in self._series)
        if self._plotgrid_fig is not None:
            self.fig = self._plotgrid_fig
            self.ax = self._plotgrid_ax
            if not has_3d:
                set_spines(self.ax)
        else:
            self.fig = self.plt.figure(figsize=self.size)
            if has_3d:
                self.ax = self.fig.add_subplot(1, 1, 1, projection="3d")
            else:
                self.ax = self.fig.add_subplot(1, 1, 1)
                set_spines(self.ax)

    @staticmethod
    def get_segments(x, y, z=None):
        """ Convert two or three lists of coordinates to a list of segments
        to be used with Matplotlib's
        :external:class:`~matplotlib.collections.LineCollection`.

        Parameters
        ==========

        x : list
            List of x-coordinates

        y : list
            List of y-coordinates

        z : list
            List of z-coordinates for a 3D line.

        Returns
        =======

        A numpy masked array of shape ``(n - 1, 2, dim)``, where ``n`` is
        the number of points and ``dim`` is 2 (or 3 when ``z`` is given).
        Entry ``i`` is the segment joining point ``i`` to point ``i + 1``.
        """
        np = import_module('numpy')
        if z is not None:
            dim = 3
            points = (x, y, z)
        else:
            dim = 2
            points = (x, y)
        # (n, 1, dim): one point per row, ready to be paired with its successor.
        points = np.ma.array(points).T.reshape(-1, 1, dim)
        return np.ma.concatenate([points[:-1], points[1:]], axis=1)

    def _process_series(self, series, ax):
        """Draw every data series onto ``ax`` and apply the plot-wide options.

        Parameters
        ==========

        series : iterable
            Data series; the ``is_*`` flags of each decide how it is drawn.

        ax : matplotlib axes
            The axes to draw on (an ``Axes3D`` when 3D series are present).

        Raises
        ======

        NotImplementedError
            If a series is of a kind this backend does not handle.
        """
        np = import_module('numpy')
        mpl_toolkits = import_module(
            'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})

        # XXX Workaround for matplotlib issue
        # https://github.com/matplotlib/matplotlib/issues/17130
        # Limits of 3D series are collected here and applied manually below.
        xlims, ylims, zlims = [], [], []

        for s in series:
            # Create the collections
            if s.is_2Dline:
                if s.is_parametric:
                    x, y, _ = s.get_data()
                else:
                    x, y = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    # Numeric/callable colours are drawn as a LineCollection
                    # coloured by the series' color array.
                    segments = self.get_segments(x, y)
                    collection = self.LineCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, label=lbl, color=s.line_color)
            elif s.is_contour:
                ax.contour(*s.get_data())
            elif s.is_3Dline:
                x, y, z, _ = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    art3d = mpl_toolkits.mplot3d.art3d
                    segments = self.get_segments(x, y, z)
                    collection = art3d.Line3DCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    lbl = _str_or_latex(s.label)
                    ax.plot(x, y, z, label=lbl, color=s.line_color)
                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_3Dsurface:
                if s.is_parametric:
                    x, y, z, _, _ = s.get_data()
                else:
                    x, y, z = s.get_data()
                collection = ax.plot_surface(x, y, z,
                    cmap=getattr(self.cm, 'viridis', self.cm.jet),
                    rstride=1, cstride=1, linewidth=0.1)
                if isinstance(s.surface_color, (float, int, Callable)):
                    # Flatten the color array to one value per face.
                    color_array = s.get_color_array()
                    color_array = color_array.reshape(color_array.size)
                    collection.set_array(color_array)
                else:
                    collection.set_color(s.surface_color)
                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_implicit:
                points = s.get_data()
                if len(points) == 2:
                    # interval math plotting
                    x, y = _matplotlib_list(points[0])
                    ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
                else:
                    # use contourf or contour depending on whether it is
                    # an inequality or equality.
                    # XXX: ``contour`` plots multiple lines. Should be fixed.
                    ListedColormap = self.matplotlib.colors.ListedColormap
                    colormap = ListedColormap(["white", s.line_color])
                    xarray, yarray, zarray, plot_type = points
                    if plot_type == 'contour':
                        ax.contour(xarray, yarray, zarray, cmap=colormap)
                    else:
                        ax.contourf(xarray, yarray, zarray, cmap=colormap)
            elif s.is_generic:
                if s.type == "markers":
                    ax.plot(*s.args, **s.rendering_kw)
                elif s.type == "annotations":
                    ax.annotate(*s.args, **s.rendering_kw)
                elif s.type == "fill":
                    ax.fill_between(*s.args, **s.rendering_kw)
                elif s.type == "rectangles":
                    ax.add_patch(
                        self.matplotlib.patches.Rectangle(
                            *s.args, **s.rendering_kw))
            else:
                # Report the offending series (not the axes) to the user.
                raise NotImplementedError(
                    '{} is not supported in the SymPy plotting module '
                    'with matplotlib backend. Please report this issue.'
                    .format(s))

        Axes3D = mpl_toolkits.mplot3d.Axes3D
        is_3d = isinstance(ax, Axes3D)
        if not is_3d:
            ax.autoscale_view(
                scalex=ax.get_autoscalex_on(),
                scaley=ax.get_autoscaley_on())
        else:
            # XXX Workaround for matplotlib issue
            # https://github.com/matplotlib/matplotlib/issues/17130
            # Use the union of the collected (min, max) limits per axis,
            # falling back to [0, 1] when no series reported any.
            for lims, set_lim in ((xlims, ax.set_xlim),
                                  (ylims, ax.set_ylim),
                                  (zlims, ax.set_zlim)):
                if lims:
                    arr = np.array(lims)
                    set_lim((np.amin(arr[:, 0]), np.amax(arr[:, 1])))
                else:
                    set_lim([0, 1])

        # Set global options.
        # TODO The 3D stuff
        # XXX The order of those is important.
        if self.xscale and not is_3d:
            ax.set_xscale(self.xscale)
        if self.yscale and not is_3d:
            ax.set_yscale(self.yscale)
        # NOTE: plain string comparison of version numbers; it only serves
        # to exclude very old matplotlib releases for 3D axes.
        if not is_3d or self.matplotlib.__version__ >= '1.2.0':  # XXX in the distant future remove this check
            ax.set_autoscale_on(self.autoscale)
        if self.axis_center and not is_3d:
            # Spine positioning is only applied to 2D axes.
            val = self.axis_center
            if val == 'center':
                ax.spines['left'].set_position('center')
                ax.spines['bottom'].set_position('center')
            elif val == 'auto':
                # Put each spine at data zero if zero lies within the
                # current limits, otherwise centre it.
                xl, xh = ax.get_xlim()
                yl, yh = ax.get_ylim()
                pos_left = ('data', 0) if xl*xh <= 0 else 'center'
                pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
                ax.spines['left'].set_position(pos_left)
                ax.spines['bottom'].set_position(pos_bottom)
            else:
                ax.spines['left'].set_position(('data', val[0]))
                ax.spines['bottom'].set_position(('data', val[1]))
        if not self.axis:
            ax.set_axis_off()
        if self.legend:
            if ax.legend():
                ax.legend_.set_visible(self.legend)
        if self.margin:
            ax.set_xmargin(self.margin)
            ax.set_ymargin(self.margin)
        if self.title:
            ax.set_title(self.title)
        if self.xlabel:
            xlbl = _str_or_latex(self.xlabel)
            ax.set_xlabel(xlbl, position=(1, 0))
        if self.ylabel:
            ylbl = _str_or_latex(self.ylabel)
            ax.set_ylabel(ylbl, position=(0, 1))
        if is_3d and self.zlabel:
            zlbl = _str_or_latex(self.zlabel)
            ax.set_zlabel(zlbl, position=(0, 1))
        # xlim and ylim should always be set at last so that plot limits
        # don't get altered during the process.
        if self.xlim:
            ax.set_xlim(self.xlim)
        if self.ylim:
            ax.set_ylim(self.ylim)
        # Apply the aspect to the axes actually being populated.
        ax.set_aspect(self.aspect)

    def process_series(self):
        """
        Create the figure and axes, then draw every series of this plot
        onto them via ``_process_series()``.
        """
        self._create_figure()
        self._process_series(self._series, self.ax)

    def show(self):
        """Render the plot and display it with ``pyplot.show``.

        When ``base_backend._show`` is false, the figure is closed instead
        of being displayed.
        """
        self.process_series()
        #TODO after fixing https://github.com/ipython/ipython/issues/1255
        # you can uncomment the next line and remove the pyplot.show() call
        #self.fig.show()
        if base_backend._show:
            self.fig.tight_layout()
            self.plt.show()
        else:
            self.close()

    def save(self, path):
        """Render the plot and write the figure to ``path`` via ``savefig``."""
        self.process_series()
        self.fig.savefig(path)

    def close(self):
        """Close the matplotlib figure; a no-op if none has been created."""
        # ``pyplot.close(None)`` would close the *current* figure, so only
        # close a figure this backend actually owns.
        fig = getattr(self, 'fig', None)
        if fig is not None:
            self.plt.close(fig)