import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import AnomalyCLIP_lib
import torch
import argparse
import torch.nn.functional as F
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
from training_libs.loss import FocalLoss, BinaryDiceLoss
from training_libs.utils import normalize
from training_libs.dataset import Dataset_train
from training_libs.logger import get_logger
from tqdm import tqdm
import numpy as np
import random
from training_libs.utils import get_transform
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
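
# Note: cudnn.deterministic = True (with benchmark disabled) trades some GPU
# throughput for run-to-run reproducibility; drop those two lines if exact
# repeatability is not needed.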
class RealTimePlotter:
    """Live matplotlib view of the training (pixel-level) and image-level losses."""
    def __init__(self):
        self.epochs = []
        self.loss_list = []
        self.image_loss_list = []
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(14, 6))
        plt.ion()
        self.fig.show()
        self.fig.canvas.flush_events()

    def update(self, epoch, loss, image_loss):
        self.epochs.append(epoch)
        self.loss_list.append(loss)
        self.image_loss_list.append(image_loss)
        self.ax1.clear()
        self.ax2.clear()
        self.ax1.plot(self.epochs, self.loss_list, label='Training Loss')
        self.ax1.set_title('Training Loss')
        self.ax1.set_xlabel('Epochs')
        self.ax1.set_ylabel('Loss')
        self.ax1.legend()
        self.ax2.plot(self.epochs, self.image_loss_list, label='Image Loss')
        self.ax2.set_title('Image Loss')
        self.ax2.set_xlabel('Epochs')
        self.ax2.set_ylabel('Loss')
        self.ax2.legend()
        self.fig.canvas.draw_idle()  # request a redraw; flush_events alone may not repaint
        self.fig.canvas.flush_events()
def train(args):
    logger = get_logger(args.save_path)
    preprocess, target_transform = get_transform(args)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = "cpu"

    # Note: the misspelled "learnabel" keys are the names AnomalyCLIP_lib looks up;
    # keep them as-is.
    AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx}

    # model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device=device, design_details=AnomalyCLIP_parameters)
    model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details=AnomalyCLIP_parameters)
    model.eval()

    train_data = Dataset_train(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name=args.dataset)
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
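    # Optional variant (commented out): parallelise data loading on multi-core machines.
    # num_workers and pin_memory are standard torch.utils.data.DataLoader arguments;
    # the values here are illustrative, not tuned.
    # train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
    #                                                shuffle=True, num_workers=4, pin_memory=True)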
    ##########################################################################################
    prompt_learner = AnomalyCLIP_PromptLearner(model.to(device), AnomalyCLIP_parameters)
    prompt_learner.to(device)
    model.to(device)
    model.visual.DAPM_replace(DPAM_layer=args.dpam)
    ##########################################################################################
    optimizer = torch.optim.Adam(list(prompt_learner.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))

    # losses
    loss_focal = FocalLoss()
    loss_dice = BinaryDiceLoss()

    model.eval()
    prompt_learner.train()
    # plotter = RealTimePlotter()
    for epoch in tqdm(range(args.epoch)):
        model.eval()
        prompt_learner.train()
        loss_list = []
        image_loss_list = []
        for items in tqdm(train_dataloader):
            image = items['img'].to(device)
            label = items['anomaly']
            # Binarize the ground-truth anomaly mask
            gt = items['img_mask'].squeeze().to(device)
            gt[gt > 0.5] = 1
            gt[gt <= 0.5] = 0
            with torch.no_grad():
                # Apply DPAM to the layers from 6 to 24.
                # DPAM_layer is the number of layers refined by DPAM, counted from top to bottom:
                # DPAM_layer = 1 means no DPAM is used; DPAM_layer = 20 is the default.
                image_features, patch_features = model.encode_image(image, args.features_list, DPAM_layer=args.dpam)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            ####################################
            prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
            text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
            text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Image-level logits: similarity between the image embedding and the
            # normal/anomalous text embeddings, divided by the temperature 0.07.
            text_probs = image_features.unsqueeze(1) @ text_features.permute(0, 2, 1)
            text_probs = text_probs[:, 0, ...] / 0.07
            image_loss = F.cross_entropy(text_probs.squeeze(), label.long().to(device))  # works on both GPU and CPU
            image_loss_list.append(image_loss.item())
            ######################################################################
            # Pixel-level loss: build a similarity map per selected feature layer
            # and supervise it with focal + dice losses against the ground-truth mask.
            similarity_map_list = []
            for idx, patch_feature in enumerate(patch_features):
                if idx >= args.feature_map_layer[0]:
                    patch_feature = patch_feature / patch_feature.norm(dim=-1, keepdim=True)
                    similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
                    similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size).permute(0, 3, 1, 2)
                    similarity_map_list.append(similarity_map)

            loss = 0
            for i in range(len(similarity_map_list)):
                loss += loss_focal(similarity_map_list[i], gt)
                loss += loss_dice(similarity_map_list[i][:, 1, :, :], gt)      # channel 1: anomalous
                loss += loss_dice(similarity_map_list[i][:, 0, :, :], 1 - gt)  # channel 0: normal

            optimizer.zero_grad()
            (loss + image_loss).backward()
            optimizer.step()
            loss_list.append(loss.item())
        # logs
        if (epoch + 1) % args.print_freq == 0:
            avg_loss = np.mean(loss_list)
            avg_image_loss = np.mean(image_loss_list)
            logger.info('epoch [{}/{}], loss:{:.4f}, image_loss:{:.4f}'.format(epoch + 1, args.epoch, avg_loss, avg_image_loss))
            # plotter.update(epoch + 1, avg_loss, avg_image_loss)  # real-time training monitor

        # save model
        if (epoch + 1) % args.save_freq == 0:
            ckp_path = os.path.join(args.save_path, 'epoch_' + str(epoch + 1) + '.pth')
            torch.save({"prompt_learner": prompt_learner.state_dict(), "epoch": epoch + 1}, ckp_path)
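
# Resuming (a minimal sketch, commented out): the checkpoints above store only the
# prompt learner's weights plus the epoch index, so restoring looks like the lines
# below. `ckpt_path` and `start_epoch` are illustrative names, not part of this script.
# ckpt = torch.load(ckpt_path, map_location=device)
# prompt_learner.load_state_dict(ckpt["prompt_learner"])
# start_epoch = ckpt["epoch"]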
if __name__ == '__main__':
    parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)
    # Paths to the training dataset and to the checkpoint/log directory
    parser.add_argument("--train_data_path", type=str, default="./data/4inlab", help="train dataset path")
    parser.add_argument("--save_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to save results')
    # Name of the training dataset
    parser.add_argument("--dataset", type=str, default='4inlab', help="train dataset name")
    # Depth of the learnable text embeddings
    parser.add_argument("--depth", type=int, default=9, help="learnable text embedding depth")
    # Prompt length and learnable text embedding length
    parser.add_argument("--n_ctx", type=int, default=12, help="prompt length")
    parser.add_argument("--t_n_ctx", type=int, default=4, help="learnable text embedding length")
    # Layers from which pixel-level similarity maps are built (accepts multiple values)
    parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="layers used for similarity maps")
    # Layers whose patch features are extracted by the image encoder
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    # Training hyperparameters
    parser.add_argument("--epoch", type=int, default=400, help="epochs")
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
    # Number of layers refined by DPAM (see the note in train())
    parser.add_argument("--dpam", type=int, default=13, help="number of DPAM layers")
    # Input image size used for training
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    # Frequency (in epochs) of logging and of checkpoint saving
    parser.add_argument("--print_freq", type=int, default=1, help="print frequency")
    parser.add_argument("--save_freq", type=int, default=1, help="save frequency")
    parser.add_argument("--seed", type=int, default=111, help="random seed")

    args = parser.parse_args()
    setup_seed(args.seed)  # set the random seed for reproducibility
    train(args)