Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from load_model import extract_sel_mean_std_bias_assignemnt | |
| from pathlib import Path | |
| from architectures.model_mapping import get_model | |
| from configs.dataset_params import dataset_constants | |
| import torch | |
| import torchvision.transforms as transforms | |
| import pandas as pd | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from get_data import get_augmentation | |
| from configs.dataset_params import normalize_params | |
| import random | |
| from evaluation.diversity import MultiKCrossChannelMaxPooledSum | |
def _scale_to_uint8(values):
    """Min-max scale ``values`` to the 0-255 range and return them as uint8."""
    values = np.asarray(values)
    lowest = np.min(values)
    span = np.max(values) - lowest
    if span == 0:
        # A constant map has no contrast; returning zeros avoids the 0/0
        # division (NaN) the unguarded formula would produce.
        return np.zeros(values.shape, dtype=np.uint8)
    return ((values - lowest) / span * 255).astype(np.uint8)


def overlapping_features_on_input(model, output, feature_maps, input, target):
    """Blend the positively weighted feature maps of one class onto an image.

    Args:
        model: Model exposing ``model.linear.layer.weight``; row ``label`` of
            that weight selects which feature maps are drawn.
        output: Model output; only used (via ``argmax``) when ``target`` is
            ``None``.
        feature_maps: Tensor of feature maps; it is detached, moved to the CPU
            and squeezed, then indexed by feature id.
            (assumes the squeezed shape is (n_features, H, W) - TODO confirm)
        input: Image as an array-like object (converted with ``np.array``).
        target: 1-based class id, or ``None`` to use the predicted class.

    Returns:
        A list whose first element is the input resized to 448x448, followed
        by one uint8 overlay image per feature whose weight for the class is
        greater than zero.
    """
    class_weights = model.linear.layer.weight
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    print("feature_maps", feature_maps.shape)

    if target is not None:
        label = target - 1
    else:
        output = output.detach().cpu().numpy()
        label = np.argmax(output)

    selection = class_weights[label, :]
    print("W", selection)

    input_np = np.array(input)
    # Read the shape from the converted array so non-ndarray inputs work too.
    h, w = input_np.shape[:2]
    print("h,w:", h, w)

    input_np = cv2.resize(input_np, (448, 448))
    feature_images = [input_np]

    # Channels written for each colour, cycled per drawn feature. In the order
    # of the original labels: R, G, B, Y (0+1), P (0+2), C (1+2).
    palette = ((0,), (1,), (2,), (0, 1), (0, 2), (1, 2))

    gray_background = None
    drawn = 0
    for feature_id, weight in enumerate(selection):
        if not weight > 0:
            continue

        heat = _scale_to_uint8(cv2.resize(feature_maps[feature_id], (448, 448)))

        # Paint the heat map into the channels of this feature's colour.
        colored = np.zeros_like(input_np)
        for channel in palette[drawn % len(palette)]:
            colored[:, :, channel] = heat
        drawn += 1

        # use Gamma correction
        colored = np.power(colored, 1.3)

        if gray_background is None:
            # The grey three-channel background is identical for every
            # feature, so it is built once, on first use.
            gray_background = cv2.cvtColor(
                cv2.cvtColor(input_np, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR
            )

        blended = _scale_to_uint8(0.2 * colored + 0.8 * gray_background)
        feature_images.append(cv2.cvtColor(blended, cv2.COLOR_RGB2BGR))

    print("len of Features:", drawn)
    return feature_images
def genreate_intepriable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn", n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None, with_featuremaps=True):
    """Load the finetuned model from disk and classify a single image on the CPU.

    Args:
        input: Image array accepted by ``PIL.Image.fromarray``.
        dataset: Key into ``dataset_constants``; also part of the default
            checkpoint folder.
        arch: Architecture name passed to ``get_model``.
        seed: Only used to build the default checkpoint folder.
        model_type: Prefix of the finetuned checkpoint file name.
        n_features: Feature count used in the finetuned checkpoint and the
            selection file names.
        n_per_class: Features per class used in the finetuned checkpoint name.
        img_size: Target size handed to ``transform_input_img``.
        reduced_strides: Forwarded to ``get_model``.
        folder: Checkpoint directory (``str`` or ``Path``); defaults to
            ``tmp/{arch}/{dataset}/{seed}/``.
        with_featuremaps: When true the feature maps are returned as well.

    Returns:
        ``(predicted_class, model, feature_maps)`` when ``with_featuremaps`` is
        true, otherwise ``(predicted_class, model)``. ``predicted_class`` is the
        argmax of the model output plus one (1-based class id).
    """
    cpu = torch.device("cpu")
    n_classes = dataset_constants[dataset]["num_classes"]

    image = Image.fromarray(input)
    print("input shape", image.size)

    model = get_model(arch, n_classes, reduced_strides)
    preprocess = transform_input_img(image, img_size)

    # Accept plain strings as well: the "/" joins below need a Path.
    folder = Path(f"tmp/{arch}/{dataset}/{seed}/") if folder is None else Path(folder)

    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=cpu))
    state_dict = torch.load(
        folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth",
        map_location=cpu,
    )
    # The selection file follows n_features instead of a hard-coded 50, so a
    # non-default feature count loads the matching selection.
    state_dict['linear.selection'] = torch.load(
        folder / f"SlDD_Selection_{n_features}.pt", map_location=cpu
    )

    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)

    # Add the batch dimension the model call expects.
    batch = preprocess(image).unsqueeze(0).to(cpu)
    model = model.to(cpu)
    model.eval()
    with torch.no_grad():
        output, feature_maps, _final_features = model(batch, with_feature_maps=True, with_final_features=True)
    print("featuresmap size:", feature_maps.size())

    # Class ids are reported 1-based.
    predicted_class = np.argmax(output.detach().cpu().numpy()) + 1

    if with_featuremaps:
        return predicted_class, model, feature_maps
    return predicted_class, model
def get_options_from_trainingset(output, model, TR, device, with_other_class):
    """Build explanation panels from randomly drawn training images.

    Four images of class ``output`` are sampled; when ``with_other_class`` is
    true, four images of each of three other randomly chosen classes are
    sampled as well. Every image is run through ``model`` and rendered with
    ``overlapping_features_on_input``.

    Args:
        output: 1-based class id of the predicted class.
        model: Model exposing ``linear.layer.weight``; called with
            ``with_feature_maps=True, with_final_features=False``.
        TR: Callable transform applied to each PIL image before the model.
        device: Device the transformed image is moved to.
        with_other_class: Whether to add three distractor classes.

    Returns:
        A list with one ``(panel, negative_ids)`` tuple per class, the
        predicted class first. ``panel`` is the class's images concatenated
        side by side (each image being its overlays stacked vertically);
        ``negative_ids`` lists the 1-based positions, among the class's
        non-zero weights, of the weights that are negative.

    Raises:
        FileNotFoundError: If a sampled training image cannot be read.
    """
    print("outputclass:", output)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    # NOTE(review): the label file is read relative to the working directory
    # while images.txt comes from data_dir - confirm this is intentional.
    labels = pd.read_csv("image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])

    options_output = labels[labels['target'] == output]
    print(options_output)
    print(labels)
    options = options_output.sample(4)

    # mode 2
    if with_other_class:
        # Draw distractors from the classes actually present in the label
        # table; the former range(1, 200) could never pick class 200.
        candidates = sorted(set(labels['target'].tolist()) - {int(output)})
        other_targets = random.sample(candidates, 3)
        all_targets = [output] + other_targets
        distractors = [labels[labels['target'] == tg].sample(4) for tg in other_targets]
        options = pd.concat([options, *distractors], ignore_index=True)
    else:
        all_targets = [output]
    print("shuffled:", options)

    op = []
    weights = model.linear.layer.weight  # integrate negative features
    model.eval()
    with torch.no_grad():
        for t in all_targets:
            options_class = options[options['target'] == t]

            # Positions are counted among the non-zero weights only.
            nonzero_weights = [value for value in weights[t - 1, :] if value != 0]
            features_id_neg = [
                position
                for position, value in enumerate(nonzero_weights, start=1)
                if value < 0
            ]

            op_class = []
            for img_id in options_class['img_id']:
                print(img_id)
                file_name = namelist.loc[namelist['img_id'] == img_id, 'file_name'].values[0]
                print("targets", t)
                print("name", file_name)
                image_path = data_dir / f"images/{file_name}"
                print(image_path)

                # Older OpenCV builds reject Path objects, and imread signals
                # failure by returning None rather than raising.
                op_img = cv2.imread(str(image_path))
                if op_img is None:
                    raise FileNotFoundError(f"could not read training image: {image_path}")
                op_img = cv2.cvtColor(op_img, cv2.COLOR_BGR2RGB)

                batch = TR(Image.fromarray(op_img)).unsqueeze(0).to(device)
                OP, feature_maps_op = model(batch, with_feature_maps=True, with_final_features=False)
                print("OP:", OP,
                      "feature_maps_op:", feature_maps_op.shape)

                overlays = overlapping_features_on_input(model, OP, feature_maps_op, op_img, t)
                # Stack the input and its overlays vertically for this image.
                op_class.append(np.concatenate([np.array(img) for img in overlays], axis=0))

            # Put the images of one class side by side.
            concatenate_class = np.concatenate([np.array(img) for img in op_class], axis=1)
            op.append((concatenate_class, features_id_neg))  # integrate negative features
    return op
def transform_input_img(input, img_size):
    """Return a transform that resizes, centre-crops and tensorises an image.

    The shorter side of ``input`` is scaled to ``img_size`` with the aspect
    ratio kept (the longer side is truncated to an int), then the result is
    centre-cropped to ``img_size`` and converted with ``ToTensor``.

    Args:
        input: Object whose ``size`` attribute unpacks into two numbers; the
            callers in this file pass a PIL image, whose ``size`` is
            ``(width, height)``.
        img_size: Side length of the final crop.

    Returns:
        A ``transforms.Compose`` pipeline; it is not applied here.
    """
    width, height = input.size
    aspect = width / height

    if width >= height:
        # Landscape or square: the height is the shorter side.
        new_height = img_size
        new_width = int(new_height * aspect)
    else:
        # Portrait: the width is the shorter side.
        new_width = img_size
        new_height = int(new_width / aspect)

    steps = [
        transforms.Resize((new_height, new_width)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
def post_next_image(OPT: str, key: str):
    """Report whether the chosen option matches the model's prediction.

    Args:
        OPT: Option picked by the user.
        key: Option that corresponds to the model's prediction.

    Returns:
        A 5-tuple: the feedback message followed by four
        ``gr.update(interactive=False)`` objects.
    """
    message = (
        "Congradulations! you can simulate the prediction of Model this time"
        if OPT == key
        else f"sorry, what the model predicted is {key}"
    )
    # The same four non-interactive updates are returned in both cases.
    locked = [gr.update(interactive=False) for _ in range(4)]
    return (message, *locked)
def get_features_on_interface(input):
    """Prepare the multiple-choice quiz: which class did the model predict?

    The image is classified, explanation panels are built for the predicted
    class and three other classes, the panels are shuffled and labelled
    "A" to "D".

    Args:
        input: Image array forwarded to ``genreate_intepriable_output``.

    Returns:
        A 4-tuple ``(option, key_op, prompt, update)``: ``option`` is a list of
        ``(panel, letter)`` pairs, ``key_op`` is the letter of the predicted
        class's panel, ``prompt`` is the question text and ``update`` is
        ``gr.update(interactive=False)``.
    """
    img_size = 448
    output, model = genreate_intepriable_output(input, dataset="CUB2011",
                                                arch="resnet50", seed=123456,
                                                model_type="qsenn", n_features=50, n_per_class=5,
                                                img_size=img_size, reduced_strides=False, folder=None,
                                                with_featuremaps=False)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    op = get_options_from_trainingset(output, model, TR, device, with_other_class=True)

    # The predicted class is first before shuffling; keep a handle on its panel.
    key = op[0][0]  # integrate negative features
    random.shuffle(op)

    option = [(panel, letter) for (panel, _negatives), letter in zip(op, "ABCD")]
    # Shuffling moves the same array object, so identity finds the key panel
    # without comparing whole images element by element.
    key_op = next(letter for panel, letter in option if panel is key)
    print("key", key_op)

    return option, key_op, " These are some class explanations from our model for different classes, which of these classes has our model predicted? ", gr.update(interactive=False)
def direct_inference(input):
    """Classify an image and show its explanation next to training examples.

    Args:
        input: Image array accepted by ``PIL.Image.fromarray``.

    Returns:
        A 2-tuple ``(image, caption)``: ``image`` is the overlay column of the
        input concatenated side by side with the panel of training images of
        the predicted class; ``caption`` holds the predicted class name and
        either the negative feature positions or a note that all features are
        positive.
    """
    img_size = 448
    prediction, model, feature_maps = genreate_intepriable_output(
        input, dataset="CUB2011",
        arch="resnet50", seed=123456,
        model_type="qsenn", n_features=50, n_per_class=5,
        img_size=448, reduced_strides=False, folder=None, with_featuremaps=True)

    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    cpu = torch.device("cpu")
    references = get_options_from_trainingset(prediction, model, TR, cpu, with_other_class=False)
    reference_panel = references[0][0]
    negative_ids = references[0][1]

    # original: rebuild the resized and cropped input as an HxWxC uint8 array
    pil_input = Image.fromarray(input)
    to_tensor = transform_input_img(pil_input, img_size)
    tensor_input = to_tensor(pil_input)
    scaled = (tensor_input * 255).clamp(0, 255).byte()
    image_np = scaled.permute(1, 2, 0).numpy()

    overlays = overlapping_features_on_input(model, prediction, feature_maps, image_np, prediction)  # input image_np
    query_column = np.concatenate([np.array(img) for img in overlays], axis=0)
    print(query_column.shape, reference_panel.shape)
    combined = np.concatenate((query_column, reference_panel), axis=1)
    print(combined.shape)

    # Look up the readable class name of the prediction.
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    name_matches = classlist.loc[classlist['cl_id'] == prediction, 'class_name']
    output_name = name_matches.values[0]

    if not negative_ids:
        caption = f"{output_name}, all features are positive."
    else:
        caption = f"{output_name}, features{', '.join(map(str, negative_ids))} are negative."
    return combined, caption
def filter_with_diversity(featuremaps, output, weight):
    """Return one locality score of the feature maps as a Python number.

    A ``MultiKCrossChannelMaxPooledSum`` is built with ``range(1, 6)`` and
    ``weight``, fed the CPU copies of ``output`` and ``featuremaps``, and the
    entry at index 4 of its first result is returned.
    (presumably index 4 corresponds to k=5 - confirm against the metric)

    Args:
        featuremaps: Tensor of feature maps.
        output: Tensor of model outputs.
        weight: Weight passed to the metric's constructor.

    Returns:
        ``locality[4].item()``.
    """
    metric = MultiKCrossChannelMaxPooledSum(range(1, 6), weight, None)
    metric(output.to("cpu"), featuremaps.to("cpu"))
    # The second result is not used here.
    locality, _exclusive_locality = metric.get_result()
    return locality[4].item()