Stefan Denner committed
Commit 208214b
Parent(s): Initial commit
Files changed:
- .gitattributes +35 -0
- .gitignore +3 -0
- LeGrad/LICENSE +21 -0
- LeGrad/legrad/__init__.py +2 -0
- LeGrad/legrad/utils.py +722 -0
- LeGrad/legrad/wrapper.py +447 -0
- README.md +12 -0
- app.py +192 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
/venv
__pycache__
LeGrad/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Walid Bousselham, Angie Boggust, Sofian Chaybouti, Hendrik Strobelt, Hilde Kuehne.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
LeGrad/legrad/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .wrapper import LeWrapper, LePreprocess
from .utils import *
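Usage note (not part of the commit): the two exports above are combined the same way app.py later in this commit uses them. A minimal sketch, with the image path and prompt as placeholders:

    import torch
    import open_clip
    from PIL import Image
    from LeGrad.legrad import LeWrapper, LePreprocess

    name = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
    model, preprocess = open_clip.create_model_from_pretrained(model_name=name)
    tokenizer = open_clip.get_tokenizer(model_name=name)
    model = LeWrapper(model)                                          # attach LeGrad hooks
    preprocess = LePreprocess(preprocess=preprocess, image_size=448)  # optional higher-res input

    image = preprocess(Image.open("example.png")).unsqueeze(0)        # placeholder image path
    text = tokenizer(["a photo of a chest x-ray"])                    # placeholder prompt
    text_emb = model.encode_text(text, normalize=True)
    heatmap = model.compute_legrad_clip(image=image, text_embedding=text_emb)
    # heatmap: [1, num_prompts, H, W] min-max normalised relevance maps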
LeGrad/legrad/utils.py
ADDED
@@ -0,0 +1,722 @@
from pathlib import Path
from typing import Optional, List, Tuple
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import cv2
import warnings

import torch
from torch import Tensor
from torch.nn import functional as F
from torch.nn.functional import pad

import open_clip
from open_clip import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
from open_clip.transformer import _expand_token
from timm.layers import resample_abs_pos_embed


################################################################################
#                                  Hooks utils                                 #
################################################################################


# ------------ Hooked Multi-Head Attention ------------
# from https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/transformer.py#L129
def hooked_attention_forward(
    self,
    x,
    x_k,
    x_v,
    attn_mask: Optional[torch.Tensor] = None,
    need_weights: bool = False,
):
    L, N, C = x.shape
    q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
    q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
    k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
    v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

    head_dim = q.shape[-1]
    scale = float(head_dim) ** -0.5
    q = q * scale
    attn = torch.bmm(q, k.transpose(-1, -2))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask
        attn += attn_mask

    attn = attn.softmax(dim=-1)
    # Hook for attention maps
    self.attention_map = attn

    x = torch.bmm(attn, v)
    x = x.transpose(0, 1).reshape(L, N, C)
    x = self.out_proj(x)
    return x


def hooked_attention_timm_forward(self, x, attn_mask=None):
    B, N, C = x.shape
    qkv = (
        self.qkv(x)
        .reshape(B, N, 3, self.num_heads, self.head_dim)
        .permute(2, 0, 3, 1, 4)
    )
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)

    q = q * self.scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x = attn @ v

    # Hook to save attention map for explainability
    self.attention_map = attn

    x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x


# ------------ Hooked Residual Transformer Block ------------
# from https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/transformer.py#L231
def hooked_resblock_forward(self, q_x, k_x=None, v_x=None, attn_mask=None):
    assert k_x is None and v_x is None, "k_x and v_x must be None"

    # Modify this line to include the necessary arguments for hooked_attention_forward
    x = q_x + self.ls1(
        self.attn(
            self.norm1(q_x),
            k_x=k_x,
            v_x=v_x,
            attn_mask=attn_mask,
        )
    )
    # Hook for intermediate features post Attn
    self.feat_post_attn = x
    x = x + self.ls2(self.mlp(self.norm2(x)))

    # Hook for intermediate features post MLP
    self.feat_post_mlp = x
    return x


# ------------ Hooked PyTorch Multi-Head Attention ------------
# modified from the PyTorch library
# https://github.com/pytorch/pytorch/blob/8c8e4e31f2ddd8e59de18ac733c0c205c23d14ad/torch/nn/functional.py#L5178
def hooked_torch_multi_head_attention_forward(
    self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None
):
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. When given a binary mask and a value is True,
            the corresponding value on the attention layer will be ignored. When given
            a byte mask and a value is non-zero, the corresponding value on the attention
            layer will be ignored.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all
            the batches while a 3D mask allows specifying a different mask for the entries of each batch.

    Shape:
        - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the positions
              with zero values will be unchanged. If a BoolTensor is provided, the positions with the
              value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

        - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
    """
    if not self._qkv_same_embed_dim:
        out, _attn_maps = hooked_torch_func_multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj_weight,
            k_proj_weight=self.k_proj_weight,
            v_proj_weight=self.v_proj_weight,
        )
        # Hook for attention maps
        self.attention_maps = _attn_maps
        return out, _attn_maps
    else:
        out, _attn_maps = hooked_torch_func_multi_head_attention_forward(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            attn_mask=attn_mask,
        )
        # Hook for attention maps
        self.attention_maps = _attn_maps
        return out, _attn_maps


def hooked_torch_func_multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    training: bool = True,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    if not torch.jit.is_scripting():
        tens_ops = (
            query,
            key,
            value,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            out_proj_weight,
            out_proj_bias,
        )
        if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(
            tens_ops
        ):
            return F.handle_torch_function(
                hooked_torch_func_multi_head_attention_forward,
                tens_ops,
                query,
                key,
                value,
                embed_dim_to_check,
                num_heads,
                in_proj_weight,
                in_proj_bias,
                bias_k,
                bias_v,
                add_zero_attn,
                dropout_p,
                out_proj_weight,
                out_proj_bias,
                training=training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=use_separate_proj_weight,
                q_proj_weight=q_proj_weight,
                k_proj_weight=k_proj_weight,
                v_proj_weight=v_proj_weight,
                static_k=static_k,
                static_v=static_v,
            )
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(
                key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]
            )
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32
            or attn_mask.dtype == torch.float64
            or attn_mask.dtype == torch.float16
            or attn_mask.dtype == torch.uint8
            or attn_mask.dtype == torch.bool
        ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
            attn_mask.dtype
        )
        if attn_mask.dtype == torch.uint8:
            warnings.warn(
                "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
            )
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
        else:
            raise RuntimeError(
                "attn_mask's dimension {} is not supported".format(attn_mask.dim())
            )
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat(
            [
                k,
                torch.zeros(
                    (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
                ),
            ],
            dim=1,
        )
        v = torch.cat(
            [
                v,
                torch.zeros(
                    (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
                ),
            ],
            dim=1,
        )
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float("-inf"),
        )
        attn_output_weights = attn_output_weights.view(
            bsz * num_heads, tgt_len, src_len
        )

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    # # use hooks for the attention weights if necessary
    # self.attention_map = attn_output_weights
    # # if attention_probs_forward_hook is not None and attention_probs_backwards_hook is not None:
    # if attention_probs_forward_hook is not None:
    #     attention_probs_forward_hook(attn_output_weights)
    #     # attn_output_weights.register_hook(attention_probs_backwards_hook)

    attn_output = torch.bmm(attn_output_weights, v)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # --- Fix: removed the unnecessary average over heads
        # average attention weights over heads
        # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        # return attn_output, attn_output_weights.sum(dim=1) / num_heads
        return attn_output, attn_output_weights
    else:
        return attn_output, None


# ------------ Hooked TimmModel's Residual Transformer Block ------------
def hooked_resblock_timm_forward(self, x: torch.Tensor) -> torch.Tensor:
    x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
    self.feat_post_attn = x
    x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
    self.feat_post_mlp = x
    return x


# ------------ Hooked TimmModel's Attentional Pooler ------------
def hooked_attentional_pooler_timm_forward(self, x):
    B, N, C = x.shape

    if self.pos_embed is not None:
        # FIXME interpolate
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

    q_latent = self.latent.expand(B, -1, -1)
    q = (
        self.q(q_latent)
        .reshape(B, self.latent_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )

    kv = (
        self.kv(x)
        .reshape(B, N, 2, self.num_heads, self.head_dim)
        .permute(2, 0, 3, 1, 4)
    )
    k, v = kv.unbind(0)

    q, k = self.q_norm(q), self.k_norm(k)

    q = q * self.scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    x = attn @ v

    # Hook to save attention map for explainability
    self.attn_probs = attn

    x = x.transpose(1, 2).reshape(B, self.latent_len, C)
    x = self.proj(x)
    x = self.proj_drop(x)

    x = x + self.mlp(self.norm(x))

    # optional pool if latent seq_len > 1 and pooled output is desired
    if self.pool == "token":
        x = x[:, 0]
    elif self.pool == "avg":
        x = x.mean(1)
    return x


# ------------ OpenCLIP ViT forward with dynamic size ------------
def vit_dynamic_size_forward(self, x: torch.Tensor):
    x = self.conv1(x)  # shape = [*, width, grid, grid]
    grid_h, grid_w = x.shape[2:]
    x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
    x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

    # class embeddings and positional embeddings
    x = torch.cat(
        [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1
    )
    # shape = [*, grid ** 2 + 1, width]
    if x.shape[1] != self.positional_embedding.shape[1]:
        self.positional_embedding.data = resample_abs_pos_embed(
            self.positional_embedding.unsqueeze(0),
            new_size=[grid_h, grid_w],
            # old_size=list(self.grid_size),
            num_prefix_tokens=1,
            interpolation="bicubic",
            antialias=True,
        )

    x = x + self.positional_embedding.to(x.dtype)

    x = self.patch_dropout(x)
    x = self.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD

    if self.attn_pool is not None:
        if self.attn_pool_contrastive is not None:
            # This is untested, WIP pooling that should match paper
            x = self.ln_post(x)  # TBD LN first or separate one after each pool?
            tokens = self.attn_pool(x)
            if self.attn_pool_type == "parallel":
                pooled = self.attn_pool_contrastive(x)
            else:
                assert self.attn_pool_type == "cascade"
                pooled = self.attn_pool_contrastive(tokens)
        else:
            # this is the original OpenCLIP CoCa setup, does not match paper
            x = self.attn_pool(x)
            x = self.ln_post(x)
            pooled, tokens = self._global_pool(x)
    elif self.final_ln_after_pool:
        pooled, tokens = self._global_pool(x)
        pooled = self.ln_post(pooled)
    else:
        x = self.ln_post(x)
        pooled, tokens = self._global_pool(x)

    if self.proj is not None:
        pooled = pooled @ self.proj

    if self.output_tokens:
        return pooled, tokens

    return pooled


################################################################################
#                             Visualization utils                              #
################################################################################


def min_max(logits):
    B, num_prompt = logits.shape[:2]
    logits_min = (
        logits.reshape(B, num_prompt, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
    )
    logits_max = (
        logits.reshape(B, num_prompt, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
    )
    logits = (logits - logits_min) / (logits_max - logits_min)
    return logits


def visualize(image, heatmaps, alpha=0.6, save_path: Path = None):
    # heatmaps of shape (N, 1, W, H)
    W, H = heatmaps.shape[-2:]
    if isinstance(image, Image.Image):
        image = image.resize((W, H))
    elif isinstance(image, torch.Tensor):
        if image.ndim > 3:
            image = image.squeeze(0)
        # undo the normalization
        image_unormed = (
            image.detach().cpu() * torch.Tensor(OPENAI_DATASET_STD)[:, None, None]
        ) + torch.Tensor(OPENAI_DATASET_MEAN)[:, None, None]
        # convert to PIL
        image = Image.fromarray(
            (image_unormed.permute(1, 2, 0).numpy() * 255).astype("uint8")
        )
    else:
        raise TypeError(
            f"image should be either of type PIL.Image.Image or torch.Tensor but found {type(image)}"
        )

    # plot image
    plt.imshow(image)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

    if heatmaps.ndim > 3:
        heatmaps = heatmaps.squeeze(0)
    heatmaps = heatmaps.detach().cpu().numpy()

    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    heatmaps = (heatmaps * 255).astype("uint8")
    heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in heatmaps]

    vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]
    for i, viz in enumerate(vizs):
        viz = cv2.cvtColor(viz.astype("uint8"), cv2.COLOR_BGR2RGB)
        plt.imshow(viz)
        plt.axis("off")
        plt.tight_layout()
        # remove the margin
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        plt.show()
        if save_path is not None:
            plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
            print(f"Saved visualization at {save_path}")


def list_pretrained():
    openclip_list_ = open_clip.list_pretrained()
    # supported (ViT-based) models are collected below
    filtered_list = []
    unsupported_models = [
        "RN",
        "convnext",
    ]  # legrad doesn't support CNN-based VLMs (for the moment)
    _str = (
        ": ".join(["model_name" + " " * (25 - len("model_name")), "pretrained"]) + "\n"
    )  # for nice display
    for model_name, pretrained in openclip_list_:
        for unsup_model in unsupported_models:
            if unsup_model in model_name:
                skip = True
                break
        else:
            skip = False
        if not skip:
            filtered_list.append((model_name, pretrained))
            _str += (
                ": ".join([model_name + " " * (25 - len(model_name)), pretrained])
                + "\n"
            )  # for nice display

    print(_str)
    return filtered_list


if __name__ == "__main__":
    list_pretrained()
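Usage note (not part of the commit): these hooked forwards are plain functions meant to be bound onto existing open_clip/timm modules at runtime, which is exactly what wrapper.py below does with types.MethodType. A minimal sketch of that binding on the timm-based BiomedCLIP trunk (the block index -2 is illustrative):

    import types
    import open_clip
    from LeGrad.legrad.utils import (
        hooked_attention_timm_forward,
        hooked_resblock_timm_forward,
    )

    model, _ = open_clip.create_model_from_pretrained(
        "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
    )
    block = model.visual.trunk.blocks[-2]  # one timm transformer block
    block.attn.forward = types.MethodType(hooked_attention_timm_forward, block.attn)
    block.forward = types.MethodType(hooked_resblock_timm_forward, block)
    # After a forward pass, the hooks leave the intermediates on the modules:
    #   block.attn.attention_map  -> post-softmax attention weights
    #   block.feat_post_attn      -> features after the attention residual
    #   block.feat_post_mlp       -> features after the MLP residual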
LeGrad/legrad/wrapper.py
ADDED
@@ -0,0 +1,447 @@
import math
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, InterpolationMode
import open_clip
from open_clip.transformer import VisionTransformer
from open_clip.timm_model import TimmModel
from einops import rearrange

from .utils import (
    hooked_attention_timm_forward,
    hooked_resblock_forward,
    hooked_attention_forward,
    hooked_resblock_timm_forward,
    hooked_attentional_pooler_timm_forward,
    vit_dynamic_size_forward,
    min_max,
    hooked_torch_multi_head_attention_forward,
)


class LeWrapper(nn.Module):
    """
    Wrapper around OpenCLIP that adds LeGrad to the model while keeping all the functionality of the original model.
    """

    def __init__(self, model, layer_index=-2):
        super(LeWrapper, self).__init__()
        # ------------ copy of model's attributes and methods ------------
        for attr in dir(model):
            if not attr.startswith("__"):
                setattr(self, attr, getattr(model, attr))

        # ------------ activate hooks & gradient ------------
        self._activate_hooks(layer_index=layer_index)

    def _activate_hooks(self, layer_index):
        # ------------ identify model's type ------------
        print("Activating necessary hooks and gradients ....")
        if isinstance(self.visual, VisionTransformer):
            # --- Activate dynamic image size ---
            self.visual.forward = types.MethodType(
                vit_dynamic_size_forward, self.visual
            )
            # Get patch size
            self.patch_size = self.visual.patch_size[0]
            # Get starting depth (in case of negative layer_index)
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.transformer.resblocks) + layer_index
            )

            if self.visual.attn_pool is None:
                self.model_type = "clip"
                self._activate_self_attention_hooks()
            else:
                self.model_type = "coca"
                self._activate_att_pool_hooks(layer_index=layer_index)

        elif isinstance(self.visual, TimmModel):
            # --- Activate dynamic image size ---
            self.visual.trunk.dynamic_img_size = True
            self.visual.trunk.patch_embed.dynamic_img_size = True
            self.visual.trunk.patch_embed.strict_img_size = False
            self.visual.trunk.patch_embed.flatten = False
            self.visual.trunk.patch_embed.output_fmt = "NHWC"
            self.model_type = "timm_siglip"
            # --- Get patch size ---
            self.patch_size = self.visual.trunk.patch_embed.patch_size[0]
            # --- Get starting depth (in case of negative layer_index) ---
            self.starting_depth = (
                layer_index
                if layer_index >= 0
                else len(self.visual.trunk.blocks) + layer_index
            )
            if (
                hasattr(self.visual.trunk, "attn_pool")
                and self.visual.trunk.attn_pool is not None
            ):
                self._activate_timm_attn_pool_hooks(layer_index=layer_index)
            else:
                self._activate_timm_self_attention_hooks()
        else:
            raise ValueError(
                "Model currently not supported, see legrad.list_pretrained() for a list of available models"
            )
        print("Hooks and gradients activated!")

    def _activate_self_attention_hooks(self):
        # Adjusting to use the correct structure
        if isinstance(self.visual, VisionTransformer):
            blocks = self.visual.transformer.resblocks
        elif isinstance(self.visual, TimmModel):
            blocks = self.visual.trunk.blocks
        else:
            raise ValueError("Unsupported model type for self-attention hooks")

        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        # Necessary steps to get intermediate representations
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.blocks"):
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_forward, blocks[layer]
            )

    def _activate_timm_self_attention_hooks(self):
        # Adjusting to use the correct structure
        blocks = self.visual.trunk.blocks

        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        # Necessary steps to get intermediate representations
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.blocks"):
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(blocks)):
            blocks[layer].attn.forward = types.MethodType(
                hooked_attention_timm_forward, blocks[layer].attn
            )
            blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, blocks[layer]
            )

    def _activate_att_pool_hooks(self, layer_index):
        # ---------- Apply Hooks + Activate/Deactivate gradients ----------
        # Necessary steps to get intermediate representations
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.transformer.resblocks"):
                # get the depth
                depth = int(
                    name.split("visual.transformer.resblocks.")[-1].split(".")[0]
                )
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # --- Activate the hooks for the specific layers ---
        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            self.visual.transformer.resblocks[layer].forward = types.MethodType(
                hooked_resblock_forward, self.visual.transformer.resblocks[layer]
            )
        # --- Apply hook on the attentional pooler ---
        self.visual.attn_pool.attn.forward = types.MethodType(
            hooked_torch_multi_head_attention_forward, self.visual.attn_pool.attn
        )

    def _activate_timm_attn_pool_hooks(self, layer_index):
        # Ensure all components are present before attaching hooks
        if (
            not hasattr(self.visual.trunk, "attn_pool")
            or self.visual.trunk.attn_pool is None
        ):
            raise ValueError("Attentional pooling not found in TimmModel")

        self.visual.trunk.attn_pool.forward = types.MethodType(
            hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool
        )
        for block in self.visual.trunk.blocks:
            if hasattr(block, "attn"):
                block.attn.forward = types.MethodType(
                    hooked_attention_forward, block.attn
                )

        # --- Deactivate gradients for modules that don't need them ---
        for name, param in self.named_parameters():
            param.requires_grad = False
            if name.startswith("visual.trunk.attn_pool"):
                param.requires_grad = True
            if name.startswith("visual.trunk.blocks"):
                # get the depth
                depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
                if depth >= self.starting_depth:
                    param.requires_grad = True

        # --- Activate the hooks for the specific layers by modifying the block's forward ---
        for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
            self.visual.trunk.blocks[layer].forward = types.MethodType(
                hooked_resblock_timm_forward, self.visual.trunk.blocks[layer]
            )

        self.visual.trunk.attn_pool.forward = types.MethodType(
            hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool
        )

    def compute_legrad(self, text_embedding, image=None, apply_correction=True):
        if "clip" in self.model_type:
            return self.compute_legrad_clip(text_embedding, image)
        elif "siglip" in self.model_type:
            return self.compute_legrad_siglip(
                text_embedding, image, apply_correction=apply_correction
            )
        elif "coca" in self.model_type:
            return self.compute_legrad_coca(text_embedding, image)

    def compute_legrad_clip(self, text_embedding, image=None):
        num_prompts = text_embedding.shape[0]
        if image is not None:
            # Ensure the image is passed through the model to get the intermediate features
            _ = self.encode_image(image)

        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())

        image_features_list = []

        for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
            # [num_patch, batch, dim]
            intermediate_feat = blocks_list[layer].feat_post_mlp
            # Mean over the patch tokens
            intermediate_feat = intermediate_feat.mean(dim=1)
            intermediate_feat = self.visual.head(
                self.visual.trunk.norm(intermediate_feat)
            )
            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[1] - 1
        w = h = int(math.sqrt(num_tokens))

        # ----- Get explainability map
        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()
            sim = text_embedding @ img_feat.transpose(-1, -2)  # [1, 1]
            one_hot = (
                F.one_hot(torch.arange(0, num_prompts))
                .float()
                .requires_grad_(True)
                .to(text_embedding.device)
            )
            one_hot = torch.sum(one_hot * sim)

            # [b, num_heads, N, N]
            attn_map = blocks_list[self.starting_depth + layer].attn.attention_map

            # -------- Get explainability map --------
            # [batch_size * num_heads, N, N]
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]
            # grad = rearrange(grad, '(b h) n m -> b h n m', b=num_prompts)  # separate batch and attn heads
            grad = torch.clamp(grad, min=0.0)

            # average attn over [CLS] + patch tokens
            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
            expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)
            # [B, 1, H, W]
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )
            accum_expl_map += expl_map

        # Min-Max Norm
        accum_expl_map = min_max(accum_expl_map)
        return accum_expl_map

    def compute_legrad_coca(self, text_embedding, image=None):
        if image is not None:
            _ = self.encode_image(image)

        blocks_list = list(
            dict(self.visual.transformer.resblocks.named_children()).values()
        )

        image_features_list = []

        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
            intermediate_feat = self.visual.transformer.resblocks[
                layer
            ].feat_post_mlp  # [num_patch, batch, dim]
            intermediate_feat = intermediate_feat.permute(
                1, 0, 2
            )  # [batch, num_patch, dim]
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
        w = h = int(math.sqrt(num_tokens))

        # ----- Get explainability map
        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.visual.zero_grad()
            # --- Apply attn_pool ---
            image_embedding = self.visual.attn_pool(img_feat)[
                :, 0
            ]  # we keep only the first pooled token as it is the only one trained with the contrastive loss
            image_embedding = image_embedding @ self.visual.proj

            sim = text_embedding @ image_embedding.transpose(-1, -2)  # [1, 1]
            one_hot = torch.sum(sim)

            attn_map = (
                self.visual.attn_pool.attn.attention_maps
            )  # [num_heads, num_latent, num_patch]

            # -------- Get explainability map --------
            grad = torch.autograd.grad(
                one_hot, [attn_map], retain_graph=True, create_graph=True
            )[0]  # [num_heads, num_latent, num_patch]
            grad = torch.clamp(grad, min=0.0)

            image_relevance = grad.mean(dim=0)[
                0, 1:
            ]  # average attn over heads + select first latent
            expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
            expl_map = F.interpolate(
                expl_map, scale_factor=self.patch_size, mode="bilinear"
            )  # [B, 1, H, W]
            accum_expl_map += expl_map

        # Min-Max Norm
        accum_expl_map = (accum_expl_map - accum_expl_map.min()) / (
            accum_expl_map.max() - accum_expl_map.min()
        )
        return accum_expl_map

    def _init_empty_embedding(self):
        if not hasattr(self, "empty_embedding"):
            # For the moment only SigLIP is supported & they all have the same tokenizer
            _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP")
            empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device)
            empty_embedding = self.encode_text(empty_text)
            empty_embedding = F.normalize(empty_embedding, dim=-1)
            self.empty_embedding = empty_embedding.t()

    def compute_legrad_siglip(
        self,
        text_embedding,
        image=None,
        apply_correction=True,
        correction_threshold=0.8,
    ):
        # --- Forward CLIP ---
        blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
        if image is not None:
            _ = self.encode_image(image)  # [bs, num_patch, dim] bs=num_masks

        image_features_list = []
        for blk in blocks_list[self.starting_depth :]:
            intermediate_feat = blk.feat_post_mlp
            image_features_list.append(intermediate_feat)

        num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
        w = h = int(math.sqrt(num_tokens))

        if apply_correction:
            self._init_empty_embedding()
            accum_expl_map_empty = 0

        accum_expl_map = 0
        for layer, (blk, img_feat) in enumerate(
            zip(blocks_list[self.starting_depth :], image_features_list)
        ):
            self.zero_grad()
            pooled_feat = self.visual.trunk.attn_pool(img_feat)
            pooled_feat = F.normalize(pooled_feat, dim=-1)
            # -------- Get explainability map --------
            sim = text_embedding @ pooled_feat.transpose(-1, -2)  # [num_mask, num_mask]
            one_hot = torch.sum(sim)
            grad = torch.autograd.grad(
                one_hot,
                [self.visual.trunk.attn_pool.attn_probs],
                retain_graph=True,
                create_graph=True,
            )[0]
            grad = torch.clamp(grad, min=0.0)

            image_relevance = grad.mean(dim=1)[
                :, 0
            ]  # average attn over heads + select query's row
            expl_map = rearrange(image_relevance, "b (w h) -> b 1 w h", w=w, h=h)
            accum_expl_map += expl_map

            if apply_correction:
                # -------- Get empty explainability map --------
                sim_empty = pooled_feat @ self.empty_embedding
                one_hot_empty = torch.sum(sim_empty)
                grad_empty = torch.autograd.grad(
                    one_hot_empty,
                    [self.visual.trunk.attn_pool.attn_probs],
                    retain_graph=True,
                    create_graph=True,
                )[0]
                grad_empty = torch.clamp(grad_empty, min=0.0)

                image_relevance_empty = grad_empty.mean(dim=1)[
                    :, 0
                ]  # average attn over heads + select query's row
                expl_map_empty = rearrange(
                    image_relevance_empty, "b (w h) -> b 1 w h", w=w, h=h
                )
                accum_expl_map_empty += expl_map_empty

        if apply_correction:
            heatmap_empty = min_max(accum_expl_map_empty)
            accum_expl_map[heatmap_empty > correction_threshold] = 0

        Res = min_max(accum_expl_map)
        Res = F.interpolate(
            Res, scale_factor=self.patch_size, mode="bilinear"
        )  # [B, 1, H, W]

        return Res


class LePreprocess(nn.Module):
    """
    Modify OpenCLIP preprocessing to accept arbitrary image size.
    """

    def __init__(self, preprocess, image_size):
        super(LePreprocess, self).__init__()
        self.transform = Compose(
            [
                Resize(
                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
                ),
                preprocess.transforms[-3],
                preprocess.transforms[-2],
                preprocess.transforms[-1],
            ]
        )

    def forward(self, image):
        return self.transform(image)
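Usage note (not part of the commit): `layer_index` controls where the accumulation over blocks starts; negative values count from the end, so the default -2 starts at the second-to-last block. An illustrative choice of a deeper starting point:

    model = LeWrapper(model, layer_index=-6)  # aggregate LeGrad maps over the last 6 blocks
    # Parameters of blocks before the starting depth stay frozen (requires_grad=False),
    # so the explainability backward pass only runs through the hooked layers.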
README.md
ADDED
@@ -0,0 +1,12 @@
---
title: MedicalVisualPromptEngineering
emoji: 🐠
colorFrom: green
colorTo: pink
sdk: gradio
sdk_version: 4.38.0
app_file: app.py
pinned: false
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,192 @@
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
import open_clip

from LeGrad.legrad import LeWrapper, LePreprocess

# Load BiomedCLIP model
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
model, preprocess = open_clip.create_model_from_pretrained(
    model_name=model_name, device=device
)
tokenizer = open_clip.get_tokenizer(model_name=model_name)
model = LeWrapper(model)  # Equip the model with LeGrad
preprocess = LePreprocess(
    preprocess=preprocess, image_size=448
)  # Optional higher-res preprocessing


def classify_image_with_biomedclip(editor_value, prompts):
    # editor_value is a dict with keys: 'background', 'layers', 'composite'
    # The 'composite' key contains the final annotated image

    if editor_value is None:
        return None, None

    # Get the composite image (background + annotations)
    image = editor_value["composite"]

    # Ensure image is in PIL format
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Preprocess and encode the image
    image_input = preprocess(image).unsqueeze(0).to(device)
    text_inputs = tokenizer(prompts).to(device)

    # Encode text and image
    text_embeddings = model.encode_text(text_inputs, normalize=True)
    image_embeddings = model.encode_image(image_input, normalize=True)

    # Generate probabilities (optional - not required for LeGrad explanations but included for completeness)
    similarity = (
        model.logit_scale.exp() * image_embeddings @ text_embeddings.T
    ).softmax(dim=-1)
    probabilities = similarity[0].detach().cpu().numpy()
    explanation_maps = model.compute_legrad_clip(
        image=image_input, text_embedding=text_embeddings[probabilities.argmax()]
    )

    # Convert explanation maps to heatmap
    explanation_maps = explanation_maps.squeeze(0).detach().cpu().numpy()
    explanation_map = (explanation_maps * 255).astype(np.uint8)  # Rescale to [0, 255]

    return probabilities, explanation_map


def update_output(editor_value, prompts_input):
    prompts_list = [p.strip() for p in prompts_input.split(",") if p.strip()]
    if not prompts_list:
        return None, "Please enter at least one prompt."

    probabilities, explanation_map = classify_image_with_biomedclip(
        editor_value, prompts_list
    )

    if probabilities is None:
        return None, "Please upload and annotate an image."

    # Create probability display
    prob_text = "\n".join(
        [
            f"{prompt}: {prob*100:.2f}%"
            for prompt, prob in zip(prompts_list, probabilities)
        ]
    )

    # Prepare the explanation map overlay
    image = editor_value["composite"]
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    explanation_image = explanation_map[0]
    if isinstance(explanation_image, torch.Tensor):
        explanation_image = explanation_image.cpu().numpy()

    # Resize the explanation map to match the size of the original image
    explanation_image_resized = cv2.resize(
        explanation_image, (image.width, image.height)
    )

    # Normalize the explanation map for proper colormap application
    explanation_image_resized = cv2.normalize(
|
99 |
+
explanation_image_resized, None, 0, 255, cv2.NORM_MINMAX
|
100 |
+
)
|
101 |
+
|
102 |
+
# Apply the colormap (e.g., COLORMAP_JET)
|
103 |
+
explanation_colormap = cv2.applyColorMap(
|
104 |
+
explanation_image_resized.astype(np.uint8), cv2.COLORMAP_JET
|
105 |
+
)
|
106 |
+
|
107 |
+
# Convert the original image to a format that OpenCV understands (RGB to BGR)
|
108 |
+
image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
109 |
+
|
110 |
+
# Blend the original image and the colormap
|
111 |
+
alpha = 0.5 # Transparency factor
|
112 |
+
blended_image = cv2.addWeighted(image_cv, 1 - alpha, explanation_colormap, alpha, 0)
|
113 |
+
|
114 |
+
# Convert back to RGB for displaying with PIL or matplotlib
|
115 |
+
blended_image_rgb = cv2.cvtColor(blended_image, cv2.COLOR_BGR2RGB)
|
116 |
+
output_image = Image.fromarray(blended_image_rgb)
|
117 |
+
|
118 |
+
return output_image, prob_text
|
119 |
+
|
120 |
+
|
121 |
+
def clear_inputs():
|
122 |
+
return None, ""
|
123 |
+
|
124 |
+
|
125 |
+
with gr.Blocks() as demo:
|
126 |
+
gr.Markdown(
|
127 |
+
"# ✨ Visual Prompt Engineering for Medical Vision Language Models in Radiology ✨",
|
128 |
+
elem_id="main-header",
|
129 |
+
)
|
130 |
+
|
131 |
+
gr.Markdown(
|
132 |
+
"This tool applies **visual prompt engineering to improve the classification of medical images using the BiomedCLIP**[3], the current state of the art in zero-shot biomedical image classification. By uploading biomedical images (e.g., chest X-rays), you can manually annotate areas of interest directly on the image. These annotations serve as visual prompts, which guide the model's attention on the region of interest. This technique improves the model's ability to focus on subtle yet important details.\n\n"
|
133 |
+
"After annotating and inputting text prompts (e.g., 'A chest X-ray with a benign/malignant lung nodule indicated by a red circle'), the tool returns classification results. These results are accompanied by **explainability maps** generated by **LeGrad** [3], which show where the model focused its attention, conditioned on the highest scoring text prompt. This helps to better interpret the model's decision-making process.\n\n"
|
134 |
+
"In our paper **[Visual Prompt Engineering for Medical Vision Language Models in Radiology](https://arxiv.org/pdf/2408.15802)**, we show, that visual prompts such as arrows, circles, and contours improve the zero-shot classification of biomedical vision language models in radiology."
|
135 |
+
)
|
136 |
+
|
137 |
+
gr.Markdown("---") # Horizontal rule for separation
|
138 |
+
|
139 |
+
gr.Markdown(
|
140 |
+
"## 📝 **How It Works**:\n"
|
141 |
+
"1. **Upload** a biomedical image.\n"
|
142 |
+
"2. **Annotate** the image using the built-in editor to highlight regions of interest.\n"
|
143 |
+
"3. **Enter text prompts** separated by comma (e.g., 'A chest X-ray with a (benign/malignant) lung nodule indicated by a red circle').\n"
|
144 |
+
"4. **Submit** to get class probabilities and an explainability map conditioned on the highest scoring text prompt."
|
145 |
+
)
|
146 |
+
|
147 |
+
gr.Markdown("---") # Horizontal rule for separation
|
148 |
+
|
149 |
+
with gr.Row():
|
150 |
+
with gr.Column():
|
151 |
+
image_editor = gr.ImageEditor(
|
152 |
+
label="Upload and Annotate Image",
|
153 |
+
type="pil",
|
154 |
+
interactive=True,
|
155 |
+
mirror_webcam=False,
|
156 |
+
layers=False,
|
157 |
+
# placeholder="Upload an image",
|
158 |
+
scale=2,
|
159 |
+
)
|
160 |
+
prompts_input = gr.Textbox(
|
161 |
+
placeholder="Enter prompts, comma-separated", label="Text Prompts"
|
162 |
+
)
|
163 |
+
submit_button = gr.Button("Submit", variant="primary")
|
164 |
+
with gr.Column():
|
165 |
+
output_image = gr.Image(
|
166 |
+
type="pil",
|
167 |
+
label="Output Image with Explanation Map",
|
168 |
+
)
|
169 |
+
prob_text = gr.Textbox(
|
170 |
+
label="Class Probabilities", interactive=False, lines=10
|
171 |
+
)
|
172 |
+
|
173 |
+
# Manually trigger the computation with the submit button
|
174 |
+
inputs = [image_editor, prompts_input]
|
175 |
+
outputs = [output_image, prob_text]
|
176 |
+
submit_button.click(fn=update_output, inputs=inputs, outputs=outputs)
|
177 |
+
|
178 |
+
gr.Markdown("---") # Horizontal rule for separation
|
179 |
+
|
180 |
+
gr.Markdown("### 📝 **References**:\n")
|
181 |
+
gr.Markdown(
|
182 |
+
"[1] Denner, S., Bujotzek, M., Bounias, D., Zimmerer, D., Stock, R., Jäger, P.F. and Maier-Hein, K., 2024. **Visual Prompt Engineering for Medical Vision Language Models in Radiology**. arXiv preprint arXiv:2408.15802."
|
183 |
+
)
|
184 |
+
gr.Markdown(
|
185 |
+
"[2] Zhang, S., Xu, Y., Usuyama, N., Bagga, J., Tinn, R., Preston, S., Rao, R., Wei, M., Valluri, N., Wong, C. and Lungren, M.P., 2023. **Large-scale domain-specific pretraining for biomedical vision-language processing**. arXiv preprint arXiv:2303.00915, 2(3), p.6.\n"
|
186 |
+
)
|
187 |
+
gr.Markdown(
|
188 |
+
"[3] Bousselham, W., Boggust, A., Chaybouti, S., Strobelt, H. and Kuehne, H., 2024. **LeGrad: An Explainability Method for Vision Transformers via Feature Formation Sensitivity**. arXiv preprint arXiv:2404.03214."
|
189 |
+
)
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
demo.launch(share=True)
|
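Because gr.ImageEditor passes its callback a dict with 'background', 'layers', and 'composite' keys, update_output can also be exercised without launching the UI. Below is a hedged smoke-test sketch, assuming it runs in the same module as app.py (or after importing update_output from it) and that the placeholder image file exists.

# Hypothetical smoke test for update_output() without the Gradio UI.
# The dict mimics the value gr.ImageEditor hands to the callback; the file
# names are placeholders.
from PIL import Image

annotated = Image.open("chest_xray_with_red_circle.png").convert("RGB")
editor_value = {"background": annotated, "layers": [], "composite": annotated}

prompts = (
    "A chest X-ray with a benign lung nodule indicated by a red circle, "
    "A chest X-ray with a malignant lung nodule indicated by a red circle"
)

overlay, prob_text = update_output(editor_value, prompts)
print(prob_text)                    # one probability per comma-separated prompt
overlay.save("legrad_overlay.png")  # heatmap blended over the annotated image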
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio==4.42.0
torch==2.4.1
open_clip_torch==2.26.1
legrad_torch
transformers==4.44.2