File size: 4,226 Bytes
9ff79dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
from einops import rearrange
from PIL import Image
from tqdm import trange

from colpali_engine.interpretability.plot_utils import plot_patches
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

# Default parent directory for generated plots; used when the caller supplies no save directory.
OUTDIR_INTERPRETABILITY = Path("outputs") / "interpretability"


@dataclass
class InterpretabilityInput:
    """
    Container grouping a text query, an image and a token index range.

    NOTE(review): nothing in the visible part of this file consumes this class, so
    the exact meaning of the two index fields (which token sequence they index into,
    and whether the end bound is inclusive) must be confirmed against its callers.
    """

    query: str  # text query
    image: Image.Image  # PIL image associated with the query
    start_idx_token: int  # presumably the first token index of interest - TODO confirm
    end_idx_token: int  # presumably the last token index of interest; inclusivity not visible here - TODO confirm


def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> None:
    """
    Save one patch-level heatmap per query token for a (query, image) pair.

    The query and the image are embedded with `model`; every text-token embedding
    is then dotted with every image-patch embedding (contraction over the hidden
    dimension), the resulting map is normalized per query token, and one figure per
    token is written to `savedir` as `token_<idx>.png`. The first and last token
    positions are skipped.

    Args:
        model: Model to run. It must have exactly one active adapter and its
            `config.name_or_path` must be a key of `VIT_CONFIG`.
        processor: Processor used to prepare the text and image inputs and to
            decode/tokenize the query for the figure titles.
        query: Text query.
        image: Input image. A copy resized to the ViT resolution is used as the
            plot background; the original image is what goes through `processor`.
        savedir: Output directory (created if missing). When falsy (`None` or an
            empty string), a new UUID-named directory under
            `OUTDIR_INTERPRETABILITY` is used and its path is printed.
        add_special_prompt_to_doc: Forwarded to `processor.process_image` as
            `add_special_prompt`. When True, only the first
            `processor.processor.image_seq_length` image embeddings are kept.

    Raises:
        ValueError: If the model does not have exactly one active adapter, or if
            its name is not referenced in `VIT_CONFIG`.
    """
    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")

    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if savedir:
        savedir = Path(savedir)  # accepts str, Path and any other os.PathLike
    else:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square (only used as the background of the plots)
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward passes (no gradients are needed for plotting)
    with torch.no_grad():
        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)

        # NOTE: `output_image` will have shape:
        # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
        # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
        output_image = model.forward(**asdict(input_image_processed))

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[
            :, : processor.processor.image_seq_length, :
        ]  # (1, n_patch_x * n_patch_y, hidden_dim)

    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = normalize_attention_map_per_query_token(
        attention_map
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = attention_map_normalized.float()

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    # NOTE(review): the titles below index `text_tokens` with positions of `input_ids`, which
    # assumes re-tokenizing the decoded string yields one token per input id - confirm.
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
        # Look the label up before creating the figure so a bad index cannot leave a figure open.
        token = text_tokens[token_idx]

        fig, _ = plot_patches(
            input_image_square,
            vit_config.patch_size,
            vit_config.resolution,
            patch_opacities=attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
        )

        try:
            fig.suptitle(f"Token #{token_idx}: `{token}`", color="white", fontsize=14)
            savepath = savedir / f"token_{token_idx}.png"
            fig.savefig(savepath)
            print(f"Saved attention map for token {token_idx} (`{token}`) to `{savepath}`.\n")
        finally:
            # pyplot keeps a reference to every open figure; always release it, even if saving fails.
            plt.close(fig)