#%%
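"""Zero-shot anomaly detection evaluation for AnomalyCLIP.

Loads a CLIP backbone with DPAM-modified attention, builds learnable
normal/abnormal text prompts, scores each test image at the image and pixel
level, and logs per-object AUROC/AP/AUPRO tables.
"""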
import argparse
import os
import random

import numpy as np
import torch
from scipy.ndimage import gaussian_filter
from tabulate import tabulate
from tqdm import tqdm

import AnomalyCLIP_lib
from training_libs.dataset import Dataset_test
from training_libs.logger import get_logger
from training_libs.metrics import image_level_metrics, pixel_level_metrics
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
from training_libs.utils import get_transform
from training_libs.visualization import visualizer


def setup_seed(seed):
    # Fix every RNG the pipeline touches so evaluation runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def test(args):
    img_size = args.image_size
    features_list = args.features_list
    dataset_dir = args.data_path
    save_path = args.save_path
    dataset_name = args.dataset
    logger = get_logger(args.save_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Key spellings ("learnabel_...") follow the names AnomalyCLIP_lib expects.
    AnomalyCLIP_parameters = {
        "Prompt_length": args.n_ctx,
        "learnabel_text_embedding_depth": args.depth,
        "learnabel_text_embedding_length": args.t_n_ctx,
    }
    model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details=AnomalyCLIP_parameters)
    model.eval()

    preprocess, target_transform = get_transform(args)
    test_data = Dataset_test(root=args.data_path, transform=preprocess, target_transform=target_transform, dataset_name=args.dataset)
    test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)
    obj_list = test_data.obj_list

    # Per-object containers for ground truth, predictions, and metric scores.
    results = {}
    metrics = {}
    for obj in obj_list:
        results[obj] = {}
        results[obj]['gt_sp'] = []
        results[obj]['pr_sp'] = []
        results[obj]['imgs_masks'] = []
        results[obj]['anomaly_maps'] = []
        metrics[obj] = {}
        metrics[obj]['pixel-auroc'] = 0
        metrics[obj]['pixel-aupro'] = 0
        metrics[obj]['image-auroc'] = 0
        metrics[obj]['image-ap'] = 0
    prompt_learner = AnomalyCLIP_PromptLearner(model.to(device=device), AnomalyCLIP_parameters)
    # To evaluate prompts trained on normal images, load the checkpoint here:
    # checkpoint = torch.load(args.checkpoint_path, map_location=torch.device(device))
    # prompt_learner.load_state_dict(checkpoint["prompt_learner"])
    prompt_learner.to(device)
    model.to(device)
    model.visual.DAPM_replace(DPAM_layer=13)

    # Build the text embeddings once; they are reused for every test image.
    prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
    text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
    text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
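    # Shape note (my reading of AnomalyCLIP_lib's conventions, worth verifying):
    # after the chunk/stack above, text_features is [n_prompt_sets, 2, embed_dim],
    # with index 0 along dim 1 the "normal" embedding and index 1 the "abnormal"
    # one. The L2 normalization makes the matmuls below cosine similarities, so
    # text_probs[:, 0, 1] in the loop is the abnormal-class probability.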
    for idx, items in enumerate(tqdm(test_dataloader)):
        image = items['img'].to(device)
        cls_name = items['cls_name']
        cls_id = items['cls_id']

        # Binarize the ground-truth mask: 1 = anomaly, 0 = good.
        gt_mask = items['img_mask']
        gt_mask[gt_mask > 0.5], gt_mask[gt_mask <= 0.5] = 1, 0
        results[cls_name[0]]['imgs_masks'].append(gt_mask)  # pixel-level labels
        results[cls_name[0]]['gt_sp'].extend(items['anomaly'].detach().cpu())

        with torch.no_grad():
            image_features, patch_features = model.encode_image(image, features_list, DPAM_layer=20)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Image-level score: temperature-scaled (0.07, as in CLIP) softmax
            # over the normal/abnormal text similarities.
            text_probs = image_features @ text_features.permute(0, 2, 1)
            text_probs = (text_probs / 0.07).softmax(-1)
            text_probs = text_probs[:, 0, 1]

            # Pixel-level map: patch-to-text similarity from each selected layer.
            anomaly_map_list = []
            for layer_idx, patch_feature in enumerate(patch_features):
                if layer_idx >= args.feature_map_layer[0]:
                    patch_feature = patch_feature / patch_feature.norm(dim=-1, keepdim=True)
                    similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
                    similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size)
                    anomaly_map = (similarity_map[..., 1] + 1 - similarity_map[..., 0]) / 2.0
                    anomaly_map_list.append(anomaly_map)
            anomaly_map = torch.stack(anomaly_map_list).sum(dim=0)

            results[cls_name[0]]['pr_sp'].extend(text_probs.detach().cpu())

            # Smooth each summed map with a Gaussian filter before scoring.
            anomaly_map = torch.stack(
                [torch.from_numpy(gaussian_filter(i, sigma=args.sigma)) for i in anomaly_map.detach().cpu()],
                dim=0)
            results[cls_name[0]]['anomaly_maps'].append(anomaly_map)

            # Save the anomaly-map overlays.
            visualizer(items['img_path'], anomaly_map.detach().cpu().numpy(), args.image_size, args.save_path, cls_name)
print("print(results)")
torch.save(results,"results/results_shinpyung_0.pt")
# print(results)
    # Aggregate metrics per object and build one table row per class.
    table_ls = []
    image_auroc_list = []
    image_ap_list = []
    pixel_auroc_list = []
    pixel_aupro_list = []
    for obj in obj_list:
        table = []
        table.append(obj)
        results[obj]['imgs_masks'] = torch.cat(results[obj]['imgs_masks'])
        results[obj]['anomaly_maps'] = torch.cat(results[obj]['anomaly_maps']).detach().cpu().numpy()
        if args.metrics == 'image-level':
            image_auroc = image_level_metrics(results, obj, "image-auroc")
            image_ap = image_level_metrics(results, obj, "image-ap")
            table.append(str(np.round(image_auroc * 100, decimals=1)))
            table.append(str(np.round(image_ap * 100, decimals=1)))
            image_auroc_list.append(image_auroc)
            image_ap_list.append(image_ap)
        elif args.metrics == 'pixel-level':
            pixel_auroc = pixel_level_metrics(results, obj, "pixel-auroc")
            pixel_aupro = pixel_level_metrics(results, obj, "pixel-aupro")
            table.append(str(np.round(pixel_auroc * 100, decimals=1)))
            table.append(str(np.round(pixel_aupro * 100, decimals=1)))
            pixel_auroc_list.append(pixel_auroc)
            pixel_aupro_list.append(pixel_aupro)
        elif args.metrics == 'image-pixel-level':
            image_auroc = image_level_metrics(results, obj, "image-auroc")
            image_ap = image_level_metrics(results, obj, "image-ap")
            pixel_auroc = pixel_level_metrics(results, obj, "pixel-auroc")
            pixel_aupro = pixel_level_metrics(results, obj, "pixel-aupro")
            table.append(str(np.round(pixel_auroc * 100, decimals=1)))
            table.append(str(np.round(pixel_aupro * 100, decimals=1)))
            table.append(str(np.round(image_auroc * 100, decimals=1)))
            table.append(str(np.round(image_ap * 100, decimals=1)))
            image_auroc_list.append(image_auroc)
            image_ap_list.append(image_ap)
            pixel_auroc_list.append(pixel_auroc)
            pixel_aupro_list.append(pixel_aupro)
        table_ls.append(table)
    # Append the per-metric means and log the final table.
    if args.metrics == 'image-level':
        table_ls.append(['mean',
                         str(np.round(np.mean(image_auroc_list) * 100, decimals=1)),
                         str(np.round(np.mean(image_ap_list) * 100, decimals=1))])
        report = tabulate(table_ls, headers=['objects', 'image_auroc', 'image_ap'], tablefmt="pipe")
    elif args.metrics == 'pixel-level':
        table_ls.append(['mean',
                         str(np.round(np.mean(pixel_auroc_list) * 100, decimals=1)),
                         str(np.round(np.mean(pixel_aupro_list) * 100, decimals=1))])
        report = tabulate(table_ls, headers=['objects', 'pixel_auroc', 'pixel_aupro'], tablefmt="pipe")
    elif args.metrics == 'image-pixel-level':
        table_ls.append(['mean',
                         str(np.round(np.mean(pixel_auroc_list) * 100, decimals=1)),
                         str(np.round(np.mean(pixel_aupro_list) * 100, decimals=1)),
                         str(np.round(np.mean(image_auroc_list) * 100, decimals=1)),
                         str(np.round(np.mean(image_ap_list) * 100, decimals=1))])
        report = tabulate(table_ls, headers=['objects', 'pixel_auroc', 'pixel_aupro', 'image_auroc', 'image_ap'], tablefmt="pipe")
    logger.info("\n%s", report)
if __name__ == '__main__':
    parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)
    # paths
    parser.add_argument("--data_path", type=str, default="./data/4inlab/", help="path to test dataset")
    parser.add_argument("--save_path", type=str, default='./results/', help='path to save results')
    parser.add_argument("--checkpoint_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to trained prompt-learner checkpoint')
    # model
    parser.add_argument("--dataset", type=str, default='4inlab', help="test dataset name")
    parser.add_argument("--image_size", type=int, default=518, help="input image size")
    parser.add_argument("--depth", type=int, default=9, help="learnable text embedding depth")
    parser.add_argument("--n_ctx", type=int, default=12, help="learnable prompt length")
    parser.add_argument("--t_n_ctx", type=int, default=4, help="learnable text embedding length")
    parser.add_argument("--metrics", type=str, default='image-pixel-level', help="image-level, pixel-level, or image-pixel-level")
    parser.add_argument("--seed", type=int, default=111, help="random seed")
    parser.add_argument("--sigma", type=int, default=4, help="Gaussian smoothing sigma for anomaly maps")
    # Earliest intermediate layer whose patch features contribute to the anomaly map
    parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="feature map layers used for the anomaly map")
    # Encoder layers from which patch features are extracted
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="layers whose features are used")
    args = parser.parse_args()
    print(args)

    setup_seed(args.seed)
    test(args)
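# Example invocation (assuming this file is saved as test.py; the paths follow
# the defaults above and are assumptions about the local layout):
#   python test.py --data_path ./data/4inlab/ --save_path ./results/ \
#       --metrics image-pixel-level --image_size 518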
#%%