import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
from einops import rearrange
from PIL import Image
from tqdm import trange

from colpali_engine.interpretability.plot_utils import plot_patches
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")

@dataclass
class InterpretabilityInput:
    query: str
    image: Image.Image
    start_idx_token: int
    end_idx_token: int


def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> None:
    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")
    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    elif isinstance(savedir, str):
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )
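    # NOTE: `process_text` and `process_image` return dataclass instances, which is
    # why the forward passes below unpack them with `asdict(...)`.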

    # Forward pass
    with torch.no_grad():
        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)

    # NOTE: `output_image` will have shape:
    # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
    # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
    with torch.no_grad():
        output_image = model.forward(**asdict(input_image_processed))
    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[
            :, : processor.processor.image_seq_length, :
        ]  # (1, n_patch_x * n_patch_y, hidden_dim)
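    # (In the PaliGemma input layout the image patch tokens come first, followed by
    # the text prompt tokens, so keeping the first `image_seq_length` embeddings
    # retains only the patch embeddings.)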

    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = normalize_attention_map_per_query_token(
        attention_map
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = attention_map_normalized.float()
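    # The einsum above takes the dot product between each text token embedding and
    # each image patch embedding (ColBERT-style late interaction); the per-token
    # normalization rescales every map independently so it can be used as patch
    # opacities below, and the cast to float32 keeps matplotlib happy when the
    # model runs in half precision (e.g. bfloat16).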

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
        fig, axis = plot_patches(
            input_image_square,
            vit_config.patch_size,
            vit_config.resolution,
            patch_opacities=attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
        )
        fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)

        savepath = savedir / f"token_{token_idx}.png"
        fig.savefig(savepath)
        print(f"Saved attention map for token {token_idx} (`{text_tokens[token_idx]}`) to `{savepath}`.\n")

        plt.close(fig)

    return
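

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The checkpoint name,
    # adapter name, query, and image path below are assumptions: substitute the
    # base model and ColPali adapter you actually use. The base checkpoint must be
    # a key of `VIT_CONFIG`, and exactly one adapter must be active. Loading
    # follows the usual transformers/PEFT `from_pretrained` / `load_adapter`
    # conventions.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = ColPali.from_pretrained(
        "google/paligemma-3b-mix-448",  # assumed base checkpoint
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    model.load_adapter("vidore/colpali")  # assumed LoRA adapter
    model.eval()

    processor = ColPaliProcessor.from_pretrained("vidore/colpali")  # assumed processor source

    image = Image.open("path/to/document_page.png")  # placeholder path
    generate_interpretability_plots(
        model=model,
        processor=processor,
        query="What is the total revenue for 2023?",  # example query
        image=image,
        savedir="outputs/interpretability/example",
    )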