Stefan Denner committed
Commit 208214b · 0 Parent(s)

Initial commit
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
1
+
2
+ /venv
3
+ __pycache__
LeGrad/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Walid Bousselham, Angie Boggust, Sofian Chaybouti, Hendrik Strobelt, Hilde Kuehne.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LeGrad/legrad/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .wrapper import LeWrapper, LePreprocess
2
+ from .utils import *
LeGrad/legrad/utils.py ADDED
@@ -0,0 +1,722 @@
1
+ from pathlib import Path
2
+ from typing import Optional, List, Tuple
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2 as cv2
7
+ import warnings
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ from torch.nn import functional as F
+ from torch.nn.functional import pad
12
+
13
+ import open_clip
14
+ from open_clip import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
15
+ from open_clip.transformer import _expand_token
16
+ from timm.layers import resample_abs_pos_embed
17
+
18
+
19
+ ################################################################################
20
+ # Hooks utils #
21
+ ################################################################################
22
+
23
+
24
+ # ------------ Hooked Multi-Head Attention ------------
25
+ # from https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/transformer.py#L129
26
+ def hooked_attention_forward(
27
+ self,
28
+ x,
29
+ k_x=None, # accepted for interface compatibility with the hooked resblock; unused
+ v_x=None,
31
+ attn_mask: Optional[torch.Tensor] = None,
32
+ need_weights: bool = False,
33
+ ):
34
+ L, N, C = x.shape
35
+ q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
36
+ q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
37
+ k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
38
+ v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
39
+
40
+ head_dim = q.shape[-1]
41
+ scale = float(head_dim) ** -0.5
42
+ q = q * scale
43
+ attn = torch.bmm(q, k.transpose(-1, -2))
44
+
45
+ if attn_mask is not None:
46
+ if attn_mask.dtype == torch.bool:
47
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
48
+ new_attn_mask.masked_fill_(attn_mask, float("-inf"))
49
+ attn_mask = new_attn_mask
50
+ attn += attn_mask
51
+
52
+ attn = attn.softmax(dim=-1)
53
+ # Hook for attention maps
54
+ self.attention_map = attn
55
+
56
+ x = torch.bmm(attn, v)
57
+ x = x.transpose(0, 1).reshape(L, N, C)
58
+ x = self.out_proj(x)
59
+ return x
60
+
61
+
62
+ def hooked_attention_timm_forward(self, x, attn_mask=None):
63
+ B, N, C = x.shape
64
+ qkv = (
65
+ self.qkv(x)
66
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
67
+ .permute(2, 0, 3, 1, 4)
68
+ )
69
+ q, k, v = qkv.unbind(0)
70
+ q, k = self.q_norm(q), self.k_norm(k)
71
+
72
+ q = q * self.scale
73
+ attn = q @ k.transpose(-2, -1)
74
+ attn = attn.softmax(dim=-1)
75
+ attn = self.attn_drop(attn)
76
+ x = attn @ v
77
+
78
+ # Hook to save attention map for explainability
79
+ self.attention_map = attn
80
+
81
+ x = x.transpose(1, 2).reshape(B, N, C)
82
+ x = self.proj(x)
83
+ x = self.proj_drop(x)
84
+ return x
85
+
86
+
87
+ # ------------ Hooked Residual Transformer Block ------------
88
+ # from https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/transformer.py#L231
89
+ def hooked_resblock_forward(self, q_x, k_x=None, v_x=None, attn_mask=None):
90
+ assert k_x is None and v_x is None, "k_x and v_x must be None"
91
+
92
+ # Call the hooked attention; k_x/v_x are forwarded only for interface compatibility
93
+ x = q_x + self.ls1(
94
+ self.attn(
95
+ self.norm1(q_x),
96
+ k_x=k_x,
97
+ v_x=v_x,
98
+ attn_mask=attn_mask,
99
+ )
100
+ )
101
+ # Hook for intermediate features post Attn
102
+ self.feat_post_attn = x
103
+ x = x + self.ls2(self.mlp(self.norm2(x)))
104
+
105
+ # Hook for intermediate features post MLP
106
+ self.feat_post_mlp = x
107
+ return x
108
+
109
+
110
+ # ------------ Hooked PyTorch's Multi-Head AttentionResidual ------------
111
+ # modified from PyTorch Library
112
+ # https://github.com/pytorch/pytorch/blob/8c8e4e31f2ddd8e59de18ac733c0c205c23d14ad/torch/nn/functional.py#L5178
113
+ def hooked_torch_multi_head_attention_forward(
114
+ self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None
115
+ ):
116
+ r"""
117
+ Args:
118
+ query, key, value: map a query and a set of key-value pairs to an output.
119
+ See "Attention Is All You Need" for more details.
120
+ key_padding_mask: if provided, specified padding elements in the key will
121
+ be ignored by the attention. When given a binary mask and a value is True,
122
+ the corresponding value on the attention layer will be ignored. When given
123
+ a byte mask and a value is non-zero, the corresponding value on the attention
124
+ layer will be ignored
125
+ need_weights: output attn_output_weights.
126
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
127
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
128
+
129
+ Shape:
130
+ - Inputs:
131
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
132
+ the embedding dimension.
133
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
134
+ the embedding dimension.
135
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
136
+ the embedding dimension.
137
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
138
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
139
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
140
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
141
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
142
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
143
+ S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
144
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
145
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
146
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
147
+ is provided, it will be added to the attention weight.
148
+
149
+ - Outputs:
150
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
151
+ E is the embedding dimension.
152
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
153
+ L is the target sequence length, S is the source sequence length.
154
+ """
155
+ if not self._qkv_same_embed_dim:
156
+ out, _attn_maps = hooked_torch_func_multi_head_attention_forward(
157
+ query,
158
+ key,
159
+ value,
160
+ self.embed_dim,
161
+ self.num_heads,
162
+ self.in_proj_weight,
163
+ self.in_proj_bias,
164
+ self.bias_k,
165
+ self.bias_v,
166
+ self.add_zero_attn,
167
+ self.dropout,
168
+ self.out_proj.weight,
169
+ self.out_proj.bias,
170
+ training=self.training,
171
+ key_padding_mask=key_padding_mask,
172
+ need_weights=True,
173
+ attn_mask=attn_mask,
174
+ use_separate_proj_weight=True,
175
+ q_proj_weight=self.q_proj_weight,
176
+ k_proj_weight=self.k_proj_weight,
177
+ v_proj_weight=self.v_proj_weight,
178
+ )
179
+ # Hook for attention maps
180
+ self.attention_maps = _attn_maps
181
+ return out, _attn_maps
182
+ else:
183
+ out, _attn_maps = hooked_torch_func_multi_head_attention_forward(
184
+ query,
185
+ key,
186
+ value,
187
+ self.embed_dim,
188
+ self.num_heads,
189
+ self.in_proj_weight,
190
+ self.in_proj_bias,
191
+ self.bias_k,
192
+ self.bias_v,
193
+ self.add_zero_attn,
194
+ self.dropout,
195
+ self.out_proj.weight,
196
+ self.out_proj.bias,
197
+ training=self.training,
198
+ key_padding_mask=key_padding_mask,
199
+ need_weights=True,
200
+ attn_mask=attn_mask,
201
+ )
202
+ # Hook for attention maps
203
+ self.attention_maps = _attn_maps
204
+ return out, _attn_maps
205
+
206
+
207
+ def hooked_torch_func_multi_head_attention_forward(
208
+ query: Tensor,
209
+ key: Tensor,
210
+ value: Tensor,
211
+ embed_dim_to_check: int,
212
+ num_heads: int,
213
+ in_proj_weight: Tensor,
214
+ in_proj_bias: Tensor,
215
+ bias_k: Optional[Tensor],
216
+ bias_v: Optional[Tensor],
217
+ add_zero_attn: bool,
218
+ dropout_p: float,
219
+ out_proj_weight: Tensor,
220
+ out_proj_bias: Tensor,
221
+ training: bool = True,
222
+ key_padding_mask: Optional[Tensor] = None,
223
+ need_weights: bool = True,
224
+ attn_mask: Optional[Tensor] = None,
225
+ use_separate_proj_weight: bool = False,
226
+ q_proj_weight: Optional[Tensor] = None,
227
+ k_proj_weight: Optional[Tensor] = None,
228
+ v_proj_weight: Optional[Tensor] = None,
229
+ static_k: Optional[Tensor] = None,
230
+ static_v: Optional[Tensor] = None,
231
+ ) -> Tuple[Tensor, Optional[Tensor]]:
232
+ if not torch.jit.is_scripting():
233
+ tens_ops = (
234
+ query,
235
+ key,
236
+ value,
237
+ in_proj_weight,
238
+ in_proj_bias,
239
+ bias_k,
240
+ bias_v,
241
+ out_proj_weight,
242
+ out_proj_bias,
243
+ )
244
+ if any([type(t) is not Tensor for t in tens_ops]) and F.has_torch_function(
245
+ tens_ops
246
+ ):
247
+ return F.handle_torch_function(
248
+ hooked_torch_func_multi_head_attention_forward,
249
+ tens_ops,
250
+ query,
251
+ key,
252
+ value,
253
+ embed_dim_to_check,
254
+ num_heads,
255
+ in_proj_weight,
256
+ in_proj_bias,
257
+ bias_k,
258
+ bias_v,
259
+ add_zero_attn,
260
+ dropout_p,
261
+ out_proj_weight,
262
+ out_proj_bias,
263
+ training=training,
264
+ key_padding_mask=key_padding_mask,
265
+ need_weights=need_weights,
266
+ attn_mask=attn_mask,
267
+ use_separate_proj_weight=use_separate_proj_weight,
268
+ q_proj_weight=q_proj_weight,
269
+ k_proj_weight=k_proj_weight,
270
+ v_proj_weight=v_proj_weight,
271
+ static_k=static_k,
272
+ static_v=static_v,
273
+ )
274
+ tgt_len, bsz, embed_dim = query.size()
275
+ assert embed_dim == embed_dim_to_check
276
+ # allow MHA to have different sizes for the feature dimension
277
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
278
+
279
+ head_dim = embed_dim // num_heads
280
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
281
+ scaling = float(head_dim) ** -0.5
282
+
283
+ if not use_separate_proj_weight:
284
+ if torch.equal(query, key) and torch.equal(key, value):
285
+ # self-attention
286
+ q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
287
+
288
+ elif torch.equal(key, value):
289
+ # encoder-decoder attention
290
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
291
+ _b = in_proj_bias
292
+ _start = 0
293
+ _end = embed_dim
294
+ _w = in_proj_weight[_start:_end, :]
295
+ if _b is not None:
296
+ _b = _b[_start:_end]
297
+ q = F.linear(query, _w, _b)
298
+
299
+ if key is None:
300
+ assert value is None
301
+ k = None
302
+ v = None
303
+ else:
304
+
305
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
306
+ _b = in_proj_bias
307
+ _start = embed_dim
308
+ _end = None
309
+ _w = in_proj_weight[_start:, :]
310
+ if _b is not None:
311
+ _b = _b[_start:]
312
+ k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
313
+
314
+ else:
315
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
316
+ _b = in_proj_bias
317
+ _start = 0
318
+ _end = embed_dim
319
+ _w = in_proj_weight[_start:_end, :]
320
+ if _b is not None:
321
+ _b = _b[_start:_end]
322
+ q = F.linear(query, _w, _b)
323
+
324
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
325
+ _b = in_proj_bias
326
+ _start = embed_dim
327
+ _end = embed_dim * 2
328
+ _w = in_proj_weight[_start:_end, :]
329
+ if _b is not None:
330
+ _b = _b[_start:_end]
331
+ k = F.linear(key, _w, _b)
332
+
333
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
334
+ _b = in_proj_bias
335
+ _start = embed_dim * 2
336
+ _end = None
337
+ _w = in_proj_weight[_start:, :]
338
+ if _b is not None:
339
+ _b = _b[_start:]
340
+ v = F.linear(value, _w, _b)
341
+ else:
342
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
343
+ len1, len2 = q_proj_weight_non_opt.size()
344
+ assert len1 == embed_dim and len2 == query.size(-1)
345
+
346
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
347
+ len1, len2 = k_proj_weight_non_opt.size()
348
+ assert len1 == embed_dim and len2 == key.size(-1)
349
+
350
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
351
+ len1, len2 = v_proj_weight_non_opt.size()
352
+ assert len1 == embed_dim and len2 == value.size(-1)
353
+
354
+ if in_proj_bias is not None:
355
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
356
+ k = F.linear(
357
+ key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)]
358
+ )
359
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
360
+ else:
361
+ q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
362
+ k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
363
+ v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
364
+ q = q * scaling
365
+
366
+ if attn_mask is not None:
367
+ assert (
368
+ attn_mask.dtype == torch.float32
369
+ or attn_mask.dtype == torch.float64
370
+ or attn_mask.dtype == torch.float16
371
+ or attn_mask.dtype == torch.uint8
372
+ or attn_mask.dtype == torch.bool
373
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
374
+ attn_mask.dtype
375
+ )
376
+ if attn_mask.dtype == torch.uint8:
377
+ warnings.warn(
378
+ "Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
379
+ )
380
+ attn_mask = attn_mask.to(torch.bool)
381
+
382
+ if attn_mask.dim() == 2:
383
+ attn_mask = attn_mask.unsqueeze(0)
384
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
385
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
386
+ elif attn_mask.dim() == 3:
387
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
388
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
389
+ else:
390
+ raise RuntimeError(
391
+ "attn_mask's dimension {} is not supported".format(attn_mask.dim())
392
+ )
393
+ # attn_mask's dim is 3 now.
394
+
395
+ # convert ByteTensor key_padding_mask to bool
396
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
397
+ warnings.warn(
398
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
399
+ )
400
+ key_padding_mask = key_padding_mask.to(torch.bool)
401
+
402
+ if bias_k is not None and bias_v is not None:
403
+ if static_k is None and static_v is None:
404
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
405
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
406
+ if attn_mask is not None:
407
+ attn_mask = pad(attn_mask, (0, 1))
408
+ if key_padding_mask is not None:
409
+ key_padding_mask = pad(key_padding_mask, (0, 1))
410
+ else:
411
+ assert static_k is None, "bias cannot be added to static key."
412
+ assert static_v is None, "bias cannot be added to static value."
413
+ else:
414
+ assert bias_k is None
415
+ assert bias_v is None
416
+
417
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
418
+ if k is not None:
419
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
420
+ if v is not None:
421
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
422
+
423
+ if static_k is not None:
424
+ assert static_k.size(0) == bsz * num_heads
425
+ assert static_k.size(2) == head_dim
426
+ k = static_k
427
+
428
+ if static_v is not None:
429
+ assert static_v.size(0) == bsz * num_heads
430
+ assert static_v.size(2) == head_dim
431
+ v = static_v
432
+
433
+ src_len = k.size(1)
434
+
435
+ if key_padding_mask is not None:
436
+ assert key_padding_mask.size(0) == bsz
437
+ assert key_padding_mask.size(1) == src_len
438
+
439
+ if add_zero_attn:
440
+ src_len += 1
441
+ k = torch.cat(
442
+ [
443
+ k,
444
+ torch.zeros(
445
+ (k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device
446
+ ),
447
+ ],
448
+ dim=1,
449
+ )
450
+ v = torch.cat(
451
+ [
452
+ v,
453
+ torch.zeros(
454
+ (v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device
455
+ ),
456
+ ],
457
+ dim=1,
458
+ )
459
+ if attn_mask is not None:
460
+ attn_mask = pad(attn_mask, (0, 1))
461
+ if key_padding_mask is not None:
462
+ key_padding_mask = pad(key_padding_mask, (0, 1))
463
+
464
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
465
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
466
+
467
+ if attn_mask is not None:
468
+ if attn_mask.dtype == torch.bool:
469
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
470
+ else:
471
+ attn_output_weights += attn_mask
472
+
473
+ if key_padding_mask is not None:
474
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
475
+ attn_output_weights = attn_output_weights.masked_fill(
476
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
477
+ float("-inf"),
478
+ )
479
+ attn_output_weights = attn_output_weights.view(
480
+ bsz * num_heads, tgt_len, src_len
481
+ )
482
+
483
+ attn_output_weights = F.softmax(attn_output_weights, dim=-1)
484
+ attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
485
+
486
+ # # use hooks for the attention weights if necessary
487
+ # self.attention_map = attn_output_weights
488
+ # # if attention_probs_forward_hook is not None and attention_probs_backwards_hook is not None:
489
+ # if attention_probs_forward_hook is not None:
490
+ # attention_probs_forward_hook(attn_output_weights)
491
+ # # attn_output_weights.register_hook(attention_probs_backwards_hook)
492
+
493
+ attn_output = torch.bmm(attn_output_weights, v)
494
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
495
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
496
+ attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
497
+
498
+ if need_weights:
499
+ # --- Modified: return the per-head attention weights instead of averaging over heads
500
+ # average attention weights over heads
501
+ # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
502
+ # return attn_output, attn_output_weights.sum(dim=1) / num_heads
503
+ return attn_output, attn_output_weights
504
+ else:
505
+ return attn_output, None
506
+
507
+
508
+ # ------------ Hooked TimmModel's Residual Transformer Block ------------
509
+ def hooked_resblock_timm_forward(self, x: torch.Tensor) -> torch.Tensor:
510
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
511
+ self.feat_post_attn = x
512
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
513
+ self.feat_post_mlp = x
514
+ return x
515
+
516
+
517
+ # ------------ Hooked TimmModel's Attentional Pooler ------------
518
+ def hooked_attentional_pooler_timm_forward(self, x):
519
+ B, N, C = x.shape
520
+
521
+ if self.pos_embed is not None:
522
+ # FIXME interpolate
523
+ x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
524
+
525
+ q_latent = self.latent.expand(B, -1, -1)
526
+ q = (
527
+ self.q(q_latent)
528
+ .reshape(B, self.latent_len, self.num_heads, self.head_dim)
529
+ .transpose(1, 2)
530
+ )
531
+
532
+ kv = (
533
+ self.kv(x)
534
+ .reshape(B, N, 2, self.num_heads, self.head_dim)
535
+ .permute(2, 0, 3, 1, 4)
536
+ )
537
+ k, v = kv.unbind(0)
538
+
539
+ q, k = self.q_norm(q), self.k_norm(k)
540
+
541
+ q = q * self.scale
542
+ attn = q @ k.transpose(-2, -1)
543
+ attn = attn.softmax(dim=-1)
544
+ x = attn @ v
545
+
546
+ # Hook to save attention map for explainability
547
+ self.attn_probs = attn
548
+
549
+ x = x.transpose(1, 2).reshape(B, self.latent_len, C)
550
+ x = self.proj(x)
551
+ x = self.proj_drop(x)
552
+
553
+ x = x + self.mlp(self.norm(x))
554
+
555
+ # optional pool if latent seq_len > 1 and pooled output is desired
556
+ if self.pool == "token":
557
+ x = x[:, 0]
558
+ elif self.pool == "avg":
559
+ x = x.mean(1)
560
+ return x
561
+
562
+
563
+ # ------------ OpenCLIP ViT forward with dynamic size ------------
564
+ def vit_dynamic_size_forward(self, x: torch.Tensor):
565
+ x = self.conv1(x) # shape = [*, width, grid, grid]
566
+ grid_h, grid_w = x.shape[2:]
567
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
568
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
569
+
570
+ # class embeddings and positional embeddings
571
+ x = torch.cat(
572
+ [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1
573
+ )
574
+ # shape = [*, grid ** 2 + 1, width]
575
+ if x.shape[1] != self.positional_embedding.shape[1]:
576
+ self.positional_embedding.data = resample_abs_pos_embed(
577
+ self.positional_embedding.unsqueeze(0),
578
+ new_size=[grid_h, grid_w],
579
+ # old_size=list(self.grid_size),
580
+ num_prefix_tokens=1,
581
+ interpolation="bicubic",
582
+ antialias=True,
583
+ )
584
+
585
+ x = x + self.positional_embedding.to(x.dtype)
586
+
587
+ x = self.patch_dropout(x)
588
+ x = self.ln_pre(x)
589
+
590
+ x = x.permute(1, 0, 2) # NLD -> LND
591
+ x = self.transformer(x)
592
+ x = x.permute(1, 0, 2) # LND -> NLD
593
+
594
+ if self.attn_pool is not None:
595
+ if self.attn_pool_contrastive is not None:
596
+ # This is untested, WIP pooling that should match paper
597
+ x = self.ln_post(x) # TBD LN first or separate one after each pool?
598
+ tokens = self.attn_pool(x)
599
+ if self.attn_pool_type == "parallel":
600
+ pooled = self.attn_pool_contrastive(x)
601
+ else:
602
+ assert self.attn_pool_type == "cascade"
603
+ pooled = self.attn_pool_contrastive(tokens)
604
+ else:
605
+ # this is the original OpenCLIP CoCa setup, does not match paper
606
+ x = self.attn_pool(x)
607
+ x = self.ln_post(x)
608
+ pooled, tokens = self._global_pool(x)
609
+ elif self.final_ln_after_pool:
610
+ pooled, tokens = self._global_pool(x)
611
+ pooled = self.ln_post(pooled)
612
+ else:
613
+ x = self.ln_post(x)
614
+ pooled, tokens = self._global_pool(x)
615
+
616
+ if self.proj is not None:
617
+ pooled = pooled @ self.proj
618
+
619
+ if self.output_tokens:
620
+ return pooled, tokens
621
+
622
+ return pooled
623
+
624
+
625
+ ################################################################################
626
+ # Visualization utils #
627
+ ################################################################################
628
+
629
+
630
+ def min_max(logits):
631
+ B, num_prompt = logits.shape[:2]
632
+ logits_min = (
633
+ logits.reshape(B, num_prompt, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
634
+ )
635
+ logits_max = (
636
+ logits.reshape(B, num_prompt, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
637
+ )
638
+ logits = (logits - logits_min) / (logits_max - logits_min)
639
+ return logits
640
+
641
+
642
+ def visualize(image, heatmaps, alpha=0.6, save_path: Path = None):
643
+ # heatmaps of shape (N, 1, W, H)
644
+ W, H = heatmaps.shape[-2:]
645
+ if isinstance(image, Image.Image):
646
+ image = image.resize((W, H))
647
+ elif isinstance(image, torch.Tensor):
648
+ if image.ndim > 3:
649
+ image = image.squeeze(0)
650
+ # undo the normalization
651
+ image_unormed = (
652
+ image.detach().cpu() * torch.Tensor(OPENAI_DATASET_STD)[:, None, None]
653
+ ) + torch.Tensor(OPENAI_DATASET_MEAN)[:, None, None]
654
+ # convert to PIL
655
+ image = Image.fromarray(
656
+ (image_unormed.permute(1, 2, 0).numpy() * 255).astype("uint8")
657
+ )
658
+ else:
659
+ raise TypeError(f"image should be of type PIL.Image.Image or torch.Tensor but found {type(image)}")
660
+
661
+ # plot image
662
+ plt.imshow(image)
663
+ plt.axis("off")
664
+ plt.tight_layout()
665
+ plt.show()
666
+
667
+ if heatmaps.ndim > 3:
668
+ heatmaps = heatmaps.squeeze(0)
669
+ heatmaps = heatmaps.detach().cpu().numpy()
670
+
671
+ img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
672
+ heatmaps = (heatmaps * 255).astype("uint8")
673
+ heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in heatmaps]
674
+
675
+ vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]
676
+ for i, viz in enumerate(vizs):
677
+ viz = cv2.cvtColor(viz.astype("uint8"), cv2.COLOR_BGR2RGB)
678
+ plt.imshow(viz)
679
+ plt.axis("off")
680
+ plt.tight_layout()
681
+ # remove the margin
682
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
683
+ plt.show()
684
+ if save_path is not None:
685
+ plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
686
+ print(f"Saved visualization at {save_path}")
687
+
688
+
689
+ def list_pretrained():
690
+ openclip_list_ = open_clip.list_pretrained()
691
+ filtered_list = []  # will hold only the supported (model_name, pretrained) pairs
696
+ unsupported_models = [
697
+ "RN",
698
+ "convnext",
699
+ ] # legrad doesn't support CNN-based VLMs (for the moment)
700
+ _str = (
701
+ ": ".join(["model_name" + " " * (25 - len("model_name")), "pretrained"]) + "\n"
702
+ ) # for nice display
703
+ for model_name, pretrained in openclip_list_:
704
+ for unsup_model in unsupported_models:
705
+ if unsup_model in model_name:
706
+ skip = True
707
+ break
708
+ else:
709
+ skip = False
710
+ if not skip:
711
+ filtered_list.append((model_name, pretrained))
712
+ _str += (
713
+ ": ".join([model_name + " " * (25 - len(model_name)), pretrained])
714
+ + "\n"
715
+ ) # for nice display
716
+
717
+ print(_str)
718
+ return filtered_list
719
+
720
+
721
+ if __name__ == "__main__":
722
+ list_pretrained()
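
The visualization helpers above assume heatmaps that have been rescaled per prompt: min_max normalizes each (batch, prompt) map to [0, 1] over its spatial dimensions, and visualize overlays JET-colored heatmaps on the input image. A minimal sketch with dummy inputs (the import path follows this commit's layout; the tensor and image are placeholders rather than model outputs):

import torch
from PIL import Image
from LeGrad.legrad.utils import min_max, visualize

# Dummy batch: one image, two prompts, 224x224 heatmaps standing in for LeGrad outputs
heatmaps = torch.rand(1, 2, 224, 224)
heatmaps = min_max(heatmaps)  # each (batch, prompt) map now spans [0, 1]
assert float(heatmaps.min()) >= 0.0 and float(heatmaps.max()) <= 1.0

# Overlay the first prompt's heatmap on a placeholder image (needs a matplotlib display backend)
image = Image.new("RGB", (224, 224), color="gray")
visualize(image, heatmaps[:, :1])
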
LeGrad/legrad/wrapper.py ADDED
@@ -0,0 +1,447 @@
1
+ import math
2
+ import types
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision.transforms import Compose, Resize, InterpolationMode
7
+ import open_clip
8
+ from open_clip.transformer import VisionTransformer
9
+ from open_clip.timm_model import TimmModel
10
+ from einops import rearrange
11
+
12
+ from .utils import (
13
+ hooked_attention_timm_forward,
14
+ hooked_resblock_forward,
15
+ hooked_attention_forward,
16
+ hooked_resblock_timm_forward,
17
+ hooked_attentional_pooler_timm_forward,
18
+ vit_dynamic_size_forward,
19
+ min_max,
20
+ hooked_torch_multi_head_attention_forward,
21
+ )
22
+
23
+
24
+ class LeWrapper(nn.Module):
25
+ """
26
+ Wrapper around OpenCLIP that adds LeGrad to the model while keeping all of the original model's functionality.
27
+ """
28
+
29
+ def __init__(self, model, layer_index=-2):
30
+ super(LeWrapper, self).__init__()
31
+ # ------------ copy of model's attributes and methods ------------
32
+ for attr in dir(model):
33
+ if not attr.startswith("__"):
34
+ setattr(self, attr, getattr(model, attr))
35
+
36
+ # ------------ activate hooks & gradient ------------
37
+ self._activate_hooks(layer_index=layer_index)
38
+
39
+ def _activate_hooks(self, layer_index):
40
+ # ------------ identify model's type ------------
41
+ print("Activating necessary hooks and gradients ....")
42
+ if isinstance(self.visual, VisionTransformer):
43
+ # --- Activate dynamic image size ---
44
+ self.visual.forward = types.MethodType(
45
+ vit_dynamic_size_forward, self.visual
46
+ )
47
+ # Get patch size
48
+ self.patch_size = self.visual.patch_size[0]
49
+ # Get starting depth (in case of negative layer_index)
50
+ self.starting_depth = (
51
+ layer_index
52
+ if layer_index >= 0
53
+ else len(self.visual.transformer.resblocks) + layer_index
54
+ )
55
+
56
+ if self.visual.attn_pool is None:
57
+ self.model_type = "clip"
58
+ self._activate_self_attention_hooks()
59
+ else:
60
+ self.model_type = "coca"
61
+ self._activate_att_pool_hooks(layer_index=layer_index)
62
+
63
+ elif isinstance(self.visual, TimmModel):
64
+ # --- Activate dynamic image size ---
65
+ self.visual.trunk.dynamic_img_size = True
66
+ self.visual.trunk.patch_embed.dynamic_img_size = True
67
+ self.visual.trunk.patch_embed.strict_img_size = False
68
+ self.visual.trunk.patch_embed.flatten = False
69
+ self.visual.trunk.patch_embed.output_fmt = "NHWC"
70
+ self.model_type = "timm_siglip"
71
+ # --- Get patch size ---
72
+ self.patch_size = self.visual.trunk.patch_embed.patch_size[0]
73
+ # --- Get starting depth (in case of negative layer_index) ---
74
+ self.starting_depth = (
75
+ layer_index
76
+ if layer_index >= 0
77
+ else len(self.visual.trunk.blocks) + layer_index
78
+ )
79
+ if (
80
+ hasattr(self.visual.trunk, "attn_pool")
81
+ and self.visual.trunk.attn_pool is not None
82
+ ):
83
+ self._activate_timm_attn_pool_hooks(layer_index=layer_index)
84
+ else:
85
+ self._activate_timm_self_attention_hooks()
86
+ else:
87
+ raise ValueError(
88
+ "Model currently not supported, see legrad.list_pretrained() for a list of available models"
89
+ )
90
+ print("Hooks and gradients activated!")
91
+
92
+ def _activate_self_attention_hooks(self):
93
+ # Adjusting to use the correct structure
94
+ if isinstance(self.visual, VisionTransformer):
95
+ blocks = self.visual.transformer.resblocks
96
+ elif isinstance(self.visual, TimmModel):
97
+ blocks = self.visual.trunk.blocks
98
+ else:
99
+ raise ValueError("Unsupported model type for self-attention hooks")
100
+
101
+ # ---------- Apply Hooks + Activate/Deactivate gradients ----------
102
+ # Necessary steps to get intermediate representations
103
+ for name, param in self.named_parameters():
104
+ param.requires_grad = False
105
+ if name.startswith(("visual.trunk.blocks", "visual.transformer.resblocks")):
+ depth = int(name.split("blocks.")[-1].split(".")[0])
107
+ if depth >= self.starting_depth:
108
+ param.requires_grad = True
109
+
110
+ # --- Activate the hooks for the specific layers ---
111
+ for layer in range(self.starting_depth, len(blocks)):
112
+ blocks[layer].attn.forward = types.MethodType(
113
+ hooked_attention_forward, blocks[layer].attn
114
+ )
115
+ blocks[layer].forward = types.MethodType(
116
+ hooked_resblock_forward, blocks[layer]
117
+ )
118
+
119
+ def _activate_timm_self_attention_hooks(self):
120
+ # Adjusting to use the correct structure
121
+ blocks = self.visual.trunk.blocks
122
+
123
+ # ---------- Apply Hooks + Activate/Deactivate gradients ----------
124
+ # Necessary steps to get intermediate representations
125
+ for name, param in self.named_parameters():
126
+ param.requires_grad = False
127
+ if name.startswith("visual.trunk.blocks"):
128
+ depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
129
+ if depth >= self.starting_depth:
130
+ param.requires_grad = True
131
+
132
+ # --- Activate the hooks for the specific layers ---
133
+ for layer in range(self.starting_depth, len(blocks)):
134
+ blocks[layer].attn.forward = types.MethodType(
135
+ hooked_attention_timm_forward, blocks[layer].attn
136
+ )
137
+ blocks[layer].forward = types.MethodType(
138
+ hooked_resblock_timm_forward, blocks[layer]
139
+ )
140
+
141
+ def _activate_att_pool_hooks(self, layer_index):
142
+ # ---------- Apply Hooks + Activate/Deactivate gradients ----------
143
+ # Necessary steps to get intermediate representations
144
+ for name, param in self.named_parameters():
145
+ param.requires_grad = False
146
+ if name.startswith("visual.transformer.resblocks"):
147
+ # get the depth
148
+ depth = int(
149
+ name.split("visual.transformer.resblocks.")[-1].split(".")[0]
150
+ )
151
+ if depth >= self.starting_depth:
152
+ param.requires_grad = True
153
+
154
+ # --- Activate the hooks for the specific layers ---
155
+ for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
156
+ self.visual.transformer.resblocks[layer].forward = types.MethodType(
157
+ hooked_resblock_forward, self.visual.transformer.resblocks[layer]
158
+ )
159
+ # --- Apply hook on the attentional pooler ---
160
+ self.visual.attn_pool.attn.forward = types.MethodType(
161
+ hooked_torch_multi_head_attention_forward, self.visual.attn_pool.attn
162
+ )
163
+
164
+ def _activate_timm_attn_pool_hooks(self, layer_index):
165
+ # Ensure all components are present before attaching hooks
166
+ if (
167
+ not hasattr(self.visual.trunk, "attn_pool")
168
+ or self.visual.trunk.attn_pool is None
169
+ ):
170
+ raise ValueError("Attentional pooling not found in TimmModel")
171
+
172
+ self.visual.trunk.attn_pool.forward = types.MethodType(
173
+ hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool
174
+ )
175
+ for block in self.visual.trunk.blocks:
176
+ if hasattr(block, "attn"):
177
+ block.attn.forward = types.MethodType(
178
+ hooked_attention_timm_forward, block.attn
179
+ )
180
+
181
+ # --- Deactivate gradient for module that don't need it ---
182
+ for name, param in self.named_parameters():
183
+ param.requires_grad = False
184
+ if name.startswith("visual.trunk.attn_pool"):
185
+ param.requires_grad = True
186
+ if name.startswith("visual.trunk.blocks"):
187
+ # get the depth
188
+ depth = int(name.split("visual.trunk.blocks.")[-1].split(".")[0])
189
+ if depth >= self.starting_depth:
190
+ param.requires_grad = True
191
+
192
+ # --- Activate the hooks for the specific layers by modifying the block's forward ---
193
+ for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
194
+ self.visual.trunk.blocks[layer].forward = types.MethodType(
195
+ hooked_resblock_timm_forward, self.visual.trunk.blocks[layer]
196
+ )
197
+
198
+ self.visual.trunk.attn_pool.forward = types.MethodType(
199
+ hooked_attentional_pooler_timm_forward, self.visual.trunk.attn_pool
200
+ )
201
+
202
+ def compute_legrad(self, text_embedding, image=None, apply_correction=True):
203
+ if "clip" in self.model_type:
204
+ return self.compute_legrad_clip(text_embedding, image)
205
+ elif "siglip" in self.model_type:
206
+ return self.compute_legrad_siglip(
207
+ text_embedding, image, apply_correction=apply_correction
208
+ )
209
+ elif "coca" in self.model_type:
210
+ return self.compute_legrad_coca(text_embedding, image)
211
+
212
+ def compute_legrad_clip(self, text_embedding, image=None):
213
+ num_prompts = text_embedding.shape[0]
214
+ if image is not None:
215
+ # Ensure the image is passed through the model to get the intermediate features
216
+ _ = self.encode_image(image)
217
+
218
+ blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
219
+
220
+ image_features_list = []
221
+
222
+ for layer in range(self.starting_depth, len(self.visual.trunk.blocks)):
223
+ # [batch, num_tokens, dim]
224
+ intermediate_feat = blocks_list[layer].feat_post_mlp
225
+ # Mean over the patch tokens
226
+ intermediate_feat = intermediate_feat.mean(dim=1)
227
+ intermediate_feat = self.visual.head(
228
+ self.visual.trunk.norm(intermediate_feat)
229
+ )
230
+ intermediate_feat = F.normalize(intermediate_feat, dim=-1)
231
+ image_features_list.append(intermediate_feat)
232
+
233
+ num_tokens = blocks_list[-1].feat_post_mlp.shape[1] - 1
234
+ w = h = int(math.sqrt(num_tokens))
235
+
236
+ # ----- Get explainability map
237
+ accum_expl_map = 0
238
+ for layer, (blk, img_feat) in enumerate(
239
+ zip(blocks_list[self.starting_depth :], image_features_list)
240
+ ):
241
+ self.visual.zero_grad()
242
+ sim = text_embedding @ img_feat.transpose(-1, -2) # [1, 1]
243
+ one_hot = (
244
+ F.one_hot(torch.arange(0, num_prompts))
245
+ .float()
246
+ .requires_grad_(True)
247
+ .to(text_embedding.device)
248
+ )
249
+ one_hot = torch.sum(one_hot * sim)
250
+
251
+ # [b, num_heads, N, N]
252
+ attn_map = blocks_list[self.starting_depth + layer].attn.attention_map
253
+
254
+ # -------- Get explainability map --------
255
+ # [batch_size * num_heads, N, N]
256
+ grad = torch.autograd.grad(
257
+ one_hot, [attn_map], retain_graph=True, create_graph=True
258
+ )[0]
259
+ # grad = rearrange(grad, '(b h) n m -> b h n m', b=num_prompts) # separate batch and attn heads
260
+ grad = torch.clamp(grad, min=0.0)
261
+
262
+ # average attn over [CLS] + patch tokens
263
+ image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]
264
+ expl_map = rearrange(image_relevance, "b (w h) -> 1 b w h", w=w, h=h)
265
+ # [B, 1, H, W]
266
+ expl_map = F.interpolate(
267
+ expl_map, scale_factor=self.patch_size, mode="bilinear"
268
+ )
269
+ accum_expl_map += expl_map
270
+
271
+ # Min-Max Norm
272
+ accum_expl_map = min_max(accum_expl_map)
273
+ return accum_expl_map
274
+
275
+ def compute_legrad_coca(self, text_embedding, image=None):
276
+ if image is not None:
277
+ _ = self.encode_image(image)
278
+
279
+ blocks_list = list(
280
+ dict(self.visual.transformer.resblocks.named_children()).values()
281
+ )
282
+
283
+ image_features_list = []
284
+
285
+ for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
286
+ intermediate_feat = self.visual.transformer.resblocks[
287
+ layer
288
+ ].feat_post_mlp # [num_patch, batch, dim]
289
+ intermediate_feat = intermediate_feat.permute(
290
+ 1, 0, 2
291
+ ) # [batch, num_patch, dim]
292
+ image_features_list.append(intermediate_feat)
293
+
294
+ num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
295
+ w = h = int(math.sqrt(num_tokens))
296
+
297
+ # ----- Get explainability map
298
+ accum_expl_map = 0
299
+ for layer, (blk, img_feat) in enumerate(
300
+ zip(blocks_list[self.starting_depth :], image_features_list)
301
+ ):
302
+ self.visual.zero_grad()
303
+ # --- Apply attn_pool ---
304
+ image_embedding = self.visual.attn_pool(img_feat)[
305
+ :, 0
306
+ ] # we keep only the first pooled token as it is only this one trained with the contrastive loss
307
+ image_embedding = image_embedding @ self.visual.proj
308
+
309
+ sim = text_embedding @ image_embedding.transpose(-1, -2) # [1, 1]
310
+ one_hot = torch.sum(sim)
311
+
312
+ attn_map = (
313
+ self.visual.attn_pool.attn.attention_maps
314
+ ) # [num_heads, num_latent, num_patch]
315
+
316
+ # -------- Get explainability map --------
317
+ grad = torch.autograd.grad(
318
+ one_hot, [attn_map], retain_graph=True, create_graph=True
319
+ )[
320
+ 0
321
+ ] # [num_heads, num_latent, num_patch]
322
+ grad = torch.clamp(grad, min=0.0)
323
+
324
+ image_relevance = grad.mean(dim=0)[
325
+ 0, 1:
326
+ ] # average attn over heads + select first latent
327
+ expl_map = rearrange(image_relevance, "(w h) -> 1 1 w h", w=w, h=h)
328
+ expl_map = F.interpolate(
329
+ expl_map, scale_factor=self.patch_size, mode="bilinear"
330
+ ) # [B, 1, H, W]
331
+ accum_expl_map += expl_map
332
+
333
+ # Min-Max Norm
334
+ accum_expl_map = (accum_expl_map - accum_expl_map.min()) / (
335
+ accum_expl_map.max() - accum_expl_map.min()
336
+ )
337
+ return accum_expl_map
338
+
339
+ def _init_empty_embedding(self):
340
+ if not hasattr(self, "empty_embedding"):
341
+ # For the moment only SigLIP is supported & they all have the same tokenizer
342
+ _tok = open_clip.get_tokenizer(model_name="ViT-B-16-SigLIP")
343
+ empty_text = _tok(["a photo of a"]).to(self.logit_scale.data.device)
344
+ empty_embedding = self.encode_text(empty_text)
345
+ empty_embedding = F.normalize(empty_embedding, dim=-1)
346
+ self.empty_embedding = empty_embedding.t()
347
+
348
+ def compute_legrad_siglip(
349
+ self,
350
+ text_embedding,
351
+ image=None,
352
+ apply_correction=True,
353
+ correction_threshold=0.8,
354
+ ):
355
+ # --- Forward CLIP ---
356
+ blocks_list = list(dict(self.visual.trunk.blocks.named_children()).values())
357
+ if image is not None:
358
+ _ = self.encode_image(image) # [bs, num_patch, dim] bs=num_masks
359
+
360
+ image_features_list = []
361
+ for blk in blocks_list[self.starting_depth :]:
362
+ intermediate_feat = blk.feat_post_mlp
363
+ image_features_list.append(intermediate_feat)
364
+
365
+ num_tokens = blocks_list[-1].feat_post_mlp.shape[1]
366
+ w = h = int(math.sqrt(num_tokens))
367
+
368
+ if apply_correction:
369
+ self._init_empty_embedding()
370
+ accum_expl_map_empty = 0
371
+
372
+ accum_expl_map = 0
373
+ for layer, (blk, img_feat) in enumerate(
374
+ zip(blocks_list[self.starting_depth :], image_features_list)
375
+ ):
376
+ self.zero_grad()
377
+ pooled_feat = self.visual.trunk.attn_pool(img_feat)
378
+ pooled_feat = F.normalize(pooled_feat, dim=-1)
379
+ # -------- Get explainability map --------
380
+ sim = text_embedding @ pooled_feat.transpose(-1, -2) # [num_mask, num_mask]
381
+ one_hot = torch.sum(sim)
382
+ grad = torch.autograd.grad(
383
+ one_hot,
384
+ [self.visual.trunk.attn_pool.attn_probs],
385
+ retain_graph=True,
386
+ create_graph=True,
387
+ )[0]
388
+ grad = torch.clamp(grad, min=0.0)
389
+
390
+ image_relevance = grad.mean(dim=1)[
391
+ :, 0
392
+ ] # average attn over [CLS] + patch tokens
393
+ expl_map = rearrange(image_relevance, "b (w h) -> b 1 w h", w=w, h=h)
394
+ accum_expl_map += expl_map
395
+
396
+ if apply_correction:
397
+ # -------- Get empty explainability map --------
398
+ sim_empty = pooled_feat @ self.empty_embedding
399
+ one_hot_empty = torch.sum(sim_empty)
400
+ grad_empty = torch.autograd.grad(
401
+ one_hot_empty,
402
+ [self.visual.trunk.attn_pool.attn_probs],
403
+ retain_graph=True,
404
+ create_graph=True,
405
+ )[0]
406
+ grad_empty = torch.clamp(grad_empty, min=0.0)
407
+
408
+ image_relevance_empty = grad_empty.mean(dim=1)[
409
+ :, 0
410
+ ] # average attn over heads + select query's row
411
+ expl_map_empty = rearrange(
412
+ image_relevance_empty, "b (w h) -> b 1 w h", w=w, h=h
413
+ )
414
+ accum_expl_map_empty += expl_map_empty
415
+
416
+ if apply_correction:
417
+ heatmap_empty = min_max(accum_expl_map_empty)
418
+ accum_expl_map[heatmap_empty > correction_threshold] = 0
419
+
420
+ Res = min_max(accum_expl_map)
421
+ Res = F.interpolate(
422
+ Res, scale_factor=self.patch_size, mode="bilinear"
423
+ ) # [B, 1, H, W]
424
+
425
+ return Res
426
+
427
+
428
+ class LePreprocess(nn.Module):
429
+ """
430
+ Modify OpenCLIP preprocessing to accept arbitrary image size.
431
+ """
432
+
433
+ def __init__(self, preprocess, image_size):
434
+ super(LePreprocess, self).__init__()
435
+ self.transform = Compose(
436
+ [
437
+ Resize(
438
+ (image_size, image_size), interpolation=InterpolationMode.BICUBIC
439
+ ),
440
+ preprocess.transforms[-3],
441
+ preprocess.transforms[-2],
442
+ preprocess.transforms[-1],
443
+ ]
444
+ )
445
+
446
+ def forward(self, image):
447
+ return self.transform(image)
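
A minimal sketch of how LeWrapper and LePreprocess are combined with an OpenCLIP model, mirroring the setup in app.py below; the model name matches the one used by this Space, while the image path and prompt are illustrative placeholders:

import torch
import open_clip
from PIL import Image
from LeGrad.legrad import LeWrapper, LePreprocess

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
model, preprocess = open_clip.create_model_from_pretrained(model_name=model_name, device=device)
tokenizer = open_clip.get_tokenizer(model_name=model_name)

model = LeWrapper(model)  # attach the LeGrad hooks and enable the required gradients
preprocess = LePreprocess(preprocess=preprocess, image_size=448)  # optional higher-resolution input

image = preprocess(Image.open("chest_xray.png").convert("RGB")).unsqueeze(0).to(device)
text = tokenizer(["A chest X-ray with a lung nodule indicated by a red circle"]).to(device)

text_embedding = model.encode_text(text, normalize=True)
heatmap = model.compute_legrad_clip(image=image, text_embedding=text_embedding)
# heatmap: min-max-normalized LeGrad explainability map(s), upsampled to the input resolution
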
README.md ADDED
@@ -0,0 +1,12 @@
1
+ ---
2
+ title: MedicalVisualPromptEngineering
3
+ emoji: 🐠
4
+ colorFrom: green
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.38.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,192 @@
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import open_clip
5
+ import numpy as np
6
+ from LeGrad.legrad import LeWrapper, LePreprocess
7
+
8
+ import cv2
+
12
+ # Load BiomedCLIP model
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ model_name = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
15
+ model, preprocess = open_clip.create_model_from_pretrained(
16
+ model_name=model_name, device=device
17
+ )
18
+ tokenizer = open_clip.get_tokenizer(model_name=model_name)
19
+ model = LeWrapper(model) # Equip the model with LeGrad
20
+ preprocess = LePreprocess(
21
+ preprocess=preprocess, image_size=448
22
+ ) # Optional higher-res preprocessing
23
+
24
+
25
+ def classify_image_with_biomedclip(editor_value, prompts):
26
+ # editor_value is a dict with keys: 'background', 'layers', 'composite'
27
+ # The 'composite' key contains the final annotated image
28
+
29
+ if editor_value is None:
30
+ return None, None
31
+
32
+ # Get the composite image (background + annotations)
33
+ image = editor_value["composite"]
34
+
35
+ # Ensure image is in PIL format
36
+ if not isinstance(image, Image.Image):
37
+ image = Image.fromarray(image)
38
+
39
+ # Preprocess and encode the image
40
+ image_input = preprocess(image).unsqueeze(0).to(device)
41
+ text_inputs = tokenizer(prompts).to(device)
42
+
43
+ # Encode text and image
44
+
45
+ text_embeddings = model.encode_text(text_inputs, normalize=True)
46
+ image_embeddings = model.encode_image(image_input, normalize=True)
47
+
48
+ # Generate probabilities (optional - not required for LeGrad explanations but included for completeness)
49
+ similarity = (
50
+ model.logit_scale.exp() * image_embeddings @ text_embeddings.T
51
+ ).softmax(dim=-1)
52
+ probabilities = similarity[0].detach().cpu().numpy()
53
+ explanation_maps = model.compute_legrad_clip(
54
+ image=image_input, text_embedding=text_embeddings[probabilities.argmax()]
55
+ )
56
+
57
+ # Convert explanation maps to heatmap
58
+ explanation_maps = explanation_maps.squeeze(0).detach().cpu().numpy()
59
+ explanation_map = (explanation_maps * 255).astype(np.uint8) # Rescale to [0, 255]
60
+
61
+ return probabilities, explanation_map
62
+
63
+ def update_output(editor_value, prompts_input):
64
+ prompts_list = [p.strip() for p in prompts_input.split(",") if p.strip()]
65
+ if not prompts_list:
66
+ return None, "Please enter at least one prompt."
67
+
68
+ probabilities, explanation_map = classify_image_with_biomedclip(
69
+ editor_value, prompts_list
70
+ )
71
+
72
+ if probabilities is None:
73
+ return None, "Please upload and annotate an image."
74
+
75
+ # Create probability display
76
+ prob_text = "\n".join(
77
+ [
78
+ f"{prompt}: {prob*100:.2f}%"
79
+ for prompt, prob in zip(prompts_list, probabilities)
80
+ ]
81
+ )
82
+
83
+ # Prepare the explanation map overlay
84
+ image = editor_value["composite"]
85
+ if not isinstance(image, Image.Image):
86
+ image = Image.fromarray(image)
87
+
88
+ explanation_image = explanation_map[0]
89
+ if isinstance(explanation_image, torch.Tensor):
90
+ explanation_image = explanation_image.cpu().numpy()
91
+
92
+ # Resize the explanation map to match the size of the original image
93
+ explanation_image_resized = cv2.resize(
94
+ explanation_image, (image.width, image.height)
95
+ )
96
+
97
+ # Normalize the explanation map for proper colormap application
98
+ explanation_image_resized = cv2.normalize(
99
+ explanation_image_resized, None, 0, 255, cv2.NORM_MINMAX
100
+ )
101
+
102
+ # Apply the colormap (e.g., COLORMAP_JET)
103
+ explanation_colormap = cv2.applyColorMap(
104
+ explanation_image_resized.astype(np.uint8), cv2.COLORMAP_JET
105
+ )
106
+
107
+ # Convert the original image to a format that OpenCV understands (RGB to BGR)
108
+ image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
109
+
110
+ # Blend the original image and the colormap
111
+ alpha = 0.5 # Transparency factor
112
+ blended_image = cv2.addWeighted(image_cv, 1 - alpha, explanation_colormap, alpha, 0)
113
+
114
+ # Convert back to RGB for displaying with PIL or matplotlib
115
+ blended_image_rgb = cv2.cvtColor(blended_image, cv2.COLOR_BGR2RGB)
116
+ output_image = Image.fromarray(blended_image_rgb)
117
+
118
+ return output_image, prob_text
119
+
120
+
121
+ def clear_inputs():
122
+ return None, ""
123
+
124
+
125
+ with gr.Blocks() as demo:
126
+ gr.Markdown(
127
+ "# ✨ Visual Prompt Engineering for Medical Vision Language Models in Radiology ✨",
128
+ elem_id="main-header",
129
+ )
130
+
131
+ gr.Markdown(
132
+ "This tool applies **visual prompt engineering to improve the classification of medical images using BiomedCLIP** [2], the current state of the art in zero-shot biomedical image classification. By uploading biomedical images (e.g., chest X-rays), you can manually annotate areas of interest directly on the image. These annotations serve as visual prompts that guide the model's attention to the region of interest. This technique improves the model's ability to focus on subtle yet important details.\n\n"
+ "After annotating and entering text prompts (e.g., 'A chest X-ray with a benign/malignant lung nodule indicated by a red circle'), the tool returns classification results. These results are accompanied by **explainability maps** generated by **LeGrad** [3], which show where the model focused its attention, conditioned on the highest-scoring text prompt. This helps to better interpret the model's decision-making process.\n\n"
+ "In our paper **[Visual Prompt Engineering for Medical Vision Language Models in Radiology](https://arxiv.org/pdf/2408.15802)** [1], we show that visual prompts such as arrows, circles, and contours improve the zero-shot classification of biomedical vision-language models in radiology."
135
+ )
136
+
137
+ gr.Markdown("---") # Horizontal rule for separation
138
+
139
+ gr.Markdown(
140
+ "## 📝 **How It Works**:\n"
141
+ "1. **Upload** a biomedical image.\n"
142
+ "2. **Annotate** the image using the built-in editor to highlight regions of interest.\n"
143
+ "3. **Enter text prompts** separated by comma (e.g., 'A chest X-ray with a (benign/malignant) lung nodule indicated by a red circle').\n"
144
+ "4. **Submit** to get class probabilities and an explainability map conditioned on the highest scoring text prompt."
145
+ )
146
+
147
+ gr.Markdown("---") # Horizontal rule for separation
148
+
149
+ with gr.Row():
150
+ with gr.Column():
151
+ image_editor = gr.ImageEditor(
152
+ label="Upload and Annotate Image",
153
+ type="pil",
154
+ interactive=True,
155
+ mirror_webcam=False,
156
+ layers=False,
157
+ # placeholder="Upload an image",
158
+ scale=2,
159
+ )
160
+ prompts_input = gr.Textbox(
161
+ placeholder="Enter prompts, comma-separated", label="Text Prompts"
162
+ )
163
+ submit_button = gr.Button("Submit", variant="primary")
164
+ with gr.Column():
165
+ output_image = gr.Image(
166
+ type="pil",
167
+ label="Output Image with Explanation Map",
168
+ )
169
+ prob_text = gr.Textbox(
170
+ label="Class Probabilities", interactive=False, lines=10
171
+ )
172
+
173
+ # Manually trigger the computation with the submit button
174
+ inputs = [image_editor, prompts_input]
175
+ outputs = [output_image, prob_text]
176
+ submit_button.click(fn=update_output, inputs=inputs, outputs=outputs)
177
+
178
+ gr.Markdown("---") # Horizontal rule for separation
179
+
180
+ gr.Markdown("### 📝 **References**:\n")
181
+ gr.Markdown(
182
+ "[1] Denner, S., Bujotzek, M., Bounias, D., Zimmerer, D., Stock, R., Jäger, P.F. and Maier-Hein, K., 2024. **Visual Prompt Engineering for Medical Vision Language Models in Radiology**. arXiv preprint arXiv:2408.15802."
183
+ )
184
+ gr.Markdown(
185
+ "[2] Zhang, S., Xu, Y., Usuyama, N., Bagga, J., Tinn, R., Preston, S., Rao, R., Wei, M., Valluri, N., Wong, C. and Lungren, M.P., 2023. **Large-scale domain-specific pretraining for biomedical vision-language processing**. arXiv preprint arXiv:2303.00915, 2(3), p.6.\n"
186
+ )
187
+ gr.Markdown(
188
+ "[3] Bousselham, W., Boggust, A., Chaybouti, S., Strobelt, H. and Kuehne, H., 2024. **LeGrad: An Explainability Method for Vision Transformers via Feature Formation Sensitivity**. arXiv preprint arXiv:2404.03214."
189
+ )
190
+
191
+ if __name__ == "__main__":
192
+ demo.launch(share=True)
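
For reference, the same pipeline can be driven without the Gradio UI by calling update_output with the dictionary layout that gr.ImageEditor produces (only the "composite" key is read above). A small sketch, assuming app.py has been imported in the same session (importing it loads the model but does not launch the demo); the image path is a placeholder:

from PIL import Image
from app import update_output

# gr.ImageEditor passes a dict; only "composite" (background + annotations) is consumed above
annotated = Image.open("annotated_chest_xray.png").convert("RGB")
editor_value = {"background": annotated, "layers": [], "composite": annotated}

overlay, prob_text = update_output(
    editor_value,
    "A chest X-ray with a benign lung nodule indicated by a red circle, "
    "A chest X-ray with a malignant lung nodule indicated by a red circle",
)
overlay.save("legrad_overlay.png")  # annotated image blended with the explanation heatmap
print(prob_text)  # per-prompt probabilities
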
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ gradio==4.42.0
2
+ torch==2.4.1
3
+ open_clip_torch==2.26.1
4
+ legrad_torch
5
+ transformers==4.44.2