import gradio as gr
from load_model import extract_sel_mean_std_bias_assignemnt
from pathlib import Path
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants, normalize_params
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
from PIL import Image
from get_data import get_augmentation
import random
from evaluation.diversity import MultiKCrossChannelMaxPooledSum
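
# Gradio demo utilities for a Q-SENN-style interpretable classifier on
# CUB-200-2011. The helpers below overlay the sparse, positively weighted
# feature maps of a prediction onto the input image, build multiple-choice
# "guess the prediction" options from the training set, and run direct
# inference with per-feature visualisations.

# overlapping_features_on_input: for the (predicted or given) class, select
# the features with positive weight in the sparse linear layer, min-max
# normalise each feature map, tint it with its own colour, and blend it over
# a grayscale copy of the input. Returns [input] + one overlay per feature.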
def overlapping_features_on_input(model, output, feature_maps, input, target):
    W = model.linear.layer.weight
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    print("feature_maps", feature_maps.shape)
    if target is not None:
        label = target - 1  # class ids are 1-based, weight rows are 0-based
    else:
        output = output.detach().cpu().numpy()
        label = np.argmax(output)
    Interpretable_Selection = W[label, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []
    input_np = cv2.resize(input_np, (448, 448))
    Feature_image_list = [input_np]
    color_id = 0  # give each feature its own single colour
    # channel indices per colour: R, G, B, Yellow, Purple, Cyan
    COLOR_CHANNELS = [[0], [1], [2], [0, 1], [0, 2], [1, 2]]
    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] > 0:
            Interpretable_Features.append(feature_maps[S])
            Feature_image = cv2.resize(feature_maps[S], (448, 448))
            # min-max normalise the activation map to [0, 255]
            Feature_image = ((Feature_image - np.min(Feature_image)) / (np.max(Feature_image) - np.min(Feature_image)) * 255).astype(np.uint8)
            # paint the activation into this feature's colour channels
            Feature_image_color = np.zeros_like(input_np)
            for ch in COLOR_CHANNELS[color_id % len(COLOR_CHANNELS)]:
                Feature_image_color[:, :, ch] = Feature_image
            Feature_image = Feature_image_color
            color_id += 1
            # gamma correction to emphasise strong activations
            Feature_image = np.power(Feature_image, 1.3)
            # Feature_image=cv2.applyColorMap(Feature_image,cv2.COLORMAP_JET)
            # blend the coloured activation over a grayscale copy of the input
            input_np = cv2.cvtColor(input_np, cv2.COLOR_BGR2GRAY)
            input_np = cv2.cvtColor(input_np, cv2.COLOR_GRAY2BGR)
            Feature_image = 0.2 * Feature_image + 0.8 * input_np
            Feature_image = ((Feature_image - np.min(Feature_image)) / (np.max(Feature_image) - np.min(Feature_image)) * 255).astype(np.uint8)
            Feature_image = cv2.cvtColor(Feature_image, cv2.COLOR_RGB2BGR)
            Feature_image_list.append(Feature_image)
    print("len of Features:", len(Interpretable_Features))
    return Feature_image_list
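
# generate_interpretable_output: rebuild the Q-SENN model (dense backbone plus
# finetuned sparse linear layer and feature selection), preprocess the input,
# and return the 1-based predicted class id together with the model and,
# optionally, its feature maps.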
def generate_interpretable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn", n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None, with_featuremaps=True):
    n_classes = dataset_constants[dataset]["num_classes"]
    input = Image.fromarray(input)
    print("input shape", input.size)
    model = get_model(arch, n_classes, reduced_strides)
    tr = transform_input_img(input, img_size)
    device = torch.device("cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=torch.device('cpu')))
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth", map_location=torch.device('cpu'))
    selection = torch.load(folder / "SlDD_Selection_50.pt", map_location=torch.device('cpu'))
    state_dict['linear.selection'] = selection
    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)
    input = tr(input)
    input = input.unsqueeze(0)
    input = input.to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
    print("feature maps size:", feature_maps.size())
    output_np = output.detach().cpu().numpy()
    output_np = np.argmax(output_np) + 1  # back to 1-based class ids
    if with_featuremaps:
        return output_np, model, feature_maps
    else:
        return output_np, model
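
# get_options_from_trainingset: sample training images of the predicted class
# (and, in quiz mode, of three random other classes), render their feature
# overlays, and stack them into one option image per class. Each option is
# returned together with the 1-based indices of its negatively weighted
# features.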
def get_options_from_trainingset(output, model, TR, device, with_other_class):
    print("outputclass:", output)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    labels = pd.read_csv(data_dir / "image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    options_output = labels[labels['target'] == output]
    print(options_output)
    print(labels)
    options = options_output.sample(4)
    # mode 2: also offer three random other classes as distractors
    if with_other_class:
        other_targets = random.sample([i for i in range(1, 201) if i != output], 3)
        all_targets = [output] + other_targets
        for tg in other_targets:
            others = labels[labels['target'] == tg]
            options_others = others.sample(4)
            options = pd.concat([options, options_others], ignore_index=True)
    else:
        all_targets = [output]
    print("options:", options)
    op = []
    # resample_img_id_list=[]  # resample filter (see disabled diversity filter below)
    W = model.linear.layer.weight  # integrate negative features
    model.eval()
    with torch.no_grad():
        for t in all_targets:
            options_class = options[options['target'] == t]
            op_class = []
            # integrate negative features: 1-based indices of this class's
            # non-zero weights that are negative
            W_class = W[t - 1, :]
            features_id = [f for f in W_class if f != 0]
            features_id_neg = [i + 1 for i, x in enumerate(features_id) if x < 0]
            for i in options_class['img_id']:
                print(i)
                filenames = namelist.loc[namelist['img_id'] == i, 'file_name'].values[0]
                targets = options.loc[options['img_id'] == i, 'target'].values[0]
                print("targets", targets)
                print("name", filenames)
                classes = classlist.loc[classlist['cl_id'] == targets, 'class_name'].values[0]
                print(data_dir / f"images/{filenames}")
                op_img = cv2.imread(str(data_dir / f"images/{filenames}"))  # cv2.imread expects a str path
                op_img = cv2.cvtColor(op_img, cv2.COLOR_BGR2RGB)
                op_imag = Image.fromarray(op_img)
                op_images = TR(op_imag)
                op_images = op_images.unsqueeze(0)
                op_images = op_images.to(device)
                OP, feature_maps_op = model(op_images, with_feature_maps=True, with_final_features=False)
                # use diversity filter (disabled): resample images whose
                # feature maps are not diverse enough
                # weight=model.linear.layer.weight
                # div=filter_with_diversity(feature_maps_op,OP,weight)
                # DIV=0.8
                # while div<=DIV:
                #     options_class_set_for_resample=labels[(labels['target']==t)
                #                                           & (~labels['img_id'].isin(options['img_id']))
                #                                           & (~labels['img_id'].isin(resample_img_id_list))]
                #     if len(options_class_set_for_resample)<=0:
                #         resample_img_id_list=[]
                #         DIV-=0.1
                #         continue
                #     resample=options_class_set_for_resample.sample(1)
                #     img_id_re=resample.iloc[0]['img_id']
                #     resample_img_id_list.append(img_id_re)
                #     filenames_re=namelist.loc[namelist['img_id']==img_id_re,'file_name'].values[0]
                #     op_img_re=cv2.imread(str(data_dir/f"images/{filenames_re}"))
                #     op_img_re=cv2.cvtColor(op_img_re, cv2.COLOR_BGR2RGB)
                #     op_imag_re=Image.fromarray(op_img_re)
                #     op_images_re=TR(op_imag_re)
                #     op_images_re=op_images_re.unsqueeze(0)
                #     op_images_re=op_images_re.to(device)
                #     OP_re, feature_maps_op_re=model(op_images_re,with_feature_maps=True,with_final_features=False)
                #     div=filter_with_diversity(feature_maps_op_re,OP_re,weight)
print("OP:",OP, | |
"feature_maps_op:",feature_maps_op.shape) | |
opt= overlapping_features_on_input(model,OP, feature_maps_op,op_img,targets) | |
image_arrays = [np.array(img) for img in opt] | |
concatenated_image = np.concatenate(image_arrays, axis=0) | |
op_class.append(concatenated_image) | |
op_class_arrays=[np.array(img)for img in op_class] | |
concatenate_class=np.concatenate(op_class_arrays, axis=1) | |
op.append((concatenate_class,features_id_neg))# intergrate negative features | |
return op | |
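
# transform_input_img: resize the shorter image side to img_size (preserving
# the aspect ratio), then center-crop a square of img_size x img_size.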
def transform_input_img(input, img_size):
    w, h = input.size  # PIL returns (width, height)
    if w >= h:
        h_new = img_size
        w_new = int(img_size * w / h)
    else:
        w_new = img_size
        h_new = int(img_size * h / w)
    return transforms.Compose([
        transforms.Resize((h_new, w_new)),  # Resize expects (height, width)
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
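
# post_next_image: grade the user's multiple-choice answer against the model's
# actual prediction and disable the option buttons afterwards.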
def post_next_image(OPT: str, key: str):
    if OPT == key:
        return ("Congratulations! You simulated the model's prediction correctly this time.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False))
    else:
        return (f"Sorry, the model actually predicted {key}.", gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False))
def get_features_on_interface(input):
    img_size = 448
    output, model = generate_interpretable_output(input, dataset="CUB2011",
                                                  arch="resnet50", seed=123456,
                                                  model_type="qsenn", n_features=50, n_per_class=5,
                                                  img_size=448, reduced_strides=False, folder=None, with_featuremaps=False)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    op = get_options_from_trainingset(output, model, TR, device, with_other_class=True)
    key = op[0][0]  # the first entry is the predicted class (integrate negative features)
    random.shuffle(op)
    option = [(op[0][0], "A"),
              (op[1][0], "B"),
              (op[2][0], "C"),
              (op[3][0], "D")]
    for value, char in option:
        if np.array_equal(value, key):
            key_op = char
    print("key", key_op)
    # if op[0][1]!=[]:
    #     option[0][1]=f"A, features {', '.join(map(str, op[0][1]))} are negative."
    # if op[1][1]!=[]:
    #     option[1][1]=f"B, features {', '.join(map(str, op[1][1]))} are negative."
    # if op[2][1]!=[]:
    #     option[2][1]=f"C, features {', '.join(map(str, op[2][1]))} are negative."
    # if op[3][1]!=[]:
    #     option[3][1]=f"D, features {', '.join(map(str, op[3][1]))} are negative."
    return option, key_op, "These are some class explanations from our model for different classes. Which of these classes did our model predict?", gr.update(interactive=False)
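
# direct_inference: explanation mode. Predict the class of the uploaded image,
# overlay its positively weighted features, place a training example of the
# predicted class next to it, and report the class name plus any negatively
# weighted features.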
def direct_inference(input):
    img_size = 448
    output, model, feature_maps = generate_interpretable_output(input, dataset="CUB2011",
                                                                arch="resnet50", seed=123456,
                                                                model_type="qsenn", n_features=50, n_per_class=5,
                                                                img_size=448, reduced_strides=False, folder=None, with_featuremaps=True)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    concatenated_image = get_options_from_trainingset(output, model, TR, device, with_other_class=False)
    # overlay the features on the original input as well
    Input = Image.fromarray(input)
    tr = transform_input_img(Input, img_size)
    Input = tr(Input)
    image_np = (Input * 255).clamp(0, 255).byte()
    image_np = image_np.permute(1, 2, 0).numpy()
    ORI = overlapping_features_on_input(model, output, feature_maps, image_np, output)
    ORI_arrays = [np.array(img) for img in ORI]
    concatenated_ORI = np.concatenate(ORI_arrays, axis=0)
    print(concatenated_ORI.shape, concatenated_image[0][0].shape)
    concatenated_image_final_array = np.concatenate((concatenated_ORI, concatenated_image[0][0]), axis=1)
    print(concatenated_image_final_array.shape)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    output_name = classlist.loc[classlist['cl_id'] == output, 'class_name'].values[0]
    if concatenated_image[0][1] != []:
        output_name_and_features = f"{output_name}, features {', '.join(map(str, concatenated_image[0][1]))} are negative."
    else:
        output_name_and_features = f"{output_name}, all features are positive."
    return concatenated_image_final_array, output_name_and_features
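
# filter_with_diversity: score how spatially diverse the active feature maps
# are, via the cross-channel max-pooled locality from the evaluation code.
# Currently only referenced by the disabled resampling filter above.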
def filter_with_diversity(featuremaps, output, weight):
    localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), weight, None)
    localizer(output.to("cpu"), featuremaps.to("cpu"))
    locality, exclusive_locality = localizer.get_result()
    diversity = locality[4]  # entry for k=5, the last of range(1, 6)
    diversity = diversity.item()
    return diversity
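
# Hypothetical usage sketch (not part of the original Space, which wires up
# its own interface elsewhere): direct_inference maps naturally onto a simple
# Gradio event, e.g.
#
#   with gr.Blocks() as demo:
#       image_in = gr.Image(label="Upload a bird image")
#       run = gr.Button("Explain prediction")
#       overlays = gr.Image(label="Feature overlays")
#       pred = gr.Textbox(label="Predicted class")
#       run.click(direct_inference, inputs=image_in, outputs=[overlays, pred])
#   demo.launch()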