quandao92 committed on
Commit 71d05bb · verified · 1 Parent(s): 0c8e042

Upload 48 files

Files changed (48)
  1. AnomalyCLIP_lib/AnomalyCLIP.py +531 -0
  2. AnomalyCLIP_lib/CLIP.py +436 -0
  3. AnomalyCLIP_lib/__init__.py +1 -0
  4. AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-38.pyc +0 -0
  5. AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-39.pyc +0 -0
  6. AnomalyCLIP_lib/__pycache__/CLIP.cpython-39.pyc +0 -0
  7. AnomalyCLIP_lib/__pycache__/__init__.cpython-38.pyc +0 -0
  8. AnomalyCLIP_lib/__pycache__/__init__.cpython-39.pyc +0 -0
  9. AnomalyCLIP_lib/__pycache__/build_model.cpython-38.pyc +0 -0
  10. AnomalyCLIP_lib/__pycache__/build_model.cpython-39.pyc +0 -0
  11. AnomalyCLIP_lib/__pycache__/clip.cpython-38.pyc +0 -0
  12. AnomalyCLIP_lib/__pycache__/clip_model.cpython-38.pyc +0 -0
  13. AnomalyCLIP_lib/__pycache__/clip_surgery_model.cpython-38.pyc +0 -0
  14. AnomalyCLIP_lib/__pycache__/constants.cpython-38.pyc +0 -0
  15. AnomalyCLIP_lib/__pycache__/constants.cpython-39.pyc +0 -0
  16. AnomalyCLIP_lib/__pycache__/model_load.cpython-38.pyc +0 -0
  17. AnomalyCLIP_lib/__pycache__/model_load.cpython-39.pyc +0 -0
  18. AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-38.pyc +0 -0
  19. AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
  20. AnomalyCLIP_lib/__pycache__/transform.cpython-38.pyc +0 -0
  21. AnomalyCLIP_lib/__pycache__/transform.cpython-39.pyc +0 -0
  22. AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz +3 -0
  23. AnomalyCLIP_lib/build_model.py +50 -0
  24. AnomalyCLIP_lib/constants.py +2 -0
  25. AnomalyCLIP_lib/model_load.py +235 -0
  26. AnomalyCLIP_lib/simple_tokenizer.py +132 -0
  27. AnomalyCLIP_lib/transform.py +133 -0
  28. README.md +193 -183
  29. dataset_config/dataset_get_json.py +51 -0
  30. dataset_config/image_ground_truth.py +68 -0
  31. dataset_config/image_resize.py +13 -0
  32. requirements.txt +20 -0
  33. test.py +231 -0
  34. train.py +207 -0
  35. training_libs/__pycache__/dataset.cpython-39.pyc +0 -0
  36. training_libs/__pycache__/logger.cpython-39.pyc +0 -0
  37. training_libs/__pycache__/loss.cpython-39.pyc +0 -0
  38. training_libs/__pycache__/metrics.cpython-39.pyc +0 -0
  39. training_libs/__pycache__/prompt_ensemble.cpython-39.pyc +0 -0
  40. training_libs/__pycache__/utils.cpython-39.pyc +0 -0
  41. training_libs/__pycache__/visualization.cpython-39.pyc +0 -0
  42. training_libs/dataset.py +116 -0
  43. training_libs/logger.py +25 -0
  44. training_libs/loss.py +125 -0
  45. training_libs/metrics.py +60 -0
  46. training_libs/prompt_ensemble.py +273 -0
  47. training_libs/utils.py +24 -0
  48. training_libs/visualization.py +25 -0
AnomalyCLIP_lib/AnomalyCLIP.py ADDED
@@ -0,0 +1,531 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+
8
+
9
+ class Bottleneck(nn.Module):
10
+ expansion = 4
11
+
12
+ def __init__(self, inplanes, planes, stride=1):
13
+ super().__init__()
14
+
15
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
16
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
17
+ self.bn1 = nn.BatchNorm2d(planes)
18
+ self.relu1 = nn.ReLU(inplace=True)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+ self.relu2 = nn.ReLU(inplace=True)
23
+
24
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
25
+
26
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
27
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
28
+ self.relu3 = nn.ReLU(inplace=True)
29
+
30
+ self.downsample = None
31
+ self.stride = stride
32
+
33
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
34
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
35
+ self.downsample = nn.Sequential(OrderedDict([
36
+ ("-1", nn.AvgPool2d(stride)),
37
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
38
+ ("1", nn.BatchNorm2d(planes * self.expansion))
39
+ ]))
40
+
41
+ def forward(self, x: torch.Tensor):
42
+ identity = x
43
+
44
+ out = self.relu1(self.bn1(self.conv1(x)))
45
+ out = self.relu2(self.bn2(self.conv2(out)))
46
+ out = self.avgpool(out)
47
+ out = self.bn3(self.conv3(out))
48
+
49
+ if self.downsample is not None:
50
+ identity = self.downsample(x)
51
+
52
+ out += identity
53
+ out = self.relu3(out)
54
+ return out
55
+
56
+
57
+ # implement attention module for v-v self-attention
58
+ class Attention(nn.Module):
59
+ def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
60
+ super().__init__()
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ self.scale = qk_scale or head_dim ** -0.5
64
+
65
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66
+ self.attn_drop = nn.Dropout(attn_drop)
67
+ self.proj = nn.Linear(out_dim, dim)
68
+ self.proj_drop = nn.Dropout(proj_drop)
69
+ self.settings = settings
70
+
71
+ def forward(self, x):
72
+ B, N, C = x.shape
73
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
74
+ q, k, v = qkv[0], qkv[1], qkv[2]
75
+
76
+ # original self-attention for the original path
77
+ attn_ori = (q @ k.transpose(-2, -1)) * self.scale
78
+ attn_ori = attn_ori.softmax(dim=-1)
79
+ attn_ori = self.attn_drop(attn_ori)
80
+
81
+ # replace k & q by v
82
+ k = v
83
+ q = k
84
+
85
+ # self-attention; a higher temperature performs better for ResNets
86
+ attn = (q @ k.transpose(-2, -1)) * self.scale
87
+ attn = (attn).softmax(dim=-1)
88
+ attn = self.attn_drop(attn)
89
+
90
+ x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
91
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
92
+ x = self.proj_drop(self.proj(x))
93
+ x_ori = self.proj_drop(self.proj(x_ori))
94
+ return [x, x_ori]
95
+
96
+
97
+
98
+ class LayerNorm(nn.LayerNorm):
99
+ """Subclass torch's LayerNorm to handle fp16."""
100
+
101
+ def forward(self, x: torch.Tensor):
102
+ orig_type = x.dtype
103
+ ret = super().forward(x.type(torch.float32))
104
+ return ret.type(orig_type)
105
+
106
+
107
+ class QuickGELU(nn.Module):
108
+ def forward(self, x: torch.Tensor):
109
+ return x * torch.sigmoid(1.702 * x)
110
+
111
+
112
+ class ResidualAttentionBlock(nn.Module):
113
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details = None):
114
+ super().__init__()
115
+
116
+ self.attn = nn.MultiheadAttention(d_model, n_head)
117
+ self.ln_1 = LayerNorm(d_model)
118
+ self.mlp = nn.Sequential(OrderedDict([
119
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
120
+ ("gelu", QuickGELU()),
121
+ ("c_proj", nn.Linear(d_model * 4, d_model))
122
+ ]))
123
+ self.ln_2 = LayerNorm(d_model)
124
+ self.attn_mask = attn_mask
125
+
126
+ def attention(self, x: torch.Tensor):
127
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
128
+ if isinstance(self.attn, Attention):
129
+ x = x.transpose(0, 1)
130
+ x, x_ori = self.attn(x)
131
+ return [x.transpose(0, 1), x_ori.transpose(0, 1)]
132
+ else:
133
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
134
+
135
+ def forward(self, x, whole = False, ffn = False):
136
+ # print("xxxxx",x.shape)
137
+ # dual paths for blocks deeper than "d"
138
+
139
+ if isinstance(self.attn, Attention):
140
+ if isinstance(x, list):
141
+ if not ffn:
142
+ x, x_ori = x
143
+ x_res = self.attention(self.ln_1(x_ori))
144
+ x_res, x_ori_res = x_res
145
+ x_ori += x_ori_res
146
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
147
+ x += x_res # skip ffn for the new path
148
+ # print('hellloooo')
149
+ return [x, x_ori]
150
+ else:
151
+ x, x_ori_1 = x
152
+ x_res = self.attention(self.ln_1(x_ori_1))
153
+ x_res, x_ori_res = x_res
154
+ x_ori = x_ori_1 + x_ori_res
155
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
156
+ x += x_res # skip ffn for the new path
157
+ x = x_res + x_ori_1
158
+ x = x + self.mlp(self.ln_2(x))
159
+ return [x, x_ori]
160
+ # start of dual path
161
+ else:
162
+ x_res = self.attention(self.ln_1(x))
163
+ if isinstance(x_res, list):
164
+ x_res, x_ori_res = x_res
165
+ x_ori = x + x_ori_res
166
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
167
+ x += x_res
168
+ return [x, x_ori]
169
+
170
+ # single path before "d"
171
+ else:
172
+ x = x + self.attention(self.ln_1(x))
173
+ x = x + self.mlp(self.ln_2(x))
174
+ return x
175
+
176
+ class ResidualAttentionBlock_learnable_token(nn.Module):
177
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
178
+ text_layer=False, i = 0):
179
+ super().__init__()
180
+
181
+ self.attn = nn.MultiheadAttention(d_model, n_head)
182
+ self.ln_1 = LayerNorm(d_model)
183
+ self.mlp = nn.Sequential(OrderedDict([
184
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
185
+ ("gelu", QuickGELU()),
186
+ ("c_proj", nn.Linear(d_model * 4, d_model))
187
+ ]))
188
+ self.ln_2 = LayerNorm(d_model)
189
+ self.attn_mask = attn_mask
190
+
191
+ self.i = i
192
+ self.compound_prompt_nctx = design_details['learnabel_text_embedding_length']
193
+ self.text_layer = text_layer
194
+ if i == 0:
195
+ self.first_layer = True
196
+ else:
197
+ self.first_layer = False
198
+
199
+ def attention(self, x: torch.Tensor):
200
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
201
+ if isinstance(self.attn, Attention):
202
+ x = x.transpose(0, 1)
203
+ x, x_ori = self.attn(x)
204
+ return [x.transpose(0, 1), x_ori.transpose(0, 1)]
205
+ else:
206
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
207
+
208
+ def forward(self, inputs):
209
+
210
+ # dual paths for blocks deeper than "d"
211
+ if isinstance(self.attn, Attention):
212
+ x = inputs[0]
213
+ if isinstance(x, list):
214
+ x, x_ori = x
215
+ x_res = self.attention(self.ln_1(x_ori))
216
+ x_res, x_ori_res = x_res
217
+ x_ori += x_ori_res
218
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
219
+ x += x_res # skip ffn for the new path
220
+ return [x, x_ori]
221
+
222
+ # start of dual path
223
+ else:
224
+ x_res = self.attention(self.ln_1(x))
225
+ if isinstance(x_res, list):
226
+ x_res, x_ori_res = x_res
227
+ x_ori = x + x_ori_res
228
+ x_ori = x_ori + self.mlp(self.ln_2(x_ori))
229
+ x += x_res
230
+ return [x, x_ori]
231
+
232
+ # single path before "d"
233
+ else:
234
+ x = inputs[0]
235
+ compound_prompts_deeper = inputs[1]
236
+ counter = inputs[2]
237
+ if not self.first_layer:
238
+ # First check if the ith layer needs compound prompts or not
239
+ if not (counter > len(compound_prompts_deeper) - 1):
240
+ # Append the learnable tokens in a different way
241
+ # x -> [77, NCLS, DIM]
242
+ # First remove the learnable tokens from previous layer
243
+ prefix = x[:1, :, :]
244
+ suffix = x[1 + self.compound_prompt_nctx:, :, :]
245
+ textual_context = compound_prompts_deeper[counter]
246
+ textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
247
+ # Add this layer's learnable tokens to the input, replacing the previous
248
+ # layer's learnable tokens
249
+ x = torch.cat([prefix, textual_context, suffix], dim=0)
250
+ # Once done, update the counter so that the next layer does not reuse the same learnable tokens
251
+ counter += 1
252
+ x = x + self.attention(self.ln_1(x))
253
+ x = x + self.mlp(self.ln_2(x))
254
+ return [x, compound_prompts_deeper, counter]
255
+
256
+
257
+ class Transformer(nn.Module):
258
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False, design_details = None ,text_layer = False):
259
+ super().__init__()
260
+ self.width = width
261
+ self.layers = layers
262
+ self.text_layer = text_layer
263
+ self.design_deatails = design_details
264
+ print("text_layer", self.text_layer)
265
+ if self.text_layer and (design_details is not None):
266
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock_learnable_token(width, heads, attn_mask, design_details, text_layer, i=i) for i in range(layers)])
267
+ else:
268
+ self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask,) for i in range(layers)])
269
+
270
+ def ori_CLIP_with_patch_forward(self, x, out_layers):
271
+ idx = 0
272
+ out_tokens = []
273
+ for r in self.resblocks:
274
+ idx += 1
275
+ x = r(x)
276
+ if idx in out_layers:
277
+ if isinstance(x, list):
278
+ out_tokens.append(x[1])
279
+ else:
280
+ out_tokens.append(x)
281
+
282
+ return [x, x], out_tokens
283
+
284
+ def AnomalyCLIP_forward(self, x, out_layers, ffn):
285
+ idx = 0
286
+ out_tokens = []
287
+ for r in self.resblocks:
288
+ idx += 1
289
+ x = r(x, ffn = ffn)
290
+ # print("out_layers", out_layers, idx)
291
+ if idx in out_layers:
292
+ if isinstance(x, list):
293
+ out_tokens.append(x[0])
294
+ else:
295
+ out_tokens.append(x)
296
+ return x, out_tokens
297
+
298
+ def forward(self, x: torch.Tensor, out_layers = [6, 12, 18, 24], DPAM_layer = None, ffn = False):
299
+ # visual encoder forward
300
+ if not self.text_layer:
301
+ out_tokens = []
302
+
303
+ if DPAM_layer is None:
304
+ [x, x], out_tokens = self.ori_CLIP_with_patch_forward(x, out_layers)
305
+ return [x, x], out_tokens
306
+ else:
307
+ x, out_tokens = self.AnomalyCLIP_forward(x, out_layers, ffn)
308
+ return x, out_tokens
309
+ # text encoder forward
310
+ # ori text embedding
311
+ elif self.design_deatails is None:
312
+ for idx, r in enumerate(self.resblocks):
313
+ x = r(x)
314
+ return x
315
+ # insert learnable text embedding
316
+ elif self.design_deatails is not None:
317
+ for idx, r in enumerate(self.resblocks):
318
+ x = r(x)
319
+ return x[0]
320
+ def get_cast_dtype(self) -> torch.dtype:
321
+ return self.resblocks[0].mlp.c_fc.weight.dtype
322
+
323
+ class VisionTransformer(nn.Module):
324
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
325
+ super().__init__()
326
+ self.input_resolution = input_resolution
327
+ self.output_dim = output_dim
328
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
329
+
330
+ scale = width ** -0.5
331
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
332
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
333
+ self.ln_pre = LayerNorm(width)
334
+
335
+ self.transformer = Transformer(width, layers, heads, need_weights=True)
336
+ self.attn = None
337
+ self.embed_dim = width
338
+ self.num_heads = heads
339
+
340
+ self.ln_post = LayerNorm(width)
341
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
342
+
343
+
344
+ @torch.no_grad()
345
+ def DAPM_replace(self, DPAM_layer):
346
+ if DPAM_layer is not None:
347
+ for i in range(1, DPAM_layer):
348
+ self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
349
+ self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
350
+ self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
351
+ self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
352
+ self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
353
+ self.transformer.resblocks[-i].attn = self.attn
354
+
355
+ @torch.no_grad()
356
+ def forward(self, x: torch.Tensor, features_list, ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False):
357
+
358
+ x = self.conv1(x) # shape = [*, width, grid, grid]
359
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
360
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
361
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
362
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
363
+ new_side = int((x.shape[1] - 1) ** 0.5)
364
+
365
+ # update the position embedding during inference for varied input size
366
+ if side != new_side:
367
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
368
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
369
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
370
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
371
+
372
+ pos = self.positional_embedding.to(x.dtype)
373
+ x = x + pos
374
+ x = self.ln_pre(x)
375
+
376
+ x = x.permute(1, 0, 2) # NLD -> LND
377
+ [x, x_ori], patch_tokens = self.transformer(x, features_list, DPAM_layer = DPAM_layer, ffn = ffn)
378
+
379
+
380
+ if True:
381
+ patch_token_list = []
382
+ for patch_token in patch_tokens:
383
+ patch_token = self.ln_post(patch_token.permute(1, 0, 2)) @ self.proj # LND -> NLD
384
+ patch_token_list.append(patch_token)
385
+ patch_tokens = patch_token_list
386
+
387
+ return x_ori[0, :, :] @ self.proj, patch_tokens
388
+
389
+
390
+ return x
391
+
392
+
393
+ from thop import profile
394
+ class AnomalyCLIP(nn.Module):
395
+ def __init__(self,
396
+ embed_dim: int,
397
+ # vision
398
+ image_resolution: int,
399
+ vision_layers: Union[Tuple[int, int, int, int], int],
400
+ vision_width: int,
401
+ vision_patch_size: int,
402
+ # text
403
+ context_length: int,
404
+ vocab_size: int,
405
+ transformer_width: int,
406
+ transformer_heads: int,
407
+ transformer_layers: int,
408
+ design_details = None
409
+ ):
410
+ super().__init__()
411
+
412
+ self.context_length = context_length
413
+
414
+ if isinstance(vision_layers, (tuple, list)):
415
+ vision_heads = vision_width * 32 // 64
416
+ self.visual = ModifiedResNet(
417
+ layers=vision_layers,
418
+ output_dim=embed_dim,
419
+ heads=vision_heads,
420
+ input_resolution=image_resolution,
421
+ width=vision_width
422
+ )
423
+ else:
424
+ vision_heads = vision_width // 64
425
+ self.visual = VisionTransformer(
426
+ input_resolution=image_resolution,
427
+ patch_size=vision_patch_size,
428
+ width=vision_width,
429
+ layers=vision_layers,
430
+ heads=vision_heads,
431
+ output_dim=embed_dim
432
+ )
433
+
434
+ self.transformer = Transformer(
435
+ width=transformer_width,
436
+ layers=transformer_layers,
437
+ heads=transformer_heads,
438
+ attn_mask=self.build_attention_mask(), text_layer=True, design_details=design_details
439
+ )
440
+
441
+ self.vocab_size = vocab_size
442
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
443
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
444
+ self.ln_final = LayerNorm(transformer_width)
445
+
446
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
447
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
448
+
449
+ self.initialize_parameters()
450
+
451
+ def initialize_parameters(self):
452
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
453
+ nn.init.normal_(self.positional_embedding, std=0.01)
454
+
455
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
456
+ attn_std = self.transformer.width ** -0.5
457
+ fc_std = (2 * self.transformer.width) ** -0.5
458
+ for block in self.transformer.resblocks:
459
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
460
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
461
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
462
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
463
+
464
+ if self.text_projection is not None:
465
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
466
+ def build_attention_mask(self):
467
+ # lazily create causal attention mask, with full attention between the vision tokens
468
+ # pytorch uses additive attention mask; fill with -inf
469
+ mask = torch.empty(self.context_length, self.context_length)
470
+ mask.fill_(float("-inf"))
471
+ mask.triu_(1) # zero out the lower diagonal
472
+ return mask
473
+
474
+ @property
475
+ def dtype(self):
476
+ return self.visual.conv1.weight.dtype
477
+
478
+ def encode_image(self, image, feature_list = [], ori_patch = False, proj_use = True, DPAM_layer = None, ffn = False):
479
+ return self.visual(image.type(self.dtype), feature_list, ori_patch = ori_patch, proj_use = proj_use, DPAM_layer = DPAM_layer, ffn = ffn)
480
+
481
+
482
+ def encode_text(self, text):
483
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
484
+
485
+ x = x + self.positional_embedding.type(self.dtype)
486
+ x = x.permute(1, 0, 2) # NLD -> LND
487
+ x = self.transformer(x)
488
+ x = x.permute(1, 0, 2) # LND -> NLD
489
+ x = self.ln_final(x).type(self.dtype)
490
+
491
+ # x.shape = [batch_size, n_ctx, transformer.width]
492
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
493
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
494
+
495
+ return x
496
+
497
+ def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
498
+ cast_dtype = self.transformer.get_cast_dtype()
499
+
500
+ # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
501
+
502
+ # x = x + self.positional_embedding.to(cast_dtype)
503
+
504
+ x = prompts + self.positional_embedding.to(cast_dtype)
505
+ x = x.permute(1, 0, 2) # NLD -> LND
506
+ # print("test", x.shape, len(deep_compound_prompts_text))
507
+ if deep_compound_prompts_text is None:
508
+ x = self.transformer(x)
509
+ else:
510
+ x = self.transformer([x, deep_compound_prompts_text, 0])
511
+ x = x.permute(1, 0, 2) # LND -> NLD
512
+ x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width]
513
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
514
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
515
+ return x
516
+
517
+ def forward(self, image, text):
518
+ image_features = self.encode_image(image)
519
+ text_features = self.encode_text(text)
520
+
521
+ # normalized features
522
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
523
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
524
+
525
+ # cosine similarity as logits
526
+ logit_scale = self.logit_scale.exp()
527
+ logits_per_image = logit_scale * image_features @ text_features.t()
528
+ logits_per_text = logits_per_image.t()
529
+
530
+ # shape = [global_batch_size, global_batch_size]
531
+ return logits_per_image, logits_per_text
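A minimal smoke test of the V-V Attention module defined at the top of this file (a sketch: the shapes and head count are illustrative for ViT-L/14, and the import assumes the package layout in this commit, including the thop dependency imported in this module):

import torch
from AnomalyCLIP_lib.AnomalyCLIP import Attention

# The dual-path attention returns [x, x_ori]: the value-value self-attention
# output (collected as dense patch tokens in AnomalyCLIP_forward above) and the
# original q-k-v output.
attn = Attention(out_dim=1024, dim=1024, num_heads=16, qkv_bias=True)
tokens = torch.randn(2, 577, 1024)   # [batch, 1 cls + 24*24 patch tokens, width]
x_vv, x_ori = attn(tokens)
print(x_vv.shape, x_ori.shape)       # both torch.Size([2, 577, 1024])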
AnomalyCLIP_lib/CLIP.py ADDED
@@ -0,0 +1,436 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+
72
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
73
+ new_side = int((x.shape[0] - 1) ** 0.5)
74
+
75
+ # update the position embedding during inference for varied input size
76
+ if side != new_side:
77
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
78
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
79
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
80
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
81
+
82
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
83
+ x, _ = F.multi_head_attention_forward(
84
+ query=x, key=x, value=x,
85
+ embed_dim_to_check=x.shape[-1],
86
+ num_heads=self.num_heads,
87
+ q_proj_weight=self.q_proj.weight,
88
+ k_proj_weight=self.k_proj.weight,
89
+ v_proj_weight=self.v_proj.weight,
90
+ in_proj_weight=None,
91
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
92
+ bias_k=None,
93
+ bias_v=None,
94
+ add_zero_attn=False,
95
+ dropout_p=0,
96
+ out_proj_weight=self.c_proj.weight,
97
+ out_proj_bias=self.c_proj.bias,
98
+ use_separate_proj_weight=True,
99
+ training=self.training,
100
+ need_weights=False
101
+ )
102
+
103
+ #return x[0]
104
+ return x.transpose(0, 1) # return both cls token and image tokens, B,N,C
105
+
106
+
107
+ class ModifiedResNet(nn.Module):
108
+ """
109
+ A ResNet class that is similar to torchvision's but contains the following changes:
110
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
111
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
112
+ - The final pooling layer is a QKV attention instead of an average pool
113
+ """
114
+
115
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
116
+ super().__init__()
117
+ self.output_dim = output_dim
118
+ self.input_resolution = input_resolution
119
+
120
+ # the 3-layer stem
121
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
122
+ self.bn1 = nn.BatchNorm2d(width // 2)
123
+ self.relu1 = nn.ReLU(inplace=True)
124
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
125
+ self.bn2 = nn.BatchNorm2d(width // 2)
126
+ self.relu2 = nn.ReLU(inplace=True)
127
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
128
+ self.bn3 = nn.BatchNorm2d(width)
129
+ self.relu3 = nn.ReLU(inplace=True)
130
+ self.avgpool = nn.AvgPool2d(2)
131
+
132
+ # residual layers
133
+ self._inplanes = width # this is a *mutable* variable used during construction
134
+ self.layer1 = self._make_layer(width, layers[0])
135
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
136
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
137
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
138
+
139
+ embed_dim = width * 32 # the ResNet feature dimension
140
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
141
+
142
+ def _make_layer(self, planes, blocks, stride=1):
143
+ layers = [Bottleneck(self._inplanes, planes, stride)]
144
+
145
+ self._inplanes = planes * Bottleneck.expansion
146
+ for _ in range(1, blocks):
147
+ layers.append(Bottleneck(self._inplanes, planes))
148
+
149
+ return nn.Sequential(*layers)
150
+
151
+ def forward(self, x):
152
+ def stem(x):
153
+ x = self.relu1(self.bn1(self.conv1(x)))
154
+ x = self.relu2(self.bn2(self.conv2(x)))
155
+ x = self.relu3(self.bn3(self.conv3(x)))
156
+ x = self.avgpool(x)
157
+ return x
158
+
159
+ x = x.type(self.conv1.weight.dtype)
160
+ x = stem(x)
161
+ x = self.layer1(x)
162
+ x = self.layer2(x)
163
+ x = self.layer3(x)
164
+ x = self.layer4(x)
165
+ x = self.attnpool(x)
166
+
167
+ return x
168
+
169
+
170
+ class LayerNorm(nn.LayerNorm):
171
+ """Subclass torch's LayerNorm to handle fp16."""
172
+
173
+ def forward(self, x: torch.Tensor):
174
+ orig_type = x.dtype
175
+ ret = super().forward(x.type(torch.float32))
176
+ return ret.type(orig_type)
177
+
178
+
179
+ class QuickGELU(nn.Module):
180
+ def forward(self, x: torch.Tensor):
181
+ return x * torch.sigmoid(1.702 * x)
182
+
183
+
184
+ class ResidualAttentionBlock(nn.Module):
185
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
186
+ super().__init__()
187
+
188
+ self.attn = nn.MultiheadAttention(d_model, n_head)
189
+ self.ln_1 = LayerNorm(d_model)
190
+ self.mlp = nn.Sequential(OrderedDict([
191
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
192
+ ("gelu", QuickGELU()),
193
+ ("c_proj", nn.Linear(d_model * 4, d_model))
194
+ ]))
195
+ self.ln_2 = LayerNorm(d_model)
196
+ self.attn_mask = attn_mask
197
+ self.need_weights = need_weights
198
+
199
+ def attention(self, x: torch.Tensor):
200
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
201
+ if self.need_weights == False:
202
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
203
+ else:
204
+ return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
205
+
206
+ def forward(self, x: torch.Tensor):
207
+ if self.need_weights == False:
208
+ x = x + self.attention(self.ln_1(x))
209
+ x = x + self.mlp(self.ln_2(x))
210
+ return x
211
+ else:
212
+ y, attn = self.attention(self.ln_1(x))
213
+ x = x + y
214
+ x = x + self.mlp(self.ln_2(x))
215
+ return x
216
+
217
+
218
+ class Transformer(nn.Module):
219
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
220
+ super().__init__()
221
+ self.width = width
222
+ self.layers = layers
223
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])
224
+
225
+ def forward(self, x: torch.Tensor):
226
+ return self.resblocks(x)
227
+
228
+ def get_cast_dtype(self) -> torch.dtype:
229
+ return self.resblocks[0].mlp.c_fc.weight.dtype
230
+
231
+
232
+
233
+ class VisionTransformer(nn.Module):
234
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
235
+ super().__init__()
236
+ self.input_resolution = input_resolution
237
+ self.output_dim = output_dim
238
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
239
+
240
+ scale = width ** -0.5
241
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
242
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
243
+ self.ln_pre = LayerNorm(width)
244
+
245
+ self.transformer = Transformer(width, layers, heads, need_weights=True)
246
+
247
+ self.ln_post = LayerNorm(width)
248
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
249
+
250
+ def forward(self, x: torch.Tensor):
251
+ x = self.conv1(x) # shape = [*, width, grid, grid]
252
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
253
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
254
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
255
+
256
+ #####################################################################################
257
+ side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
258
+ new_side = int((x.shape[1] - 1) ** 0.5)
259
+
260
+ # update the position embedding during inference for varied input size
261
+ if side != new_side:
262
+ new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
263
+ new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
264
+ new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
265
+ self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
266
+ #####################################################################################
267
+
268
+
269
+ x = x + self.positional_embedding.to(x.dtype)
270
+ x = self.ln_pre(x)
271
+
272
+ x = x.permute(1, 0, 2) # NLD -> LND
273
+ x = self.transformer(x)
274
+ x = x.permute(1, 0, 2) # LND -> NLD
275
+
276
+ #x = self.ln_post(x[:, 0, :])
277
+ x = self.ln_post(x) # return both cls token and image tokens
278
+
279
+ if self.proj is not None:
280
+ x = x @ self.proj
281
+
282
+ return x
283
+
284
+
285
+ class CLIP(nn.Module):
286
+ def __init__(self,
287
+ embed_dim: int,
288
+ # vision
289
+ image_resolution: int,
290
+ vision_layers: Union[Tuple[int, int, int, int], int],
291
+ vision_width: int,
292
+ vision_patch_size: int,
293
+ # text
294
+ context_length: int,
295
+ vocab_size: int,
296
+ transformer_width: int,
297
+ transformer_heads: int,
298
+ transformer_layers: int
299
+ ):
300
+ super().__init__()
301
+
302
+ self.context_length = context_length
303
+
304
+ if isinstance(vision_layers, (tuple, list)):
305
+ vision_heads = vision_width * 32 // 64
306
+ self.visual = ModifiedResNet(
307
+ layers=vision_layers,
308
+ output_dim=embed_dim,
309
+ heads=vision_heads,
310
+ input_resolution=image_resolution,
311
+ width=vision_width
312
+ )
313
+ else:
314
+ vision_heads = vision_width // 64
315
+ self.visual = VisionTransformer(
316
+ input_resolution=image_resolution,
317
+ patch_size=vision_patch_size,
318
+ width=vision_width,
319
+ layers=vision_layers,
320
+ heads=vision_heads,
321
+ output_dim=embed_dim
322
+ )
323
+
324
+ self.transformer = Transformer(
325
+ width=transformer_width,
326
+ layers=transformer_layers,
327
+ heads=transformer_heads,
328
+ attn_mask=self.build_attention_mask()
329
+ )
330
+
331
+ self.vocab_size = vocab_size
332
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
333
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
334
+ self.ln_final = LayerNorm(transformer_width)
335
+
336
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
337
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
338
+
339
+ self.initialize_parameters()
340
+
341
+ def initialize_parameters(self):
342
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
343
+ nn.init.normal_(self.positional_embedding, std=0.01)
344
+
345
+ if isinstance(self.visual, ModifiedResNet):
346
+ if self.visual.attnpool is not None:
347
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
348
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
349
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
350
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
351
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
352
+
353
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
354
+ for name, param in resnet_block.named_parameters():
355
+ if name.endswith("bn3.weight"):
356
+ nn.init.zeros_(param)
357
+
358
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
359
+ attn_std = self.transformer.width ** -0.5
360
+ fc_std = (2 * self.transformer.width) ** -0.5
361
+ for block in self.transformer.resblocks:
362
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
363
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
364
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
365
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
366
+
367
+ if self.text_projection is not None:
368
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
369
+
370
+ def build_attention_mask(self):
371
+ # lazily create causal attention mask, with full attention between the vision tokens
372
+ # pytorch uses additive attention mask; fill with -inf
373
+ mask = torch.empty(self.context_length, self.context_length)
374
+ mask.fill_(float("-inf"))
375
+ mask.triu_(1) # zero out the lower diagonal
376
+ return mask
377
+
378
+ @property
379
+ def dtype(self):
380
+ return self.visual.conv1.weight.dtype
381
+
382
+ def encode_image(self, image):
383
+ return self.visual(image.type(self.dtype))
384
+
385
+ def encode_text(self, text):
386
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
387
+
388
+ x = x + self.positional_embedding.type(self.dtype)
389
+ x = x.permute(1, 0, 2) # NLD -> LND
390
+ x = self.transformer(x)
391
+ x = x.permute(1, 0, 2) # LND -> NLD
392
+ x = self.ln_final(x).type(self.dtype)
393
+
394
+ # x.shape = [batch_size, n_ctx, transformer.width]
395
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
396
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
397
+
398
+ return x
399
+
400
+ def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
401
+ cast_dtype = self.transformer.get_cast_dtype()
402
+
403
+ # x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
404
+
405
+ # x = x + self.positional_embedding.to(cast_dtype)
406
+
407
+ x = prompts + self.positional_embedding.to(cast_dtype)
408
+ x = x.permute(1, 0, 2) # NLD -> LND
409
+ # print("test", x.shape, len(deep_compound_prompts_text))
410
+ if deep_compound_prompts_text is None:
411
+ x = self.transformer(x)
412
+ else:
413
+ x = self.transformer([x, deep_compound_prompts_text, 0])
414
+ x = x.permute(1, 0, 2) # LND -> NLD
415
+ x = self.ln_final(x).type(self.dtype) # [batch_size, n_ctx, transformer.width]
416
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
417
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
418
+ return x
419
+
420
+
421
+
422
+ def forward(self, image, text):
423
+ image_features = self.encode_image(image)
424
+ text_features = self.encode_text(text)
425
+
426
+ # normalized features
427
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
428
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
429
+
430
+ # cosine similarity as logits
431
+ logit_scale = self.logit_scale.exp()
432
+ logits_per_image = logit_scale * image_features @ text_features.t()
433
+ logits_per_text = logits_per_image.t()
434
+
435
+ # shape = [global_batch_size, global_batch_size]
436
+ return logits_per_image, logits_per_text
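For reference, a standalone sketch of two idioms used above: the EOT-token pooling in encode_text and the cosine-similarity logits in forward. Token ids and feature sizes are illustrative only.

import torch
import torch.nn.functional as F

# EOT pooling: for each sequence, take the hidden state at the position of the
# end-of-text token, which has the largest id in CLIP's vocabulary.
tokens = torch.tensor([[49406, 320, 1125, 49407, 0, 0],
                       [49406, 1125, 49407, 0, 0, 0]])    # <|startoftext|> ... <|endoftext|>
hidden = torch.randn(2, 6, 512)                           # [batch, n_ctx, transformer.width]
pooled = hidden[torch.arange(hidden.shape[0]), tokens.argmax(dim=-1)]   # [2, 512]

# Cosine-similarity logits, as in CLIP.forward: normalize, scale, matmul.
image_features = F.normalize(torch.randn(4, 512), dim=-1)
text_features = F.normalize(pooled, dim=-1)
logit_scale = torch.tensor(1 / 0.07)                      # value of logit_scale.exp() at init
logits_per_image = logit_scale * image_features @ text_features.t()    # [4, 2]
probs = logits_per_image.softmax(dim=-1)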
AnomalyCLIP_lib/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_load import *
AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-38.pyc ADDED
Binary file (15.2 kB)
AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-39.pyc ADDED
Binary file (15.2 kB)
AnomalyCLIP_lib/__pycache__/CLIP.cpython-39.pyc ADDED
Binary file (13.9 kB)
AnomalyCLIP_lib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (199 Bytes)
AnomalyCLIP_lib/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (209 Bytes)
AnomalyCLIP_lib/__pycache__/build_model.cpython-38.pyc ADDED
Binary file (2.27 kB)
AnomalyCLIP_lib/__pycache__/build_model.cpython-39.pyc ADDED
Binary file (2.21 kB)
AnomalyCLIP_lib/__pycache__/clip.cpython-38.pyc ADDED
Binary file (19.7 kB)
AnomalyCLIP_lib/__pycache__/clip_model.cpython-38.pyc ADDED
Binary file (13.9 kB)
AnomalyCLIP_lib/__pycache__/clip_surgery_model.cpython-38.pyc ADDED
Binary file (20.7 kB)
AnomalyCLIP_lib/__pycache__/constants.cpython-38.pyc ADDED
Binary file (279 Bytes)
AnomalyCLIP_lib/__pycache__/constants.cpython-39.pyc ADDED
Binary file (289 Bytes)
AnomalyCLIP_lib/__pycache__/model_load.cpython-38.pyc ADDED
Binary file (7.79 kB)
AnomalyCLIP_lib/__pycache__/model_load.cpython-39.pyc ADDED
Binary file (7.87 kB)
AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-38.pyc ADDED
Binary file (5.82 kB)
AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-39.pyc ADDED
Binary file (5.79 kB)
AnomalyCLIP_lib/__pycache__/transform.cpython-38.pyc ADDED
Binary file (4.18 kB)
AnomalyCLIP_lib/__pycache__/transform.cpython-39.pyc ADDED
Binary file (4.16 kB)
AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
AnomalyCLIP_lib/build_model.py ADDED
@@ -0,0 +1,50 @@
1
+ from torch import nn
2
+ from .CLIP import CLIP
3
+ from .AnomalyCLIP import AnomalyCLIP
4
+
5
+ def build_model(name: str, state_dict: dict, design_details = None):
6
+ vit = "visual.proj" in state_dict
7
+
8
+ if vit:
9
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
10
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
11
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
12
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
13
+ image_resolution = vision_patch_size * grid_size
14
+ else:
15
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
16
+ vision_layers = tuple(counts)
17
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
18
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
19
+ vision_patch_size = None
20
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
21
+ image_resolution = output_width * 32
22
+
23
+ embed_dim = state_dict["text_projection"].shape[1]
24
+ context_length = state_dict["positional_embedding"].shape[0]
25
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
26
+ transformer_width = state_dict["ln_final.weight"].shape[0]
27
+ transformer_heads = transformer_width // 64
28
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
29
+ # print('name', name)
30
+ # if 'CS-' in name:
31
+ if design_details is not None:
32
+ model = AnomalyCLIP(
33
+ embed_dim,
34
+ image_resolution, vision_layers, vision_width, vision_patch_size,
35
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details = design_details
36
+ )
37
+ else:
38
+ model = CLIP(
39
+ embed_dim,
40
+ image_resolution, vision_layers, vision_width, vision_patch_size,
41
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
42
+ )
43
+
44
+ for key in ["input_resolution", "context_length", "vocab_size"]:
45
+ if key in state_dict:
46
+ del state_dict[key]
47
+
48
+ #convert_weights(model)
49
+ model.load_state_dict(state_dict)
50
+ return model.eval()
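A hedged sketch of how this factory is typically driven. The checkpoint path is a placeholder, and only the design_details key actually read in AnomalyCLIP.py is shown; with design_details=None the plain CLIP module is returned instead.

import torch
from AnomalyCLIP_lib.build_model import build_model

# Placeholder path to an OpenAI CLIP checkpoint saved as a JIT archive; build_model
# infers width / layers / patch size / resolution from the state-dict keys above.
state_dict = torch.jit.load("ViT-L-14-336px.pt", map_location="cpu").state_dict()

design_details = {"learnabel_text_embedding_length": 4}   # key read by the learnable text blocks
model = build_model("ViT-L/14@336px", state_dict, design_details=design_details)
print(type(model).__name__)   # "AnomalyCLIP"; "CLIP" when design_details is None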
AnomalyCLIP_lib/constants.py ADDED
@@ -0,0 +1,2 @@
1
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
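These are the standard OpenAI CLIP normalization statistics. A minimal sketch of using them in a torchvision pipeline (the 336-pixel resize mirrors the values hard-coded in _transform() in model_load.py below; RGB conversion is omitted for brevity):

from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from AnomalyCLIP_lib.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

preprocess = Compose([
    Resize((336, 336), interpolation=InterpolationMode.BICUBIC),   # no center crop, as in _transform()
    ToTensor(),
    Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
])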
AnomalyCLIP_lib/model_load.py ADDED
@@ -0,0 +1,235 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+
14
+ from .build_model import build_model
15
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
16
+ from torchvision.transforms import InterpolationMode
17
+
18
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
19
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
20
+
21
+
22
+ __all__ = ["available_models", "load",
23
+ "get_similarity_map", "compute_similarity"]
24
+ _tokenizer = _Tokenizer()
25
+
26
+ _MODELS = {
27
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
28
+ }
29
+
30
+
31
+ def _download(
32
+ url: str,
33
+ cache_dir: Union[str, None] = None,
34
+ ):
35
+
36
+ if not cache_dir:
37
+ # cache_dir = os.path.expanduser("~/.cache/clip")
38
+ cache_dir = os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip")
39
+ os.makedirs(cache_dir, exist_ok=True)
40
+ filename = os.path.basename(url)
41
+
42
+ if 'openaipublic' in url:
43
+ expected_sha256 = url.split("/")[-2]
44
+ elif 'mlfoundations' in url:
45
+ expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
46
+ else:
47
+ expected_sha256 = ''
48
+
49
+ download_target = os.path.join(cache_dir, filename)
50
+
51
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
52
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
53
+
54
+ if os.path.isfile(download_target):
55
+ if expected_sha256:
56
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
57
+ return download_target
58
+ else:
59
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
60
+ else:
61
+ return download_target
62
+
63
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
64
+ with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
65
+ while True:
66
+ buffer = source.read(8192)
67
+ if not buffer:
68
+ break
69
+
70
+ output.write(buffer)
71
+ loop.update(len(buffer))
72
+
73
+ if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
74
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
75
+
76
+ return download_target
77
+
78
+
79
+ def _convert_image_to_rgb(image):
80
+ return image.convert("RGB")
81
+
82
+
83
+ def _transform(n_px):
84
+ return Compose([
85
+ Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC),
86
+ #CenterCrop(n_px), # rm center crop to explain whole image
87
+ _convert_image_to_rgb,
88
+ ToTensor(),
89
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
90
+ ])
91
+
92
+
93
+ def available_models() -> List[str]:
94
+ """Returns the names of available CLIP models"""
95
+ return list(_MODELS.keys())
96
+
97
+
98
+ def load_state_dict(checkpoint_path: str, map_location='cpu'):
99
+ checkpoint = torch.load(checkpoint_path, map_location=map_location)
100
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
101
+ state_dict = checkpoint['state_dict']
102
+ else:
103
+ state_dict = checkpoint
104
+ if next(iter(state_dict.items()))[0].startswith('module'):
105
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
106
+ return state_dict
107
+
108
+ def load_checkpoint(model, checkpoint_path, strict=True):
109
+ state_dict = load_state_dict(checkpoint_path)
110
+ # detect old format and make compatible with new format
111
+ if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
112
+ state_dict = convert_to_custom_text_state_dict(state_dict)
113
+ resize_pos_embed(state_dict, model)
114
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict)
115
+ return incompatible_keys
116
+
117
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", design_details = None, jit: bool = False, download_root: str = None):
118
+ """Load a CLIP model
119
+
120
+ Parameters
121
+ ----------
122
+ name : str
123
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
124
+
125
+ device : Union[str, torch.device]
126
+ The device to put the loaded model
127
+
128
+ jit : bool
129
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
130
+
131
+ download_root: str
132
+ path to download the model files; by default, it uses "~/.cache/clip"
133
+
134
+ Returns
135
+ -------
136
+ model : torch.nn.Module
137
+ The CLIP model
138
+
139
+ preprocess : Callable[[PIL.Image], torch.Tensor]
140
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
141
+ """
142
+ print("name", name)
143
+ if name in _MODELS:
144
+ # model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
145
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip"))
146
+ elif os.path.isfile(name):
147
+ model_path = name
148
+ else:
149
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
150
+
151
+ with open(model_path, 'rb') as opened_file:
152
+ try:
153
+ # loading JIT archive
154
+ model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
155
+ state_dict = None
156
+ except RuntimeError:
157
+ # loading saved state dict
158
+ if jit:
159
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
160
+ jit = False
161
+ state_dict = torch.load(opened_file, map_location="cpu")
162
+
163
+ if not jit:
164
+ model = build_model(name, state_dict or model.state_dict(), design_details).to(device)
165
+ if str(device) == "cpu":
166
+ model.float()
167
+ return model, _transform(model.visual.input_resolution)
168
+
169
+ # patch the device names
170
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
171
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
172
+
173
+ def patch_device(module):
174
+ try:
175
+ graphs = [module.graph] if hasattr(module, "graph") else []
176
+ except RuntimeError:
177
+ graphs = []
178
+
179
+ if hasattr(module, "forward1"):
180
+ graphs.append(module.forward1.graph)
181
+
182
+ for graph in graphs:
183
+ for node in graph.findAllNodes("prim::Constant"):
184
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
185
+ node.copyAttributes(device_node)
186
+
187
+ model.apply(patch_device)
188
+ patch_device(model.encode_image)
189
+ patch_device(model.encode_text)
190
+
191
+ # patch dtype to float32 on CPU
192
+ if str(device) == "cpu":
193
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
194
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
195
+ float_node = float_input.node()
196
+
197
+ def patch_float(module):
198
+ try:
199
+ graphs = [module.graph] if hasattr(module, "graph") else []
200
+ except RuntimeError:
201
+ graphs = []
202
+
203
+ if hasattr(module, "forward1"):
204
+ graphs.append(module.forward1.graph)
205
+
206
+ for graph in graphs:
207
+ for node in graph.findAllNodes("aten::to"):
208
+ inputs = list(node.inputs())
209
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
210
+ if inputs[i].node()["value"] == 5:
211
+ inputs[i].node().copyAttributes(float_node)
212
+
213
+ model.apply(patch_float)
214
+ patch_float(model.encode_image)
215
+ patch_float(model.encode_text)
216
+
217
+ model.float()
218
+
219
+ return model, _transform(model.input_resolution.item())
220
+
221
+
222
+ def get_similarity_map(sm, shape):
223
+ side = int(sm.shape[1] ** 0.5)
224
+ sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)
225
+ sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
226
+ sm = sm.permute(0, 2, 3, 1)
227
+ return sm
228
+
229
+
230
+ def compute_similarity(image_features, text_features, t=2):
231
+ prob_1 = image_features[:, :1, :] @ text_features.t()
232
+ b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
233
+ feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
234
+ similarity = feats.sum(-1)
235
+ return (similarity/0.07).softmax(-1), prob_1
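An end-to-end sketch of the helpers exported from this module. The image path and the prompt features are placeholders (in the repository's pipeline, text features are produced by encode_text_learn together with the learnable prompts, see training_libs/prompt_ensemble.py); the layer indices and DPAM depth are illustrative.

import os
import torch
from PIL import Image
import AnomalyCLIP_lib

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = AnomalyCLIP_lib.load(
    "ViT-L/14@336px", device=device,
    design_details={"learnabel_text_embedding_length": 4},
    download_root=os.path.expanduser("~/.cache/clip"))
model.visual.DAPM_replace(DPAM_layer=20)            # swap in V-V attention for the last blocks

img = preprocess(Image.open("sample.png")).unsqueeze(0).to(device)   # placeholder image
with torch.no_grad():
    cls_feat, patch_tokens = model.encode_image(img, [6, 12, 18, 24], DPAM_layer=20)

# Placeholder "normal" / "anomalous" prompt features with matching dtype.
text_features = torch.randn(2, cls_feat.shape[-1], device=device, dtype=cls_feat.dtype)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

patches = patch_tokens[-1] / patch_tokens[-1].norm(dim=-1, keepdim=True)     # [B, N+1, C]
similarity, _ = AnomalyCLIP_lib.compute_similarity(patches, text_features)   # [B, N+1, 2]
anomaly_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], img.shape[-2:])
print(anomaly_map.shape)   # [B, H, W, 2]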
AnomalyCLIP_lib/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
AnomalyCLIP_lib/transform.py ADDED
@@ -0,0 +1,133 @@
1
+ import warnings
2
+ from dataclasses import dataclass, asdict
3
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torchvision.transforms.functional as F
8
+
9
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
10
+ CenterCrop
11
+
12
+ from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
13
+
14
+
15
+ @dataclass
16
+ class AugmentationCfg:
17
+ scale: Tuple[float, float] = (0.9, 1.0)
18
+ ratio: Optional[Tuple[float, float]] = None
19
+ color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
20
+ interpolation: Optional[str] = None
21
+ re_prob: Optional[float] = None
22
+ re_count: Optional[int] = None
23
+ use_timm: bool = False
24
+
25
+
26
+ class ResizeMaxSize(nn.Module):
27
+
28
+ def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
29
+ super().__init__()
30
+ if not isinstance(max_size, int):
31
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
32
+ self.max_size = max_size
33
+ self.interpolation = interpolation
34
+ self.fn = min if fn == 'min' else max
35
+ self.fill = fill
36
+
37
+ def forward(self, img):
38
+ if isinstance(img, torch.Tensor):
39
+ height, width = img.shape[:2]
40
+ else:
41
+ width, height = img.size
42
+ scale = self.max_size / float(max(height, width))
43
+ if scale != 1.0:
44
+ new_size = tuple(round(dim * scale) for dim in (height, width))
45
+ img = F.resize(img, new_size, self.interpolation)
46
+ pad_h = self.max_size - new_size[0]
47
+ pad_w = self.max_size - new_size[1]
48
+ img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
49
+ return img
50
+
51
+
52
+ def _convert_to_rgb(image):
53
+ return image.convert('RGB')
54
+
55
+
56
+ def image_transform(
57
+ image_size: int,
58
+ is_train: bool,
59
+ mean: Optional[Tuple[float, ...]] = None,
60
+ std: Optional[Tuple[float, ...]] = None,
61
+ resize_longest_max: bool = False,
62
+ fill_color: int = 0,
63
+ aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
64
+ ):
65
+ mean = mean or OPENAI_DATASET_MEAN
66
+ if not isinstance(mean, (list, tuple)):
67
+ mean = (mean,) * 3
68
+
69
+ std = std or OPENAI_DATASET_STD
70
+ if not isinstance(std, (list, tuple)):
71
+ std = (std,) * 3
72
+
73
+ if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
74
+ # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
75
+ image_size = image_size[0]
76
+
77
+ if isinstance(aug_cfg, dict):
78
+ aug_cfg = AugmentationCfg(**aug_cfg)
79
+ else:
80
+ aug_cfg = aug_cfg or AugmentationCfg()
81
+ normalize = Normalize(mean=mean, std=std)
82
+ if is_train:
83
+ aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
84
+ use_timm = aug_cfg_dict.pop('use_timm', False)
85
+ if use_timm:
86
+ from timm.data import create_transform # timm can still be optional
87
+ if isinstance(image_size, (tuple, list)):
88
+ assert len(image_size) >= 2
89
+ input_size = (3,) + image_size[-2:]
90
+ else:
91
+ input_size = (3, image_size, image_size)
92
+ # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
93
+ aug_cfg_dict.setdefault('interpolation', 'random')
94
+ aug_cfg_dict.setdefault('color_jitter', None) # disable by default
95
+ train_transform = create_transform(
96
+ input_size=input_size,
97
+ is_training=True,
98
+ hflip=0.,
99
+ mean=mean,
100
+ std=std,
101
+ re_mode='pixel',
102
+ **aug_cfg_dict,
103
+ )
104
+ else:
105
+ train_transform = Compose([
106
+ RandomResizedCrop(
107
+ image_size,
108
+ scale=aug_cfg_dict.pop('scale'),
109
+ interpolation=InterpolationMode.BICUBIC,
110
+ ),
111
+ _convert_to_rgb,
112
+ ToTensor(),
113
+ normalize,
114
+ ])
115
+ if aug_cfg_dict:
116
+ warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
117
+ return train_transform
118
+ else:
119
+ if resize_longest_max:
120
+ transforms = [
121
+ ResizeMaxSize(image_size, fill=fill_color)
122
+ ]
123
+ else:
124
+ transforms = [
125
+ Resize(image_size, interpolation=InterpolationMode.BICUBIC),
126
+ CenterCrop(image_size),
127
+ ]
128
+ transforms.extend([
129
+ _convert_to_rgb,
130
+ ToTensor(),
131
+ normalize,
132
+ ])
133
+ return Compose(transforms)
README.md CHANGED
@@ -1,117 +1,80 @@
1
 
2
- # CLIP 기반 제품 결함 탐지 모델 카드
3
-
4
- ## 모델 세부사항
5
-
6
- ### 모델 설명
7
-
8
- 모델은 CLIP 기반의 이상 탐지 방법을 사용하여 제품 결함을 탐지합니다.
9
- 사전 훈련된 CLIP 모델을 fine-tuning하여 제품 이미지에서 결함을 식별하고, 생산 라인에서 품질 관리 및 결함 감지를 자동화합니다.
10
-
11
- - **Developed by:** 오석
12
- - **Funded by:** 4INLAB INC.
13
- - **Shared by:** zhou2023anomalyclip
14
- - **Model type:** CLIP based Anomaly Detection
15
- - **Language(s):** Python, PyTorch
16
- - **License:** Apache 2.0, MIT, GPL-3.0
17
-
18
- ### 기술적 제한사항
19
-
20
- - 모델은 결함 탐지를 위한 충분하고 다양한 훈련 데이터를 필요로 합니다. 훈련 데이터셋이 부족하거나 불균형할 경우, 모델의 성능이 저하될 수 있습니다.
21
- - 실시간 결함 감지 성능은 하드웨어 사양에 따라 달라질 수 있으며, 높은 해상도에서 결함을 탐지하는 정확도가 떨어질 수 있습니다.
22
- - 결함이 미세하거나 제품 간 유사성이 매우 높은 경우, 모델이 결함을 정확하게 탐지하지 못할 수 있습니다.
23
-
24
- ## 학습 세부사항
25
-
26
- ### Hardware
27
- - **CPU:** Intel Core i9-13900K (24 Cores, 32 Threads)
28
- - **RAM:** 64GB DDR5
29
- - **GPU:** NVIDIA RTX 4090Ti 24GB
30
- - **Storage:** 1TB NVMe SSD + 2TB HDD
31
- - **Operating System:** Windows 11 pro
32
-
33
- ### 데이터셋 정보
34
-
35
- 이 모델은 시계열 재고 데이터를 사용하여 훈련됩니다. 이 데이터는 재고 수준, 날짜 및 기타 관련 특성에 대한 정보를 포함하고 있습니다.
36
- 데이터는 Conv1D와 BiLSTM 레이어에 적합하도록 MinMax 스케일링을 사용하여 전처리되고 정규화됩니다.
37
-
38
-
39
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/E8pMyLfUnlIQFCLbTiLba.png)
40
-
41
- - **Data sources:** https://huggingface.co/datasets/quandao92/vision-inventory-prediction-data
42
- - **Training size:**
43
- - 1차 : Few-shot learning with anomaly (10ea), good (4ea)
44
- - 2차 : Few-shot learning with anomaly (10ea), good (10ea)
45
- - 3 : Few-shot learning with anomaly (10ea), good (110ea)
46
-
47
- - **Time-step:** 5초 이내
48
-
49
- - **Data Processing Techniques:**
50
- - normalization:
51
- description: "이미지 픽셀 값을 평균 및 표준편차로 표준화"
52
- method: "'Normalize' from 'torchvision.transforms'"
53
- - max_resize:
54
- description: "이미지의 최대 크기를 유지하며, 비율을 맞추고 패딩을 추가하여 크기 조정"
55
- method: "Custom 'ResizeMaxSize' class"
56
- - random_resized_crop:
57
- description: "훈련 중에 이미지를 랜덤으로 자르고 크기를 조정하여 변형을 추가"
58
- method: "'RandomResizedCrop' from 'torchvision.transforms'"
59
- - resize:
60
- description: "모델 입력에 맞게 이미지를 고정된 크기로 조정"
61
- method: "'Resize' with BICUBIC interpolation"
62
- - center_crop:
63
- description: "이미지의 중앙 부분을 지정된 크기로 자르기"
64
- method: "'CenterCrop'"
65
- - to_tensor:
66
- description: "이미지를 PyTorch 텐서로 변환"
67
- method: "'ToTensor'"
68
- - augmentation (optional):
69
- description: "데이터 증강을 위해 다양한 랜덤 변환 적용, 'AugmentationCfg'로 설정 가능"
70
- method: "Uses 'timm' library if specified"
71
-
72
-
73
- # AD-CLIP Model Architecture
74
-
75
-
76
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/1wFBzBCgF4sOefROGE7RO.png)
77
-
78
- - **model:**
79
- - input_layer:
80
- - image_size: [640, 640, 3] # 표준 입력 이미지 크기
81
- - backbone:
82
- - name: CLIP (ViT-B-32) # CLIP 모델의 비전 트랜스포머를 백본으로 사용
83
- - filters: [32, 64, 128, 256, 512] # 비전 트랜스포머의 각 레이어 필터 크기
84
- - neck:
85
- - name: Anomaly Detection Module # 결함 탐지를 위한 추가 모듈
86
- - method: Contrastive Learning # CLIP 모델의 특징을 사용한 대조 학습 기법
87
- - head:
88
- - name: Anomaly Detection Head # 결함 탐지를 위한 최종 출력 레이어
89
- - outputs:
90
- - anomaly_score: 1 # 이상 탐지 점수 (비정상/정상 구분)
91
- - class_probabilities: N # 각 클래스에 대한 확률 (결함 여부)
92
-
93
- # Optimizer and Loss Function
94
- - **training:**
95
- - optimizer:
96
- - name: AdamW # AdamW 옵티마이저 (가중치 감쇠 포함)
97
- - lr: 0.0001 # 학습률
98
- - loss:
99
- - classification_loss: 1.0 # 분류 손실 (교차 엔트로피)
100
- - anomaly_loss: 1.0 # 결함 탐지 손실 (이상 탐지 모델에 대한 손실)
101
- - contrastive_loss: 1.0 # 대조 학습 손실 (유사도 기반 손실)
102
-
103
- # Metrics
104
- - **metrics:**
105
- - Precision # 정밀도 (Precision)
106
- - Recall # 재현율 (Recall)
107
- - mAP # 평균 정밀도 (Mean Average Precision)
108
- - F1-Score # F1-점수 (균형 잡힌 평가 지표)
109
-
110
- # Training Parameters
111
- **하이퍼파라미터 설정**
112
- - Learning Rate: 0.001.
113
- - Batch Size: 8.
114
- - Epochs: 200.
115
 
116
  # Pre-trained CLIP model
117
  | Model | Download |
@@ -121,75 +84,132 @@
121
  | ViT-L/14 | [download](https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt) |
122
  | ViT-L/14@336px | [download](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt) |
123
 
124
- # Evaluation Parameters
125
- - F1-score: 95%이상.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
 
 
 
 
 
127
 
128
- # 학습 성능 테스트 결과
 
 
 
 
129
 
130
- - **학습성능 결과과 그래프**:
131
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/RduhNlkWiyPXj-vbAkJga.png)
 
 
 
132
 
133
- <div style="display: flex; justify-content: space-between;">
134
- <div style="text-align: center; margin-right: 20px;">
135
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/_lUD77x-yueXycuIn7jya.png" height="80%" width="100%" style="margin-right:5px;">
136
- <p>1차 학습 성능</p>
137
- </div>
138
- <div style="text-align: center; margin-right: 20px;">
139
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/NHDH9N94cI-KqP8k-ASUN.png" height="80%" width="100%" style="margin-right:5px;">
140
- <p>2차 학습 성능</p>
141
- </div>
142
- <div style="text-align: center; margin-right: 20px;">
143
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/6n0DnnQjXD8Ql-p3Owxan.png" height="80%" width="100%" style="margin-right:5px;">
144
- <p>3차 학습 성능</p>
145
- </div>
146
- </div>
147
 
148
- - **학습 결과표**:
149
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/kDxl9q6X2dxCRJm5nc7jR.png)
150
-
151
- - **테스트 결과**:
152
- <div style="display: flex; justify-content: space-between;">
153
- <div style="text-align: center; margin-right: 20px;">
154
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/A91V0GdrcUcX01cC-biG9.png" height="600" width="1000" style="margin-right:5px;">
155
- <p>Anomaly Product</p>
156
- </div>
157
- <div style="text-align: center; margin-right: 20px;">
158
- <img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/PxleIhphzViTGCubVhWn7.png" height="600" width="1000" style="margin-right:5px;">
159
- <p>Normal Product</p>
160
- </div>
161
- </div>
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- # 설치 및 실행 가이라인
 
 
165
 
166
- 이 모델을 실행하려면 Python과 함께 다음 라이브러리가 필요합니다:
167
 
168
- - **ftfy==6.2.0**: 텍스트 정규화 및 인코딩 문제를 해결하는 라이브러리.
169
- - **matplotlib==3.9.0**: 데이터 시각화 및 그래프 생성을 위한 라이브러리.
170
- - **numpy==1.24.3**: 수치 연산을 위한 핵심 라이브러리.
171
- - **opencv_python==4.9.0.80**: 이미지 및 비디오 처리용 라이브러리.
172
- - **pandas==2.2.2**: 데이터 분석 및 조작을 위한 라이브러리.
173
- - **Pillow==10.3.0**: 이미지 파일 처리 및 변환을 위한 라이브러리.
174
- - **PyQt5==5.15.10**: GUI 애플리케이션 개발을 위한 프레임워크.
175
- - **PyQt5_sip==12.13.0**: PyQt5와 Python 간의 인터페이스를 제공하는 라이브러리.
176
- - **regex==2024.5.15**: 정규 표현식 처리를 위한 라이브러리.
177
- - **scikit_learn==1.2.2**: 기계 학습 및 데이터 분석을 위한 라이브러리.
178
- - **scipy==1.9.1**: 과학 및 기술 계산을 위한 라이브러리.
179
- - **setuptools==59.5.0**: Python 패키지 배포 및 설치를 위한 라이브러리.
180
- - **scikit-image**: 이미지 처리 및 분석을 위한 라이브러리.
181
- - **tabulate==0.9.0**: 표 형태로 데이터를 출력하는 라이브러리.
182
- - **thop==0.1.1.post2209072238**: PyTorch 모델의 FLOP 수를 계산하는 도구.
183
- - **timm==0.6.13**: 다양한 최신 이미지 분류 모델을 제공하는 라이브러리.
184
- - **torch==2.0.0**: PyTorch 딥러닝 프레임워크.
185
- - **torchvision==0.15.1**: 컴퓨터 비전 작업을 위한 PyTorch 확장 라이브러리.
186
- - **tqdm==4.65.0**: 진행 상황을 시각적으로 표시하는 라이브러리.
187
- - **pyautogui**: GUI 자동화를 위한 라이브러리.
188
 
 
 
 
 
 
189
 
190
 
191
 
192
- ### 모델 실행 단계:
193
 
194
  ### ✅ Prompt generating
195
  ```ruby
@@ -249,18 +269,8 @@ parser.add_argument("--dpam", type=int, default=20, help="dpam size")
249
  → If you want to focus only on the final layers (where the model usually learns complex features), you can choose fewer DPAM layers.
250
  ```
251
 
252
- ### ✅ Test process
253
 
254
- 👍 **Load pre-trained and Fine tuned (Checkpoints) models**
255
- 1. Pre-trained mode (./pre-trained model/):
256
- ```ruby
257
- → Contains the pre-trained model (ViT-B, ViT-L,....)
258
- → Used as the starting point for training the CLIP model
259
- → Pre-trained model helps speed up and improve training by leveraging previously learned features
260
- ```
261
- 2. Fine-tuned models (./checkpoint/):
262
- ```ruby
263
- → "epoch_N.pth" files in this folder store the model's states during the fine-tuning process.
264
- → Each ".pth" file represents a version of the model fine-tuned from the pre-trained model
265
- → These checkpoints can be used to resume fine-tuning, evaluate the model at different stages, or select the best-performing version
266
- ```
 
1
 
2
+ # CLIP based ANOMALY DETECTION
3
+
4
+ <div align="center">
5
+
6
+ [![Status](https://img.shields.io/badge/status-active-success.svg)]()
7
+ [![GitHub Issues](https://img.shields.io/github/issues/kylelobo/The-Documentation-Compendium.svg)](https://github.com/kylelobo/The-Documentation-Compendium/issues)
8
+ [![GitHub Pull Requests](https://img.shields.io/github/issues-pr/kylelobo/The-Documentation-Compendium.svg)](https://github.com/kylelobo/The-Documentation-Compendium/pulls)
9
+ [![License](https://img.shields.io/badge/license-MIT-blue.svg)](/LICENSE)
10
+
11
+ </div>
12
+
13
+ ---
14
+
15
+ <p align="center"> Zero-shot anomaly detection (AD) requires detection models trained on auxiliary data to detect anomalies in a target dataset without using any of its training samples. AnomalyCLIP learns object-agnostic text prompts that capture generic normality and abnormality in an image regardless of its foreground objects. This lets the model focus on abnormal image regions rather than object semantics, enabling generalized normality and abnormality recognition across diverse types of objects. All experiments are conducted in PyTorch 2.0.0 on a single NVIDIA RTX 4090 24GB.
16
+ <br>
17
+ </p>
18
+
19
+
20
+
21
+
22
+ # 📝 Table of Contents
23
+
24
+ - [Update](#update)
25
+ - [Install & Dependence](#install--dependence)
26
+ - [Dataset Preparation](#dataset-preparation)
27
+ - [Pre-trained CLIP model](#pre-trained-clip-model)
28
+ - [Usage](#usage)
29
+ - [Code Details](#code-details)
30
+ - [References](#references)
31
+
32
+ # Update
33
+ - 08.08.2024: Code has been released !!!
34
+
35
+
36
+ # Install & Dependence
37
+
38
+ ### ⭕ Tested Platform
39
+ - Software Information
40
+ ```
41
+ OS: Windows 11 64 bit
42
+ Python: 3.9.18 (anaconda)
43
+ PyTorch: 2.0.0
44
+ Cuda Toolkit: 11.8
45
+ cuDNN: 9.3.0.75 (for CUDA 11)
46
+ ```
47
+ ![analysis](./docs/CUDA_info.png)
48
+
49
+ - Hardware
50
+ ```
51
+ CPU: Intel(R) Core(TM) i7-14700K 3.40 GHz
52
+ RAM: 64GB
53
+ GPU: Nvidia RTX4090 (24GB)
54
+ ```
55
+
56
+
57
+ - Install Python libraries
58
+ ```
59
+ pip install -r requirements.txt
60
+ ```
61
+
62
+ # Dataset Preparation
63
+
64
+ Download the dataset below:
65
+
66
+ * Industrial Domain:
67
+
68
+ | Dataset | Download |
69
+ | --- | --- |
70
+ | MVTec | [download](https://www.mvtec.com/company/research/datasets/mvtec-ad) |
71
+ | VisA | [download](https://github.com/amazon-science/spot-diff) |
72
+ | MPDD | [download](https://github.com/stepanje/MPDD) |
73
+ | BTAD | [download](http://avires.dimi.uniud.it/papers/btad/btad.zip) |
74
+ | SDD | [download](https://www.vicos.si/resources/kolektorsdd/) |
75
+ | DAGM | [download](https://www.kaggle.com/datasets/mhskjelvareid/dagm-2007-competition-dataset-optical-inspection) |
76
+ | DTD-Synthetic | [download](https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1) |
77
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  # Pre-trained CLIP model
80
  | Model | Download |
 
84
  | ViT-L/14 | [download](https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt) |
85
  | ViT-L/14@336px | [download](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt) |
86
 
87
+ # Usage
88
+ - for train (Fine-tuning)
89
+ ```ruby
90
+ python train.py
91
+ ```
92
+ - for testing on a dataset (batch evaluation of many images)
93
+ ```ruby
94
+ python test.py
95
+ ```
96
+ - for a simple test (single-image test)
97
+ ```ruby
98
+ python Simple_test_code.py
99
+ ```
100
+ - for the UI app test (a simple demo app)
101
+ ```ruby
102
+ python monitor_check.py
103
+ ```
104
+ - for real-time detection test (webcam and video tracking)
105
+ ```ruby
106
+ python real_time_CLIP.py
107
+ ```
108
+
109
+ # Code Details
110
+
111
+ ### ✅Dataset configuration
112
+
113
+ - Dataset configuration as example below
114
+ ```
115
+ ├── data/
116
+ │ ├── COMP_1/
117
+ │ │ ├── product_1/
118
+ │ │ │ ├──ground_truth
119
+ │ │ │ │ ├──anomaly_1
120
+ │ │ │ │ ├──anomaly_2
121
+ │ │ │ │
122
+ │ │ │ ├──test/
123
+ │ │ │ │ ├──good
124
+ │ │ │ │ ├──anomaly_1
125
+ │ │ │ │ ├──anomaly_2
126
+ │ │ │ │
127
+ │ │ │ ├──train/
128
+ │ │ │ │ ├──good
129
+ │ │ │ │ ├──anomaly_1
130
+ │ │ │ │ ├──anomaly_2
131
+ │ │ │ │
132
+ │ │ ├── product_2/
133
+ │ │ │ │
134
+ │ │
135
+ │ ├── COMP_2/
136
+ │ │
137
+ ```
138
 
139
+ - Generate the JSON files that store all of the above dataset information (→ meta_train.json, meta_test.json); an example entry is sketched after the command below
140
+ ```ruby
141
+ cd dataset_config
142
+ python dataset_get_json.py
143
+ ```
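+
+ The generated meta file maps each split and class to a list of sample entries. A minimal sketch of the layout (the file names here are illustrative; the real entries are built from your folder contents):
+ ```
+ {
+     "train": {
+         "shinpyung": [
+             {
+                 "img_path": "shinpyung/train/anomaly_1/000.jpg",
+                 "mask_path": "shinpyung/ground_truth/anomaly_1/000_mask.jpg",
+                 "cls_name": "shinpyung",
+                 "specie_name": "anomaly_1",
+                 "anomaly": 1
+             }
+         ]
+     },
+     "test": { "shinpyung": [ ... ] }
+ }
+ ```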
144
 
145
+ - Create all ground_truth masks (anomaly masks only) by hand
146
+ ```ruby
147
+ cd dataset_config
148
+ python image_ground_truth.py
149
+ ```
150
 
151
+ - Dataset configuration for train and test
152
+ ```ruby
153
+ cd training_libs
154
+ python dataset.py
155
+ ```
156
 
157
+ → The `__init__` method takes the dataset root directory, the transform functions, the dataset name, and the mode as inputs.
158
+ It reads the meta-information JSON file (meta_train.json) and stores the class-name list and every data entry in a list.
159
+ → It calls the generate_class_info function to build the class information and map each class name to a class ID.
160
+ → The `__len__` method returns the number of samples in the dataset.
161
+ The `__getitem__` method returns the sample at a given index:
162
+ it reads the image from its path and builds a mask image depending on whether the sample is anomalous.
163
+ → Transform functions are applied to the image and mask when needed.
164
+ → It returns a dictionary with the image, mask, class name, anomaly flag, image path, and class ID (see the usage sketch below).
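+
+ A minimal usage sketch (the transforms here are simple stand-ins; train.py builds the real ones with training_libs.utils.get_transform(args)):
+ ```python
+ import torch
+ from torchvision import transforms
+ from training_libs.dataset import Dataset_train
+
+ # stand-in transforms, only for illustration
+ preprocess = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])
+ target_transform = transforms.Compose([transforms.Resize((518, 518)), transforms.ToTensor()])
+
+ train_data = Dataset_train(root="./data/4inlab", transform=preprocess,
+                            target_transform=target_transform, dataset_name="4inlab")
+ loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True)
+ batch = next(iter(loader))
+ print(batch["img"].shape, batch["img_mask"].shape, batch["cls_name"], batch["anomaly"])
+ ```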
 
 
 
 
 
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
+ ### ✅ Image pre-processing (transformation) for train and test
168
+ ```ruby
169
+ training_libs/utils.py
170
+ ```
171
+ ```ruby
172
+ AnomalyCLIP_lib/transform.py
173
+ ```
174
+ ⭐ **Data Processing Techniques**
175
+ 1. Normalization
176
+ → Standardize image pixel values using mean and standard deviation
177
+ → Utilized via *'Normalize'* from *'torchvision.transforms'*
178
+
179
+ 2. ResizeMaxSize
180
+ → Resize the image to a maximum dimension while maintaining aspect ratio, with padding
181
+ → Custom *'ResizeMaxSize'* class
182
+
183
+ 3. RandomResizedCrop
184
+ → Randomly crop and resize images during training to create variability
185
+ → Implemented via *'RandomResizedCrop'* from *'torchvision.transforms'*
186
+
187
+ 4. Resize
188
+ → Resize images to a fixed size for model input
189
+ → Done using *'Resize'* with BICUBIC interpolation
190
+
191
+ 5. Center Crop
192
+ → Crop the central region of the image to the desired size
193
+ → Applied using *'CenterCrop'*
194
+
195
+ 6. ToTensor
196
+ → Convert images to PyTorch tensors
197
+ → Done with *'ToTensor'*
198
 
199
+ 7. Augmentation (Optional)
200
+ → Apply various random transformations for data augmentation, configurable via *'AugmentationCfg'*
201
+ → Uses *'timm'* library if specified
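+
+ Putting the evaluation-time pieces above together, a minimal sketch equivalent to image_transform(image_size, is_train=False) in AnomalyCLIP_lib/transform.py (image_size = 518 matches the --image_size default used by train.py and test.py):
+ ```python
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode
+ from AnomalyCLIP_lib.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+
+ image_size = 518
+ preprocess = Compose([
+     Resize(image_size, interpolation=InterpolationMode.BICUBIC),  # bicubic resize
+     CenterCrop(image_size),                                       # crop the central region
+     lambda img: img.convert("RGB"),                                # _convert_to_rgb
+     ToTensor(),                                                    # PIL image -> tensor
+     Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),   # standardize pixel values
+ ])
+ ```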
202
 
 
203
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ ⭐ **Libraries Used**
206
+ 1. *'torch'*: Core deep learning library for tensor operations and model building
207
+ 2. *'torchvision'*: Provides image processing utilities like Resize, CenterCrop, Normalize, etc
208
+ 3. *'timm'*: Optional, for advanced augmentation and transformations
209
+ 4. *'AnomalyCLIP_lib'*: Custom library for dataset-specific constants and transformations
210
 
211
 
212
 
 
213
 
214
  ### ✅ Prompt generating
215
  ```ruby
 
269
  → If you want to focus only on the final layers (where the model usually learns complex features), you can choose fewer DPAM layers.
270
  ```
271
 
 
272
 
273
+
274
+ # References
275
+ - AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection [[github](https://github.com/zqhang/AnomalyCLIP.git)]
276
+
 
 
 
 
 
 
 
 
 
dataset_config/dataset_get_json.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ import json
3
+
4
+
5
+ class DATASolver(object):
6
+
7
+ CLSNAMES = [
8
+ 'shinpyung',
9
+ # 'gear'# Change to 'welding_test' for testing
10
+ ]
11
+
12
+ def __init__(self, root='data/4inlab'):
13
+ self.root = root
14
+ self.meta_path = f'{root}/meta_train.json' # Change to meta_test.json for testing
15
+
16
+ def run(self):
17
+ info = dict(train={}, test={})
18
+ anomaly_samples = 0
19
+ normal_samples = 0
20
+ for cls_name in self.CLSNAMES:
21
+ cls_dir = f'{self.root}/{cls_name}'
22
+ for phase in ['train', 'test']:
23
+ cls_info = []
24
+ species = os.listdir(f'{cls_dir}/{phase}')
25
+ for specie in species:
26
+ is_abnormal = True if specie not in ['good'] else False
27
+ img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
28
+ mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
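+ # NOTE: masks are paired with images purely by sorted order, so each anomalous specie
+ # folder under ground_truth must contain exactly one mask per image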
29
+ img_names.sort()
30
+ mask_names.sort() if mask_names is not None else None
31
+ for idx, img_name in enumerate(img_names):
32
+ info_img = dict(
33
+ img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
34
+ mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
35
+ cls_name=cls_name,
36
+ specie_name=specie,
37
+ anomaly=1 if is_abnormal else 0,
38
+ )
39
+ cls_info.append(info_img)
40
+ if phase == 'test':
41
+ if is_abnormal:
42
+ anomaly_samples = anomaly_samples + 1
43
+ else:
44
+ normal_samples = normal_samples + 1
45
+ info[phase][cls_name] = cls_info
46
+ with open(self.meta_path, 'w') as f:
47
+ f.write(json.dumps(info, indent=4) + "\n")
48
+ print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples)
49
+ if __name__ == '__main__':
50
+ runner = DATASolver(root='data/4inlab')
51
+ runner.run()
dataset_config/image_ground_truth.py ADDED
@@ -0,0 +1,68 @@
1
+ #%%
2
+ import cv2
3
+ import numpy as np
4
+ import os
5
+
6
+ # Initialize global variables
7
+ points = []
8
+ drawing = False
9
+
10
+ # Function to clear the drawn points
11
+ def clear_points():
12
+ global points
13
+ points = []
14
+
15
+ # Mouse callback function
16
+ def draw_polygon(event, x, y, flags, param):
17
+ global points, drawing, img
18
+
19
+ if event == cv2.EVENT_LBUTTONDOWN:
20
+ drawing = True
21
+ points.append((x, y))
22
+
23
+ elif event == cv2.EVENT_MOUSEMOVE:
24
+ if drawing:
25
+ img_copy = img.copy()
26
+ for i in range(1, len(points)):
27
+ cv2.line(img_copy, points[i - 1], points[i], (255, 0, 0), 2)
28
+ if len(points) > 0:
29
+ cv2.circle(img_copy, points[-1], 3, (0, 0, 255), -1, lineType=cv2.LINE_AA) # Show the most recently selected mouse point
30
+ cv2.imshow('image', img_copy)
31
+
32
+ elif event == cv2.EVENT_LBUTTONUP:
33
+ drawing = False
34
+ points.append((x, y))
35
+ pts = np.array(points, np.int32)
36
+ pts = pts.reshape((-1, 1, 2))
37
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
38
+ cv2.fillPoly(mask, [pts], 255)
39
+ cv2.imwrite(mask_path, mask)
40
+ cv2.imshow('image', img)
41
+
42
+ # Function to process images in a folder
43
+ def process_images_in_folder(folder_path):
44
+ global img, mask_path
45
+
46
+ for img_name in os.listdir(folder_path):
47
+ if img_name.endswith('.jpg'):
48
+ img_path = os.path.join(folder_path, img_name)
49
+ mask_path = os.path.join(folder_path, f'{os.path.splitext(img_name)[0]}_mask.jpg')
50
+ img = cv2.imread(img_path)
51
+
52
+ # Create a window and bind the mouse callback function
53
+ cv2.namedWindow('image')
54
+ cv2.setMouseCallback('image', draw_polygon)
55
+
56
+ while True:
57
+ cv2.imshow('image', img)
58
+ k = cv2.waitKey(1) & 0xFF
59
+ if k == 27: # Press 'ESC' to exit
60
+ break
61
+
62
+ clear_points()
63
+ cv2.destroyAllWindows()
64
+
65
+ # Define folders to process
66
+ folder_path=r'C:\Users\20240805\Documents\GitHub\AD-CLIP\data\4inlab\shinpyung\train\anomaly'
67
+ process_images_in_folder(folder_path)
68
+ # %%
dataset_config/image_resize.py ADDED
@@ -0,0 +1,13 @@
1
+ from PIL import Image
2
+ import os
3
+
4
+ def resize_image(input_image_path, output_image_path, size=(518, 518)):
5
+ with Image.open(input_image_path) as image:
6
+ resized_image = image.resize(size)
7
+ resized_image.save(output_image_path)
8
+
9
+ # Example usage: resize every image in a folder into its 'resize' subfolder
+ # (resize_image expects file paths, so iterate over the folder contents)
+ input_dir = r'\4inlab\shinpyung\train\anomaly'
+ output_dir = r'\4inlab\shinpyung\train\anomaly\resize'
+ os.makedirs(output_dir, exist_ok=True)
+ for name in os.listdir(input_dir):
+     if name.lower().endswith(('.jpg', '.jpeg', '.png')):
+         resize_image(os.path.join(input_dir, name), os.path.join(output_dir, name))
13
+
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ ftfy==6.2.0
2
+ matplotlib==3.9.0
3
+ numpy==1.24.3
4
+ opencv_python==4.9.0.80
5
+ pandas==2.2.2
6
+ Pillow==10.3.0
7
+ PyQt5==5.15.10
8
+ PyQt5_sip==12.13.0
9
+ regex==2024.5.15
10
+ scikit_learn==1.2.2
11
+ scipy==1.9.1
12
+ setuptools==59.5.0
13
+ scikit-image
14
+ tabulate==0.9.0
15
+ thop==0.1.1.post2209072238
16
+ timm==0.6.13
17
+ torch==2.0.0
18
+ torchvision==0.15.1
19
+ tqdm==4.65.0
20
+ pyautogui
test.py ADDED
@@ -0,0 +1,231 @@
1
+ #%%
2
+ import AnomalyCLIP_lib
3
+ import torch
4
+ import argparse
5
+ import torch.nn.functional as F
6
+ from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
7
+ from training_libs.loss import FocalLoss, BinaryDiceLoss
8
+ from training_libs.utils import normalize
9
+ from training_libs.dataset import Dataset_test
10
+ from training_libs.logger import get_logger
11
+ from tqdm import tqdm
12
+
13
+ import os
14
+ import random
15
+ import numpy as np
16
+ from tabulate import tabulate
17
+ from training_libs.utils import get_transform
18
+
19
+ def setup_seed(seed):
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed_all(seed)
22
+ np.random.seed(seed)
23
+ random.seed(seed)
24
+ torch.backends.cudnn.deterministic = True
25
+ torch.backends.cudnn.benchmark = False
26
+
27
+ from training_libs.visualization import visualizer
28
+
29
+ from training_libs.metrics import image_level_metrics, pixel_level_metrics
30
+ from tqdm import tqdm
31
+ from scipy.ndimage import gaussian_filter
32
+
33
+
34
+ def test(args):
35
+ img_size = args.image_size
36
+ features_list = args.features_list
37
+ dataset_dir = args.data_path
38
+ save_path = args.save_path
39
+ dataset_name = args.dataset
40
+
41
+ logger = get_logger(args.save_path)
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ # device = "gpu"
44
+
45
+ AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx}
46
+ model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details = AnomalyCLIP_parameters)
47
+ model.eval()
48
+ # torch.save(model.state_dict(),"pre-trained models/clip")
49
+
50
+ preprocess, target_transform = get_transform(args)
51
+ test_data = Dataset_test(root=args.data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset)
52
+ test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
53
+ obj_list = test_data.obj_list
54
+
55
+
56
+ results = {}
57
+ metrics = {}
58
+ for obj in obj_list:
59
+ results[obj] = {}
60
+ results[obj]['gt_sp'] = []
61
+ results[obj]['pr_sp'] = []
62
+ results[obj]['imgs_masks'] = []
63
+ results[obj]['anomaly_maps'] = []
64
+ metrics[obj] = {}
65
+ metrics[obj]['pixel-auroc'] = 0
66
+ metrics[obj]['pixel-aupro'] = 0
67
+ metrics[obj]['image-auroc'] = 0
68
+ metrics[obj]['image-ap'] = 0
69
+
70
+ prompt_learner = AnomalyCLIP_PromptLearner(model.to(device=device), AnomalyCLIP_parameters)
71
+
72
+
73
+ #Add check-point from trained model with normal images
74
+ # checkpoint = torch.load("checkpoint/241120_SP_DPAM_13_518/epoch_500.pth",map_location=torch.device('cpu'))
75
+ # prompt_learner.load_state_dict(checkpoint["prompt_learner"])
76
+
77
+
78
+ #Add check-point from trained model with normal images
79
+ # checkpoint = torch.load(args.checkpoint_path,map_location=torch.device(device=device))
80
+ # prompt_learner.load_state_dict(checkpoint["prompt_learner"])
81
+
82
+
83
+ prompt_learner.to(device)
84
+ model.to(device)
85
+ model.visual.DAPM_replace(DPAM_layer = 13)
86
+
87
+ prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None)
88
+ print("print(prompts)")
89
+ print(prompts)
90
+
91
+
92
+
93
+ text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
94
+ text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1)
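+ # The learned prompt embeddings come in two halves (normality first, then abnormality);
+ # chunking into 2 and stacking pairs them as one (normal, abnormal) text pair per class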
95
+ text_features = text_features/text_features.norm(dim=-1, keepdim=True)
96
+
97
+
98
+
99
+ model.to(device)
100
+ for idx, items in enumerate(tqdm(test_dataloader)):
101
+ image = items['img'].to(device)
102
+ cls_name = items['cls_name']
103
+ cls_id = items['cls_id']
104
+
105
+ gt_mask_initial = items['img_mask']
106
+ #convert gt mask to good (0) and anomaly (1)
107
+ gt_mask = items['img_mask']
108
+ gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
109
+
110
+
111
+ results[cls_name[0]]['imgs_masks'].append(gt_mask) # px
112
+ results[cls_name[0]]['gt_sp'].extend(items['anomaly'].detach().cpu())
113
+
114
+ with torch.no_grad():
115
+ image_features, patch_features = model.encode_image(image, features_list, DPAM_layer = 20)
116
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
117
+
118
+ text_probs = image_features @ text_features.permute(0, 2, 1)
119
+ text_probs = (text_probs/0.07).softmax(-1)
120
+ text_probs = text_probs[:, 0, 1]
121
+ anomaly_map_list = []
122
+ for idx, patch_feature in enumerate(patch_features):
123
+ if idx >= args.feature_map_layer[0]:
124
+ patch_feature = patch_feature/ patch_feature.norm(dim = -1, keepdim = True)
125
+ similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
126
+ similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size)
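+ # Fuse the two channels into one anomaly score in [0, 1]: high where the 'abnormal'
+ # prompt (index 1) wins and the 'normal' prompt (index 0) loses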
127
+ anomaly_map = (similarity_map[...,1] + 1 - similarity_map[...,0])/2.0
128
+ anomaly_map_list.append(anomaly_map)
129
+
130
+ anomaly_map = torch.stack(anomaly_map_list)
131
+
132
+ anomaly_map = anomaly_map.sum(dim = 0)
133
+ results[cls_name[0]]['pr_sp'].extend(text_probs.detach().cpu())
134
+ anomaly_map = torch.stack([torch.from_numpy(gaussian_filter(i, sigma = args.sigma)) for i in anomaly_map.detach().cpu()], dim = 0 )
135
+ results[cls_name[0]]['anomaly_maps'].append(anomaly_map)
136
+
137
+ #Save the anomaly map images
138
+ visualizer(items['img_path'], anomaly_map.detach().cpu().numpy(), args.image_size, args.save_path, cls_name)
139
+
140
+ print("print(results)")
141
+ torch.save(results,"results/results_shinpyung_0.pt")
142
+ # print(results)
143
+
144
+ table_ls = []
145
+ image_auroc_list = []
146
+ image_ap_list = []
147
+ pixel_auroc_list = []
148
+ pixel_aupro_list = []
149
+ for obj in obj_list:
150
+ table = []
151
+ table.append(obj)
152
+ results[obj]['imgs_masks'] = torch.cat(results[obj]['imgs_masks'])
153
+ results[obj]['anomaly_maps'] = torch.cat(results[obj]['anomaly_maps']).detach().cpu().numpy()
154
+ if args.metrics == 'image-level':
155
+ image_auroc = image_level_metrics(results, obj, "image-auroc")
156
+ image_ap = image_level_metrics(results, obj, "image-ap")
157
+ table.append(str(np.round(image_auroc * 100, decimals=1)))
158
+ table.append(str(np.round(image_ap * 100, decimals=1)))
159
+ image_auroc_list.append(image_auroc)
160
+ image_ap_list.append(image_ap)
161
+ elif args.metrics == 'pixel-level':
162
+ pixel_auroc = pixel_level_metrics(results, obj, "pixel-auroc")
163
+ pixel_aupro = pixel_level_metrics(results, obj, "pixel-aupro")
164
+ table.append(str(np.round(pixel_auroc * 100, decimals=1)))
165
+ table.append(str(np.round(pixel_aupro * 100, decimals=1)))
166
+ pixel_auroc_list.append(pixel_auroc)
167
+ pixel_aupro_list.append(pixel_aupro)
168
+ elif args.metrics == 'image-pixel-level':
169
+ image_auroc = image_level_metrics(results, obj, "image-auroc")
170
+ image_ap = image_level_metrics(results, obj, "image-ap")
171
+ pixel_auroc = pixel_level_metrics(results, obj, "pixel-auroc")
172
+ pixel_aupro = pixel_level_metrics(results, obj, "pixel-aupro")
173
+ table.append(str(np.round(pixel_auroc * 100, decimals=1)))
174
+ table.append(str(np.round(pixel_aupro * 100, decimals=1)))
175
+ table.append(str(np.round(image_auroc * 100, decimals=1)))
176
+ table.append(str(np.round(image_ap * 100, decimals=1)))
177
+ image_auroc_list.append(image_auroc)
178
+ image_ap_list.append(image_ap)
179
+ pixel_auroc_list.append(pixel_auroc)
180
+ pixel_aupro_list.append(pixel_aupro)
181
+ table_ls.append(table)
182
+
183
+ if args.metrics == 'image-level':
184
+ # logger
185
+ table_ls.append(['mean',
186
+ str(np.round(np.mean(image_auroc_list) * 100, decimals=1)),
187
+ str(np.round(np.mean(image_ap_list) * 100, decimals=1))])
188
+ results = tabulate(table_ls, headers=['objects', 'image_auroc', 'image_ap'], tablefmt="pipe")
189
+ elif args.metrics == 'pixel-level':
190
+ # logger
191
+ table_ls.append(['mean', str(np.round(np.mean(pixel_auroc_list) * 100, decimals=1)),
192
+ str(np.round(np.mean(pixel_aupro_list) * 100, decimals=1))
193
+ ])
194
+ results = tabulate(table_ls, headers=['objects', 'pixel_auroc', 'pixel_aupro'], tablefmt="pipe")
195
+ elif args.metrics == 'image-pixel-level':
196
+ # logger
197
+ table_ls.append(['mean', str(np.round(np.mean(pixel_auroc_list) * 100, decimals=1)),
198
+ str(np.round(np.mean(pixel_aupro_list) * 100, decimals=1)),
199
+ str(np.round(np.mean(image_auroc_list) * 100, decimals=1)),
200
+ str(np.round(np.mean(image_ap_list) * 100, decimals=1))])
201
+ results = tabulate(table_ls, headers=['objects', 'pixel_auroc', 'pixel_aupro', 'image_auroc', 'image_ap'], tablefmt="pipe")
202
+ logger.info("\n%s", results)
203
+
204
+
205
+ if __name__ == '__main__':
206
+ parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)
207
+ # paths
208
+ parser.add_argument("--data_path", type=str, default="./data/4inlab/", help="path to test dataset")
209
+ parser.add_argument("--save_path", type=str, default='./results/', help='path to save results')
210
+ parser.add_argument("--checkpoint_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to checkpoint')
211
+ # model
212
+ parser.add_argument("--dataset", type=str, default='4inlab')
213
+ parser.add_argument("--image_size", type=int, default=518, help="image size")
214
+ parser.add_argument("--depth", type=int, default=9, help="image size")
215
+ parser.add_argument("--n_ctx", type=int, default=12, help="zero shot")
216
+ parser.add_argument("--t_n_ctx", type=int, default=4, help="zero shot")
217
+ parser.add_argument("--metrics", type=str, default='image-pixel-level')
218
+ parser.add_argument("--seed", type=int, default=111, help="random seed")
219
+ parser.add_argument("--sigma", type=int, default=4, help="zero shot")
220
+ # Specify layers from which feature maps will be extracted (can pass multiple values)
221
+ parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="zero shot")
222
+
223
+ # List of layers whose features will be used
224
+ parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
225
+
226
+
227
+ args = parser.parse_args()
228
+ print(args)
229
+ setup_seed(args.seed)
230
+ test(args)
231
+ #%%
train.py ADDED
@@ -0,0 +1,207 @@
1
+ import os
2
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
3
+ import AnomalyCLIP_lib
4
+ import torch
5
+ import argparse
6
+ import torch.nn.functional as F
7
+ from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
8
+ from training_libs.loss import FocalLoss, BinaryDiceLoss
9
+ from training_libs.utils import normalize
10
+ from training_libs.dataset import Dataset_train
11
+ from training_libs.logger import get_logger
12
+ from tqdm import tqdm
13
+ import numpy as np
14
+ import random
15
+ from training_libs.utils import get_transform
16
+ import matplotlib.pyplot as plt
17
+
18
+ import warnings
19
+ warnings.filterwarnings("ignore", category=UserWarning)
20
+
21
+
22
+ def setup_seed(seed):
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+ np.random.seed(seed)
26
+ random.seed(seed)
27
+ torch.backends.cudnn.deterministic = True
28
+ torch.backends.cudnn.benchmark = False
29
+
30
+ class RealTimePlotter: #
31
+ def __init__(self):
32
+ self.epochs = []
33
+ self.loss_list = []
34
+ self.image_loss_list = []
35
+ self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(14, 6))
36
+ plt.ion()
37
+ self.fig.show()
38
+ self.fig.canvas.flush_events()
39
+
40
+ def update(self, epoch, loss, image_loss):
41
+ self.epochs.append(epoch)
42
+ self.loss_list.append(loss)
43
+ self.image_loss_list.append(image_loss)
44
+
45
+ self.ax1.clear()
46
+ self.ax2.clear()
47
+
48
+ self.ax1.plot(self.epochs, self.loss_list, label='Training Loss')
49
+ self.ax1.set_title('Training Loss')
50
+ self.ax1.set_xlabel('Epochs')
51
+ self.ax1.set_ylabel('Loss')
52
+ self.ax1.legend()
53
+
54
+ self.ax2.plot(self.epochs, self.image_loss_list, label='Image Loss')
55
+ self.ax2.set_title('Image Loss')
56
+ self.ax2.set_xlabel('Epochs')
57
+ self.ax2.set_ylabel('Loss')
58
+ self.ax2.legend()
59
+
60
+ self.fig.canvas.flush_events()
61
+
62
+ def train(args):
63
+
64
+ logger = get_logger(args.save_path)
65
+
66
+ preprocess, target_transform = get_transform(args)
67
+ device = "cuda" if torch.cuda.is_available() else "cpu"
68
+ # device = "cpu"
69
+
70
+ AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx}
71
+
72
+ # model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device=device, design_details = AnomalyCLIP_parameters)
73
+ model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details = AnomalyCLIP_parameters)
74
+ model.eval()
75
+
76
+ train_data = Dataset_train(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset)
77
+ train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
78
+
79
+ ##########################################################################################
80
+ prompt_learner = AnomalyCLIP_PromptLearner(model.to(device), AnomalyCLIP_parameters)
81
+ prompt_learner.to(device)
82
+ model.to(device)
83
+ model.visual.DAPM_replace(DPAM_layer = args.dpam)
84
+ ##########################################################################################
85
+ optimizer = torch.optim.Adam(list(prompt_learner.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))
86
+
87
+ # losses
88
+ loss_focal = FocalLoss()
89
+ loss_dice = BinaryDiceLoss()
90
+
91
+
92
+ model.eval()
93
+ prompt_learner.train()
94
+ # plotter = RealTimePlotter()
95
+
96
+ for epoch in tqdm(range(args.epoch)):
97
+ model.eval()
98
+ prompt_learner.train()
99
+ loss_list = []
100
+ image_loss_list = []
101
+
102
+ for items in tqdm(train_dataloader):
103
+ image = items['img'].to(device)
104
+ label = items['anomaly']
105
+
106
+ gt = items['img_mask'].squeeze().to(device)
107
+ gt[gt > 0.5] = 1
108
+ gt[gt <= 0.5] = 0
109
+
110
+ with torch.no_grad():
111
+ # Apply DPAM to layers 6 through 24
112
+ # DPAM_layer is the number of layers refined by DPAM, counted from top to bottom
113
+ # DPAM_layer = 1 means no DPAM is used
114
+ # DPAM_layer = 20 is the default
115
+ image_features, patch_features = model.encode_image(image, args.features_list, DPAM_layer = args.dpam)
116
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
117
+
118
+ ####################################
119
+ prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None)
120
+ text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
121
+ text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1)
122
+ text_features = text_features/text_features.norm(dim=-1, keepdim=True)
123
+ # Apply DPAM surgery
124
+ text_probs = image_features.unsqueeze(1) @ text_features.permute(0, 2, 1)
125
+ text_probs = text_probs[:, 0, ...]/0.07
126
+
127
+
128
+
129
+ image_loss = F.cross_entropy(text_probs.squeeze(), label.long().to(device)) # move labels to the active device (GPU or CPU)
130
+ #image_loss = F.cross_entropy(text_probs.squeeze(), label.long()) #Without GPU and using CPU
131
+ image_loss_list.append(image_loss.item())
132
+ ######################################################################
133
+ similarity_map_list = []
134
+ # similarity_map_list.append(similarity_map)
135
+ for idx, patch_feature in enumerate(patch_features):
136
+ if idx >= args.feature_map_layer[0]:
137
+ patch_feature = patch_feature/ patch_feature.norm(dim = -1, keepdim = True)
138
+ similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
139
+ similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size).permute(0, 3, 1, 2)
140
+ similarity_map_list.append(similarity_map)
141
+
142
+ loss = 0
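+ # Pixel-level objective: focal loss on each two-channel similarity map, plus Dice losses
+ # on its abnormal (channel 1) and normal (channel 0) slices against the ground-truth mask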
143
+ for i in range(len(similarity_map_list)):
144
+ loss += loss_focal(similarity_map_list[i], gt)
145
+ loss += loss_dice(similarity_map_list[i][:, 1, :, :], gt)
146
+ loss += loss_dice(similarity_map_list[i][:, 0, :, :], 1-gt)
147
+
148
+ optimizer.zero_grad()
149
+ (loss+image_loss).backward()
150
+ optimizer.step()
151
+ loss_list.append(loss.item())
152
+ # logs
153
+ if (epoch + 1) % args.print_freq == 0:
154
+ avg_loss = np.mean(loss_list)
155
+ avg_image_loss = np.mean(image_loss_list)
156
+ logger.info('epoch [{}/{}], loss:{:.4f}, image_loss:{:.4f}'.format(epoch + 1, args.epoch, avg_loss, avg_image_loss))
157
+ # plotter.update(epoch + 1, avg_loss, avg_image_loss) #Realtime training performance monitoring
158
+
159
+ # save model
160
+ if (epoch + 1) % args.save_freq == 0:
161
+ ckp_path = os.path.join(args.save_path, 'epoch_' + str(epoch + 1) + '.pth')
162
+ torch.save({"prompt_learner": prompt_learner.state_dict(),"epoch":epoch+1}, ckp_path)
163
+
164
+ if __name__ == '__main__':
165
+ parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True) # Initialize the argument parser
166
+
167
+ # Define the path to the training dataset and model checkpoint saving
168
+ parser.add_argument("--train_data_path", type=str, default="./data/4inlab", help="train dataset path")
169
+ parser.add_argument("--save_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to save results')
170
+
171
+ # Specify the name of the training dataset
172
+ parser.add_argument("--dataset", type=str, default='4inlab', help="train dataset name")
173
+
174
+ # Set the depth of the learnable text embedding
175
+ parser.add_argument("--depth", type=int, default=9, help="learnable text embedding depth")
176
+
177
+ # Set the prompt length and learnable text embedding length for "zero-shot" learning
178
+ parser.add_argument("--n_ctx", type=int, default=12, help="zero shot")
179
+ parser.add_argument("--t_n_ctx", type=int, default=4, help="zero shot")
180
+
181
+ # Specify layers from which feature maps will be extracted (can pass multiple values)
182
+ parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="zero shot")
183
+
184
+ # List of layers whose features will be used
185
+ parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
186
+
187
+ # Setting parameters for training
188
+ parser.add_argument("--epoch", type=int, default=400, help="epochs")
189
+ parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate")
190
+ parser.add_argument("--batch_size", type=int, default=8, help="batch size")
191
+
192
+ # Size/depth parameter for the DPAM (Deep Prompt Attention Mechanism)
193
+ parser.add_argument("--dpam", type=int, default=13, help="dpam size")
194
+
195
+ # Define the size of input images used for training
196
+ parser.add_argument("--image_size", type=int, default=518, help="image size")
197
+
198
+ # Frequency (in epochs) of logging training information and saving
199
+ parser.add_argument("--print_freq", type=int, default=1, help="print frequency")
200
+ parser.add_argument("--save_freq", type=int, default=1, help="save frequency")
201
+ parser.add_argument("--seed", type=int, default=111, help="random seed")
202
+
203
+ args = parser.parse_args() # Parse the command-line arguments and store them in the 'args' object
204
+ setup_seed(args.seed) # Set the random seed for reproducibility using the provided seed value
205
+ train(args) # Call the training function with the parsed arguments
206
+
207
+
training_libs/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (3.54 kB). View file
 
training_libs/__pycache__/logger.cpython-39.pyc ADDED
Binary file (890 Bytes). View file
 
training_libs/__pycache__/loss.cpython-39.pyc ADDED
Binary file (4.19 kB). View file
 
training_libs/__pycache__/metrics.cpython-39.pyc ADDED
Binary file (1.98 kB). View file
 
training_libs/__pycache__/prompt_ensemble.cpython-39.pyc ADDED
Binary file (7.13 kB). View file
 
training_libs/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.06 kB). View file
 
training_libs/__pycache__/visualization.cpython-39.pyc ADDED
Binary file (1.17 kB). View file
 
training_libs/dataset.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch.utils.data as data
2
+ import json
3
+ import random
4
+ from PIL import Image
5
+ import numpy as np
6
+ import torch
7
+ import os
8
+
9
+ def generate_class_info(dataset_name, mode='train'):
10
+ class_name_map_class_id = {}
11
+ if dataset_name == 'mvtec':
12
+ # obj_list = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill',
13
+ # 'transistor', 'metal_nut', 'screw', 'toothbrush', 'zipper', 'tile', 'wood']
14
+ obj_list = ['bottle']
15
+ elif dataset_name == '4inlab':
16
+ if mode=='train':
17
+ obj_list = ['shinpyung'] # With training
18
+ elif mode=='test':
19
+ obj_list = ['shinpyung'] # With testing
20
+ elif dataset_name == 'task1':
21
+ if mode=='train':
22
+ obj_list = ['cup']
23
+ elif dataset_name == 'task2':
24
+ if mode=='train':
25
+ obj_list = ['fire']
26
+ elif dataset_name == 'smoke_cloud':
27
+ if mode=='train':
28
+ obj_list = ['fire']
29
+
30
+ for k, index in zip(obj_list, range(len(obj_list))):
31
+ class_name_map_class_id[k] = index
32
+
33
+ return obj_list, class_name_map_class_id
34
+
35
+ class Dataset_test(data.Dataset):
36
+ def __init__(self, root, transform, target_transform, dataset_name, mode="test"):
37
+ self.root = root
38
+ self.transform = transform
39
+ self.target_transform = target_transform
40
+ self.data_all = []
41
+ meta_info = json.load(open(f'{self.root}/meta_train.json', 'r'))
42
+ name = self.root.split('/')[-1]
43
+ meta_info = meta_info[mode]
44
+
45
+ self.cls_names = list(meta_info.keys())
46
+ for cls_name in self.cls_names:
47
+ self.data_all.extend(meta_info[cls_name])
48
+ self.length = len(self.data_all)
49
+
50
+ self.obj_list, self.class_name_map_class_id = generate_class_info(dataset_name,mode='test')
51
+ def __len__(self):
52
+ return self.length
53
+
54
+ def __getitem__(self, index):
55
+ data = self.data_all[index]
56
+ img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
57
+ data['specie_name'], data['anomaly']
58
+ img = Image.open(os.path.join(self.root, img_path))
59
+ if anomaly == 0:
60
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
61
+ else:
62
+ if os.path.isdir(os.path.join(self.root, mask_path)):
63
+ # just for classification not report error
64
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
65
+ else:
66
+ img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
67
+ img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
68
+ # transforms
69
+ img = self.transform(img) if self.transform is not None else img
70
+ img_mask = self.target_transform(
71
+ img_mask) if self.target_transform is not None and img_mask is not None else img_mask
72
+ img_mask = [] if img_mask is None else img_mask
73
+ return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
74
+ 'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]}
75
+
76
+
77
+ class Dataset_train(data.Dataset):
78
+ def __init__(self, root, transform, target_transform, dataset_name, mode="train"):
79
+ self.root = root
80
+ self.transform = transform
81
+ self.target_transform = target_transform
82
+ self.data_all = []
83
+ meta_info = json.load(open(f'{self.root}/meta_train.json', 'r'))
84
+ name = self.root.split('/')[-1]
85
+ meta_info = meta_info[mode]
86
+
87
+ self.cls_names = list(meta_info.keys())
88
+ for cls_name in self.cls_names:
89
+ self.data_all.extend(meta_info[cls_name])
90
+ self.length = len(self.data_all)
91
+
92
+ self.obj_list, self.class_name_map_class_id = generate_class_info(dataset_name,mode='train')
93
+ def __len__(self):
94
+ return self.length
95
+
96
+ def __getitem__(self, index):
97
+ data = self.data_all[index]
98
+ img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
99
+ data['specie_name'], data['anomaly']
100
+ img = Image.open(os.path.join(self.root, img_path))
101
+ if anomaly == 0:
102
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
103
+ else:
104
+ if os.path.isdir(os.path.join(self.root, mask_path)):
105
+ # just for classification not report error
106
+ img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
107
+ else:
108
+ img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
109
+ img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
110
+ # transforms
111
+ img = self.transform(img) if self.transform is not None else img
112
+ img_mask = self.target_transform(
113
+ img_mask) if self.target_transform is not None and img_mask is not None else img_mask
114
+ img_mask = [] if img_mask is None else img_mask
115
+ return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
116
+ 'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]}
training_libs/logger.py ADDED
@@ -0,0 +1,25 @@
1
+
2
+ import logging
3
+ import os
4
+
5
+ def get_logger(save_path):
6
+ if not os.path.exists(save_path):
7
+ os.makedirs(save_path)
8
+
9
+ txt_path = os.path.join(save_path, 'log.txt')
10
+ # logger
11
+ root_logger = logging.getLogger()
12
+ for handler in root_logger.handlers[:]:
13
+ root_logger.removeHandler(handler)
14
+ root_logger.setLevel(logging.WARNING)
15
+ logger = logging.getLogger('test')
16
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
17
+ datefmt='%y-%m-%d %H:%M:%S')
18
+ logger.setLevel(logging.INFO)
19
+ file_handler = logging.FileHandler(txt_path, mode='a')
20
+ file_handler.setFormatter(formatter)
21
+ logger.addHandler(file_handler)
22
+ console_handler = logging.StreamHandler()
23
+ console_handler.setFormatter(formatter)
24
+ logger.addHandler(console_handler)
25
+ return logger
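Usage is straightforward; a short sketch (the save path is illustrative):

```python
from training_libs.logger import get_logger

logger = get_logger("./results/demo_run")     # creates the directory and ./results/demo_run/log.txt
logger.info("image-AUROC: %.1f", 91.5)        # written to both the console and log.txt
```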
training_libs/loss.py ADDED
@@ -0,0 +1,125 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from math import exp
+ 
+ class FocalLoss(nn.Module):
+     """
+     copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
+     This is an implementation of Focal Loss with label-smoothed cross entropy support, as proposed in
+     'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002)
+     Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
+     :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
+     :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
+     putting more focus on hard, misclassified examples
+     :param smooth: (float, double) smoothing value applied to the one-hot targets
+     :param balance_index: (int) index of the class to balance; should be specified when alpha is a float
+     :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
+     """
+ 
+     def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
+         super(FocalLoss, self).__init__()
+         self.apply_nonlin = apply_nonlin
+         self.alpha = alpha
+         self.gamma = gamma
+         self.balance_index = balance_index
+         self.smooth = smooth
+         self.size_average = size_average
+ 
+         if self.smooth is not None:
+             if self.smooth < 0 or self.smooth > 1.0:
+                 raise ValueError('smooth value should be in [0,1]')
+ 
+     def forward(self, logit, target):
+         if self.apply_nonlin is not None:
+             logit = self.apply_nonlin(logit)
+         num_class = logit.shape[1]
+ 
+         if logit.dim() > 2:
+             # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
+             logit = logit.view(logit.size(0), logit.size(1), -1)
+             logit = logit.permute(0, 2, 1).contiguous()
+             logit = logit.view(-1, logit.size(-1))
+         target = torch.squeeze(target, 1)
+         target = target.view(-1, 1)
+         alpha = self.alpha
+ 
+         if alpha is None:
+             alpha = torch.ones(num_class, 1)
+         elif isinstance(alpha, (list, np.ndarray)):
+             assert len(alpha) == num_class
+             alpha = torch.FloatTensor(alpha).view(num_class, 1)
+             alpha = alpha / alpha.sum()
+         elif isinstance(alpha, float):
+             alpha = torch.ones(num_class, 1)
+             alpha = alpha * (1 - self.alpha)
+             alpha[self.balance_index] = self.alpha
+ 
+         else:
+             raise TypeError('Not support alpha type')
+ 
+         if alpha.device != logit.device:
+             alpha = alpha.to(logit.device)
+ 
+         idx = target.cpu().long()
+ 
+         one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
+         one_hot_key = one_hot_key.scatter_(1, idx, 1)
+         if one_hot_key.device != logit.device:
+             one_hot_key = one_hot_key.to(logit.device)
+ 
+         if self.smooth:
+             one_hot_key = torch.clamp(
+                 one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
+         pt = (one_hot_key * logit).sum(1) + self.smooth
+         logpt = pt.log()
+ 
+         gamma = self.gamma
+ 
+         alpha = alpha[idx]
+         alpha = torch.squeeze(alpha)
+         loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
+ 
+         if self.size_average:
+             loss = loss.mean()
+         return loss
+ 
+ 
+ class BinaryDiceLoss(nn.Module):
+     def __init__(self):
+         super(BinaryDiceLoss, self).__init__()
+ 
+     def forward(self, input, targets):
+         # Get the size N of each batch
+         N = targets.size()[0]
+         # Smooth variable
+         smooth = 1
+         # Reshape the width and height to the same dimension
+         input_flat = input.view(N, -1)
+         targets_flat = targets.view(N, -1)
+ 
+         intersection = input_flat * targets_flat
+         N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
+         # Calculate the loss average for each image in a batch
+         loss = 1 - N_dice_eff.sum() / N
+         return loss
+ 
+ def smooth(arr, lamda1):
+     new_array = arr
+     arr2 = torch.zeros_like(arr)
+     arr2[:, :-1, :] = arr[:, 1:, :]
+     arr2[:, -1, :] = arr[:, -1, :]
+ 
+     new_array2 = torch.zeros_like(new_array)
+     new_array2[:, :, :-1] = new_array[:, :, 1:]
+     new_array2[:, :, -1] = new_array[:, :, -1]
+     # penalize differences between vertically and horizontally shifted copies (smoothness prior)
+     loss = (torch.sum((arr2 - arr) ** 2) + torch.sum((new_array2 - new_array) ** 2)) / 2
+     return lamda1 * loss
+ 
+ def sparsity(arr, target, lamda2):
+     if target == 0:
+         loss = torch.mean(torch.norm(arr, dim=0))
+     else:
+         loss = torch.mean(torch.norm(1-arr, dim=0))
+     return lamda2 * loss
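A small self-contained sketch of how these two losses are typically combined on a two-channel (normal/anomaly) probability map; the shapes and random inputs are illustrative only:

```python
import torch
from training_libs.loss import FocalLoss, BinaryDiceLoss

focal = FocalLoss()
dice = BinaryDiceLoss()

probs = torch.rand(2, 2, 64, 64)                   # [N, 2, H, W] pseudo-probabilities (e.g. softmax output)
masks = torch.randint(0, 2, (2, 64, 64)).float()   # binary ground-truth masks

loss = focal(probs, masks) + dice(probs[:, 1, :, :], masks)
print(loss.item())
```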
training_libs/metrics.py ADDED
@@ -0,0 +1,60 @@
+ from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
+ import numpy as np
+ from skimage import measure
+ 
+ def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
+     # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
+     binary_amaps = np.zeros_like(amaps, dtype=bool)
+     min_th, max_th = amaps.min(), amaps.max()
+     delta = (max_th - min_th) / max_step
+     pros, fprs, ths = [], [], []
+     for th in np.arange(min_th, max_th, delta):
+         binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1
+         pro = []
+         for binary_amap, mask in zip(binary_amaps, masks):
+             for region in measure.regionprops(measure.label(mask)):
+                 tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()
+                 pro.append(tp_pixels / region.area)
+         inverse_masks = 1 - masks
+         fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
+         fpr = fp_pixels / inverse_masks.sum()
+         pros.append(np.array(pro).mean())
+         fprs.append(fpr)
+         ths.append(th)
+     pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)
+     idxes = fprs < expect_fpr
+     fprs = fprs[idxes]
+     fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
+     pro_auc = auc(fprs, pros[idxes])
+     return pro_auc
+ 
+ 
+ def image_level_metrics(results, obj, metric):
+     gt = results[obj]['gt_sp']
+     pr = results[obj]['pr_sp']
+     gt = np.array(gt)
+     pr = np.array(pr)
+     if metric == 'image-auroc':
+         performance = roc_auc_score(gt, pr)
+     elif metric == 'image-ap':
+         performance = average_precision_score(gt, pr)
+ 
+     return performance
+ 
+ 
+ def pixel_level_metrics(results, obj, metric):
+     gt = results[obj]['imgs_masks']
+     pr = results[obj]['anomaly_maps']
+     gt = np.array(gt)
+     pr = np.array(pr)
+     if metric == 'pixel-auroc':
+         performance = roc_auc_score(gt.ravel(), pr.ravel())
+     elif metric == 'pixel-aupro':
+         if len(gt.shape) == 4:
+             gt = gt.squeeze(1)
+         if len(pr.shape) == 4:
+             pr = pr.squeeze(1)
+         performance = cal_pro_score(gt, pr)
+     return performance
+ 
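The two metric helpers expect a nested `results` dict keyed by object name; a tiny illustrative example with made-up scores:

```python
import numpy as np
from training_libs.metrics import image_level_metrics, pixel_level_metrics

results = {
    "object": {
        "gt_sp": [0, 0, 1, 1],                                   # image-level labels
        "pr_sp": [0.10, 0.35, 0.80, 0.65],                       # image-level anomaly scores
        "imgs_masks": np.random.randint(0, 2, (4, 1, 16, 16)),   # pixel-level ground truth
        "anomaly_maps": np.random.rand(4, 1, 16, 16),            # pixel-level predictions
    }
}
print(image_level_metrics(results, "object", "image-auroc"))
print(pixel_level_metrics(results, "object", "pixel-auroc"))
```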
training_libs/prompt_ensemble.py ADDED
@@ -0,0 +1,273 @@
+ import os
+ from typing import Union, List
+ from pkg_resources import packaging
+ import torch
+ import numpy as np
+ from AnomalyCLIP_lib.simple_tokenizer import SimpleTokenizer as _Tokenizer
+ # from open_clip import tokenizer
+ # simple_tokenizer = tokenizer.SimpleTokenizer()
+ from copy import deepcopy
+ import torch.nn as nn
+ 
+ _tokenizer = _Tokenizer()
+ 
+ 
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
+     """
+     Returns the tokenized representation of given input string(s)
+ 
+     Parameters
+     ----------
+     texts : Union[str, List[str]]
+         An input string or a list of input strings to tokenize
+ 
+     context_length : int
+         The context length to use; all CLIP models use 77 as the context length
+ 
+     truncate : bool
+         Whether to truncate the text in case its encoding is longer than the context length
+ 
+     Returns
+     -------
+     A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+     We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+     """
+     if isinstance(texts, str):
+         texts = [texts]
+ 
+     sot_token = _tokenizer.encoder["<|startoftext|>"]
+     eot_token = _tokenizer.encoder["<|endoftext|>"]
+     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+     if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+     else:
+         result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+ 
+     for i, tokens in enumerate(all_tokens):
+         if len(tokens) > context_length:
+             if truncate:
+                 tokens = tokens[:context_length]
+                 tokens[-1] = eot_token
+             else:
+                 raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+         result[i, :len(tokens)] = torch.tensor(tokens)
+ 
+     return result
+ 
+ # def encode_text_with_prompt_ensemble(model, texts, device):
+ #     prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
+ #     prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
+ #     prompt_state = [prompt_normal, prompt_abnormal]
+ #     prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
+ 
+ #     text_features = []
+ #     for i in range(len(prompt_state)):
+ #         prompted_state = [state.format(texts[0]) for state in prompt_state[i]]
+ #         prompted_sentence = []
+ #         for s in prompted_state:
+ #             for template in prompt_templates:
+ #                 prompted_sentence.append(template.format(s))
+ #         prompted_sentence = tokenize(prompted_sentence)
+ #         class_embeddings = model.encode_text(prompted_sentence.to(device))
+ #         class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+ #         class_embedding = class_embeddings.mean(dim=0)
+ #         class_embedding /= class_embedding.norm()
+ #         text_features.append(class_embedding)
+ 
+ #     text_features = torch.stack(text_features, dim=1).to(device).t()
+ 
+ #     return text_features
+ 
+ 
+ def _get_clones(module, N):
+     return nn.ModuleList([deepcopy(module) for i in range(N)])
+ 
+ 
+ class AnomalyCLIP_PromptLearner(nn.Module):
+     def __init__(self, clip_model, design_details):
+         super().__init__()
+         classnames = ["object"]
+         self.n_cls = len(classnames)
+         self.n_ctx = design_details["Prompt_length"]
+         n_ctx_pos = self.n_ctx
+         n_ctx_neg = self.n_ctx
+         self.text_encoder_n_ctx = design_details["learnabel_text_embedding_length"]
+         ctx_init_pos = ""
+         ctx_init_neg = ""
+         dtype = clip_model.transformer.get_cast_dtype()
+         device = clip_model.token_embedding.weight.device
+ 
+         ctx_dim = clip_model.ln_final.weight.shape[0]
+ 
+         self.classnames = classnames
+ 
+         self.state_normal_list = [
+             "{}",
+         ]
+ 
+         self.state_anomaly_list = [
+             "damaged {}",
+         ]
+ 
+         normal_num = len(self.state_normal_list)
+         anormaly_num = len(self.state_anomaly_list)
+         self.normal_num = normal_num
+         self.anormaly_num = anormaly_num
+ 
+         if ctx_init_pos and ctx_init_neg:
+             # use given words to initialize context vectors
+             ctx_init_pos = ctx_init_pos.replace("_", " ")
+             ctx_init_neg = ctx_init_neg.replace("_", " ")
+             n_ctx_pos = len(ctx_init_pos.split(" "))
+             n_ctx_neg = len(ctx_init_neg.split(" "))
+             # Initialize text into BPE encoding
+             prompt_pos = tokenize(ctx_init_pos)
+             prompt_neg = tokenize(ctx_init_neg)
+             with torch.no_grad():
+                 # Generate the corresponding text embedding
+                 embedding_pos = clip_model.token_embedding(prompt_pos).type(dtype)
+                 embedding_neg = clip_model.token_embedding(prompt_neg).type(dtype)
+                 # Drop the SOS token and keep the next n_ctx embeddings as the learnable context
+                 ctx_vectors_pos = embedding_pos[0, 1: 1 + n_ctx_pos, :]
+                 ctx_vectors_neg = embedding_neg[0, 1: 1 + n_ctx_neg, :]
+                 prompt_prefix_pos = ctx_init_pos
+                 prompt_prefix_neg = ctx_init_neg
+                 if True:
+                     ctx_vectors_pos_ = []
+                     ctx_vectors_neg_ = []
+                     for _ in range(self.n_cls):
+                         ctx_vectors_pos_.append(deepcopy(ctx_vectors_pos))
+                         ctx_vectors_neg_.append(deepcopy(ctx_vectors_neg))
+                     ctx_vectors_pos = torch.stack(ctx_vectors_pos_, dim=0)
+                     ctx_vectors_neg = torch.stack(ctx_vectors_neg_, dim=0)
+ 
+         else:
+             # Random initialization
+             if True:
+                 print("Initializing class-specific contexts")
+                 # n_cls is the number of classes, n_ctx_pos is the number of learnable tokens, ctx_dim is the prompt dimension
+                 ctx_vectors_pos = torch.empty(self.n_cls, self.normal_num, n_ctx_pos, ctx_dim, dtype=dtype)
+                 ctx_vectors_neg = torch.empty(self.n_cls, self.anormaly_num, n_ctx_neg, ctx_dim, dtype=dtype)
+             else:
+                 print("Initializing a generic context")
+                 ctx_vectors_pos = torch.empty(n_ctx_pos, ctx_dim, dtype=dtype)
+                 ctx_vectors_neg = torch.empty(n_ctx_neg, ctx_dim, dtype=dtype)
+             nn.init.normal_(ctx_vectors_pos, std=0.02)
+             nn.init.normal_(ctx_vectors_neg, std=0.02)
+             prompt_prefix_pos = " ".join(["X"] * n_ctx_pos)
+             prompt_prefix_neg = " ".join(["X"] * n_ctx_neg)
+         self.compound_prompts_depth = design_details["learnabel_text_embedding_depth"]
+         self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(self.text_encoder_n_ctx, ctx_dim))
+                                                        for _ in range(self.compound_prompts_depth - 1)])
+         for single_para in self.compound_prompts_text:
+             print("single_para", single_para.shape)
+             nn.init.normal_(single_para, std=0.02)
+ 
+         single_layer = nn.Linear(ctx_dim, 896)
+         self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)
+ 
+         self.ctx_pos = nn.Parameter(ctx_vectors_pos)  # to be optimized
+         self.ctx_neg = nn.Parameter(ctx_vectors_neg)  # to be optimized
+ 
+         classnames = [name.replace("_", " ") for name in classnames]
+         name_lens = [len(_tokenizer.encode(name)) for name in classnames]
+ 
+         prompts_pos = [prompt_prefix_pos + " " + template.format(name) + "." for template in self.state_normal_list for name in classnames]
+         prompts_neg = [prompt_prefix_neg + " " + template.format(name) + "." for template in self.state_anomaly_list for name in classnames]
+ 
+         # print("Normal Prompt:", prompts_pos)
+         # print("Anomaly Prompt:", prompts_neg)
+ 
+         # Tokenize the prompts and move them to the same device as the token embedding
+         tokenized_prompts_pos = [tokenize(p_pos).to(device) for p_pos in prompts_pos]
+         tokenized_prompts_neg = [tokenize(p_neg).to(device) for p_neg in prompts_neg]
+ 
+         tokenized_prompts_pos = torch.cat(tokenized_prompts_pos)
+         tokenized_prompts_neg = torch.cat(tokenized_prompts_neg)
+         # Generate the corresponding text embedding
+         with torch.no_grad():
+             embedding_pos = clip_model.token_embedding(tokenized_prompts_pos).type(dtype)
+             embedding_neg = clip_model.token_embedding(tokenized_prompts_neg).type(dtype)
+             n, l, d = embedding_pos.shape
+             print("embedding_pos", embedding_pos.shape)
+             embedding_pos = embedding_pos.reshape(normal_num, self.n_cls, l, d).permute(1, 0, 2, 3)
+             embedding_neg = embedding_neg.reshape(anormaly_num, self.n_cls, l, d).permute(1, 0, 2, 3)
+ 
+         self.register_buffer("token_prefix_pos", embedding_pos[:, :, :1, :])
+         self.register_buffer("token_suffix_pos", embedding_pos[:, :, 1 + n_ctx_pos:, :])
+         self.register_buffer("token_prefix_neg", embedding_neg[:, :, :1, :])
+         self.register_buffer("token_suffix_neg", embedding_neg[:, :, 1 + n_ctx_neg:, :])
+ 
+         n, d = tokenized_prompts_pos.shape
+         tokenized_prompts_pos = tokenized_prompts_pos.reshape(normal_num, self.n_cls, d).permute(1, 0, 2)
+ 
+         n, d = tokenized_prompts_neg.shape
+         tokenized_prompts_neg = tokenized_prompts_neg.reshape(anormaly_num, self.n_cls, d).permute(1, 0, 2)
+ 
+         self.n_ctx_pos = n_ctx_pos
+         self.n_ctx_neg = n_ctx_neg
+         # tokenized_prompts = torch.cat([tokenized_prompts_pos, tokenized_prompts_neg], dim=0)  # torch.Tensor
+         self.register_buffer("tokenized_prompts_pos", tokenized_prompts_pos)
+         self.register_buffer("tokenized_prompts_neg", tokenized_prompts_neg)
+         print("tokenized_prompts shape", self.tokenized_prompts_pos.shape, self.tokenized_prompts_neg.shape)
+ 
+     def forward(self, cls_id=None):
+ 
+         ctx_pos = self.ctx_pos
+         ctx_neg = self.ctx_neg
+         # print("shape", self.ctx_pos[0:1].shape, ctx_pos.shape)
+         prefix_pos = self.token_prefix_pos
+         prefix_neg = self.token_prefix_neg
+         suffix_pos = self.token_suffix_pos
+         suffix_neg = self.token_suffix_neg
+ 
+         # print(prefix_pos.shape, prefix_neg.shape)
+ 
+         prompts_pos = torch.cat(
+             [
+                 # N (the number of templates), 1, dim
+                 prefix_pos,  # (n_cls, 1, dim)
+                 ctx_pos,     # (n_cls, n_ctx, dim)
+                 suffix_pos,  # (n_cls, *, dim)
+             ],
+             dim=2,
+         )
+ 
+         prompts_neg = torch.cat(
+             [
+                 prefix_neg,  # (n_cls, 1, dim)
+                 ctx_neg,     # (n_cls, n_ctx, dim)
+                 suffix_neg,  # (n_cls, *, dim)
+             ],
+             dim=2,
+         )
+         _, _, l, d = prompts_pos.shape
+         prompts_pos = prompts_pos.reshape(-1, l, d)
+         _, _, l, d = prompts_neg.shape
+         prompts_neg = prompts_neg.reshape(-1, l, d)
+         prompts = torch.cat([prompts_pos, prompts_neg], dim=0)
+ 
+         _, l, d = self.tokenized_prompts_pos.shape
+         tokenized_prompts_pos = self.tokenized_prompts_pos.reshape(-1, d)
+         _, l, d = self.tokenized_prompts_neg.shape
+         tokenized_prompts_neg = self.tokenized_prompts_neg.reshape(-1, d)
+         tokenized_prompts = torch.cat((tokenized_prompts_pos, tokenized_prompts_neg), dim=0)
+ 
+         return prompts, tokenized_prompts, self.compound_prompts_text
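tokenize() mirrors CLIP's tokenizer and pads to the 77-token context window; the prompt learner additionally needs the modified CLIP model plus a design_details dict with the keys used above. A brief sketch, where the loader call and the concrete hyper-parameter values are assumptions rather than something stated in this upload:

```python
from training_libs.prompt_ensemble import tokenize, AnomalyCLIP_PromptLearner

tokens = tokenize(["a photo of a damaged object."])
print(tokens.shape)                                # torch.Size([1, 77])

# Hypothetical wiring of the prompt learner; the actual CLIP loader lives in AnomalyCLIP_lib.
design_details = {"Prompt_length": 12,             # illustrative values, not repo-verified
                  "learnabel_text_embedding_depth": 9,
                  "learnabel_text_embedding_length": 4}
# model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device="cpu", design_details=design_details)
# prompt_learner = AnomalyCLIP_PromptLearner(model, design_details)
# prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
```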
training_libs/utils.py ADDED
@@ -0,0 +1,24 @@
+ import torchvision.transforms as transforms
+ # from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
+ from AnomalyCLIP_lib.transform import image_transform
+ from AnomalyCLIP_lib.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+ 
+ 
+ def normalize(pred, max_value=None, min_value=None):
+     if max_value is None or min_value is None:
+         return (pred - pred.min()) / (pred.max() - pred.min())
+     else:
+         return (pred - min_value) / (max_value - min_value)
+ 
+ def get_transform(args):
+     preprocess = image_transform(args.image_size, is_train=False, mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
+     target_transform = transforms.Compose([
+         transforms.Resize((args.image_size, args.image_size)),
+         transforms.CenterCrop(args.image_size),
+         transforms.ToTensor()
+     ])
+     preprocess.transforms[0] = transforms.Resize(size=(args.image_size, args.image_size), interpolation=transforms.InterpolationMode.BICUBIC,
+                                                  max_size=None, antialias=None)
+     preprocess.transforms[1] = transforms.CenterCrop(size=(args.image_size, args.image_size))
+     return preprocess, target_transform
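get_transform only needs an object with an image_size attribute; a short sketch, assuming image_transform returns a torchvision Compose as in OpenCLIP and using an illustrative 518-pixel resolution:

```python
import argparse
import numpy as np
from PIL import Image
from training_libs.utils import get_transform, normalize

preprocess, target_transform = get_transform(argparse.Namespace(image_size=518))
img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))   # dummy RGB image
x = preprocess(img)                                              # tensor of shape [3, 518, 518]
print(x.shape, normalize(np.random.rand(8, 8)).min())            # the min-max helper scales to [0, 1]
```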
training_libs/visualization.py ADDED
@@ -0,0 +1,25 @@
+ import cv2
+ import os
+ # min-max scale to [0, 1]; sklearn.preprocessing.normalize is row-wise L2 and would distort the overlay
+ from training_libs.utils import normalize
+ import numpy as np
+ 
+ def visualizer(pathes, anomaly_map, img_size, save_path, cls_name):
+     for idx, path in enumerate(pathes):
+         cls = path.split('/')[-2]
+         filename = path.split('/')[-1]
+         vis = cv2.cvtColor(cv2.resize(cv2.imread(path), (img_size, img_size)), cv2.COLOR_BGR2RGB)  # RGB
+         mask = normalize(anomaly_map[idx])
+         vis = apply_ad_scoremap(vis, mask)
+         vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)  # BGR
+         save_vis = os.path.join(save_path, 'imgs', cls_name[idx], cls)
+         if not os.path.exists(save_vis):
+             os.makedirs(save_vis)
+         cv2.imwrite(os.path.join(save_vis, filename), vis)
+ 
+ def apply_ad_scoremap(image, scoremap, alpha=0.5):
+     np_image = np.asarray(image, dtype=float)
+     scoremap = (scoremap * 255).astype(np.uint8)
+     scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
+     scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
+     return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)
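A self-contained sketch of visualizer(); the dummy image and random anomaly map exist only to make the call runnable:

```python
import os
import cv2
import numpy as np
from training_libs.visualization import visualizer

os.makedirs("demo/sample_class", exist_ok=True)
cv2.imwrite("demo/sample_class/000.png", np.zeros((64, 64, 3), dtype=np.uint8))  # dummy input image

anomaly_map = np.random.rand(1, 64, 64)                          # one fake anomaly map
visualizer(["demo/sample_class/000.png"], anomaly_map, img_size=64,
           save_path="./demo_vis", cls_name=["object"])
# -> writes ./demo_vis/imgs/object/sample_class/000.png with a JET heat-map overlay
```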