Spaces:
Running
Running
| """ | |
| 2D visualization primitives based on Matplotlib. | |
| 1) Plot images with `plot_images`. | |
| 2) Call `plot_keypoints` or `plot_matches` any number of times. | |
| 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. | |
| """ | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patheffects as path_effects | |
| import numpy as np | |
def cm_RdGn(x):
    """Map scores to colors: red at 0, yellow at 0.5, green at 1.

    Args:
        x: scalar or array of scores; values are clipped to [0, 1].

    Returns:
        An array of RGB values in [0, 1] with a trailing axis of size 3.
    """
    red = np.array([[1.0, 0, 0]])
    green = np.array([[0, 1.0, 0]])
    # Rescale to [0, 2] so that both channels saturate at the midpoint,
    # which is what produces yellow at 0.5 after the final clip.
    weight = np.clip(x, 0, 1)[..., None] * 2
    return np.clip(weight * green + (2 - weight) * red, 0, 1)
def plot_images(
    imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True, figsize=4.5
):
    """Show several images side by side in a new single-row figure.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormap name for monochrome images; either one value shared
            by all images or a list/tuple with one entry per image.
        dpi: resolution passed to `plt.subplots`.
        pad: padding passed to `fig.tight_layout`.
        adaptive: whether the figure size should fit the image aspect ratios
            (otherwise every image is given a 4:3 slot).
        figsize: height of the figure; the width is this value multiplied by
            the sum of the per-image width/height ratios.

    Returns:
        The created Matplotlib figure.
    """
    count = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * count

    # Width / height ratio of every column.
    ratios = (
        [im.shape[1] / im.shape[0] for im in imgs] if adaptive else [4 / 3] * count
    )

    fig, axs = plt.subplots(
        1,
        count,
        figsize=[sum(ratios) * figsize, figsize],
        dpi=dpi,
        gridspec_kw={"width_ratios": ratios},
    )
    # With a single column `plt.subplots` returns a bare Axes, not a sequence.
    if count == 1:
        axs = [axs]

    for idx, (ax, img) in enumerate(zip(axs, imgs)):
        ax.imshow(img, cmap=plt.get_cmap(cmaps[idx]))
        ax.set_axis_off()
        if titles:
            ax.set_title(titles[idx])

    fig.tight_layout(pad=pad)
    return fig
def plot_keypoints(kpts, colors="lime", ps=4):
    """Plot keypoints for existing images.

    Keypoint sets are paired, in order, with the axes of the current figure.

    Args:
        kpts: list of ndarrays of size (N, 2).
        colors: string, or list of list of tuples (one for each keypoints).
        ps: size of the keypoints as float.
    """
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)
    axes = plt.gcf().axes
    for ax, pts, color in zip(axes, kpts, colors):
        try:
            xs, ys = pts[:, 0], pts[:, 1]
        except IndexError:
            # A set that cannot be indexed as (N, 2) -- e.g. an empty 1-D
            # array -- has nothing to draw. Skip only this set so the
            # keypoints of the remaining images are still plotted.
            continue
        ax.scatter(xs, ys, c=color, s=ps, linewidths=0)
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
    """Plot matches for a pair of existing images.

    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2).
        color: color of the matches. Either a single color (string or RGB
            tuple) shared by all matches, or one color per match given as a
            list of tuples/lists or as a 2D NumPy array with N rows.
            Random if not given.
        lw: width of the lines (no line if lw=0).
        ps: size of the end points (no endpoint if ps=0)
        indices: indices of the images to draw the matches on.
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    axes = fig.axes
    assert len(axes) > max(indices), "indices refer to axes missing in the figure"
    ax0, ax1 = axes[indices[0]], axes[indices[1]]
    # NOTE(review): rendering once here presumably finalizes the layout before
    # the connection patches are attached -- confirm before removing.
    fig.canvas.draw()

    n = len(kpts0)
    assert n == len(kpts1), "kpts0 and kpts1 must have the same length"

    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(n)).tolist()
    else:
        if isinstance(color, np.ndarray) and color.ndim == 2 and len(color) == n:
            # One row per match: turn the rows into a list of colors so that
            # they are not mistaken for a single color and replicated below.
            color = color.tolist()
        if len(color) > 0 and not isinstance(color[0], (tuple, list)):
            # A single color: share it between all matches.
            color = [color] * n

    if lw > 0:
        # Each match is a line spanning both axes, anchored in the data
        # coordinates of each image.
        for i in range(n):
            fig.add_artist(
                matplotlib.patches.ConnectionPatch(
                    xyA=(kpts0[i, 0], kpts0[i, 1]),
                    coordsA=ax0.transData,
                    xyB=(kpts1[i, 0], kpts1[i, 1]),
                    coordsB=ax1.transData,
                    zorder=1,
                    color=color[i],
                    linewidth=lw,
                    alpha=a,
                )
            )

    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=2,
    ha="left",
    va="top",
):
    """Write a text label on one of the axes of the current figure.

    Args:
        idx: index of the target axes in the current figure.
        text: the string to display.
        pos: position of the text, in axes coordinates.
        fs: font size.
        color: color of the text.
        lcolor: color of the outline drawn around the text; no outline if None.
        lwidth: width of the outline.
        ha, va: horizontal and vertical alignment of the text.
    """
    target = plt.gcf().axes[idx]
    label = target.text(
        *pos,
        text,
        transform=target.transAxes,
        color=color,
        fontsize=fs,
        ha=ha,
        va=va,
    )
    if lcolor is None:
        return
    # Stroke draws the outline; Normal then draws the regular text over it.
    outline = path_effects.Stroke(linewidth=lwidth, foreground=lcolor)
    label.set_path_effects([outline, path_effects.Normal()])
def save_plot(path, **kw):
    """Write the current figure to `path` with a tight box and zero padding.

    Any extra keyword arguments are forwarded to `plt.savefig`.
    """
    options = dict(bbox_inches="tight", pad_inches=0, **kw)
    plt.savefig(path, **options)