import random
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image

from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants, normalize_params
from evaluation.diversity import MultiKCrossChannelMaxPooledSum
from get_data import get_augmentation
from load_model import extract_sel_mean_std_bias_assignemnt


def overlapping_features_on_input(model, output, feature_maps, input, target):
    """Overlay the positively weighted feature maps of the predicted (or given)
    class onto the input image, assigning each feature its own color."""
    W = model.linear.layer.weight
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    print("feature_maps", feature_maps.shape)
    if target is not None:
        label = target - 1  # targets are 1-indexed, weight rows are 0-indexed
    else:
        output = output.detach().cpu().numpy()
        label = np.argmax(output)
    Interpretable_Selection = W[label, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []
    input_np = cv2.resize(input_np, (448, 448))
    Feature_image_list = [input_np]
    # Cycle through six fixed colors (R, G, B, yellow, purple, cyan), mapping
    # each color to the image channels it activates.
    COLORS = {'R': [0], 'G': [1], 'B': [2], 'Y': [0, 1], 'P': [0, 2], 'C': [1, 2]}
    color_keys = list(COLORS)
    color_id = 0
    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] > 0:
            Interpretable_Features.append(feature_maps[S])
            Feature_image = cv2.resize(feature_maps[S], (448, 448))
            # Normalize the activation map to [0, 255].
            Feature_image = (Feature_image - np.min(Feature_image)) / (np.max(Feature_image) - np.min(Feature_image)) * 255
            Feature_image = Feature_image.astype(np.uint8)
            # Paint the normalized activation into the channels of the current color.
            channels = COLORS[color_keys[color_id % len(color_keys)]]
            Feature_image_color = np.zeros_like(input_np)
            for c in channels:
                Feature_image_color[:, :, c] = Feature_image
            Feature_image = Feature_image_color
            color_id += 1
            # Gamma correction to emphasize strong activations.
            Feature_image = np.power(Feature_image, 1.3)
            # Blend the colored activation with a grayscale copy of the input.
            input_gray = cv2.cvtColor(input_np, cv2.COLOR_BGR2GRAY)
            input_gray = cv2.cvtColor(input_gray, cv2.COLOR_GRAY2BGR)
            Feature_image = 0.2 * Feature_image + 0.8 * input_gray
            Feature_image = (Feature_image - np.min(Feature_image)) / (np.max(Feature_image) - np.min(Feature_image)) * 255
            Feature_image = Feature_image.astype(np.uint8)
            Feature_image = cv2.cvtColor(Feature_image, cv2.COLOR_RGB2BGR)
            Feature_image_list.append(Feature_image)
    print("len of Features:", len(Interpretable_Features))
    return Feature_image_list
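
# A minimal usage sketch for the overlay helper above. Assumptions: `model` is
# a loaded Q-SENN model (see genreate_intepriable_output below), `img_rgb` is
# an HxWx3 uint8 numpy array, and `x` is its preprocessed batch tensor:
#
#   out, fmaps, _ = model(x, with_feature_maps=True, with_final_features=True)
#   tiles = overlapping_features_on_input(model, out, fmaps, img_rgb, target=None)
#   cv2.imwrite("overlay_strip.jpg", np.vstack(tiles))  # input + one tile per active feature
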
def genreate_intepriable_output(input, dataset="CUB2011", arch="resnet50", seed=123456,
                                model_type="qsenn", n_features=50, n_per_class=5,
                                img_size=448, reduced_strides=False, folder=None,
                                with_featuremaps=True):
    """Load the finetuned Q-SENN model, run it on the input image, and return
    the 1-indexed predicted class (optionally with the feature maps)."""
    n_classes = dataset_constants[dataset]["num_classes"]
    input = Image.fromarray(input)
    print("input shape", input.size)
    model = get_model(arch, n_classes, reduced_strides)
    tr = transform_input_img(input, img_size)
    device = torch.device("cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=device))
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth",
                            map_location=device)
    selection = torch.load(folder / "SlDD_Selection_50.pt", map_location=device)
    state_dict['linear.selection'] = selection
    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)
    input = tr(input)
    input = input.unsqueeze(0).to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True,
                                                     with_final_features=True)
    print("featuremap size:", feature_maps.size())
    output_np = output.detach().cpu().numpy()
    output_np = np.argmax(output_np) + 1  # back to 1-indexed class ids
    if with_featuremaps:
        return output_np, model, feature_maps
    else:
        return output_np, model
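
# A minimal call sketch for the loader above, assuming the pretrained weights
# exist under tmp/resnet50/CUB2011/123456/ (the hypothetical "bird.jpg" stands
# in for any RGB input image):
#
#   img = np.array(Image.open("bird.jpg").convert("RGB"))
#   pred_class, model, fmaps = genreate_intepriable_output(img, with_featuremaps=True)
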
def get_options_from_trainingset(output, model, TR, device, with_other_class):
    """Sample training images of the predicted class (and, in quiz mode, of
    three distractor classes), and build one concatenated explanation image
    per class, paired with the indices of its negatively weighted features."""
    print("outputclass:", output)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    labels = pd.read_csv(data_dir / "image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    options_output = labels[labels['target'] == output]
    print(options_output)
    print(labels)
    options = options_output.sample(4)
    if with_other_class:
        # Quiz mode: add three random distractor classes (CUB ids run 1..200).
        other_targets = random.sample([i for i in range(1, 201) if i != output], 3)
        all_targets = [output] + other_targets
        for tg in other_targets:
            others = labels[labels['target'] == tg]
            options_others = others.sample(4)
            options = pd.concat([options, options_others], ignore_index=True)
    else:
        all_targets = [output]
    print("options:", options)
    op = []
    W = model.linear.layer.weight  # needed to report negatively weighted features
    model.eval()
    with torch.no_grad():
        for t in all_targets:
            options_class = options[options['target'] == t]
            op_class = []
            # Collect the 1-indexed positions of negative weights among the
            # nonzero features of this class.
            W_class = W[t - 1, :]
            features_id = [f for f in W_class if f != 0]
            features_id_neg = [i + 1 for i, x in enumerate(features_id) if x < 0]
            for i in options_class['img_id']:
                print(i)
                filenames = namelist.loc[namelist['img_id'] == i, 'file_name'].values[0]
                targets = options.loc[options['img_id'] == i, 'target'].values[0]
                print("targets", targets)
                print("name", filenames)
                classes = classlist.loc[classlist['cl_id'] == targets, 'class_name'].values[0]
                print(data_dir / f"images/{filenames}")
                op_img = cv2.imread(str(data_dir / f"images/{filenames}"))
                op_img = cv2.cvtColor(op_img, cv2.COLOR_BGR2RGB)
                op_imag = Image.fromarray(op_img)
                op_images = TR(op_imag)
                op_images = op_images.unsqueeze(0).to(device)
                OP, feature_maps_op = model(op_images, with_feature_maps=True,
                                            with_final_features=False)
                # An optional diversity-based resampling filter (built on
                # filter_with_diversity below) was disabled here.
                print("OP:", OP, "feature_maps_op:", feature_maps_op.shape)
                opt = overlapping_features_on_input(model, OP, feature_maps_op, op_img, targets)
                image_arrays = [np.array(img) for img in opt]
                concatenated_image = np.concatenate(image_arrays, axis=0)
                op_class.append(concatenated_image)
            op_class_arrays = [np.array(img) for img in op_class]
            concatenate_class = np.concatenate(op_class_arrays, axis=1)
            op.append((concatenate_class, features_id_neg))
    return op


def transform_input_img(input_img, img_size):
    """Resize so the shorter side equals img_size, then center-crop a square."""
    w, h = input_img.size  # PIL size is (width, height)
    aspect = w / h
    if w >= h:
        new_h = img_size
        new_w = int(new_h * aspect)
    else:
        new_w = img_size
        new_h = int(new_w / aspect)
    return transforms.Compose([
        transforms.Resize((new_h, new_w)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
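
# Example of the transform above: a 600x400 (WxH) image has its shorter side
# resized to 448 (giving 672x448) and is then center-cropped to 448x448:
#
#   tr = transform_input_img(img, 448)
#   x = tr(img)  # tensor of shape (3, 448, 448)
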
def post_next_image(OPT: str, key: str):
    """Check the user's quiz answer and disable the four option buttons."""
    if OPT == key:
        msg = "Congratulations! You correctly simulated the model's prediction this time."
    else:
        msg = f"Sorry, the model's prediction was {key}."
    return (msg, gr.update(interactive=False), gr.update(interactive=False),
            gr.update(interactive=False), gr.update(interactive=False))


def get_features_on_interface(input):
    img_size = 448
    output, model = genreate_intepriable_output(input, dataset="CUB2011", arch="resnet50",
                                                seed=123456, model_type="qsenn",
                                                n_features=50, n_per_class=5, img_size=448,
                                                reduced_strides=False, folder=None,
                                                with_featuremaps=False)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    op = get_options_from_trainingset(output, model, TR, device, with_other_class=True)
    key = op[0][0]  # the explanation image of the true class, before shuffling
    random.shuffle(op)
    option = [(op[0][0], "A"), (op[1][0], "B"), (op[2][0], "C"), (op[3][0], "D")]
    for value, char in option:
        if np.array_equal(value, key):
            key_op = char
    print("key", key_op)
    return (option, key_op,
            "These are some class explanations from our model for different classes. "
            "Which of these classes did our model predict?",
            gr.update(interactive=False))


def direct_inference(input):
    img_size = 448
    output, model, feature_maps = genreate_intepriable_output(input, dataset="CUB2011",
                                                              arch="resnet50", seed=123456,
                                                              model_type="qsenn",
                                                              n_features=50, n_per_class=5,
                                                              img_size=448, reduced_strides=False,
                                                              folder=None, with_featuremaps=True)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    concatenated_image = get_options_from_trainingset(output, model, TR, device,
                                                      with_other_class=False)
    # Overlay the features on the original input as well.
    Input = Image.fromarray(input)
    tr = transform_input_img(Input, img_size)
    Input = tr(Input)
    image_np = (Input * 255).clamp(0, 255).byte()
    image_np = image_np.permute(1, 2, 0).numpy()
    ORI = overlapping_features_on_input(model, output, feature_maps, image_np, output)
    ORI_arrays = [np.array(img) for img in ORI]
    concatenated_ORI = np.concatenate(ORI_arrays, axis=0)
    print(concatenated_ORI.shape, concatenated_image[0][0].shape)
    concatenated_image_final_array = np.concatenate((concatenated_ORI, concatenated_image[0][0]),
                                                    axis=1)
    print(concatenated_image_final_array.shape)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    output_name = classlist.loc[classlist['cl_id'] == output, 'class_name'].values[0]
    if concatenated_image[0][1] != []:
        output_name_and_features = f"{output_name}, features {', '.join(map(str, concatenated_image[0][1]))} are negative."
    else:
        output_name_and_features = f"{output_name}, all features are positive."
    return concatenated_image_final_array, output_name_and_features
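
# A minimal call sketch for direct inference, with a hypothetical input path:
#
#   overlay, label_text = direct_inference(np.array(Image.open("bird.jpg").convert("RGB")))
#   print(label_text)  # e.g. "013.Bobolink, all features are positive."
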
def filter_with_diversity(featuremaps, output, weight):
    """Return the top-5 diversity score of the feature localizations."""
    localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), weight, None)
    localizer(output.to("cpu"), featuremaps.to("cpu"))
    locality, exclusive_locality = localizer.get_result()
    diversity = locality[4]
    return diversity.item()
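

# ---------------------------------------------------------------------------
# Minimal launch sketch. The actual Gradio UI definition is not part of this
# section, so the wiring below is an assumption based solely on the callback
# signatures above; component choices and labels are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with gr.Blocks() as demo:
        inp = gr.Image(label="Input image")
        with gr.Row():
            quiz_btn = gr.Button("Guess the prediction")
            infer_btn = gr.Button("Direct inference")
        question = gr.Textbox(label="Question")
        gallery = gr.Gallery(label="Class explanations")
        key_state = gr.State()
        answer_btns = [gr.Button(c) for c in "ABCD"]
        result = gr.Textbox(label="Result")
        out_img = gr.Image(label="Explanation overlay")
        out_name = gr.Textbox(label="Predicted class")

        # Quiz mode: show four shuffled class explanations and remember the key.
        quiz_btn.click(get_features_on_interface, inputs=inp,
                       outputs=[gallery, key_state, question, quiz_btn])
        # Each answer button checks the guess and disables all four buttons.
        for letter, btn in zip("ABCD", answer_btns):
            btn.click(lambda key, letter=letter: post_next_image(letter, key),
                      inputs=key_state, outputs=[result, *answer_btns])
        # Direct mode: show the feature overlay and the predicted class name.
        infer_btn.click(direct_inference, inputs=inp, outputs=[out_img, out_name])
    demo.launch()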