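"""Gradio demo for interpretable bird classification on CUB-200-2011.

The app loads a finetuned Q-SENN/SlDD ResNet-50, classifies the uploaded image,
and then shows heatmap overlays of the model's interpretable features on sampled
dataset images, asking the user whether each feature is also present in the input.
"""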
import gradio as gr
from load_model import extract_sel_mean_std_bias_assignemnt
from pathlib import Path
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
from PIL import Image
from get_data import get_augmentation
from configs.dataset_params import normalize_params
def overlapping_features_on_input(model, output, feature_maps, input, target):
    """Overlay the positively weighted feature maps of the given (or predicted) class on the input image."""
    W = model.linear.layer.weight
    output = output.detach().cpu().numpy()
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    if target is not None:
        label = target
    else:
        # CUB class ids are 1-indexed in the annotation files.
        label = np.argmax(output) + 1
    Interpretable_Selection = W[label, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []
    Feature_image_list = []
    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] > 0:
            Interpretable_Features.append(feature_maps[S])
            # Resize the feature map to the input resolution, normalize it to [0, 255]
            # and blend it as a JET heatmap over the input image.
            Feature_image = cv2.resize(feature_maps[S], (w, h))
            Feature_image = ((Feature_image - np.min(Feature_image)) / (np.max(Feature_image) - np.min(Feature_image))) * 255
            Feature_image = Feature_image.astype(np.uint8)
            Feature_image = cv2.applyColorMap(Feature_image, cv2.COLORMAP_JET)
            Feature_image = 0.3 * Feature_image + 0.7 * input_np
            Feature_image = np.clip(Feature_image, 0, 255).astype(np.uint8)
            Feature_image_list.append(Feature_image)
    print("len of Features:", len(Interpretable_Features))
    return Feature_image_list
def genreate_intepriable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn",
                                n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None):
    """Load the finetuned model, classify the input image and return feature-heatmap overlays on sampled CUB images."""
    n_classes = dataset_constants[dataset]["num_classes"]
    model = get_model(arch, n_classes, reduced_strides)
    tr = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
    selection = torch.load(folder / "SlDD_Selection_50.pt")
    state_dict['linear.selection'] = selection
    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)
    input = Image.fromarray(input)
    input = tr(input)
    input = input.unsqueeze(0)
    input = input.to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
    print("final features:", final_features)
    output = output.detach().cpu().numpy()
    output = np.argmax(output) + 1  # CUB class ids are 1-indexed
    print("outputclass:", output)
    data_dir = Path.home() / "tmp/Datasets/CUB200/CUB_200_2011/"
    labels = pd.read_csv(data_dir / "image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    # Sample one image of the predicted class and three images of other classes, then shuffle them.
    options_output = labels[labels['target'] == output]
    options_output = options_output.sample(1)
    others = labels[labels['target'] != output]
    options_others = others.sample(3)
    options = pd.concat([options_others, options_output], ignore_index=True)
    shuffled_options = options.sample(frac=1).reset_index(drop=True)
    print("shuffled:", shuffled_options)
    op = []
    for i in shuffled_options['img_id']:
        filenames = namelist.loc[namelist['img_id'] == i, 'file_name'].values[0]
        targets = shuffled_options.loc[shuffled_options['img_id'] == i, 'target'].values[0]
        classes = classlist.loc[classlist['cl_id'] == targets, 'class_name'].values[0]
        op_img = cv2.imread(str(data_dir / f"images/{filenames}"))  # cv2.imread expects a string path
        op_imag = Image.fromarray(op_img)  # note: cv2 loads images in BGR order
        op_images = TR(op_imag)
        op_images = op_images.unsqueeze(0)
        op_images = op_images.to(device)
        with torch.no_grad():
            OP, feature_maps_op = model(op_images, with_feature_maps=True, with_final_features=False)
        opt = overlapping_features_on_input(model, OP, feature_maps_op, op_img, targets)
        op += opt
    return op
def post_next_image(op):
    """Drop the current feature image and show the next one, or finish when none are left."""
    if len(op) <= 1:
        return [], None, "all done, thank you!"
    else:
        op = op[1:]
        return op, op[0], "Is this feature also in your input?"
def get_features_on_interface(input):
    op = genreate_intepriable_output(input, dataset="CUB2011",
                                     arch="resnet50", seed=123456,
                                     model_type="qsenn", n_features=50, n_per_class=5,
                                     img_size=448, reduced_strides=False, folder=None)
    return op, op[0], "Is this feature also in your input?", gr.update(interactive=False)
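# UI sketch: one input image, a button that generates the interpretable feature
# overlays, and Yes/No buttons that both simply advance to the next overlay
# (the answer itself is not stored anywhere).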
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Interpretable Bird Classification</h1>")
    image_input = gr.Image()
    image_output = gr.Image()
    text_output = gr.Markdown()
    but_generate = gr.Button("Get some interpretable features")
    but_feedback_y = gr.Button("Yes")
    but_feedback_n = gr.Button("No")
    image_list = gr.State([])
    but_generate.click(fn=get_features_on_interface, inputs=image_input,
                       outputs=[image_list, image_output, text_output, but_generate])
    but_feedback_y.click(fn=post_next_image, inputs=image_list,
                         outputs=[image_list, image_output, text_output])
    but_feedback_n.click(fn=post_next_image, inputs=image_list,
                         outputs=[image_list, image_output, text_output])
demo.launch()
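# demo.launch() blocks and serves the interface locally (Gradio's default is
# http://127.0.0.1:7860); pass share=True for a temporary public link.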