import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import AnomalyCLIP_lib
import torch
import argparse
import torch.nn.functional as F
from training_libs.prompt_ensemble import AnomalyCLIP_PromptLearner
from training_libs.loss import FocalLoss, BinaryDiceLoss
from training_libs.utils import normalize
from training_libs.dataset import Dataset_train
from training_libs.logger import get_logger
from tqdm import tqdm
import numpy as np
import random
from training_libs.utils import get_transform
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
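
# Note: the cudnn flags above cover most sources of nondeterminism; on newer PyTorch
# versions, torch.use_deterministic_algorithms(True) can be added for stricter
# determinism, at some speed cost (and errors on ops lacking deterministic kernels).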
class RealTimePlotter:
    """Live side-by-side plots of the pixel-level loss and the image-level loss."""

    def __init__(self):
        self.epochs = []
        self.loss_list = []
        self.image_loss_list = []
        self.fig, (self.ax1, self.ax2) = plt.subplots(1, 2, figsize=(14, 6))
        plt.ion()  # interactive mode so the window refreshes without blocking training
        self.fig.show()
        self.fig.canvas.flush_events()

    def update(self, epoch, loss, image_loss):
        self.epochs.append(epoch)
        self.loss_list.append(loss)
        self.image_loss_list.append(image_loss)
        self.ax1.clear()
        self.ax2.clear()
        self.ax1.plot(self.epochs, self.loss_list, label='Training Loss')
        self.ax1.set_title('Training Loss')
        self.ax1.set_xlabel('Epochs')
        self.ax1.set_ylabel('Loss')
        self.ax1.legend()
        self.ax2.plot(self.epochs, self.image_loss_list, label='Image Loss')
        self.ax2.set_title('Image Loss')
        self.ax2.set_xlabel('Epochs')
        self.ax2.set_ylabel('Loss')
        self.ax2.legend()
        self.fig.canvas.draw()  # redraw before flushing so the new curves actually appear
        self.fig.canvas.flush_events()
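
# Usage sketch: create one plotter before the epoch loop and feed it the averaged
# losses at each logging step (see the commented-out plotter lines in train()):
#   plotter = RealTimePlotter()
#   plotter.update(epoch + 1, avg_loss, avg_image_loss)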
def train(args):
    logger = get_logger(args.save_path)
    preprocess, target_transform = get_transform(args)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE: the misspelled 'learnabel_*' keys are kept intentionally; they match the
    # keys that AnomalyCLIP_lib reads from design_details.
    AnomalyCLIP_parameters = {"Prompt_length": args.n_ctx, "learnabel_text_embedding_depth": args.depth, "learnabel_text_embedding_length": args.t_n_ctx}
    # model, _ = AnomalyCLIP_lib.load("ViT-L/14@336px", device=device, design_details=AnomalyCLIP_parameters)
    model, _ = AnomalyCLIP_lib.load("pre-trained models/clip/ViT-B-32.pt", device=device, design_details=AnomalyCLIP_parameters)
    model.eval()

    train_data = Dataset_train(root=args.train_data_path, transform=preprocess, target_transform=target_transform, dataset_name=args.dataset)
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
    ##########################################################################################
    prompt_learner = AnomalyCLIP_PromptLearner(model.to(device), AnomalyCLIP_parameters)
    prompt_learner.to(device)
    model.to(device)
    model.visual.DAPM_replace(DPAM_layer=args.dpam)
    ##########################################################################################

    # Only the prompt learner's parameters are optimized; the CLIP backbone stays frozen.
    optimizer = torch.optim.Adam(list(prompt_learner.parameters()), lr=args.learning_rate, betas=(0.5, 0.999))

    # losses
    loss_focal = FocalLoss()
    loss_dice = BinaryDiceLoss()

    model.eval()
    prompt_learner.train()
    # plotter = RealTimePlotter()
    for epoch in tqdm(range(args.epoch)):
        model.eval()
        prompt_learner.train()
        loss_list = []
        image_loss_list = []
        for items in tqdm(train_dataloader):
            image = items['img'].to(device)
            label = items['anomaly']
            # Binarize the ground-truth mask at 0.5
            gt = items['img_mask'].squeeze().to(device)
            gt[gt > 0.5] = 1
            gt[gt <= 0.5] = 0

            with torch.no_grad():
                # DPAM_layer is the number of layers refined by DPAM, counted from the
                # top down: DPAM_layer = 1 means no DPAM is applied; the original default is 20.
                image_features, patch_features = model.encode_image(image, args.features_list, DPAM_layer=args.dpam)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            ####################################
            prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id=None)
            text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()
            # Split the stacked normal/anomalous prompt embeddings into two chunks
            # and pair them along a new dimension.
            text_features = torch.stack(torch.chunk(text_features, dim=0, chunks=2), dim=1)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Image-level logits: cosine similarity scaled by the CLIP temperature (0.07)
            text_probs = image_features.unsqueeze(1) @ text_features.permute(0, 2, 1)
            text_probs = text_probs[:, 0, ...] / 0.07
            image_loss = F.cross_entropy(text_probs.squeeze(), label.long().to(device))  # device-agnostic (GPU or CPU)
            image_loss_list.append(image_loss.item())
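
            # Shape sketch (illustrative, assuming batch size B, embedding dim D, and a
            # single object-agnostic prompt pair):
            #   image_features: (B, D); text_features: (1, 2, D)
            #   text_probs: (B, 1, 2) -> squeezed to (B, 2) logits over {normal, anomalous}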
            ######################################################################
            similarity_map_list = []
            for idx, patch_feature in enumerate(patch_features):
                if idx >= args.feature_map_layer[0]:
                    patch_feature = patch_feature / patch_feature.norm(dim=-1, keepdim=True)
                    similarity, _ = AnomalyCLIP_lib.compute_similarity(patch_feature, text_features[0])
                    similarity_map = AnomalyCLIP_lib.get_similarity_map(similarity[:, 1:, :], args.image_size).permute(0, 3, 1, 2)
                    similarity_map_list.append(similarity_map)

            # Pixel-level loss over each selected layer: focal loss on the two-channel
            # map, plus Dice on the anomaly channel (vs. gt) and the normal channel (vs. 1 - gt).
            loss = 0
            for i in range(len(similarity_map_list)):
                loss += loss_focal(similarity_map_list[i], gt)
                loss += loss_dice(similarity_map_list[i][:, 1, :, :], gt)
                loss += loss_dice(similarity_map_list[i][:, 0, :, :], 1 - gt)

            optimizer.zero_grad()
            (loss + image_loss).backward()
            optimizer.step()
            loss_list.append(loss.item())
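
            # The combined objective is image_loss + the summed segmentation losses.
            # Gradients reach only the prompt learner: the image encoder ran under
            # torch.no_grad(), and the optimizer holds prompt_learner.parameters() alone.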
        # logs
        if (epoch + 1) % args.print_freq == 0:
            avg_loss = np.mean(loss_list)
            avg_image_loss = np.mean(image_loss_list)
            logger.info('epoch [{}/{}], loss:{:.4f}, image_loss:{:.4f}'.format(epoch + 1, args.epoch, avg_loss, avg_image_loss))
            # plotter.update(epoch + 1, avg_loss, avg_image_loss)  # real-time training curves

        # save checkpoint
        if (epoch + 1) % args.save_freq == 0:
            ckp_path = os.path.join(args.save_path, 'epoch_' + str(epoch + 1) + '.pth')
            torch.save({"prompt_learner": prompt_learner.state_dict(), "epoch": epoch + 1}, ckp_path)
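
# Restoring sketch (hypothetical inference-side snippet; 'ckp_path' as saved above):
#   checkpoint = torch.load(ckp_path, map_location=device)
#   prompt_learner.load_state_dict(checkpoint["prompt_learner"])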
if __name__ == '__main__':
    parser = argparse.ArgumentParser("AnomalyCLIP", add_help=True)  # initialize the argument parser
    # Paths to the training dataset and the checkpoint directory
    parser.add_argument("--train_data_path", type=str, default="./data/4inlab", help="train dataset path")
    parser.add_argument("--save_path", type=str, default='./checkpoint/241122_SP_DPAM_13_518', help='path to save results')
    # Name of the training dataset
    parser.add_argument("--dataset", type=str, default='4inlab', help="train dataset name")
    # Depth of the learnable text embeddings
    parser.add_argument("--depth", type=int, default=9, help="learnable text embedding depth")
    # Prompt length and learnable text embedding length
    parser.add_argument("--n_ctx", type=int, default=12, help="learnable prompt length")
    parser.add_argument("--t_n_ctx", type=int, default=4, help="learnable text embedding length")
    # Layers whose patch features feed the similarity maps (multiple values allowed)
    parser.add_argument("--feature_map_layer", type=int, nargs="+", default=[0, 1, 2, 3], help="layers used for similarity maps")
    # Encoder layers whose features are returned by encode_image
    parser.add_argument("--features_list", type=int, nargs="+", default=[6, 12, 18, 24], help="features used")
    # Training hyperparameters
    parser.add_argument("--epoch", type=int, default=400, help="epochs")
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
    # Number of layers refined by DPAM
    parser.add_argument("--dpam", type=int, default=13, help="number of DPAM-refined layers")
    # Input image size used for training
    parser.add_argument("--image_size", type=int, default=518, help="image size")
    # Frequency (in epochs) of logging and checkpointing
    parser.add_argument("--print_freq", type=int, default=1, help="print frequency")
    parser.add_argument("--save_freq", type=int, default=1, help="save frequency")
    parser.add_argument("--seed", type=int, default=111, help="random seed")

    args = parser.parse_args()
    setup_seed(args.seed)  # set random seeds for reproducibility
    train(args)
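
# Example invocation (script name assumed; values mirror the defaults above):
#   python train.py --train_data_path ./data/4inlab --dataset 4inlab \
#       --save_path ./checkpoint/241122_SP_DPAM_13_518 --image_size 518 --dpam 13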