Spaces:
Running
Running
File size: 4,479 Bytes
9ff79dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
from typing import Any, Dict, Optional, Tuple, cast
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
import torch
from PIL import Image
MAX_OPACITY = 255
def plot_patches(
img: Image.Image,
patch_size: int,
image_resolution: int,
patch_opacities: Optional[npt.NDArray | torch.Tensor] = None,
figsize: Tuple[int, int] = (8, 8),
style: Dict[str, Any] | str | None = None,
) -> Tuple[plt.Figure, plt.Axes]:
"""
Plot patches of a square image.
Set `style` to "dark_background" if your image has a light background.
"""
# Get the number of patches
if image_resolution % patch_size != 0:
raise ValueError("The image resolution must be divisible by the patch size.")
num_patches = image_resolution // patch_size
# Default style
if style is None:
style = {}
# Sanity checks
if patch_opacities is not None:
if isinstance(patch_opacities, torch.Tensor):
patch_opacities = cast(npt.NDArray, patch_opacities.cpu().numpy())
if patch_opacities.shape != (num_patches, num_patches):
raise ValueError("The shape of the patch_opacities tensor is not correct.")
if not np.all((0 <= patch_opacities) & (patch_opacities <= 1)):
raise ValueError("The patch_opacities tensor must have values between 0 and 1.")
# If the image is not square, raise an error
if img.size[0] != img.size[1]:
raise ValueError("The image must be square.")
# Get the image as a numpy array
img_array = np.array(img.convert("RGBA")) # (H, W, C) where the last channel is the alpha channel
# Create a figure
with plt.style.context(style):
fig, axis = plt.subplots(num_patches, num_patches, figsize=figsize)
# Plot the patches
for i in range(num_patches):
for j in range(num_patches):
patch = img_array[i * patch_size : (i + 1) * patch_size, j * patch_size : (j + 1) * patch_size, :]
# Set the opacity of the patch
if patch_opacities is not None:
patch[:, :, -1] = round(patch_opacities[i, j] * MAX_OPACITY)
axis[i, j].imshow(patch)
axis[i, j].axis("off")
fig.subplots_adjust(wspace=0.1, hspace=0.1)
fig.tight_layout()
return fig, axis
def plot_attention_heatmap(
img: Image.Image,
patch_size: int,
image_resolution: int,
attention_map: npt.NDArray | torch.Tensor,
figsize: Tuple[int, int] = (8, 8),
style: Dict[str, Any] | str | None = None,
show_colorbar: bool = False,
show_axes: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
"""
Plot a heatmap of the attention map over the image.
The image must be square and `attention_map` must be normalized between 0 and 1.
"""
# Get the number of patches
if image_resolution % patch_size != 0:
raise ValueError("The image resolution must be divisible by the patch size.")
num_patches = image_resolution // patch_size
# Default style
if style is None:
style = {}
# Sanity checks
if isinstance(attention_map, torch.Tensor):
attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
if attention_map.shape != (num_patches, num_patches):
raise ValueError("The shape of the patch_opacities tensor is not correct.")
if not np.all((0 <= attention_map) & (attention_map <= 1)):
raise ValueError("The patch_opacities tensor must have values between 0 and 1.")
# If the image is not square, raise an error
if img.size[0] != img.size[1]:
raise ValueError("The image must be square.")
# Get the image as a numpy array
img_array = np.array(img.convert("RGBA")) # (H, W, C) where the last channel is the alpha channel
# Get the attention map as a numpy array
attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
img.size, Image.Resampling.BICUBIC
)
# Create a figure
with plt.style.context(style):
fig, ax = plt.subplots(figsize=figsize)
ax.imshow(img_array)
im = ax.imshow(
attention_map_image,
cmap=sns.color_palette("mako", as_cmap=True),
alpha=0.5,
)
if show_colorbar:
fig.colorbar(im)
if not show_axes:
ax.set_axis_off()
fig.tight_layout()
return fig, ax
|