Upload 48 files
- AnomalyCLIP_lib/AnomalyCLIP.py +531 -0
- AnomalyCLIP_lib/CLIP.py +436 -0
- AnomalyCLIP_lib/__init__.py +1 -0
- AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/CLIP.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/__init__.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/__init__.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/build_model.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/build_model.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/clip.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/clip_model.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/clip_surgery_model.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/constants.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/constants.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/model_load.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/model_load.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/transform.cpython-38.pyc +0 -0
- AnomalyCLIP_lib/__pycache__/transform.cpython-39.pyc +0 -0
- AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz +3 -0
- AnomalyCLIP_lib/build_model.py +50 -0
- AnomalyCLIP_lib/constants.py +2 -0
- AnomalyCLIP_lib/model_load.py +235 -0
- AnomalyCLIP_lib/simple_tokenizer.py +132 -0
- AnomalyCLIP_lib/transform.py +133 -0
- README.md +193 -183
- dataset_config/dataset_get_json.py +51 -0
- dataset_config/image_ground_truth.py +68 -0
- dataset_config/image_resize.py +13 -0
- requirements.txt +20 -0
- test.py +231 -0
- train.py +207 -0
- training_libs/__pycache__/dataset.cpython-39.pyc +0 -0
- training_libs/__pycache__/logger.cpython-39.pyc +0 -0
- training_libs/__pycache__/loss.cpython-39.pyc +0 -0
- training_libs/__pycache__/metrics.cpython-39.pyc +0 -0
- training_libs/__pycache__/prompt_ensemble.cpython-39.pyc +0 -0
- training_libs/__pycache__/utils.cpython-39.pyc +0 -0
- training_libs/__pycache__/visualization.cpython-39.pyc +0 -0
- training_libs/dataset.py +116 -0
- training_libs/logger.py +25 -0
- training_libs/loss.py +125 -0
- training_libs/metrics.py +60 -0
- training_libs/prompt_ensemble.py +273 -0
- training_libs/utils.py +24 -0
- training_libs/visualization.py +25 -0
AnomalyCLIP_lib/AnomalyCLIP.py
ADDED
@@ -0,0 +1,531 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
from torch import nn

from .CLIP import ModifiedResNet  # needed for the ResNet visual branch in AnomalyCLIP.__init__


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


# attention module for v-v self-attention
class Attention(nn.Module):
    def __init__(self, out_dim, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., settings=''):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(out_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.settings = settings

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # original self-attention for the original path
        attn_ori = (q @ k.transpose(-2, -1)) * self.scale
        attn_ori = attn_ori.softmax(dim=-1)
        attn_ori = self.attn_drop(attn_ori)

        # replace k & q by v
        k = v
        q = k

        # self-attention; a higher temperature performs better for ResNets
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_ori = (attn_ori @ v).transpose(1, 2).reshape(B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj_drop(self.proj(x))
        x_ori = self.proj_drop(self.proj(x_ori))
        return [x, x_ori]


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if isinstance(self.attn, Attention):
            x = x.transpose(0, 1)
            x, x_ori = self.attn(x)
            return [x.transpose(0, 1), x_ori.transpose(0, 1)]
        else:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x, whole=False, ffn=False):
        # dual paths for blocks deeper than "d"
        if isinstance(self.attn, Attention):
            if isinstance(x, list):
                if not ffn:
                    x, x_ori = x
                    x_res = self.attention(self.ln_1(x_ori))
                    x_res, x_ori_res = x_res
                    x_ori += x_ori_res
                    x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                    x += x_res  # skip ffn for the new path
                    return [x, x_ori]
                else:
                    x, x_ori_1 = x
                    x_res = self.attention(self.ln_1(x_ori_1))
                    x_res, x_ori_res = x_res
                    x_ori = x_ori_1 + x_ori_res
                    x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                    x += x_res  # skip ffn for the new path
                    x = x_res + x_ori_1
                    x = x + self.mlp(self.ln_2(x))
                    return [x, x_ori]
            # start of dual path
            else:
                x_res = self.attention(self.ln_1(x))
                if isinstance(x_res, list):
                    x_res, x_ori_res = x_res
                    x_ori = x + x_ori_res
                    x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                    x += x_res
                    return [x, x_ori]

        # single path before "d"
        else:
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x


class ResidualAttentionBlock_learnable_token(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
                 text_layer=False, i=0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

        self.i = i
        self.compound_prompt_nctx = design_details['learnabel_text_embedding_length']
        self.text_layer = text_layer
        if i == 0:
            self.first_layer = True
        else:
            self.first_layer = False

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if isinstance(self.attn, Attention):
            x = x.transpose(0, 1)
            x, x_ori = self.attn(x)
            return [x.transpose(0, 1), x_ori.transpose(0, 1)]
        else:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, inputs):
        # dual paths for blocks deeper than "d"
        if isinstance(self.attn, Attention):
            x = inputs[0]
            if isinstance(x, list):
                x, x_ori = x
                x_res = self.attention(self.ln_1(x_ori))
                x_res, x_ori_res = x_res
                x_ori += x_ori_res
                x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                x += x_res  # skip ffn for the new path
                return [x, x_ori]

            # start of dual path
            else:
                x_res = self.attention(self.ln_1(x))
                if isinstance(x_res, list):
                    x_res, x_ori_res = x_res
                    x_ori = x + x_ori_res
                    x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                    x += x_res
                    return [x, x_ori]

        # single path before "d"
        else:
            x = inputs[0]
            compound_prompts_deeper = inputs[1]
            counter = inputs[2]
            if not self.first_layer:
                # First check whether the i-th layer needs compound prompts or not
                if not (counter > len(compound_prompts_deeper) - 1):
                    # x -> [77, NCLS, DIM]
                    # First remove the learnable tokens inserted by the previous layer
                    prefix = x[:1, :, :]
                    suffix = x[1 + self.compound_prompt_nctx:, :, :]
                    textual_context = compound_prompts_deeper[counter]
                    textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
                    # Insert this layer's learnable tokens in place of the previous layer's tokens
                    x = torch.cat([prefix, textual_context, suffix], dim=0)
                    # Update the counter so the next layer does not reuse the same learnable tokens
                    counter += 1
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return [x, compound_prompts_deeper, counter]


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False, design_details=None, text_layer=False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.text_layer = text_layer
        self.design_deatails = design_details
        print("text_layer", self.text_layer)
        if self.text_layer and (design_details is not None):
            self.resblocks = nn.ModuleList([ResidualAttentionBlock_learnable_token(width, heads, attn_mask, design_details, text_layer, i=i) for i in range(layers)])
        else:
            self.resblocks = nn.ModuleList([ResidualAttentionBlock(width, heads, attn_mask) for i in range(layers)])

    def ori_CLIP_with_patch_forward(self, x, out_layers):
        idx = 0
        out_tokens = []
        for r in self.resblocks:
            idx += 1
            x = r(x)
            if idx in out_layers:
                if isinstance(x, list):
                    out_tokens.append(x[1])
                else:
                    out_tokens.append(x)

        return [x, x], out_tokens

    def AnomalyCLIP_forward(self, x, out_layers, ffn):
        idx = 0
        out_tokens = []
        for r in self.resblocks:
            idx += 1
            x = r(x, ffn=ffn)
            if idx in out_layers:
                if isinstance(x, list):
                    out_tokens.append(x[0])
                else:
                    out_tokens.append(x)
        return x, out_tokens

    def forward(self, x: torch.Tensor, out_layers=[6, 12, 18, 24], DPAM_layer=None, ffn=False):
        # visual encoder forward
        if not self.text_layer:
            out_tokens = []

            if DPAM_layer is None:
                [x, x], out_tokens = self.ori_CLIP_with_patch_forward(x, out_layers)
                return [x, x], out_tokens
            else:
                x, out_tokens = self.AnomalyCLIP_forward(x, out_layers, ffn)
                return x, out_tokens
        # text encoder forward
        # original text embedding
        elif self.design_deatails is None:
            for idx, r in enumerate(self.resblocks):
                x = r(x)
            return x
        # insert learnable text embedding
        elif self.design_deatails is not None:
            for idx, r in enumerate(self.resblocks):
                x = r(x)
            return x[0]

    def get_cast_dtype(self) -> torch.dtype:
        return self.resblocks[0].mlp.c_fc.weight.dtype


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, need_weights=True)
        self.attn = None
        self.embed_dim = width
        self.num_heads = heads

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    @torch.no_grad()
    def DAPM_replace(self, DPAM_layer):
        if DPAM_layer is not None:
            for i in range(1, DPAM_layer):
                self.attn = Attention(self.embed_dim, self.embed_dim, self.num_heads, True)
                self.attn.qkv.weight.data = self.transformer.resblocks[-i].attn.in_proj_weight.clone()
                self.attn.qkv.bias.data = self.transformer.resblocks[-i].attn.in_proj_bias.clone()
                self.attn.proj.weight.data = self.transformer.resblocks[-i].attn.out_proj.weight.clone()
                self.attn.proj.bias.data = self.transformer.resblocks[-i].attn.out_proj.bias.clone()
                self.transformer.resblocks[-i].attn = self.attn

    @torch.no_grad()
    def forward(self, x: torch.Tensor, features_list, ori_patch=False, proj_use=True, DPAM_layer=None, ffn=False):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[1] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        pos = self.positional_embedding.to(x.dtype)
        x = x + pos
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        [x, x_ori], patch_tokens = self.transformer(x, features_list, DPAM_layer=DPAM_layer, ffn=ffn)

        patch_token_list = []
        for patch_token in patch_tokens:
            patch_token = self.ln_post(patch_token.permute(1, 0, 2)) @ self.proj  # LND -> NLD
            patch_token_list.append(patch_token)
        patch_tokens = patch_token_list

        return x_ori[0, :, :] @ self.proj, patch_tokens


from thop import profile


class AnomalyCLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int,
                 design_details=None
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(), text_layer=True, design_details=design_details
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, feature_list=[], ori_patch=False, proj_use=True, DPAM_layer=None, ffn=False):
        return self.visual(image.type(self.dtype), feature_list, ori_patch=ori_patch, proj_use=proj_use, DPAM_layer=DPAM_layer, ffn=ffn)

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text=None, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = prompts + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        if deep_compound_prompts_text is None:
            x = self.transformer(x)
        else:
            x = self.transformer([x, deep_compound_prompts_text, 0])
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
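For orientation, a minimal sketch (not one of the uploaded files) of how the dual-path visual encoder above can be exercised: the tiny model dimensions, the DPAM depth, and the layer indices are illustrative assumptions, and only the `learnabel_text_embedding_length` key of `design_details` is read by the modules in this file.

# Hedged sketch: build a small AnomalyCLIP, apply the DPAM surgery, and collect multi-layer patch tokens.
# All hyper-parameters below are illustrative, not values prescribed by this upload.
import torch
from AnomalyCLIP_lib.AnomalyCLIP import AnomalyCLIP

design_details = {"learnabel_text_embedding_length": 4}   # the only key read in this file
model = AnomalyCLIP(embed_dim=64, image_resolution=64, vision_layers=4, vision_width=128,
                    vision_patch_size=16, context_length=77, vocab_size=49408,
                    transformer_width=64, transformer_heads=8, transformer_layers=4,
                    design_details=design_details).eval()

model.visual.DAPM_replace(DPAM_layer=3)   # replace attention in the last two blocks with v-v Attention

image = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    cls_features, patch_tokens = model.encode_image(image, [2, 4], DPAM_layer=3)
# cls_features: [1, 64] projected CLS token from the original path
# patch_tokens: one [1, 17, 64] tensor per requested layer (16 patches + CLS)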
AnomalyCLIP_lib/CLIP.py
ADDED
@@ -0,0 +1,436 @@
from collections import OrderedDict
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC

        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[0] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)

        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # return x[0]
        return x.transpose(0, 1)  # return both the cls token and the image tokens, B,N,C


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.relu2(self.bn2(self.conv2(x)))
            x = self.relu3(self.bn3(self.conv3(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.need_weights = need_weights

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if not self.need_weights:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
        else:
            return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)

    def forward(self, x: torch.Tensor):
        if not self.need_weights:
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return x
        else:
            y, attn = self.attention(self.ln_1(x))
            x = x + y
            x = x + self.mlp(self.ln_2(x))
            return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, need_weights: bool = False):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, need_weights if i == layers - 1 else False) for i in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)

    def get_cast_dtype(self) -> torch.dtype:
        return self.resblocks[0].mlp.c_fc.weight.dtype


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads, need_weights=True)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

        #####################################################################################
        side = int((self.positional_embedding.shape[0] - 1) ** 0.5)
        new_side = int((x.shape[1] - 1) ** 0.5)

        # update the position embedding during inference for varied input size
        if side != new_side:
            new_pos = self.positional_embedding[1:, :].reshape(-1, side, side, x.shape[-1]).permute(0, 3, 1, 2)
            new_pos = torch.nn.functional.interpolate(new_pos, (new_side, new_side), mode='bilinear')
            new_pos = new_pos.reshape(-1, x.shape[-1], new_side * new_side).transpose(1, 2)
            self.positional_embedding.data = torch.cat([self.positional_embedding[:1, :], new_pos[0]], 0)
        #####################################################################################

        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # x = self.ln_post(x[:, 0, :])
        x = self.ln_post(x)  # return both the cls token and the image tokens

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text=None, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = prompts + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        if deep_compound_prompts_text is None:
            x = self.transformer(x)
        else:
            x = self.transformer([x, deep_compound_prompts_text, 0])
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
AnomalyCLIP_lib/__init__.py
ADDED
@@ -0,0 +1 @@
from .model_load import *
AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-38.pyc
ADDED
Binary file (15.2 kB)
AnomalyCLIP_lib/__pycache__/AnomalyCLIP.cpython-39.pyc
ADDED
Binary file (15.2 kB)
AnomalyCLIP_lib/__pycache__/CLIP.cpython-39.pyc
ADDED
Binary file (13.9 kB)
AnomalyCLIP_lib/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (199 Bytes)
AnomalyCLIP_lib/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (209 Bytes)
AnomalyCLIP_lib/__pycache__/build_model.cpython-38.pyc
ADDED
Binary file (2.27 kB)
AnomalyCLIP_lib/__pycache__/build_model.cpython-39.pyc
ADDED
Binary file (2.21 kB)
AnomalyCLIP_lib/__pycache__/clip.cpython-38.pyc
ADDED
Binary file (19.7 kB)
AnomalyCLIP_lib/__pycache__/clip_model.cpython-38.pyc
ADDED
Binary file (13.9 kB)
AnomalyCLIP_lib/__pycache__/clip_surgery_model.cpython-38.pyc
ADDED
Binary file (20.7 kB)
AnomalyCLIP_lib/__pycache__/constants.cpython-38.pyc
ADDED
Binary file (279 Bytes)
AnomalyCLIP_lib/__pycache__/constants.cpython-39.pyc
ADDED
Binary file (289 Bytes)
AnomalyCLIP_lib/__pycache__/model_load.cpython-38.pyc
ADDED
Binary file (7.79 kB)
AnomalyCLIP_lib/__pycache__/model_load.cpython-39.pyc
ADDED
Binary file (7.87 kB)
AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-38.pyc
ADDED
Binary file (5.82 kB)
AnomalyCLIP_lib/__pycache__/simple_tokenizer.cpython-39.pyc
ADDED
Binary file (5.79 kB)
AnomalyCLIP_lib/__pycache__/transform.cpython-38.pyc
ADDED
Binary file (4.18 kB)
AnomalyCLIP_lib/__pycache__/transform.cpython-39.pyc
ADDED
Binary file (4.16 kB)
AnomalyCLIP_lib/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
AnomalyCLIP_lib/build_model.py
ADDED
@@ -0,0 +1,50 @@
from torch import nn
from .CLIP import CLIP
from .AnomalyCLIP import AnomalyCLIP


def build_model(name: str, state_dict: dict, design_details=None):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))

    if design_details is not None:
        model = AnomalyCLIP(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, design_details=design_details
        )
    else:
        model = CLIP(
            embed_dim,
            image_resolution, vision_layers, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
        )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    # convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
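A hedged usage sketch (not one of the uploaded files) of how build_model() above selects between the two architectures: the same checkpoint state dict yields plain CLIP when design_details is None and the AnomalyCLIP variant otherwise. The checkpoint path and the design_details value are placeholders/assumptions.

# Hedged sketch: "ViT-L-14-336px.pt" is a placeholder for the OpenAI TorchScript checkpoint.
# build_model() mutates the dict it receives, so a fresh copy is passed for each call.
import torch
from AnomalyCLIP_lib.build_model import build_model

state_dict = torch.jit.load("ViT-L-14-336px.pt", map_location="cpu").state_dict()

plain_clip = build_model("ViT-L/14@336px", dict(state_dict))  # CLIP branch
anomaly_clip = build_model("ViT-L/14@336px", dict(state_dict),
                           design_details={"learnabel_text_embedding_length": 4})  # AnomalyCLIP branch; value is illustrative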
AnomalyCLIP_lib/constants.py
ADDED
@@ -0,0 +1,2 @@
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
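These constants duplicate the normalization statistics that are hard-coded in _transform() in model_load.py below; a short sketch of building an equivalent preprocessing pipeline from them follows (the 336-pixel resize is an illustrative choice, not fixed by this file).

# Hedged sketch: CLIP-style preprocessing built from the exported constants.
# Non-RGB inputs would additionally need a .convert("RGB") step, as in _transform().
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from AnomalyCLIP_lib.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

preprocess = Compose([
    Resize((336, 336), interpolation=InterpolationMode.BICUBIC),
    ToTensor(),
    Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
])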
AnomalyCLIP_lib/model_load.py
ADDED
@@ -0,0 +1,235 @@
import hashlib
import os
import urllib
import warnings
from typing import Union, List
from pkg_resources import packaging

import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from tqdm import tqdm
import numpy as np

from .build_model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from torchvision.transforms import InterpolationMode

if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load",
           "get_similarity_map", "compute_similarity"]
_tokenizer = _Tokenizer()

_MODELS = {
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}


def _download(
    url: str,
    cache_dir: Union[str, None] = None,
):
    if not cache_dir:
        # cache_dir = os.path.expanduser("~/.cache/clip")
        cache_dir = os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if expected_sha256:
            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
                return download_target
            else:
                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
        else:
            return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _convert_image_to_rgb(image):
    return image.convert("RGB")


def _transform(n_px):
    return Compose([
        Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC),
        # CenterCrop(n_px),  # rm center crop to explain whole image
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])


def available_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list(_MODELS.keys())


def load_state_dict(checkpoint_path: str, map_location='cpu'):
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    if next(iter(state_dict.items()))[0].startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict


def load_checkpoint(model, checkpoint_path, strict=True):
    state_dict = load_state_dict(checkpoint_path)
    # detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)
    resize_pos_embed(state_dict, model)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", design_details=None, jit: bool = False, download_root: str = None):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or the more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    print("name", name)
    if name in _MODELS:
        # model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("/remote-home/iot_zhouqihang/root/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        model = build_model(name, state_dict or model.state_dict(), design_details).to(device)
        if str(device) == "cpu":
|
166 |
+
model.float()
|
167 |
+
return model, _transform(model.visual.input_resolution)
|
168 |
+
|
169 |
+
# patch the device names
|
170 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
171 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
172 |
+
|
173 |
+
def patch_device(module):
|
174 |
+
try:
|
175 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
176 |
+
except RuntimeError:
|
177 |
+
graphs = []
|
178 |
+
|
179 |
+
if hasattr(module, "forward1"):
|
180 |
+
graphs.append(module.forward1.graph)
|
181 |
+
|
182 |
+
for graph in graphs:
|
183 |
+
for node in graph.findAllNodes("prim::Constant"):
|
184 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
185 |
+
node.copyAttributes(device_node)
|
186 |
+
|
187 |
+
model.apply(patch_device)
|
188 |
+
patch_device(model.encode_image)
|
189 |
+
patch_device(model.encode_text)
|
190 |
+
|
191 |
+
# patch dtype to float32 on CPU
|
192 |
+
if str(device) == "cpu":
|
193 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
194 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
195 |
+
float_node = float_input.node()
|
196 |
+
|
197 |
+
def patch_float(module):
|
198 |
+
try:
|
199 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
200 |
+
except RuntimeError:
|
201 |
+
graphs = []
|
202 |
+
|
203 |
+
if hasattr(module, "forward1"):
|
204 |
+
graphs.append(module.forward1.graph)
|
205 |
+
|
206 |
+
for graph in graphs:
|
207 |
+
for node in graph.findAllNodes("aten::to"):
|
208 |
+
inputs = list(node.inputs())
|
209 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
210 |
+
if inputs[i].node()["value"] == 5:
|
211 |
+
inputs[i].node().copyAttributes(float_node)
|
212 |
+
|
213 |
+
model.apply(patch_float)
|
214 |
+
patch_float(model.encode_image)
|
215 |
+
patch_float(model.encode_text)
|
216 |
+
|
217 |
+
model.float()
|
218 |
+
|
219 |
+
return model, _transform(model.input_resolution.item())
|
220 |
+
|
221 |
+
|
222 |
+
def get_similarity_map(sm, shape):
|
223 |
+
side = int(sm.shape[1] ** 0.5)
|
224 |
+
sm = sm.reshape(sm.shape[0], side, side, -1).permute(0, 3, 1, 2)
|
225 |
+
sm = torch.nn.functional.interpolate(sm, shape, mode='bilinear')
|
226 |
+
sm = sm.permute(0, 2, 3, 1)
|
227 |
+
return sm
|
228 |
+
|
229 |
+
|
230 |
+
def compute_similarity(image_features, text_features, t=2):
|
231 |
+
prob_1 = image_features[:, :1, :] @ text_features.t()
|
232 |
+
b, n_t, n_i, c = image_features.shape[0], text_features.shape[0], image_features.shape[1], image_features.shape[2]
|
233 |
+
feats = image_features.reshape(b, n_i, 1, c) * text_features.reshape(1, 1, n_t, c)
|
234 |
+
similarity = feats.sum(-1)
|
235 |
+
return (similarity/0.07).softmax(-1), prob_1
|
AnomalyCLIP_lib/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import ftfy
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache()
|
11 |
+
def default_bpe():
|
12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def bytes_to_unicode():
|
17 |
+
"""
|
18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
19 |
+
The reversible bpe codes work on unicode strings.
|
20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
25 |
+
"""
|
26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8+n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def get_pairs(word):
|
39 |
+
"""Return set of symbol pairs in a word.
|
40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
41 |
+
"""
|
42 |
+
pairs = set()
|
43 |
+
prev_char = word[0]
|
44 |
+
for char in word[1:]:
|
45 |
+
pairs.add((prev_char, char))
|
46 |
+
prev_char = char
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def basic_clean(text):
|
51 |
+
text = ftfy.fix_text(text)
|
52 |
+
text = html.unescape(html.unescape(text))
|
53 |
+
return text.strip()
|
54 |
+
|
55 |
+
|
56 |
+
def whitespace_clean(text):
|
57 |
+
text = re.sub(r'\s+', ' ', text)
|
58 |
+
text = text.strip()
|
59 |
+
return text
|
60 |
+
|
61 |
+
|
62 |
+
class SimpleTokenizer(object):
|
63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
64 |
+
self.byte_encoder = bytes_to_unicode()
|
65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
67 |
+
merges = merges[1:49152-256-2+1]
|
68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
69 |
+
vocab = list(bytes_to_unicode().values())
|
70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
71 |
+
for merge in merges:
|
72 |
+
vocab.append(''.join(merge))
|
73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
79 |
+
|
80 |
+
def bpe(self, token):
|
81 |
+
if token in self.cache:
|
82 |
+
return self.cache[token]
|
83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
84 |
+
pairs = get_pairs(word)
|
85 |
+
|
86 |
+
if not pairs:
|
87 |
+
return token+'</w>'
|
88 |
+
|
89 |
+
while True:
|
90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
91 |
+
if bigram not in self.bpe_ranks:
|
92 |
+
break
|
93 |
+
first, second = bigram
|
94 |
+
new_word = []
|
95 |
+
i = 0
|
96 |
+
while i < len(word):
|
97 |
+
try:
|
98 |
+
j = word.index(first, i)
|
99 |
+
new_word.extend(word[i:j])
|
100 |
+
i = j
|
101 |
+
except:
|
102 |
+
new_word.extend(word[i:])
|
103 |
+
break
|
104 |
+
|
105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
106 |
+
new_word.append(first+second)
|
107 |
+
i += 2
|
108 |
+
else:
|
109 |
+
new_word.append(word[i])
|
110 |
+
i += 1
|
111 |
+
new_word = tuple(new_word)
|
112 |
+
word = new_word
|
113 |
+
if len(word) == 1:
|
114 |
+
break
|
115 |
+
else:
|
116 |
+
pairs = get_pairs(word)
|
117 |
+
word = ' '.join(word)
|
118 |
+
self.cache[token] = word
|
119 |
+
return word
|
120 |
+
|
121 |
+
def encode(self, text):
|
122 |
+
bpe_tokens = []
|
123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
124 |
+
for token in re.findall(self.pat, text):
|
125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
127 |
+
return bpe_tokens
|
128 |
+
|
129 |
+
def decode(self, tokens):
|
130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
132 |
+
return text
|
AnomalyCLIP_lib/transform.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
from dataclasses import dataclass, asdict
|
3 |
+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torchvision.transforms.functional as F
|
8 |
+
|
9 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
10 |
+
CenterCrop
|
11 |
+
|
12 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class AugmentationCfg:
|
17 |
+
scale: Tuple[float, float] = (0.9, 1.0)
|
18 |
+
ratio: Optional[Tuple[float, float]] = None
|
19 |
+
color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
|
20 |
+
interpolation: Optional[str] = None
|
21 |
+
re_prob: Optional[float] = None
|
22 |
+
re_count: Optional[int] = None
|
23 |
+
use_timm: bool = False
|
24 |
+
|
25 |
+
|
26 |
+
class ResizeMaxSize(nn.Module):
|
27 |
+
|
28 |
+
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
29 |
+
super().__init__()
|
30 |
+
if not isinstance(max_size, int):
|
31 |
+
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
32 |
+
self.max_size = max_size
|
33 |
+
self.interpolation = interpolation
|
34 |
+
self.fn = min if fn == 'min' else min
|
35 |
+
self.fill = fill
|
36 |
+
|
37 |
+
def forward(self, img):
|
38 |
+
if isinstance(img, torch.Tensor):
|
39 |
+
height, width = img.shape[:2]
|
40 |
+
else:
|
41 |
+
width, height = img.size
|
42 |
+
scale = self.max_size / float(max(height, width))
|
43 |
+
if scale != 1.0:
|
44 |
+
new_size = tuple(round(dim * scale) for dim in (height, width))
|
45 |
+
img = F.resize(img, new_size, self.interpolation)
|
46 |
+
pad_h = self.max_size - new_size[0]
|
47 |
+
pad_w = self.max_size - new_size[1]
|
48 |
+
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
49 |
+
return img
|
50 |
+
|
51 |
+
|
52 |
+
def _convert_to_rgb(image):
|
53 |
+
return image.convert('RGB')
|
54 |
+
|
55 |
+
|
56 |
+
def image_transform(
|
57 |
+
image_size: int,
|
58 |
+
is_train: bool,
|
59 |
+
mean: Optional[Tuple[float, ...]] = None,
|
60 |
+
std: Optional[Tuple[float, ...]] = None,
|
61 |
+
resize_longest_max: bool = False,
|
62 |
+
fill_color: int = 0,
|
63 |
+
aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
|
64 |
+
):
|
65 |
+
mean = mean or OPENAI_DATASET_MEAN
|
66 |
+
if not isinstance(mean, (list, tuple)):
|
67 |
+
mean = (mean,) * 3
|
68 |
+
|
69 |
+
std = std or OPENAI_DATASET_STD
|
70 |
+
if not isinstance(std, (list, tuple)):
|
71 |
+
std = (std,) * 3
|
72 |
+
|
73 |
+
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
74 |
+
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
75 |
+
image_size = image_size[0]
|
76 |
+
|
77 |
+
if isinstance(aug_cfg, dict):
|
78 |
+
aug_cfg = AugmentationCfg(**aug_cfg)
|
79 |
+
else:
|
80 |
+
aug_cfg = aug_cfg or AugmentationCfg()
|
81 |
+
normalize = Normalize(mean=mean, std=std)
|
82 |
+
if is_train:
|
83 |
+
aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
|
84 |
+
use_timm = aug_cfg_dict.pop('use_timm', False)
|
85 |
+
if use_timm:
|
86 |
+
from timm.data import create_transform # timm can still be optional
|
87 |
+
if isinstance(image_size, (tuple, list)):
|
88 |
+
assert len(image_size) >= 2
|
89 |
+
input_size = (3,) + image_size[-2:]
|
90 |
+
else:
|
91 |
+
input_size = (3, image_size, image_size)
|
92 |
+
# by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
|
93 |
+
aug_cfg_dict.setdefault('interpolation', 'random')
|
94 |
+
aug_cfg_dict.setdefault('color_jitter', None) # disable by default
|
95 |
+
train_transform = create_transform(
|
96 |
+
input_size=input_size,
|
97 |
+
is_training=True,
|
98 |
+
hflip=0.,
|
99 |
+
mean=mean,
|
100 |
+
std=std,
|
101 |
+
re_mode='pixel',
|
102 |
+
**aug_cfg_dict,
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
train_transform = Compose([
|
106 |
+
RandomResizedCrop(
|
107 |
+
image_size,
|
108 |
+
scale=aug_cfg_dict.pop('scale'),
|
109 |
+
interpolation=InterpolationMode.BICUBIC,
|
110 |
+
),
|
111 |
+
_convert_to_rgb,
|
112 |
+
ToTensor(),
|
113 |
+
normalize,
|
114 |
+
])
|
115 |
+
if aug_cfg_dict:
|
116 |
+
warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
|
117 |
+
return train_transform
|
118 |
+
else:
|
119 |
+
if resize_longest_max:
|
120 |
+
transforms = [
|
121 |
+
ResizeMaxSize(image_size, fill=fill_color)
|
122 |
+
]
|
123 |
+
else:
|
124 |
+
transforms = [
|
125 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
126 |
+
CenterCrop(image_size),
|
127 |
+
]
|
128 |
+
transforms.extend([
|
129 |
+
_convert_to_rgb,
|
130 |
+
ToTensor(),
|
131 |
+
normalize,
|
132 |
+
])
|
133 |
+
return Compose(transforms)
|
README.md
CHANGED
@@ -1,117 +1,80 @@
|
|
1 |
|
2 |
-
# CLIP
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
- **model:**
|
79 |
-
- input_layer:
|
80 |
-
- image_size: [640, 640, 3] # 표준 입력 이미지 크기
|
81 |
-
- backbone:
|
82 |
-
- name: CLIP (ViT-B-32) # CLIP 모델의 비전 트랜스포머를 백본으로 사용
|
83 |
-
- filters: [32, 64, 128, 256, 512] # 비전 트랜스포머의 각 레이어 필터 크기
|
84 |
-
- neck:
|
85 |
-
- name: Anomaly Detection Module # 결함 탐지를 위한 추가 모듈
|
86 |
-
- method: Contrastive Learning # CLIP 모델의 특징을 사용한 대조 학습 기법
|
87 |
-
- head:
|
88 |
-
- name: Anomaly Detection Head # 결함 탐지를 위한 최종 출력 레이어
|
89 |
-
- outputs:
|
90 |
-
- anomaly_score: 1 # 이상 탐지 점수 (비정상/정상 구분)
|
91 |
-
- class_probabilities: N # 각 클래스에 대한 확률 (결함 여부)
|
92 |
-
|
93 |
-
# Optimizer and Loss Function
|
94 |
-
- **training:**
|
95 |
-
- optimizer:
|
96 |
-
- name: AdamW # AdamW 옵티마이저 (가중치 감쇠 포함)
|
97 |
-
- lr: 0.0001 # 학습률
|
98 |
-
- loss:
|
99 |
-
- classification_loss: 1.0 # 분류 손실 (교차 엔트로피)
|
100 |
-
- anomaly_loss: 1.0 # 결함 탐지 손실 (이상 탐지 모델에 대한 손실)
|
101 |
-
- contrastive_loss: 1.0 # 대조 학습 손실 (유사도 기반 손실)
|
102 |
-
|
103 |
-
# Metrics
|
104 |
-
- **metrics:**
|
105 |
-
- Precision # 정밀도 (Precision)
|
106 |
-
- Recall # 재현율 (Recall)
|
107 |
-
- mAP # 평균 정밀도 (Mean Average Precision)
|
108 |
-
- F1-Score # F1-점수 (균형 잡힌 평가 지표)
|
109 |
-
|
110 |
-
# Training Parameters
|
111 |
-
**하이퍼파라미터 설정**
|
112 |
-
- Learning Rate: 0.001.
|
113 |
-
- Batch Size: 8.
|
114 |
-
- Epochs: 200.
|
115 |
|
116 |
# Pre-trained CLIP model
|
117 |
| Model | Download |
|
@@ -121,75 +84,132 @@
|
|
121 |
| ViT-L/14 | [download](https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt) |
|
122 |
| ViT-L/14@336px | [download](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt) |
|
123 |
|
124 |
-
#
|
125 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
-
|
131 |
-
|
|
|
|
|
|
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
</div>
|
142 |
-
<div style="text-align: center; margin-right: 20px;">
|
143 |
-
<img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/6n0DnnQjXD8Ql-p3Owxan.png" height="80%" width="100%" style="margin-right:5px;">
|
144 |
-
<p>3차 학습 성능</p>
|
145 |
-
</div>
|
146 |
-
</div>
|
147 |
|
148 |
-
- **학습 결과표**:
|
149 |
-

|
150 |
-
|
151 |
-
- **테스트 결과**:
|
152 |
-
<div style="display: flex; justify-content: space-between;">
|
153 |
-
<div style="text-align: center; margin-right: 20px;">
|
154 |
-
<img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/A91V0GdrcUcX01cC-biG9.png" height="600" width="1000" style="margin-right:5px;">
|
155 |
-
<p>Anomaly Product</p>
|
156 |
-
</div>
|
157 |
-
<div style="text-align: center; margin-right: 20px;">
|
158 |
-
<img src="https://cdn-uploads.huggingface.co/production/uploads/65e7d0935ea025ead9623dde/PxleIhphzViTGCubVhWn7.png" height="600" width="1000" style="margin-right:5px;">
|
159 |
-
<p>Normal Product</p>
|
160 |
-
</div>
|
161 |
-
</div>
|
162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
-
|
|
|
|
|
165 |
|
166 |
-
이 모델을 실행하려면 Python과 함께 다음 라이브러리가 필요합니다:
|
167 |
|
168 |
-
- **ftfy==6.2.0**: 텍스트 정규화 및 인코딩 문제를 해결하는 라이브러리.
|
169 |
-
- **matplotlib==3.9.0**: 데이터 시각화 및 그래프 생성을 위한 라이브러리.
|
170 |
-
- **numpy==1.24.3**: 수치 연산을 위한 핵심 라이브러리.
|
171 |
-
- **opencv_python==4.9.0.80**: 이미지 및 비디오 처리용 라이브러리.
|
172 |
-
- **pandas==2.2.2**: 데이터 분석 및 조작을 위한 라이브러리.
|
173 |
-
- **Pillow==10.3.0**: 이미지 파일 처리 및 변환을 위한 라이브러리.
|
174 |
-
- **PyQt5==5.15.10**: GUI 애플리케이션 개발을 위한 프레임워크.
|
175 |
-
- **PyQt5_sip==12.13.0**: PyQt5와 Python 간의 인터페이스를 제공하는 라이브러리.
|
176 |
-
- **regex==2024.5.15**: 정규 표현식 처리를 위한 라이브러리.
|
177 |
-
- **scikit_learn==1.2.2**: 기계 학습 및 데이터 분석을 위한 라이브러리.
|
178 |
-
- **scipy==1.9.1**: 과학 및 기술 계산을 위한 라이브러리.
|
179 |
-
- **setuptools==59.5.0**: Python 패키지 배포 및 설치를 위한 라이브러리.
|
180 |
-
- **scikit-image**: 이미지 처리 및 분석을 위한 라이브러리.
|
181 |
-
- **tabulate==0.9.0**: 표 형태로 데이터를 출력하는 라이브러리.
|
182 |
-
- **thop==0.1.1.post2209072238**: PyTorch 모델의 FLOP 수를 계산하는 도구.
|
183 |
-
- **timm==0.6.13**: 다양한 최신 이미지 분류 모델을 제공하는 라이브러리.
|
184 |
-
- **torch==2.0.0**: PyTorch 딥러닝 프레임워크.
|
185 |
-
- **torchvision==0.15.1**: 컴퓨터 비전 작업을 위한 PyTorch 확장 라이브러리.
|
186 |
-
- **tqdm==4.65.0**: 진행 상황을 시각적으로 표시하는 라이브러리.
|
187 |
-
- **pyautogui**: GUI 자동화를 위한 라이브러리.
|
188 |
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
|
191 |
|
192 |
-
### 모델 실행 단계:
|
193 |
|
194 |
### ✅ Prompt generating
|
195 |
```ruby
|
@@ -249,18 +269,8 @@ parser.add_argument("--dpam", type=int, default=20, help="dpam size")
|
|
249 |
→ If you want to focus only on the final layers (where the model usually learns complex features), you can choose fewer DPAM layers.
|
250 |
```
|
251 |
|
252 |
-
### ✅ Test process
|
253 |
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
→ Used as the starting point for training the CLIP model
|
259 |
-
→ Pre-trained model helps speed up and improve training by leveraging previously learned features
|
260 |
-
```
|
261 |
-
2. Fine-tuned models (./checkpoint/):
|
262 |
-
```ruby
|
263 |
-
→ "epoch_N.pth" files in this folder store the model's states during the fine-tuning process.
|
264 |
-
→ Each ".pth" file represents a version of the model fine-tuned from the pre-trained model
|
265 |
-
→ These checkpoints can be used to resume fine-tuning, evaluate the model at different stages, or select the best-performing version
|
266 |
-
```
|
|
|
1 |
|
2 |
+
# CLIP based ANOMALY DETECTION
|
3 |
+
|
4 |
+
<div align="center">
|
5 |
+
|
6 |
+
[]()
|
7 |
+
[](https://github.com/kylelobo/The-Documentation-Compendium/issues)
|
8 |
+
[](https://github.com/kylelobo/The-Documentation-Compendium/pulls)
|
9 |
+
[](/LICENSE)
|
10 |
+
|
11 |
+
</div>
|
12 |
+
|
13 |
+
---
|
14 |
+
|
15 |
+
<p align="center"> Anomaly detection (AD) requires detection models trained using auxiliary data to detect anomalies without any training sample in a target dataset. AnomalyCLIP is to learn object-agnostic text prompts that capture generic normality and abnormality in an image regardless of its foreground objects. This allows our model to focus on the abnormal image regions rather than the object semantics, enabling generalized normality and abnormality recognition on diverse types of objects. All experiments are conducted in PyTorch-2.0.0 with a single NVIDIA RTX 4090 24GB.
|
16 |
+
<br>
|
17 |
+
</p>
|
18 |
+
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
# 📝 Table of Contents
|
23 |
+
|
24 |
+
- [Update](#update)
|
25 |
+
- [Install & Dependence](#install--dependence)
|
26 |
+
- [Dataset Preparation](#dataset-preparation)
|
27 |
+
- [Pre-trained CLIP model](#pre-trained-clip-model)
|
28 |
+
- [Usage](#usage)
|
29 |
+
- [Code Details](#code-details)
|
30 |
+
- [References](#references)
|
31 |
+
|
32 |
+
# Update
|
33 |
+
- 08.08.2024: Code has been released !!!
|
34 |
+
|
35 |
+
|
36 |
+
# Install & Dependence
|
37 |
+
|
38 |
+
### ⭕ Tested Platform
|
39 |
+
- Software Information
|
40 |
+
```
|
41 |
+
OS: Windows 11 64 bit
|
42 |
+
Python: 3.9.18 (anaconda)
|
43 |
+
PyTorch: 2.0.0
|
44 |
+
Cuda Toolkit: 11.8
|
45 |
+
CudDNN: 9.3.0.75 for cuda11
|
46 |
+
```
|
47 |
+

|
48 |
+
|
49 |
+
- Hardware
|
50 |
+
```
|
51 |
+
CPU: Intel(R) Core(TM) i7-14700K 3.40 GHz
|
52 |
+
RAM: 64GB
|
53 |
+
GPU: Nvidia RTX4090 (24GB)
|
54 |
+
```
|
55 |
+
|
56 |
+
|
57 |
+
- Install Python libraries
|
58 |
+
```
|
59 |
+
pip install -r requirements.txt
|
60 |
+
```
|
61 |
+
|
62 |
+
# Dataset Preparation
|
63 |
+
|
64 |
+
Download the dataset below:
|
65 |
+
|
66 |
+
* Industrial Domain:
|
67 |
+
|
68 |
+
| Dataset | Download |
|
69 |
+
| --- | --- |
|
70 |
+
| MVTec | [download](https://www.mvtec.com/company/research/datasets/mvtec-ad) |
|
71 |
+
| VisA | [download](https://github.com/amazon-science/spot-diff) |
|
72 |
+
| MPDD | [download](https://github.com/stepanje/MPDD) |
|
73 |
+
| BTAD | [download](http://avires.dimi.uniud.it/papers/btad/btad.zip) |
|
74 |
+
| SDD | [download](https://www.vicos.si/resources/kolektorsdd/) |
|
75 |
+
| DAGM | [download](https://www.kaggle.com/datasets/mhskjelvareid/dagm-2007-competition-dataset-optical-inspection) |
|
76 |
+
| DTD-Synthetic | [download](https://drive.google.com/drive/folders/10OyPzvI3H6llCZBxKxFlKWt1Pw1tkMK1) |
|
77 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
# Pre-trained CLIP model
|
80 |
| Model | Download |
|
|
|
84 |
| ViT-L/14 | [download](https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt) |
|
85 |
| ViT-L/14@336px | [download](https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt) |
|
86 |
|
87 |
+
# Usage
|
88 |
+
- for train (Fine-tuning)
|
89 |
+
```ruby
|
90 |
+
python train.py
|
91 |
+
```
|
92 |
+
- for test with dataset (many data)
|
93 |
+
```ruby
|
94 |
+
python test.py
|
95 |
+
```
|
96 |
+
- for simple test (개별 이미지 테스트)
|
97 |
+
```ruby
|
98 |
+
python Simple_test_code.py
|
99 |
+
```
|
100 |
+
- for UI app test (simple app developed)
|
101 |
+
```ruby
|
102 |
+
python monitor_check.py
|
103 |
+
```
|
104 |
+
- for real-time detection test (webcam and video tracking)
|
105 |
+
```ruby
|
106 |
+
python real_time_CLIP.py
|
107 |
+
```
|
108 |
+
|
109 |
+
# Code Details
|
110 |
+
|
111 |
+
### ✅Dataset configuration
|
112 |
+
|
113 |
+
- Dataset configuration as example below
|
114 |
+
```
|
115 |
+
├── data/
|
116 |
+
│ ├── COMP_1/
|
117 |
+
│ │ ├── product_1/
|
118 |
+
│ │ │ ├──grouth_truth
|
119 |
+
│ │ │ │ ├──anomaly_1
|
120 |
+
│ │ │ │ ├──anomaly_2
|
121 |
+
│ │ │ │
|
122 |
+
│ │ │ ├──test/
|
123 |
+
│ │ │ │ ├──good
|
124 |
+
│ │ │ │ ├──anomaly_1
|
125 |
+
│ │ │ │ ├──anomaly_2
|
126 |
+
│ │ │ │
|
127 |
+
│ │ │ ├──train/
|
128 |
+
│ │ │ │ ├──good
|
129 |
+
│ │ │ │ ├──anomaly_1
|
130 |
+
│ │ │ │ ├──anomaly_2
|
131 |
+
│ │ │ │
|
132 |
+
│ │ ├── product_2/
|
133 |
+
│ │ │ │
|
134 |
+
│ │
|
135 |
+
│ ├── COMP_2/
|
136 |
+
│ │
|
137 |
+
```
|
138 |
|
139 |
+
- Generate JSON file storing all the above information of dataset ( -> meta_train.json, meta_test.json)
|
140 |
+
```ruby
|
141 |
+
cd dataset_config
|
142 |
+
python dataset_get_json.py
|
143 |
+
```
|
144 |
|
145 |
+
- Making all grouth_truth (only anomaly mask) by hand
|
146 |
+
```ruby
|
147 |
+
cd dataset_config
|
148 |
+
python image_ground_truth.py
|
149 |
+
```
|
150 |
|
151 |
+
- Dataset configuration for train and test
|
152 |
+
```ruby
|
153 |
+
cd training_libs
|
154 |
+
python dataset.py
|
155 |
+
```
|
156 |
|
157 |
+
→ _ _init_ _ 메서드는 데이터셋의 루트 디렉토리, 변환 함수, 데이터셋 이름, 모드를 입력으로 받음
|
158 |
+
→ 메타 정보를 담은 JSON 파일 (meta_train.json)을 읽어와 클래스 이름 목록과 모든 데이터 항목을 리스트에 저장
|
159 |
+
→ generate_class_info 함수를 호출하여 클래스 정보를 생성하고 클래스 이름을 클래스 ID에 매핑
|
160 |
+
→ _ _len_ _ 메서드는 데이터셋의 샘플 수를 반환
|
161 |
+
→ _ _getitem_ _ 메서드는 주어진 인덱스의 샘플 데이터를 반환
|
162 |
+
→ 이미지 경로를 통해 이미지를 읽고, 이상 여부에 따라 마스크 이미지를 생성
|
163 |
+
→ 필요시 이미지와 마스크에 변환 함수를 적용
|
164 |
+
→ 이미지, 마스크, 클래스 이름, 이상 여부, 이미지 경로, 클래스 ID를 포함한 딕셔너리를 반환
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
|
167 |
+
### ✅ Image pre-processing (transformation) for train and test
|
168 |
+
```ruby
|
169 |
+
training_lib/utils.py
|
170 |
+
```
|
171 |
+
```ruby
|
172 |
+
AnomalyCLIP_lib/transform.py
|
173 |
+
```
|
174 |
+
⭐ **Data Processing Techniques**
|
175 |
+
1. Normalization
|
176 |
+
→ Standardize image pixel values using mean and standard deviation
|
177 |
+
→ Utilized via *'Normalize'* from *'torchvision.transforms'*
|
178 |
+
|
179 |
+
2. Normalization
|
180 |
+
→ Resize the image to a maximum dimension while maintaining aspect ratio, with padding
|
181 |
+
→ Custom *'ResizeMaxSize'* class
|
182 |
+
|
183 |
+
3. RandomResizedCrop
|
184 |
+
→ Randomly crop and resize images during training to create variability
|
185 |
+
→ Implemented via *'RandomResizedCrop'* from *'torchvision.transforms'*
|
186 |
+
|
187 |
+
4. Resize
|
188 |
+
→ Resize images to a fixed size for model input
|
189 |
+
→ Done using *'Resize'* with BICUBIC interpolation
|
190 |
+
|
191 |
+
5. Center Crop
|
192 |
+
→ Crop the central region of the image to the desired size
|
193 |
+
→ Applied using *'CenterCrop'*
|
194 |
+
|
195 |
+
6. ToTensor
|
196 |
+
→ Convert images to PyTorch tensors
|
197 |
+
→ Done with *'ToTensor'*
|
198 |
|
199 |
+
7. Augmentation (Optional)
|
200 |
+
→ Apply various random transformations for data augmentation, configurable via *'AugmentationCfg' *
|
201 |
+
→ Uses *'timm'* library if specified
|
202 |
|
|
|
203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
+
⭐ **Libraries Used**
|
206 |
+
1. *'torch'*: Core deep learning library for tensor operations and model building
|
207 |
+
2. *'torchvision'*: Provides image processing utilities like Resize, CenterCrop, Normalize, etc
|
208 |
+
3. *'timm'*: Optional, for advanced augmentation and transformations
|
209 |
+
4. *'AnomalyCLIP_lib'*: Custom library for dataset-specific constants and transformations
|
210 |
|
211 |
|
212 |
|
|
|
213 |
|
214 |
### ✅ Prompt generating
|
215 |
```ruby
|
|
|
269 |
→ If you want to focus only on the final layers (where the model usually learns complex features), you can choose fewer DPAM layers.
|
270 |
```
|
271 |
|
|
|
272 |
|
273 |
+
|
274 |
+
# References
|
275 |
+
- AnomalyCLIP: Object-agnostic Prompt Learning for Zero-shot Anomaly Detection [[github](https://github.com/zqhang/AnomalyCLIP.git)]
|
276 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset_config/dataset_get_json.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
|
5 |
+
class DATASolver(object):
|
6 |
+
|
7 |
+
CLSNAMES = [
|
8 |
+
'shinpyung',
|
9 |
+
# 'gear'# Change to 'welding_test' for testing
|
10 |
+
]
|
11 |
+
|
12 |
+
def __init__(self, root='data/4inlab'):
|
13 |
+
self.root = root
|
14 |
+
self.meta_path = f'{root}/meta_train.json' # Change to meta_test.json for testing
|
15 |
+
|
16 |
+
def run(self):
|
17 |
+
info = dict(train={}, test={})
|
18 |
+
anomaly_samples = 0
|
19 |
+
normal_samples = 0
|
20 |
+
for cls_name in self.CLSNAMES:
|
21 |
+
cls_dir = f'{self.root}/{cls_name}'
|
22 |
+
for phase in ['train', 'test']:
|
23 |
+
cls_info = []
|
24 |
+
species = os.listdir(f'{cls_dir}/{phase}')
|
25 |
+
for specie in species:
|
26 |
+
is_abnormal = True if specie not in ['good'] else False
|
27 |
+
img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
|
28 |
+
mask_names = os.listdir(f'{cls_dir}/ground_truth/{specie}') if is_abnormal else None
|
29 |
+
img_names.sort()
|
30 |
+
mask_names.sort() if mask_names is not None else None
|
31 |
+
for idx, img_name in enumerate(img_names):
|
32 |
+
info_img = dict(
|
33 |
+
img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
|
34 |
+
mask_path=f'{cls_name}/ground_truth/{specie}/{mask_names[idx]}' if is_abnormal else '',
|
35 |
+
cls_name=cls_name,
|
36 |
+
specie_name=specie,
|
37 |
+
anomaly=1 if is_abnormal else 0,
|
38 |
+
)
|
39 |
+
cls_info.append(info_img)
|
40 |
+
if phase == 'test':
|
41 |
+
if is_abnormal:
|
42 |
+
anomaly_samples = anomaly_samples + 1
|
43 |
+
else:
|
44 |
+
normal_samples = normal_samples + 1
|
45 |
+
info[phase][cls_name] = cls_info
|
46 |
+
with open(self.meta_path, 'w') as f:
|
47 |
+
f.write(json.dumps(info, indent=4) + "\n")
|
48 |
+
print('normal_samples', normal_samples, 'anomaly_samples', anomaly_samples)
|
49 |
+
if __name__ == '__main__':
|
50 |
+
runner = DATASolver(root='data/4inlab')
|
51 |
+
runner.run()
|
dataset_config/image_ground_truth.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import cv2
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
|
6 |
+
# Initialize global variables
|
7 |
+
points = []
|
8 |
+
drawing = False
|
9 |
+
|
10 |
+
# Function to clear the drawn points
|
11 |
+
def clear_points():
|
12 |
+
global points
|
13 |
+
points = []
|
14 |
+
|
15 |
+
# Mouse callback function
|
16 |
+
def draw_polygon(event, x, y, flags, param):
|
17 |
+
global points, drawing, img
|
18 |
+
|
19 |
+
if event == cv2.EVENT_LBUTTONDOWN:
|
20 |
+
drawing = True
|
21 |
+
points.append((x, y))
|
22 |
+
|
23 |
+
elif event == cv2.EVENT_MOUSEMOVE:
|
24 |
+
if drawing:
|
25 |
+
img_copy = img.copy()
|
26 |
+
for i in range(1, len(points)):
|
27 |
+
cv2.line(img_copy, points[i - 1], points[i], (255, 0, 0), 2)
|
28 |
+
if len(points) > 0:
|
29 |
+
cv2.circle(img_copy, points[-1], 3, (0, 0, 255), -1, lineType=cv2.LINE_AA) # Hiển thị điểm chọn của chuột
|
30 |
+
cv2.imshow('image', img_copy)
|
31 |
+
|
32 |
+
elif event == cv2.EVENT_LBUTTONUP:
|
33 |
+
drawing = False
|
34 |
+
points.append((x, y))
|
35 |
+
pts = np.array(points, np.int32)
|
36 |
+
pts = pts.reshape((-1, 1, 2))
|
37 |
+
mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
38 |
+
cv2.fillPoly(mask, [pts], 255)
|
39 |
+
cv2.imwrite(mask_path, mask)
|
40 |
+
cv2.imshow('image', img)
|
41 |
+
|
42 |
+
# Function to process images in a folder
|
43 |
+
def process_images_in_folder(folder_path):
|
44 |
+
global img, mask_path
|
45 |
+
|
46 |
+
for img_name in os.listdir(folder_path):
|
47 |
+
if img_name.endswith('.jpg'):
|
48 |
+
img_path = os.path.join(folder_path, img_name)
|
49 |
+
mask_path = os.path.join(folder_path, f'{os.path.splitext(img_name)[0]}_mask.jpg')
|
50 |
+
img = cv2.imread(img_path)
|
51 |
+
|
52 |
+
# Create a window and bind the mouse callback function
|
53 |
+
cv2.namedWindow('image')
|
54 |
+
cv2.setMouseCallback('image', draw_polygon)
|
55 |
+
|
56 |
+
while True:
|
57 |
+
cv2.imshow('image', img)
|
58 |
+
k = cv2.waitKey(1) & 0xFF
|
59 |
+
if k == 27: # Press 'ESC' to exit
|
60 |
+
break
|
61 |
+
|
62 |
+
clear_points()
|
63 |
+
cv2.destroyAllWindows()
|
64 |
+
|
65 |
+
# Define folders to process
|
66 |
+
folder_path=r'C:\Users\20240805\Documents\GitHub\AD-CLIP\data\4inlab\shinpyung\train\anomaly'
|
67 |
+
process_images_in_folder(folder_path)
|
68 |
+
# %%
|
dataset_config/image_resize.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import os
|
3 |
+
|
4 |
+
def resize_image(input_image_path, output_image_path, size=(518, 518)):
|
5 |
+
with Image.open(input_image_path) as image:
|
6 |
+
resized_image = image.resize(size)
|
7 |
+
resized_image.save(output_image_path)
|
8 |
+
|
9 |
+
# Example usage:
|
10 |
+
input_image = r'\4inlab\shinpyung\train\anomaly'
|
11 |
+
output_image = r'\4inlab\shinpyung\train\anomaly\resize'
|
12 |
+
resize_image(input_image, output_image)
|
13 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ftfy==6.2.0
|
2 |
+
matplotlib==3.9.0
|
3 |
+
numpy==1.24.3
|
4 |
+
opencv_python==4.9.0.80
|
5 |
+
pandas==2.2.2
|
6 |
+
Pillow==10.3.0
|
7 |
+
PyQt5==5.15.10
|
8 |
+
PyQt5_sip==12.13.0
|
9 |
+
regex==2024.5.15
|
10 |
+
scikit_learn==1.2.2
|
11 |
+
scipy==1.9.1
|
12 |
+
setuptools==59.5.0
|
13 |
+
scikit-image
|
14 |
+
tabulate==0.9.0
|
15 |
+
thop==0.1.1.post2209072238
|
16 |
+
timm==0.6.13
|
17 |
+
torch==2.0.0
|
18 |
+
torchvision==0.15.1
|
19 |
+
tqdm==4.65.0
|
20 |
+
pyautogui
|
test.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import AnomalyCLIP_lib
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
|
7 |
+
from training_libs.loss import FocalLoss, BinaryDiceLoss
|
8 |
+
from training_libs.utils import normalize
|
9 |
+
from training_libs.dataset import Dataset_test
|
10 |
+
from training_libs.logger import get_logger
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
import os
|
14 |
+
import random
|
15 |
+
import numpy as np
|
16 |
+
from tabulate import tabulate
|
17 |
+
from training_libs.utils import get_transform
|
18 |
+
|
19 |
+
def setup_seed(seed):
|
20 |
+
torch.manual_seed(seed)
|
21 |
+
torch.cuda.manual_seed_all(seed)
|
22 |
+
np.random.seed(seed)
|
23 |
+
random.seed(seed)
|
24 |
+
torch.backends.cudnn.deterministic = True
|
25 |
+
torch.backends.cudnn.benchmark = False
|
26 |
+
|
27 |
+
from training_libs.visualization import visualizer
|
28 |
+
|
29 |
+
from training_libs.metrics import image_level_metrics, pixel_level_metrics
|
30 |
+
from tqdm import tqdm
|
31 |
+
from scipy.ndimage import gaussian_filter
|
32 |
+
|
33 |
+
|
34 |
+
def test(args):
|
35 |
+
img_size = args.image_size
|
36 |
+
features_list = args.features_list
|
37 |
+
dataset_dir = args.data_path
|
38 |
+
save_path = args.save_path
|
39 |
+
dataset_name = args.dataset
|
40 |
+
|
41 |
+
logger = get_logger(args.save_path)
|
42 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
43 |
+
# device = "gpu"
|
44 |
+
|
45 |
+
AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx}
|
46 |
+
model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details = AnomalyCLIP_parameters)
|
47 |
+
model.eval()
|
48 |
+
# torch.save(model.state_dict(),"pre-trained models/clip")
|
49 |
+
|
50 |
+
preprocess, target_transform = get_transform(args)
|
51 |
+
test_data = Dataset_test(root=args.data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset)
|
52 |
+
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
|
53 |
+
obj_list = test_data.obj_list
|
54 |
+
|
55 |
+
|
56 |
+
results = {}
|
57 |
+
metrics = {}
|
58 |
+
for obj in obj_list:
|
59 |
+
results[obj] = {}
|
60 |
+
results[obj]['gt_sp'] = []
|
61 |
+
results[obj]['pr_sp'] = []
|
62 |
+
results[obj]['imgs_masks'] = []
|
63 |
+
results[obj]['anomaly_maps'] = []
|
64 |
+
metrics[obj] = {}
|
65 |
+
metrics[obj]['pixel-auroc'] = 0
|
66 |
+
metrics[obj]['pixel-aupro'] = 0
|
67 |
+
metrics[obj]['image-auroc'] = 0
|
68 |
+
metrics[obj]['image-ap'] = 0
|
69 |
+
|
70 |
+
prompt_learner = AnomalyCLIP_PromptLearner(model.to(device=device), AnomalyCLIP_parameters)
|
71 |
+
|
72 |
+
|
73 |
+
#Add check-point from trained model with normal images
|
74 |
+
# checkpoint = torch.load("checkpoint/241120_SP_DPAM_13_518/epoch_500.pth",map_location=torch.device('cpu'))
|
75 |
+
# prompt_learner.load_state_dict(checkpoint["prompt_learner"])
|
76 |
+
|
77 |
+
|
78 |
+
#Add check-point from trained model with normal images
|
79 |
+
# checkpoint = torch.load(args.checkpoint_path,map_location=torch.device(device=device))
|
80 |
+
# prompt_learner.load_state_dict(checkpoint["prompt_learner"])
|
81 |
+
|
82 |
+
|
83 |
+
prompt_learner.to(device)
|
84 |
+
model.to(device)
|
85 |
+
model.visual.DAPM_replace(DPAM_layer = 13)
|
86 |
+
|
87 |
+
prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None)
|
88 |
+
print("print(prompts)")
|
89 |
+
print(prompts)
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
|
94 |
+
text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1)
|
95 |
+
text_features = text_features/text_features.norm(dim=-1, keepdim=True)
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
model.to(device)
|
100 |
+
for idx, items in enumerate(tqdm(test_dataloader)):
|
101 |
+
image = items['img'].to(device)
|
102 |
+
cls_name = items['cls_name']
|
103 |
+
cls_id = items['cls_id']
|
104 |
+
|
105 |
+
gt_mask_initial = items['img_mask']
|
106 |
+
#convert gt mask to good (0) and anomaly (1)
|
107 |
+
gt_mask = items['img_mask']
|
108 |
+
gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
|
109 |
+
|
110 |
+
|
111 |
+
results[cls_name[0]]['imgs_masks'].append(gt_mask) # px
|
112 |
+
results[cls_name[0]]['gt_sp'].extend(items['anomaly'].detach().cpu())
|
113 |
+
|
114 |
+
with torch.no_grad():
|
115 |
+
image_features, patch_features = model.encode_image(image, features_list, DPAM_layer = 20)
|
116 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
117 |
+
|
118 |
+
text_probs = image_features @ text_features.permute(0, 2, 1)
|
119 |
+
text_probs = (text_probs/0.07).softmax(-1)
|
120 |
+
text_probs = text_probs[:, 0, 1]
|
121 |
+
anomaly_map_list = []
|
122 |
+
for idx, patch_feature in enumerate(patch_features):
|
123 |
+
if idx >= args.feature_map_layer[0]:
|
124 |
+
patch_feature = patch_feature/ patch_feature.norm(dim = -1, keepdim = True)
|
125 |
+
similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
|
126 |
+
similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size)
|
127 |
+
anomaly_map = (similarity_map[...,1] + 1 - similarity_map[...,0])/2.0
|
128 |
+
anomaly_map_list.append(anomaly_map)
|
129 |
+
|
130 |
+
anomaly_map = torch.stack(anomaly_map_list)
|
131 |
+
|
132 |
+
anomaly_map = anomaly_map.sum(dim = 0)
|
133 |
+
results[cls_name[0]]['pr_sp'].extend(text_probs.detach().cpu())
|
134 |
+
anomaly_map = torch.stack([torch.from_numpy(gaussian_filter(i, sigma = args.sigma)) for i in anomaly_map.detach().cpu()], dim = 0 )
|
135 |
+
results[cls_name[0]]['anomaly_maps'].append(anomaly_map)
|
136 |
+
|
137 |
+
#Save the anomaly map images
|
138 |
+
visualizer(items['img_path'], anomaly_map.detach().cpu().numpy(), args.image_size, args.save_path, cls_name)
|
139 |
+
|
140 |
+
print("print(results)")
|
141 |
+
torch.save(results,"results/results_shinpyung_0.pt")
|
142 |
+
# print(results)
|
143 |
+
|
144 |
+
table_ls = []
|
145 |
+
image_auroc_list = []
|
146 |
+
image_ap_list = []
|
147 |
+
pixel_auroc_list = []
|
148 |
+
pixel_aupro_list = []
|
149 |
+
for obj in obj_list:
|
150 |
+
table = []
|
151 |
+
table.append(obj)
|
152 |
+
results[obj]['imgs_masks'] = torch.cat(results[obj]['imgs_masks'])
|
153 |
+
results[obj]['anomaly_maps'] = torch.cat(results[obj]['anomaly_maps']).detach().cpu().numpy()
|
154 |
+
if args.metrics == 'image-level':
|
155 |
+
image_auroc = image_level_metrics(results, obj, "image-auroc")
|
156 |
+
image_ap = image_level_metrics(results, obj, "image-ap")
|
157 |
+
table.append(str(np.round(image_auroc * 100, decimals=1)))
|
158 |
+
table.append(str(np.round(image_ap * 100, decimals=1)))
|
159 |
+
image_auroc_list.append(image_auroc)
|
160 |
+
image_ap_list.append(image_ap)
|
161 |
+
elif args.metrics == 'pixel-level':
|
162 |
+
pixel_auroc = pixel_level_metrics(results, obj, "pixel-auroc")
|
163 |
+
pixel_aupro = pixel_level_metrics(results, obj, "pixel-aupro")
|
164 |
+
table.append(str(np.round(pixel_auroc * 100, decimals=1)))
|
165 |
+
table.append(str(np.round(pixel_aupro * 100, decimals=1)))
|
166 |
+
pixel_auroc_list.append(pixel_auroc)
|
167 |
+
pixel_aupro_list.append(pixel_aupro)
|
168 |
+
elif args.metrics == 'image-pixel-level':
|
169 |
+
image_auroc = image_level_metrics(results, obj, "image-auroc")
|
170 |
+
image_ap = image_level_metrics(results, obj, "image-ap")
|
171 |
+
pixel_auroc = pixel_level_metrics(results, obj, "pixel-auroc")
|
172 |
+
pixel_aupro = pixel_level_metrics(results, obj, "pixel-aupro")
|
173 |
+
table.append(str(np.round(pixel_auroc * 100, decimals=1)))
|
174 |
+
table.append(str(np.round(pixel_aupro * 100, decimals=1)))
|
175 |
+
table.append(str(np.round(image_auroc * 100, decimals=1)))
|
176 |
+
table.append(str(np.round(image_ap * 100, decimals=1)))
|
177 |
+
image_auroc_list.append(image_auroc)
|
178 |
+
image_ap_list.append(image_ap)
|
179 |
+
pixel_auroc_list.append(pixel_auroc)
|
180 |
+
pixel_aupro_list.append(pixel_aupro)
|
181 |
+
table_ls.append(table)
|
182 |
+
|
183 |
+
if args.metrics == 'image-level':
|
184 |
+
# logger
|
185 |
+
table_ls.append(['mean',
|
186 |
+
str(np.round(np.mean(image_auroc_list) * 100, decimals=1)),
|
187 |
+
str(np.round(np.mean(image_ap_list) * 100, decimals=1))])
|
188 |
+
results = tabulate(table_ls, headers=['objects', 'image_auroc', 'image_ap'], tablefmt="pipe")
|
189 |
+
elif args.metrics == 'pixel-level':
|
190 |
+
# logger
|
191 |
+
table_ls.append(['mean', str(np.round(np.mean(pixel_auroc_list) * 100, decimals=1)),
|
192 |
+
str(np.round(np.mean(pixel_aupro_list) * 100, decimals=1))
|
193 |
+
])
|
194 |
+
results = tabulate(table_ls, headers=['objects', 'pixel_auroc', 'pixel_aupro'], tablefmt="pipe")
|
195 |
+
elif args.metrics == 'image-pixel-level':
|
196 |
+
# logger
|
197 |
+
table_ls.append(['mean', str(np.round(np.mean(pixel_auroc_list) * 100, decimals=1)),
|
198 |
+
str(np.round(np.mean(pixel_aupro_list) * 100, decimals=1)),
|
199 |
+
str(np.round(np.mean(image_auroc_list) * 100, decimals=1)),
|
200 |
+
str(np.round(np.mean(image_ap_list) * 100, decimals=1))])
|
201 |
+
results = tabulate(table_ls, headers=['objects', 'pixel_auroc', 'pixel_aupro', 'image_auroc', 'image_ap'], tablefmt="pipe")
|
202 |
+
logger.info("\n%s", results)
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == '__main__':
|
206 |
+
parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)
|
207 |
+
# paths
|
208 |
+
parser.add_argument("--data_path", type=str, default="./data/4inlab/", help="path to test dataset")
|
209 |
+
parser.add_argument("--save_path", type=str, default='./results/', help='path to save results')
|
210 |
+
parser.add_argument("--checkpoint_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to checkpoint')
|
211 |
+
# model
|
212 |
+
parser.add_argument("--dataset", type=str, default='4inlab')
|
213 |
+
parser.add_argument("--image_size", type=int, default=518, help="image size")
|
214 |
+
parser.add_argument("--depth", type=int, default=9, help="image size")
|
215 |
+
parser.add_argument("--n_ctx", type=int, default=12, help="zero shot")
|
216 |
+
parser.add_argument("--t_n_ctx", type=int, default=4, help="zero shot")
|
217 |
+
parser.add_argument("--metrics", type=str, default='image-pixel-level')
|
218 |
+
parser.add_argument("--seed", type=int, default=111, help="random seed")
|
219 |
+
parser.add_argument("--sigma", type=int, default=4, help="zero shot")
|
220 |
+
# Specify layers from which feature maps will be extracted (can pass multiple values)
|
221 |
+
parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="zero shot")
|
222 |
+
|
223 |
+
# List of layers whose features will be used
|
224 |
+
parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
|
225 |
+
|
226 |
+
|
227 |
+
args = parser.parse_args()
|
228 |
+
print(args)
|
229 |
+
setup_seed(args.seed)
|
230 |
+
test(args)
|
231 |
+
#%%
|
train.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
|
3 |
+
import AnomalyCLIP_lib
|
4 |
+
import torch
|
5 |
+
import argparse
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
|
8 |
+
from training_libs.loss import FocalLoss, BinaryDiceLoss
|
9 |
+
from training_libs.utils import normalize
|
10 |
+
from training_libs.dataset import Dataset_train
|
11 |
+
from training_libs.logger import get_logger
|
12 |
+
from tqdm import tqdm
|
13 |
+
import numpy as np
|
14 |
+
import random
|
15 |
+
from training_libs.utils import get_transform
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
|
18 |
+
import warnings
|
19 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
20 |
+
|
21 |
+
|
22 |
+
def setup_seed(seed):
|
23 |
+
torch.manual_seed(seed)
|
24 |
+
torch.cuda.manual_seed_all(seed)
|
25 |
+
np.random.seed(seed)
|
26 |
+
random.seed(seed)
|
27 |
+
torch.backends.cudnn.deterministic = True
|
28 |
+
torch.backends.cudnn.benchmark = False
|
29 |
+
|
30 |
+
class RealTimePlotter: #
|
31 |
+
def __init__(self):
|
32 |
+
self.epochs = []
|
33 |
+
self.loss_list = []
|
34 |
+
self.image_loss_list = []
|
35 |
+
self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(14, 6))
|
36 |
+
plt.ion()
|
37 |
+
self.fig.show()
|
38 |
+
self.fig.canvas.flush_events()
|
39 |
+
|
40 |
+
def update(self, epoch, loss, image_loss):
|
41 |
+
self.epochs.append(epoch)
|
42 |
+
self.loss_list.append(loss)
|
43 |
+
self.image_loss_list.append(image_loss)
|
44 |
+
|
45 |
+
self.ax1.clear()
|
46 |
+
self.ax2.clear()
|
47 |
+
|
48 |
+
self.ax1.plot(self.epochs, self.loss_list, label='Training Loss')
|
49 |
+
self.ax1.set_title('Training Loss')
|
50 |
+
self.ax1.set_xlabel('Epochs')
|
51 |
+
self.ax1.set_ylabel('Loss')
|
52 |
+
self.ax1.legend()
|
53 |
+
|
54 |
+
self.ax2.plot(self.epochs, self.image_loss_list, label='Image Loss')
|
55 |
+
self.ax2.set_title('Image Loss')
|
56 |
+
        self.ax2.set_xlabel('Epochs')
        self.ax2.set_ylabel('Loss')
        self.ax2.legend()

        self.fig.canvas.flush_events()

def train(args):

    logger = get_logger(args.save_path)

    preprocess, target_transform = get_transform(args)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "cpu"

    AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx}

    # model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device=device, design_details = AnomalyCLIP_parameters)
    model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details = AnomalyCLIP_parameters)
    model.eval()

    train_data = Dataset_train(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name = args.dataset)
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)

    ##########################################################################################
    prompt_learner = AnomalyCLIP_PromptLearner(model.to(device), AnomalyCLIP_parameters)
    prompt_learner.to(device)
    model.to(device)
    model.visual.DAPM_replace(DPAM_layer = args.dpam)
    ##########################################################################################
    optimizer = torch.optim.Adam(list(prompt_learner.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))

    # losses
    loss_focal = FocalLoss()
    loss_dice = BinaryDiceLoss()

    model.eval()
    prompt_learner.train()
    # plotter = RealTimePlotter()

    for epoch in tqdm(range(args.epoch)):
        model.eval()
        prompt_learner.train()
        loss_list = []
        image_loss_list = []

        for items in tqdm(train_dataloader):
            image = items['img'].to(device)
            label = items['anomaly']

            gt = items['img_mask'].squeeze().to(device)
            gt[gt > 0.5] = 1
            gt[gt <= 0.5] = 0

            with torch.no_grad():
                # Apply DPAM to layers 6 through 24.
                # DPAM_layer is the number of layers refined by DPAM, counted from top to bottom:
                # DPAM_layer = 1 means no DPAM is applied; DPAM_layer = 20 is the default.
                image_features, patch_features = model.encode_image(image, args.features_list, DPAM_layer = args.dpam)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            ####################################
            prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None)
            text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
            text_features = torch.stack(torch.chunk(text_features, dim = 0, chunks = 2), dim = 1)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            # Image-level scores: similarity between image and text features, scaled by the 0.07 temperature
            text_probs = image_features.unsqueeze(1) @ text_features.permute(0, 2, 1)
            text_probs = text_probs[:, 0, ...] / 0.07

            image_loss = F.cross_entropy(text_probs.squeeze(), label.long().cuda())  # GPU path
            # image_loss = F.cross_entropy(text_probs.squeeze(), label.long())  # CPU-only alternative
            image_loss_list.append(image_loss.item())
            ######################################################################
            similarity_map_list = []
            # similarity_map_list.append(similarity_map)
            for idx, patch_feature in enumerate(patch_features):
                if idx >= args.feature_map_layer[0]:
                    patch_feature = patch_feature / patch_feature.norm(dim = -1, keepdim = True)
                    similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
                    similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size).permute(0, 3, 1, 2)
                    similarity_map_list.append(similarity_map)

            loss = 0
            for i in range(len(similarity_map_list)):
                loss += loss_focal(similarity_map_list[i], gt)
                loss += loss_dice(similarity_map_list[i][:, 1, :, :], gt)
                loss += loss_dice(similarity_map_list[i][:, 0, :, :], 1 - gt)

            optimizer.zero_grad()
            (loss + image_loss).backward()
            optimizer.step()
            loss_list.append(loss.item())
        # logs
        if (epoch + 1) % args.print_freq == 0:
            avg_loss = np.mean(loss_list)
            avg_image_loss = np.mean(image_loss_list)
            logger.info('epoch [{}/{}], loss:{:.4f}, image_loss:{:.4f}'.format(epoch + 1, args.epoch, avg_loss, avg_image_loss))
            # plotter.update(epoch + 1, avg_loss, avg_image_loss)  # real-time training curve monitoring

        # save model
        if (epoch + 1) % args.save_freq == 0:
            ckp_path = os.path.join(args.save_path, 'epoch_' + str(epoch + 1) + '.pth')
            torch.save({"prompt_learner": prompt_learner.state_dict(), "epoch": epoch + 1}, ckp_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)  # initialize the argument parser

    # Paths to the training data and to the checkpoint directory
    parser.add_argument("--train_data_path", type=str, default="./data/4inlab", help="train dataset path")
    parser.add_argument("--save_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to save results')

    # Name of the training dataset
    parser.add_argument("--dataset", type=str, default='4inlab', help="train dataset name")

    # Depth of the learnable text embedding
    parser.add_argument("--depth", type=int, default=9, help="learnable text embedding depth")

    # Prompt length and learnable text embedding length
    parser.add_argument("--n_ctx", type=int, default=12, help="length of the learnable prompt")
    parser.add_argument("--t_n_ctx", type=int, default=4, help="length of the learnable text embedding")

    # Layers from which similarity maps are computed (accepts multiple values)
    parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="layers used for the similarity maps")

    # List of layers whose patch features are extracted
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")

    # Training hyperparameters
    parser.add_argument("--epoch", type=int, default=400, help="epochs")
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size")

    # Number of layers refined by DPAM
    parser.add_argument("--dpam", type=int, default=13, help="dpam layer depth")

    # Size of the input images used for training
    parser.add_argument("--image_size", type=int, default=518, help="image size")

    # Frequency (in epochs) of logging and of checkpoint saving
    parser.add_argument("--print_freq", type=int, default=1, help="print frequency")
    parser.add_argument("--save_freq", type=int, default=1, help="save frequency")
    parser.add_argument("--seed", type=int, default=111, help="random seed")

    args = parser.parse_args()  # parse the command-line arguments into 'args'
    setup_seed(args.seed)       # set the random seed for reproducibility
    train(args)                 # run training with the parsed arguments
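Only the prompt learner is optimized in this script, and only its state dict is checkpointed. For reference, a minimal sketch of how such a checkpoint could be restored for inference, assuming the modules shown in this commit; the checkpoint path and epoch number are placeholders rather than values from an actual run.

# Hypothetical checkpoint-restore sketch (not part of the commit); paths are placeholders.
import torch
import AnomalyCLIP_lib
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner

design_details = {"Prompt_length": 12, "learnabel_text_embedding_depth": 9, "learnabel_text_embedding_length": 4}
device = "cuda" if torch.cuda.is_available() else "cpu"

# Same backbone and design details as used during training above.
model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details=design_details)
model.eval()

prompt_learner = AnomalyCLIP_PromptLearner(model.to(device), design_details)
checkpoint = torch.load("./checkpoint/241122_SP_DPAM_13_518/epoch_400.pth", map_location=device)
prompt_learner.load_state_dict(checkpoint["prompt_learner"])  # matches the keys written by torch.save above
prompt_learner.to(device)
prompt_learner.eval()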
training_libs/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (3.54 kB)

training_libs/__pycache__/logger.cpython-39.pyc
ADDED
Binary file (890 Bytes)

training_libs/__pycache__/loss.cpython-39.pyc
ADDED
Binary file (4.19 kB)

training_libs/__pycache__/metrics.cpython-39.pyc
ADDED
Binary file (1.98 kB)

training_libs/__pycache__/prompt_ensemble.cpython-39.pyc
ADDED
Binary file (7.13 kB)

training_libs/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.06 kB)

training_libs/__pycache__/visualization.cpython-39.pyc
ADDED
Binary file (1.17 kB)
training_libs/dataset.py
ADDED
@@ -0,0 +1,116 @@
import torch.utils.data as data
import json
import random
from PIL import Image
import numpy as np
import torch
import os

def generate_class_info(dataset_name, mode='train'):
    class_name_map_class_id = {}
    if dataset_name == 'mvtec':
        # obj_list = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill',
        #             'transistor', 'metal_nut', 'screw', 'toothbrush', 'zipper', 'tile', 'wood']
        obj_list = ['bottle']
    elif dataset_name == '4inlab':
        if mode == 'train':
            obj_list = ['shinpyung']  # objects used for training
        elif mode == 'test':
            obj_list = ['shinpyung']  # objects used for testing
    elif dataset_name == 'task1':
        if mode == 'train':
            obj_list = ['cup']
    elif dataset_name == 'task2':
        if mode == 'train':
            obj_list = ['fire']
    elif dataset_name == 'smoke_cloud':
        if mode == 'train':
            obj_list = ['fire']

    for k, index in zip(obj_list, range(len(obj_list))):
        class_name_map_class_id[k] = index

    return obj_list, class_name_map_class_id

class Dataset_test(data.Dataset):
    def __init__(self, root, transform, target_transform, dataset_name, mode="test"):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.data_all = []
        meta_info = json.load(open(f'{self.root}/meta_train.json', 'r'))
        name = self.root.split('/')[-1]
        meta_info = meta_info[mode]

        self.cls_names = list(meta_info.keys())
        for cls_name in self.cls_names:
            self.data_all.extend(meta_info[cls_name])
        self.length = len(self.data_all)

        self.obj_list, self.class_name_map_class_id = generate_class_info(dataset_name, mode='test')

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        data = self.data_all[index]
        img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
                                                              data['specie_name'], data['anomaly']
        img = Image.open(os.path.join(self.root, img_path))
        if anomaly == 0:
            img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
        else:
            if os.path.isdir(os.path.join(self.root, mask_path)):
                # classification-only samples have no pixel mask; use an empty one instead of raising an error
                img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
            else:
                img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
                img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
        # transforms
        img = self.transform(img) if self.transform is not None else img
        img_mask = self.target_transform(
            img_mask) if self.target_transform is not None and img_mask is not None else img_mask
        img_mask = [] if img_mask is None else img_mask
        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]}


class Dataset_train(data.Dataset):
    def __init__(self, root, transform, target_transform, dataset_name, mode="train"):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.data_all = []
        meta_info = json.load(open(f'{self.root}/meta_train.json', 'r'))
        name = self.root.split('/')[-1]
        meta_info = meta_info[mode]

        self.cls_names = list(meta_info.keys())
        for cls_name in self.cls_names:
            self.data_all.extend(meta_info[cls_name])
        self.length = len(self.data_all)

        self.obj_list, self.class_name_map_class_id = generate_class_info(dataset_name, mode='train')

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        data = self.data_all[index]
        img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], data['cls_name'], \
                                                              data['specie_name'], data['anomaly']
        img = Image.open(os.path.join(self.root, img_path))
        if anomaly == 0:
            img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
        else:
            if os.path.isdir(os.path.join(self.root, mask_path)):
                # classification-only samples have no pixel mask; use an empty one instead of raising an error
                img_mask = Image.fromarray(np.zeros((img.size[0], img.size[1])), mode='L')
            else:
                img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
                img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
        # transforms
        img = self.transform(img) if self.transform is not None else img
        img_mask = self.target_transform(
            img_mask) if self.target_transform is not None and img_mask is not None else img_mask
        img_mask = [] if img_mask is None else img_mask
        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': os.path.join(self.root, img_path), "cls_id": self.class_name_map_class_id[cls_name]}
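Both dataset classes read {root}/meta_train.json and unpack five fields per sample in __getitem__. Below is a hypothetical sketch of that file's structure for the 4inlab/shinpyung layout used above; the concrete image and mask file names are invented for illustration only.

# Hypothetical meta_train.json writer; field names mirror __getitem__, file names are placeholders.
import json

meta = {
    "train": {
        "shinpyung": [
            {"img_path": "shinpyung/train/good/000.png",
             "mask_path": "",                                   # unused when anomaly == 0
             "cls_name": "shinpyung", "specie_name": "good", "anomaly": 0},
            {"img_path": "shinpyung/train/defect/001.png",
             "mask_path": "shinpyung/ground_truth/defect/001.png",
             "cls_name": "shinpyung", "specie_name": "defect", "anomaly": 1},
        ]
    },
    "test": {
        "shinpyung": []
    }
}

with open("./data/4inlab/meta_train.json", "w") as f:
    json.dump(meta, f, indent=4)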
training_libs/logger.py
ADDED
@@ -0,0 +1,25 @@
import logging
import os

def get_logger(save_path):
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    txt_path = os.path.join(save_path, 'log.txt')
    # logger
    root_logger = logging.getLogger()
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)
    root_logger.setLevel(logging.WARNING)
    logger = logging.getLogger('test')
    formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                                  datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(txt_path, mode='a')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger
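A short usage sketch (the save path below is a placeholder): messages are appended to <save_path>/log.txt and echoed to the console, which is how train.py reports its per-epoch losses.

# Minimal usage sketch; the directory is created if it does not exist.
from training_libs.logger import get_logger

logger = get_logger("./checkpoint/example_run")  # placeholder path
logger.info("epoch [1/400], loss:1.2345, image_loss:0.6789")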
training_libs/loss.py
ADDED
@@ -0,0 +1,125 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import exp

class FocalLoss(nn.Module):
    """
    Copied from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
    This is an implementation of Focal Loss with smooth-label cross entropy support, as proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):
        Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
        putting more focus on hard, misclassified examples
    :param smooth: (float, double) smoothing value used in the cross entropy
    :param balance_index: (int) index of the class to balance; must be specified when alpha is a float
    :param size_average: (bool, optional) by default, losses are averaged over each loss element in the batch
    """

    def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
        super(FocalLoss, self).__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, logit, target):
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]

        if logit.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        alpha = self.alpha

        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError('Not support alpha type')

        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)

        idx = target.cpu().long()

        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)

        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()

        gamma = self.gamma

        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        return loss


class BinaryDiceLoss(nn.Module):
    def __init__(self):
        super(BinaryDiceLoss, self).__init__()

    def forward(self, input, targets):
        # Get the batch size N
        N = targets.size()[0]
        # Smoothing term
        smooth = 1
        # Flatten width and height into a single dimension
        input_flat = input.view(N, -1)
        targets_flat = targets.view(N, -1)

        intersection = input_flat * targets_flat
        N_dice_eff = (2 * intersection.sum(1) + smooth) / (input_flat.sum(1) + targets_flat.sum(1) + smooth)
        # Average the loss over the images in the batch
        loss = 1 - N_dice_eff.sum() / N
        return loss

def smooth(arr, lamda1):
    new_array = arr
    arr2 = torch.zeros_like(arr)
    arr2[:, :-1, :] = arr[:, 1:, :]
    arr2[:, -1, :] = arr[:, -1, :]

    new_array2 = torch.zeros_like(new_array)
    new_array2[:, :, :-1] = new_array[:, :, 1:]
    new_array2[:, :, -1] = new_array[:, :, -1]
    loss = (torch.sum((arr2 - arr) ** 2) + torch.sum((new_array2 - new_array) ** 2)) / 2
    return lamda1 * loss

def sparsity(arr, target, lamda2):
    if target == 0:
        loss = torch.mean(torch.norm(arr, dim=0))
    else:
        loss = torch.mean(torch.norm(1 - arr, dim=0))
    return lamda2 * loss
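A small toy sketch (not part of the commit) of how FocalLoss and BinaryDiceLoss are combined on the shapes the training loop produces: each similarity map is (N, 2, H, W) with channel 0 scoring "normal" and channel 1 scoring "anomaly", and the ground-truth mask is a binary (N, H, W) tensor. FocalLoss expects probabilities, so the toy scores are softmax-normalized here.

# Toy shapes only; in train.py the maps come from AnomalyCLIP_lib.get_similarity_map.
import torch
from training_libs.loss import FocalLoss, BinaryDiceLoss

N, H, W = 2, 64, 64
similarity_map = torch.softmax(torch.randn(N, 2, H, W), dim=1)  # two-channel normal/anomaly probabilities
gt = (torch.rand(N, H, W) > 0.9).float()                        # binary anomaly mask

loss_focal = FocalLoss()
loss_dice = BinaryDiceLoss()

loss = loss_focal(similarity_map, gt)                        # focal loss over both channels
loss = loss + loss_dice(similarity_map[:, 1, :, :], gt)      # dice on the anomaly channel vs. the mask
loss = loss + loss_dice(similarity_map[:, 0, :, :], 1 - gt)  # dice on the normal channel vs. the inverted mask
print(loss.item())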
training_libs/metrics.py
ADDED
@@ -0,0 +1,60 @@
from sklearn.metrics import auc, roc_auc_score, average_precision_score, f1_score, precision_recall_curve, pairwise
import numpy as np
from skimage import measure

def cal_pro_score(masks, amaps, max_step=200, expect_fpr=0.3):
    # ref: https://github.com/gudovskiy/cflow-ad/blob/master/train.py
    binary_amaps = np.zeros_like(amaps, dtype=bool)
    min_th, max_th = amaps.min(), amaps.max()
    delta = (max_th - min_th) / max_step
    pros, fprs, ths = [], [], []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps[amaps <= th], binary_amaps[amaps > th] = 0, 1
        pro = []
        for binary_amap, mask in zip(binary_amaps, masks):
            for region in measure.regionprops(measure.label(mask)):
                tp_pixels = binary_amap[region.coords[:, 0], region.coords[:, 1]].sum()
                pro.append(tp_pixels / region.area)
        inverse_masks = 1 - masks
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        fpr = fp_pixels / inverse_masks.sum()
        pros.append(np.array(pro).mean())
        fprs.append(fpr)
        ths.append(th)
    pros, fprs, ths = np.array(pros), np.array(fprs), np.array(ths)
    idxes = fprs < expect_fpr
    fprs = fprs[idxes]
    fprs = (fprs - fprs.min()) / (fprs.max() - fprs.min())
    pro_auc = auc(fprs, pros[idxes])
    return pro_auc


def image_level_metrics(results, obj, metric):
    gt = results[obj]['gt_sp']
    pr = results[obj]['pr_sp']
    gt = np.array(gt)
    pr = np.array(pr)
    if metric == 'image-auroc':
        performance = roc_auc_score(gt, pr)
    elif metric == 'image-ap':
        performance = average_precision_score(gt, pr)

    return performance
    # table.append(str(np.round(performance * 100, decimals=1)))


def pixel_level_metrics(results, obj, metric):
    gt = results[obj]['imgs_masks']
    pr = results[obj]['anomaly_maps']
    gt = np.array(gt)
    pr = np.array(pr)
    if metric == 'pixel-auroc':
        performance = roc_auc_score(gt.ravel(), pr.ravel())
    elif metric == 'pixel-aupro':
        if len(gt.shape) == 4:
            gt = gt.squeeze(1)
        if len(pr.shape) == 4:
            pr = pr.squeeze(1)
        performance = cal_pro_score(gt, pr)
    return performance
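A toy sketch (not part of the commit) of the results dictionary these helpers expect: one entry per object class, with image-level labels and scores plus pixel-level masks and anomaly maps. The numbers are random placeholders; real values come from a test loop over the model's predictions.

# Random placeholder data only.
import numpy as np
from training_libs.metrics import image_level_metrics, pixel_level_metrics

results = {
    "shinpyung": {
        "gt_sp": [0, 0, 1, 1],                                             # image-level ground-truth labels
        "pr_sp": [0.1, 0.3, 0.8, 0.9],                                     # image-level anomaly scores
        "imgs_masks": (np.random.rand(4, 64, 64) > 0.9).astype(np.uint8),  # pixel-level ground truth
        "anomaly_maps": np.random.rand(4, 64, 64),                         # pixel-level anomaly maps
    }
}

print(image_level_metrics(results, "shinpyung", "image-auroc"))
print(image_level_metrics(results, "shinpyung", "image-ap"))
print(pixel_level_metrics(results, "shinpyung", "pixel-auroc"))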
training_libs/prompt_ensemble.py
ADDED
@@ -0,0 +1,273 @@
import os
from typing import Union, List
from pkg_resources import packaging
import torch
import numpy as np
from AnomalyCLIP_lib.simple_tokenizer import SimpleTokenizer as _Tokenizer
# from open_clip import tokenizer
# simple_tokenizer = tokenizer.SimpleTokenizer()
from copy import deepcopy
import torch.nn as nn

_tokenizer = _Tokenizer()


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
    else:
        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if truncate:
                tokens = tokens[:context_length]
                tokens[-1] = eot_token
            else:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result

# def encode_text_with_prompt_ensemble(model, texts, device):
#     prompt_normal = ['{}', 'flawless {}', 'perfect {}', 'unblemished {}', '{} without flaw', '{} without defect', '{} without damage']
#     prompt_abnormal = ['damaged {}', 'broken {}', '{} with flaw', '{} with defect', '{} with damage']
#     prompt_state = [prompt_normal, prompt_abnormal]
#     prompt_templates = ['a bad photo of a {}.', 'a low resolution photo of the {}.', 'a bad photo of the {}.', 'a cropped photo of the {}.', 'a bright photo of a {}.', 'a dark photo of the {}.', 'a photo of my {}.', 'a photo of the cool {}.', 'a close-up photo of a {}.', 'a black and white photo of the {}.', 'a bright photo of the {}.', 'a cropped photo of a {}.', 'a jpeg corrupted photo of a {}.', 'a blurry photo of the {}.', 'a photo of the {}.', 'a good photo of the {}.', 'a photo of one {}.', 'a close-up photo of the {}.', 'a photo of a {}.', 'a low resolution photo of a {}.', 'a photo of a large {}.', 'a blurry photo of a {}.', 'a jpeg corrupted photo of the {}.', 'a good photo of a {}.', 'a photo of the small {}.', 'a photo of the large {}.', 'a black and white photo of a {}.', 'a dark photo of a {}.', 'a photo of a cool {}.', 'a photo of a small {}.', 'there is a {} in the scene.', 'there is the {} in the scene.', 'this is a {} in the scene.', 'this is the {} in the scene.', 'this is one {} in the scene.']
#
#     text_features = []
#     for i in range(len(prompt_state)):
#         prompted_state = [state.format(texts[0]) for state in prompt_state[i]]
#         prompted_sentence = []
#         for s in prompted_state:
#             for template in prompt_templates:
#                 prompted_sentence.append(template.format(s))
#         prompted_sentence = tokenize(prompted_sentence)
#         class_embeddings = model.encode_text(prompted_sentence.to(device))
#         class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
#         class_embedding = class_embeddings.mean(dim=0)
#         class_embedding /= class_embedding.norm()
#         text_features.append(class_embedding)
#
#     text_features = torch.stack(text_features, dim=1).to(device).t()
#
#     return text_features


def _get_clones(module, N):
    return nn.ModuleList([deepcopy(module) for i in range(N)])

class AnomalyCLIP_PromptLearner(nn.Module):
    def __init__(self, clip_model, design_details):
        super().__init__()
        classnames = ["object"]
        self.n_cls = len(classnames)
        self.n_ctx = design_details["Prompt_length"]
        n_ctx_pos = self.n_ctx
        n_ctx_neg = self.n_ctx
        self.text_encoder_n_ctx = design_details["learnabel_text_embedding_length"]
        ctx_init_pos = ""
        ctx_init_neg = ""
        dtype = clip_model.transformer.get_cast_dtype()
        device = clip_model.token_embedding.weight.device

        ctx_dim = clip_model.ln_final.weight.shape[0]

        self.classnames = classnames

        self.state_normal_list = [
            "{}",
        ]

        self.state_anomaly_list = [
            "damaged {}",
        ]

        normal_num = len(self.state_normal_list)
        anormaly_num = len(self.state_anomaly_list)
        self.normal_num = normal_num
        self.anormaly_num = anormaly_num

        if ctx_init_pos and ctx_init_neg:
            # use given words to initialize context vectors
            ctx_init_pos = ctx_init_pos.replace("_", " ")
            ctx_init_neg = ctx_init_neg.replace("_", " ")
            n_ctx_pos = len(ctx_init_pos.split(" "))
            n_ctx_neg = len(ctx_init_neg.split(" "))
            # Tokenize the initialization text into BPE encodings
            prompt_pos = tokenize(ctx_init_pos)
            prompt_neg = tokenize(ctx_init_neg)
            with torch.no_grad():
                # Generate the corresponding text embeddings
                embedding_pos = clip_model.token_embedding(prompt_pos).type(dtype)
                embedding_neg = clip_model.token_embedding(prompt_neg).type(dtype)
            # Drop the SOS/EOS positions and keep the n_ctx token embeddings as the learnable textual prompt
            ctx_vectors_pos = embedding_pos[0, 1: 1 + n_ctx_pos, :]
            ctx_vectors_neg = embedding_neg[0, 1: 1 + n_ctx_neg, :]
            prompt_prefix_pos = ctx_init_pos
            prompt_prefix_neg = ctx_init_neg
            if True:
                ctx_vectors_pos_ = []
                ctx_vectors_neg_ = []
                for _ in range(self.n_cls):
                    ctx_vectors_pos_.append(deepcopy(ctx_vectors_pos))
                    ctx_vectors_neg_.append(deepcopy(ctx_vectors_neg))
                ctx_vectors_pos = torch.stack(ctx_vectors_pos_, dim=0)
                ctx_vectors_neg = torch.stack(ctx_vectors_neg_, dim=0)

        else:
            # Random initialization
            if True:
                print("Initializing class-specific contexts")
                # n_cls is the number of classes, n_ctx_pos is the number of learnable tokens, ctx_dim is the prompt dimension
                ctx_vectors_pos = torch.empty(self.n_cls, self.normal_num, n_ctx_pos, ctx_dim, dtype=dtype)
                ctx_vectors_neg = torch.empty(self.n_cls, self.anormaly_num, n_ctx_neg, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors_pos = torch.empty(n_ctx_pos, ctx_dim, dtype=dtype)
                ctx_vectors_neg = torch.empty(n_ctx_neg, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors_pos, std=0.02)
            nn.init.normal_(ctx_vectors_neg, std=0.02)
            prompt_prefix_pos = " ".join(["X"] * n_ctx_pos)
            prompt_prefix_neg = " ".join(["X"] * n_ctx_neg)
        self.compound_prompts_depth = design_details["learnabel_text_embedding_depth"]
        self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(self.text_encoder_n_ctx, ctx_dim))
                                                       for _ in range(self.compound_prompts_depth - 1)])
        for single_para in self.compound_prompts_text:
            print("single_para", single_para.shape)
            nn.init.normal_(single_para, std=0.02)

        single_layer = nn.Linear(ctx_dim, 896)
        self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)

        self.ctx_pos = nn.Parameter(ctx_vectors_pos)  # to be optimized
        self.ctx_neg = nn.Parameter(ctx_vectors_neg)  # to be optimized

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]

        prompts_pos = [prompt_prefix_pos + " " + template.format(name) + "." for template in self.state_normal_list for name in classnames]
        prompts_neg = [prompt_prefix_neg + " " + template.format(name) + "." for template in self.state_anomaly_list for name in classnames]

        # print("Normal Prompt:", prompts_pos)
        # print("Anomaly Prompt:", prompts_neg)

        tokenized_prompts_pos = []
        tokenized_prompts_neg = []

        for p_pos in prompts_pos:
            tokenized_prompts_pos.append(tokenize(p_pos))
        for p_neg in prompts_neg:
            tokenized_prompts_neg.append(tokenize(p_neg))

        tokenized_prompts_pos = [tokenize(p_pos).to(device) for p_pos in prompts_pos]  # move to the embedding device
        tokenized_prompts_neg = [tokenize(p_neg).to(device) for p_neg in prompts_neg]  # move to the embedding device

        tokenized_prompts_pos = torch.cat(tokenized_prompts_pos)
        tokenized_prompts_neg = torch.cat(tokenized_prompts_neg)
        # Generate the corresponding text embeddings
        with torch.no_grad():
            embedding_pos = clip_model.token_embedding(tokenized_prompts_pos).type(dtype)
            embedding_neg = clip_model.token_embedding(tokenized_prompts_neg).type(dtype)
            n, l, d = embedding_pos.shape
            print("embedding_pos", embedding_pos.shape)
            embedding_pos = embedding_pos.reshape(normal_num, self.n_cls, l, d).permute(1, 0, 2, 3)
            embedding_neg = embedding_neg.reshape(anormaly_num, self.n_cls, l, d).permute(1, 0, 2, 3)

        self.register_buffer("token_prefix_pos", embedding_pos[:, :, :1, :])
        self.register_buffer("token_suffix_pos", embedding_pos[:, :, 1 + n_ctx_pos:, :])
        self.register_buffer("token_prefix_neg", embedding_neg[:, :, :1, :])
        self.register_buffer("token_suffix_neg", embedding_neg[:, :, 1 + n_ctx_neg:, :])

        n, d = tokenized_prompts_pos.shape
        tokenized_prompts_pos = tokenized_prompts_pos.reshape(normal_num, self.n_cls, d).permute(1, 0, 2)

        n, d = tokenized_prompts_neg.shape
        tokenized_prompts_neg = tokenized_prompts_neg.reshape(anormaly_num, self.n_cls, d).permute(1, 0, 2)

        self.n_ctx_pos = n_ctx_pos
        self.n_ctx_neg = n_ctx_neg
        # tokenized_prompts = torch.cat([tokenized_prompts_pos, tokenized_prompts_neg], dim=0)  # torch.Tensor
        self.register_buffer("tokenized_prompts_pos", tokenized_prompts_pos)
        self.register_buffer("tokenized_prompts_neg", tokenized_prompts_neg)
        print("tokenized_prompts shape", self.tokenized_prompts_pos.shape, self.tokenized_prompts_neg.shape)

    def forward(self, cls_id=None):

        ctx_pos = self.ctx_pos
        ctx_neg = self.ctx_neg
        # print("shape", self.ctx_pos[0:1].shape, ctx_pos.shape)
        prefix_pos = self.token_prefix_pos
        prefix_neg = self.token_prefix_neg
        suffix_pos = self.token_suffix_pos
        suffix_neg = self.token_suffix_neg

        # print(prefix_pos.shape, prefix_neg.shape)

        prompts_pos = torch.cat(
            [
                # N (number of templates), 1, dim
                prefix_pos,  # (n_cls, 1, dim)
                ctx_pos,     # (n_cls, n_ctx, dim)
                suffix_pos,  # (n_cls, *, dim)
            ],
            dim=2,
        )

        prompts_neg = torch.cat(
            [
                prefix_neg,  # (n_cls, 1, dim)
                ctx_neg,     # (n_cls, n_ctx, dim)
                suffix_neg,  # (n_cls, *, dim)
            ],
            dim=2,
        )
        _, _, l, d = prompts_pos.shape
        prompts_pos = prompts_pos.reshape(-1, l, d)
        _, _, l, d = prompts_neg.shape
        prompts_neg = prompts_neg.reshape(-1, l, d)
        prompts = torch.cat([prompts_pos, prompts_neg], dim=0)

        _, l, d = self.tokenized_prompts_pos.shape
        tokenized_prompts_pos = self.tokenized_prompts_pos.reshape(-1, d)
        _, l, d = self.tokenized_prompts_neg.shape
        tokenized_prompts_neg = self.tokenized_prompts_neg.reshape(-1, d)
        tokenized_prompts = torch.cat((tokenized_prompts_pos, tokenized_prompts_neg), dim=0)

        return prompts, tokenized_prompts, self.compound_prompts_text
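A sketch (not part of the commit) of how the prompt learner is queried, mirroring train.py: forward() concatenates the learnable normal and anomaly prompts, and the resulting text features are split back into the two states with torch.chunk. With the single built-in "object" class and one template per state, the final tensor is expected to have shape (1, 2, embedding_dim).

# Assumes the same backbone loading path as train.py in this commit.
import torch
import AnomalyCLIP_lib
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner

design_details = {"Prompt_length": 12, "learnabel_text_embedding_depth": 9, "learnabel_text_embedding_length": 4}
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details=design_details)

prompt_learner = AnomalyCLIP_PromptLearner(model.to(device), design_details).to(device)
prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)

with torch.no_grad():
    text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()

# First half of the rows are the normal prompts, second half the anomaly prompts.
text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
print(text_features.shape)  # expected: (1, 2, embedding_dim)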
training_libs/utils.py
ADDED
@@ -0,0 +1,24 @@
import torchvision.transforms as transforms
# from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from AnomalyCLIP_lib.transform import image_transform
from AnomalyCLIP_lib.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD


def normalize(pred, max_value=None, min_value=None):
    if max_value is None or min_value is None:
        return (pred - pred.min()) / (pred.max() - pred.min())
    else:
        return (pred - min_value) / (max_value - min_value)

def get_transform(args):
    preprocess = image_transform(args.image_size, is_train=False, mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
    target_transform = transforms.Compose([
        transforms.Resize((args.image_size, args.image_size)),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor()
    ])
    preprocess.transforms[0] = transforms.Resize(size=(args.image_size, args.image_size), interpolation=transforms.InterpolationMode.BICUBIC,
                                                 max_size=None, antialias=None)
    preprocess.transforms[1] = transforms.CenterCrop(size=(args.image_size, args.image_size))
    return preprocess, target_transform
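A brief sketch (not part of the commit) of what get_transform returns for the 518-pixel setting used by train.py: preprocess resizes and normalizes RGB images with the OpenAI CLIP statistics, while target_transform only resizes and tensorizes the masks.

# argparse.Namespace stands in for the parsed args object used in train.py.
import argparse
import numpy as np
from PIL import Image
from training_libs.utils import get_transform

args = argparse.Namespace(image_size=518)
preprocess, target_transform = get_transform(args)

img = Image.fromarray(np.zeros((720, 1280, 3), dtype=np.uint8))           # dummy RGB image
mask = Image.fromarray(np.zeros((720, 1280), dtype=np.uint8), mode='L')   # dummy binary mask

print(preprocess(img).shape)         # expected: torch.Size([3, 518, 518])
print(target_transform(mask).shape)  # expected: torch.Size([1, 518, 518])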
training_libs/visualization.py
ADDED
@@ -0,0 +1,25 @@
import cv2
import os
# from utils import normalize
from sklearn.preprocessing import normalize
import numpy as np

def visualizer(pathes, anomaly_map, img_size, save_path, cls_name):
    for idx, path in enumerate(pathes):
        cls = path.split('/')[-2]
        filename = path.split('/')[-1]
        vis = cv2.cvtColor(cv2.resize(cv2.imread(path), (img_size, img_size)), cv2.COLOR_BGR2RGB)  # RGB
        mask = normalize(anomaly_map[idx])
        vis = apply_ad_scoremap(vis, mask)
        vis = cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)  # BGR
        save_vis = os.path.join(save_path, 'imgs', cls_name[idx], cls)
        if not os.path.exists(save_vis):
            os.makedirs(save_vis)
        cv2.imwrite(os.path.join(save_vis, filename), vis)

def apply_ad_scoremap(image, scoremap, alpha=0.5):
    np_image = np.asarray(image, dtype=float)
    scoremap = (scoremap * 255).astype(np.uint8)
    scoremap = cv2.applyColorMap(scoremap, cv2.COLORMAP_JET)
    scoremap = cv2.cvtColor(scoremap, cv2.COLOR_BGR2RGB)
    return (alpha * np_image + (1 - alpha) * scoremap).astype(np.uint8)
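Finally, a self-contained usage sketch (not part of the commit): it writes a dummy image to disk, overlays a random anomaly map, and saves the JET-colormap result under <save_path>/imgs/<cls_name>/<class>/. Note that the normalize imported above is scikit-learn's row-wise L2 normalization, not the min-max normalize defined in training_libs/utils.py (the commented-out import), so overlay intensities are scaled per row rather than to [0, 1].

# Dummy data only; the demo paths below are placeholders.
import os
import cv2
import numpy as np
from training_libs.visualization import visualizer

os.makedirs("demo/shinpyung", exist_ok=True)
cv2.imwrite("demo/shinpyung/000.png", np.full((518, 518, 3), 128, dtype=np.uint8))

anomaly_map = np.random.rand(1, 518, 518).astype(np.float32)  # one (H, W) map per image, values in [0, 1]
visualizer(pathes=["demo/shinpyung/000.png"],
           anomaly_map=anomaly_map,
           img_size=518,
           save_path="./results",
           cls_name=["shinpyung"])
# The overlay is written to ./results/imgs/shinpyung/shinpyung/000.png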