|
""" |
|
Axislines includes modified implementation of the Axes class. The |
|
biggest difference is that the artists responsible for drawing the axis spine, |
|
ticks, ticklabels and axis labels are separated out from Matplotlib's Axis |
|
class. Originally, this change was motivated to support curvilinear |
|
grid. Here are a few reasons that I came up with a new axes class: |
|
|
|
* "top" and "bottom" x-axis (or "left" and "right" y-axis) can have |
|
different ticks (tick locations and labels). This is not possible |
|
with the current Matplotlib, although some twin axes trick can help. |
|
|
|
* Curvilinear grid. |
|
|
|
* angled ticks. |
|
|
|
In the new axes class, xaxis and yaxis is set to not visible by |
|
default, and new set of artist (AxisArtist) are defined to draw axis |
|
line, ticks, ticklabels and axis label. Axes.axis attribute serves as |
|
a dictionary of these artists, i.e., ax.axis["left"] is a AxisArtist |
|
instance responsible to draw left y-axis. The default Axes.axis contains |
|
"bottom", "left", "top" and "right". |
|
|
|
AxisArtist can be considered as a container artist and has the following |
|
children artists which will draw ticks, labels, etc. |
|
|
|
* line |
|
* major_ticks, major_ticklabels |
|
* minor_ticks, minor_ticklabels |
|
* offsetText |
|
* label |
|
|
|
Note that these are separate artists from `matplotlib.axis.Axis`, thus most |
|
tick-related functions in Matplotlib won't work. For example, color and |
|
markerwidth of the ``ax.axis["bottom"].major_ticks`` will follow those of |
|
Axes.xaxis unless explicitly specified. |
|
|
|
In addition to AxisArtist, the Axes will have *gridlines* attribute, |
|
which obviously draws grid lines. The gridlines needs to be separated |
|
from the axis as some gridlines can never pass any axis. |
|
""" |
|
|
|
import numpy as np |
|
|
|
import matplotlib as mpl |
|
from matplotlib import _api |
|
import matplotlib.axes as maxes |
|
from matplotlib.path import Path |
|
from mpl_toolkits.axes_grid1 import mpl_axes |
|
from .axisline_style import AxislineStyle |
|
from .axis_artist import AxisArtist, GridlinesCollection |
|
|
|
|
|
class _AxisArtistHelperBase: |
|
""" |
|
Base class for axis helper. |
|
|
|
Subclasses should define the methods listed below. The *axes* |
|
argument will be the ``.axes`` attribute of the caller artist. :: |
|
|
|
# Construct the spine. |
|
|
|
def get_line_transform(self, axes): |
|
return transform |
|
|
|
def get_line(self, axes): |
|
return path |
|
|
|
# Construct the label. |
|
|
|
def get_axislabel_transform(self, axes): |
|
return transform |
|
|
|
def get_axislabel_pos_angle(self, axes): |
|
return (x, y), angle |
|
|
|
# Construct the ticks. |
|
|
|
def get_tick_transform(self, axes): |
|
return transform |
|
|
|
def get_tick_iterators(self, axes): |
|
# A pair of iterables (one for major ticks, one for minor ticks) |
|
# that yield (tick_position, tick_angle, tick_label). |
|
return iter_major, iter_minor |
|
""" |
|
|
|
def update_lim(self, axes): |
|
pass |
|
|
|
def _to_xy(self, values, const): |
|
""" |
|
Create a (*values.shape, 2)-shape array representing (x, y) pairs. |
|
|
|
The other coordinate is filled with the constant *const*. |
|
|
|
Example:: |
|
|
|
>>> self.nth_coord = 0 |
|
>>> self._to_xy([1, 2, 3], const=0) |
|
array([[1, 0], |
|
[2, 0], |
|
[3, 0]]) |
|
""" |
|
if self.nth_coord == 0: |
|
return np.stack(np.broadcast_arrays(values, const), axis=-1) |
|
elif self.nth_coord == 1: |
|
return np.stack(np.broadcast_arrays(const, values), axis=-1) |
|
else: |
|
raise ValueError("Unexpected nth_coord") |
|
|
|
|
|
class _FixedAxisArtistHelperBase(_AxisArtistHelperBase): |
|
"""Helper class for a fixed (in the axes coordinate) axis.""" |
|
|
|
passthru_pt = _api.deprecated("3.7")(property( |
|
lambda self: {"left": (0, 0), "right": (1, 0), |
|
"bottom": (0, 0), "top": (0, 1)}[self._loc])) |
|
|
|
def __init__(self, loc, nth_coord=None): |
|
"""``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis.""" |
|
self.nth_coord = ( |
|
nth_coord if nth_coord is not None else |
|
_api.check_getitem( |
|
{"bottom": 0, "top": 0, "left": 1, "right": 1}, loc=loc)) |
|
if (nth_coord == 0 and loc not in ["left", "right"] |
|
or nth_coord == 1 and loc not in ["bottom", "top"]): |
|
_api.warn_deprecated( |
|
"3.7", message=f"{loc=!r} is incompatible with " |
|
"{nth_coord=}; support is deprecated since %(since)s") |
|
self._loc = loc |
|
self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc] |
|
super().__init__() |
|
|
|
self._path = Path(self._to_xy((0, 1), const=self._pos)) |
|
|
|
def get_nth_coord(self): |
|
return self.nth_coord |
|
|
|
|
|
|
|
def get_line(self, axes): |
|
return self._path |
|
|
|
def get_line_transform(self, axes): |
|
return axes.transAxes |
|
|
|
|
|
|
|
def get_axislabel_transform(self, axes): |
|
return axes.transAxes |
|
|
|
def get_axislabel_pos_angle(self, axes): |
|
""" |
|
Return the label reference position in transAxes. |
|
|
|
get_label_transform() returns a transform of (transAxes+offset) |
|
""" |
|
return dict(left=((0., 0.5), 90), |
|
right=((1., 0.5), 90), |
|
bottom=((0.5, 0.), 0), |
|
top=((0.5, 1.), 0))[self._loc] |
|
|
|
|
|
|
|
def get_tick_transform(self, axes): |
|
return [axes.get_xaxis_transform(), |
|
axes.get_yaxis_transform()][self.nth_coord] |
|
|
|
|
|
class _FloatingAxisArtistHelperBase(_AxisArtistHelperBase): |
|
|
|
def __init__(self, nth_coord, value): |
|
self.nth_coord = nth_coord |
|
self._value = value |
|
super().__init__() |
|
|
|
def get_nth_coord(self): |
|
return self.nth_coord |
|
|
|
def get_line(self, axes): |
|
raise RuntimeError( |
|
"get_line method should be defined by the derived class") |
|
|
|
|
|
class FixedAxisArtistHelperRectilinear(_FixedAxisArtistHelperBase): |
|
|
|
def __init__(self, axes, loc, nth_coord=None): |
|
""" |
|
nth_coord = along which coordinate value varies |
|
in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis |
|
""" |
|
super().__init__(loc, nth_coord) |
|
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] |
|
|
|
|
|
|
|
def get_tick_iterators(self, axes): |
|
"""tick_loc, tick_angle, tick_label""" |
|
if self._loc in ["bottom", "top"]: |
|
angle_normal, angle_tangent = 90, 0 |
|
else: |
|
angle_normal, angle_tangent = 0, 90 |
|
|
|
major = self.axis.major |
|
major_locs = major.locator() |
|
major_labels = major.formatter.format_ticks(major_locs) |
|
|
|
minor = self.axis.minor |
|
minor_locs = minor.locator() |
|
minor_labels = minor.formatter.format_ticks(minor_locs) |
|
|
|
tick_to_axes = self.get_tick_transform(axes) - axes.transAxes |
|
|
|
def _f(locs, labels): |
|
for loc, label in zip(locs, labels): |
|
c = self._to_xy(loc, const=self._pos) |
|
|
|
c2 = tick_to_axes.transform(c) |
|
if mpl.transforms._interval_contains_close( |
|
(0, 1), c2[self.nth_coord]): |
|
yield c, angle_normal, angle_tangent, label |
|
|
|
return _f(major_locs, major_labels), _f(minor_locs, minor_labels) |
|
|
|
|
|
class FloatingAxisArtistHelperRectilinear(_FloatingAxisArtistHelperBase): |
|
|
|
def __init__(self, axes, nth_coord, |
|
passingthrough_point, axis_direction="bottom"): |
|
super().__init__(nth_coord, passingthrough_point) |
|
self._axis_direction = axis_direction |
|
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] |
|
|
|
def get_line(self, axes): |
|
fixed_coord = 1 - self.nth_coord |
|
data_to_axes = axes.transData - axes.transAxes |
|
p = data_to_axes.transform([self._value, self._value]) |
|
return Path(self._to_xy((0, 1), const=p[fixed_coord])) |
|
|
|
def get_line_transform(self, axes): |
|
return axes.transAxes |
|
|
|
def get_axislabel_transform(self, axes): |
|
return axes.transAxes |
|
|
|
def get_axislabel_pos_angle(self, axes): |
|
""" |
|
Return the label reference position in transAxes. |
|
|
|
get_label_transform() returns a transform of (transAxes+offset) |
|
""" |
|
angle = [0, 90][self.nth_coord] |
|
fixed_coord = 1 - self.nth_coord |
|
data_to_axes = axes.transData - axes.transAxes |
|
p = data_to_axes.transform([self._value, self._value]) |
|
verts = self._to_xy(0.5, const=p[fixed_coord]) |
|
if 0 <= verts[fixed_coord] <= 1: |
|
return verts, angle |
|
else: |
|
return None, None |
|
|
|
def get_tick_transform(self, axes): |
|
return axes.transData |
|
|
|
def get_tick_iterators(self, axes): |
|
"""tick_loc, tick_angle, tick_label""" |
|
if self.nth_coord == 0: |
|
angle_normal, angle_tangent = 90, 0 |
|
else: |
|
angle_normal, angle_tangent = 0, 90 |
|
|
|
major = self.axis.major |
|
major_locs = major.locator() |
|
major_labels = major.formatter.format_ticks(major_locs) |
|
|
|
minor = self.axis.minor |
|
minor_locs = minor.locator() |
|
minor_labels = minor.formatter.format_ticks(minor_locs) |
|
|
|
data_to_axes = axes.transData - axes.transAxes |
|
|
|
def _f(locs, labels): |
|
for loc, label in zip(locs, labels): |
|
c = self._to_xy(loc, const=self._value) |
|
c1, c2 = data_to_axes.transform(c) |
|
if 0 <= c1 <= 1 and 0 <= c2 <= 1: |
|
yield c, angle_normal, angle_tangent, label |
|
|
|
return _f(major_locs, major_labels), _f(minor_locs, minor_labels) |
|
|
|
|
|
class AxisArtistHelper: |
|
Fixed = _FixedAxisArtistHelperBase |
|
Floating = _FloatingAxisArtistHelperBase |
|
|
|
|
|
class AxisArtistHelperRectlinear: |
|
Fixed = FixedAxisArtistHelperRectilinear |
|
Floating = FloatingAxisArtistHelperRectilinear |
|
|
|
|
|
class GridHelperBase: |
|
|
|
def __init__(self): |
|
self._old_limits = None |
|
super().__init__() |
|
|
|
def update_lim(self, axes): |
|
x1, x2 = axes.get_xlim() |
|
y1, y2 = axes.get_ylim() |
|
if self._old_limits != (x1, x2, y1, y2): |
|
self._update_grid(x1, y1, x2, y2) |
|
self._old_limits = (x1, x2, y1, y2) |
|
|
|
def _update_grid(self, x1, y1, x2, y2): |
|
"""Cache relevant computations when the axes limits have changed.""" |
|
|
|
def get_gridlines(self, which, axis): |
|
""" |
|
Return list of grid lines as a list of paths (list of points). |
|
|
|
Parameters |
|
---------- |
|
which : {"both", "major", "minor"} |
|
axis : {"both", "x", "y"} |
|
""" |
|
return [] |
|
|
|
|
|
class GridHelperRectlinear(GridHelperBase): |
|
|
|
def __init__(self, axes): |
|
super().__init__() |
|
self.axes = axes |
|
|
|
def new_fixed_axis(self, loc, |
|
nth_coord=None, |
|
axis_direction=None, |
|
offset=None, |
|
axes=None, |
|
): |
|
if axes is None: |
|
_api.warn_external( |
|
"'new_fixed_axis' explicitly requires the axes keyword.") |
|
axes = self.axes |
|
if axis_direction is None: |
|
axis_direction = loc |
|
|
|
helper = FixedAxisArtistHelperRectilinear(axes, loc, nth_coord) |
|
axisline = AxisArtist(axes, helper, offset=offset, |
|
axis_direction=axis_direction) |
|
return axisline |
|
|
|
def new_floating_axis(self, nth_coord, value, |
|
axis_direction="bottom", |
|
axes=None, |
|
): |
|
if axes is None: |
|
_api.warn_external( |
|
"'new_floating_axis' explicitly requires the axes keyword.") |
|
axes = self.axes |
|
|
|
helper = FloatingAxisArtistHelperRectilinear( |
|
axes, nth_coord, value, axis_direction) |
|
axisline = AxisArtist(axes, helper, axis_direction=axis_direction) |
|
axisline.line.set_clip_on(True) |
|
axisline.line.set_clip_box(axisline.axes.bbox) |
|
return axisline |
|
|
|
def get_gridlines(self, which="major", axis="both"): |
|
""" |
|
Return list of gridline coordinates in data coordinates. |
|
|
|
Parameters |
|
---------- |
|
which : {"both", "major", "minor"} |
|
axis : {"both", "x", "y"} |
|
""" |
|
_api.check_in_list(["both", "major", "minor"], which=which) |
|
_api.check_in_list(["both", "x", "y"], axis=axis) |
|
gridlines = [] |
|
|
|
if axis in ("both", "x"): |
|
locs = [] |
|
y1, y2 = self.axes.get_ylim() |
|
if which in ("both", "major"): |
|
locs.extend(self.axes.xaxis.major.locator()) |
|
if which in ("both", "minor"): |
|
locs.extend(self.axes.xaxis.minor.locator()) |
|
|
|
for x in locs: |
|
gridlines.append([[x, x], [y1, y2]]) |
|
|
|
if axis in ("both", "y"): |
|
x1, x2 = self.axes.get_xlim() |
|
locs = [] |
|
if self.axes.yaxis._major_tick_kw["gridOn"]: |
|
locs.extend(self.axes.yaxis.major.locator()) |
|
if self.axes.yaxis._minor_tick_kw["gridOn"]: |
|
locs.extend(self.axes.yaxis.minor.locator()) |
|
|
|
for y in locs: |
|
gridlines.append([[x1, x2], [y, y]]) |
|
|
|
return gridlines |
|
|
|
|
|
class Axes(maxes.Axes): |
|
|
|
@_api.deprecated("3.8", alternative="ax.axis") |
|
def __call__(self, *args, **kwargs): |
|
return maxes.Axes.axis(self.axes, *args, **kwargs) |
|
|
|
def __init__(self, *args, grid_helper=None, **kwargs): |
|
self._axisline_on = True |
|
self._grid_helper = (grid_helper if grid_helper |
|
else GridHelperRectlinear(self)) |
|
super().__init__(*args, **kwargs) |
|
self.toggle_axisline(True) |
|
|
|
def toggle_axisline(self, b=None): |
|
if b is None: |
|
b = not self._axisline_on |
|
if b: |
|
self._axisline_on = True |
|
self.spines[:].set_visible(False) |
|
self.xaxis.set_visible(False) |
|
self.yaxis.set_visible(False) |
|
else: |
|
self._axisline_on = False |
|
self.spines[:].set_visible(True) |
|
self.xaxis.set_visible(True) |
|
self.yaxis.set_visible(True) |
|
|
|
@property |
|
def axis(self): |
|
return self._axislines |
|
|
|
def clear(self): |
|
|
|
|
|
|
|
self.gridlines = gridlines = GridlinesCollection( |
|
[], |
|
colors=mpl.rcParams['grid.color'], |
|
linestyles=mpl.rcParams['grid.linestyle'], |
|
linewidths=mpl.rcParams['grid.linewidth']) |
|
self._set_artist_props(gridlines) |
|
gridlines.set_grid_helper(self.get_grid_helper()) |
|
|
|
super().clear() |
|
|
|
|
|
gridlines.set_clip_path(self.axes.patch) |
|
|
|
|
|
self._axislines = mpl_axes.Axes.AxisDict(self) |
|
new_fixed_axis = self.get_grid_helper().new_fixed_axis |
|
self._axislines.update({ |
|
loc: new_fixed_axis(loc=loc, axes=self, axis_direction=loc) |
|
for loc in ["bottom", "top", "left", "right"]}) |
|
for axisline in [self._axislines["top"], self._axislines["right"]]: |
|
axisline.label.set_visible(False) |
|
axisline.major_ticklabels.set_visible(False) |
|
axisline.minor_ticklabels.set_visible(False) |
|
|
|
def get_grid_helper(self): |
|
return self._grid_helper |
|
|
|
def grid(self, visible=None, which='major', axis="both", **kwargs): |
|
""" |
|
Toggle the gridlines, and optionally set the properties of the lines. |
|
""" |
|
|
|
|
|
|
|
super().grid(visible, which=which, axis=axis, **kwargs) |
|
if not self._axisline_on: |
|
return |
|
if visible is None: |
|
visible = (self.axes.xaxis._minor_tick_kw["gridOn"] |
|
or self.axes.xaxis._major_tick_kw["gridOn"] |
|
or self.axes.yaxis._minor_tick_kw["gridOn"] |
|
or self.axes.yaxis._major_tick_kw["gridOn"]) |
|
self.gridlines.set(which=which, axis=axis, visible=visible) |
|
self.gridlines.set(**kwargs) |
|
|
|
def get_children(self): |
|
if self._axisline_on: |
|
children = [*self._axislines.values(), self.gridlines] |
|
else: |
|
children = [] |
|
children.extend(super().get_children()) |
|
return children |
|
|
|
def new_fixed_axis(self, loc, offset=None): |
|
gh = self.get_grid_helper() |
|
axis = gh.new_fixed_axis(loc, |
|
nth_coord=None, |
|
axis_direction=None, |
|
offset=offset, |
|
axes=self, |
|
) |
|
return axis |
|
|
|
def new_floating_axis(self, nth_coord, value, axis_direction="bottom"): |
|
gh = self.get_grid_helper() |
|
axis = gh.new_floating_axis(nth_coord, value, |
|
axis_direction=axis_direction, |
|
axes=self) |
|
return axis |
|
|
|
|
|
class AxesZero(Axes): |
|
|
|
def clear(self): |
|
super().clear() |
|
new_floating_axis = self.get_grid_helper().new_floating_axis |
|
self._axislines.update( |
|
xzero=new_floating_axis( |
|
nth_coord=0, value=0., axis_direction="bottom", axes=self), |
|
yzero=new_floating_axis( |
|
nth_coord=1, value=0., axis_direction="left", axes=self), |
|
) |
|
for k in ["xzero", "yzero"]: |
|
self._axislines[k].line.set_clip_path(self.patch) |
|
self._axislines[k].set_visible(False) |
|
|
|
|
|
Subplot = Axes |
|
SubplotZero = AxesZero |
|
|