Spaces:
Running
Running
from typing import Any, Dict, Optional, Tuple, cast | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import numpy.typing as npt | |
import seaborn as sns | |
import torch | |
from PIL import Image | |
MAX_OPACITY = 255 | |
def plot_patches( | |
img: Image.Image, | |
patch_size: int, | |
image_resolution: int, | |
patch_opacities: Optional[npt.NDArray | torch.Tensor] = None, | |
figsize: Tuple[int, int] = (8, 8), | |
style: Dict[str, Any] | str | None = None, | |
) -> Tuple[plt.Figure, plt.Axes]: | |
""" | |
Plot patches of a square image. | |
Set `style` to "dark_background" if your image has a light background. | |
""" | |
# Get the number of patches | |
if image_resolution % patch_size != 0: | |
raise ValueError("The image resolution must be divisible by the patch size.") | |
num_patches = image_resolution // patch_size | |
# Default style | |
if style is None: | |
style = {} | |
# Sanity checks | |
if patch_opacities is not None: | |
if isinstance(patch_opacities, torch.Tensor): | |
patch_opacities = cast(npt.NDArray, patch_opacities.cpu().numpy()) | |
if patch_opacities.shape != (num_patches, num_patches): | |
raise ValueError("The shape of the patch_opacities tensor is not correct.") | |
if not np.all((0 <= patch_opacities) & (patch_opacities <= 1)): | |
raise ValueError("The patch_opacities tensor must have values between 0 and 1.") | |
# If the image is not square, raise an error | |
if img.size[0] != img.size[1]: | |
raise ValueError("The image must be square.") | |
# Get the image as a numpy array | |
img_array = np.array(img.convert("RGBA")) # (H, W, C) where the last channel is the alpha channel | |
# Create a figure | |
with plt.style.context(style): | |
fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize) | |
# Plot the patches | |
for i in range(num_patches): | |
for j in range(num_patches): | |
patch = img_array[i * patch_size : (i + 1) * patch_size, j * patch_size : (j + 1) * patch_size, :] | |
# Set the opacity of the patch | |
if patch_opacities is not None: | |
patch[:, :, -1] = round(patch_opacities[i, j] * MAX_OPACITY) | |
axis[i, j].imshow(patch) | |
axis[i, j].axis("off") | |
fig.subplots_adjust(wspace=0.1, hspace=0.1) | |
fig.tight_layout() | |
return fig, axis | |
def plot_attention_heatmap( | |
img: Image.Image, | |
patch_size: int, | |
image_resolution: int, | |
attention_map: npt.NDArray | torch.Tensor, | |
figsize: Tuple[int, int] = (8, 8), | |
style: Dict[str, Any] | str | None = None, | |
show_colorbar: bool = False, | |
show_axes: bool = False, | |
) -> Tuple[plt.Figure, plt.Axes]: | |
""" | |
Plot a heatmap of the attention map over the image. | |
The image must be square and `attention_map` must be normalized between 0 and 1. | |
""" | |
# Get the number of patches | |
if image_resolution % patch_size != 0: | |
raise ValueError("The image resolution must be divisible by the patch size.") | |
num_patches = image_resolution // patch_size | |
# Default style | |
if style is None: | |
style = {} | |
# Sanity checks | |
if isinstance(attention_map, torch.Tensor): | |
attention_map = cast(npt.NDArray, attention_map.cpu().numpy()) | |
if attention_map.shape != (num_patches, num_patches): | |
raise ValueError("The shape of the patch_opacities tensor is not correct.") | |
if not np.all((0 <= attention_map) & (attention_map <= 1)): | |
raise ValueError("The patch_opacities tensor must have values between 0 and 1.") | |
# If the image is not square, raise an error | |
if img.size[0] != img.size[1]: | |
raise ValueError("The image must be square.") | |
# Get the image as a numpy array | |
img_array = np.array(img.convert("RGBA")) # (H, W, C) where the last channel is the alpha channel | |
# Get the attention map as a numpy array | |
attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize( | |
img.size, Image.Resampling.BICUBIC | |
) | |
# Create a figure | |
with plt.style.context(style): | |
fig, ax = plt.subplots(figsize=figsize) | |
ax.imshow(img_array) | |
im = ax.imshow( | |
attention_map_image, | |
cmap=sns.color_palette("mako", as_cmap=True), | |
alpha=0.5, | |
) | |
if show_colorbar: | |
fig.colorbar(im) | |
if not show_axes: | |
ax.set_axis_off() | |
fig.tight_layout() | |
return fig, ax | |