import gradio as gr
from load_model import extract_sel_mean_std_bias_assignemnt
from pathlib import Path
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
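
# Gradio demo for interpretable bird classification with a Q-SENN model on
# CUB-200-2011: the model classifies the uploaded image, then the app shows
# heatmap overlays of the positively weighted features on sample dataset
# images and asks, for each one, whether the user also sees that feature in
# their own input. The tmp/ paths below assume the Q-SENN repository's
# default checkpoint and dataset layout.
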
def overlapping_features_on_input(model, output, feature_maps, input, target):
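    # Overlay every feature map that contributes positively to the predicted
    # (or given) class as a JET heatmap on top of the input image and return
    # the list of blended images.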
    W = model.linear.layer.weight
    output = output.detach().cpu().numpy()
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    if target is not None:
        label = target
    else:
        label = np.argmax(output) + 1
    # Class ids follow the 1-indexed CUB convention; weight rows are 0-indexed.
    Interpretable_Selection = W[label - 1, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []
    Feature_image_list = []
    for S in range(len(Interpretable_Selection)):
        # Only features with a positive weight for this class are shown.
        if Interpretable_Selection[S] > 0:
            Interpretable_Features.append(feature_maps[S])
            Feature_image = cv2.resize(feature_maps[S], (w, h))
            # Min-max normalise to [0, 255]; the epsilon guards against a
            # constant feature map (division by zero).
            Feature_image = ((Feature_image - np.min(Feature_image))
                             / (np.max(Feature_image) - np.min(Feature_image) + 1e-8)) * 255
            Feature_image = Feature_image.astype(np.uint8)
            Feature_image = cv2.applyColorMap(Feature_image, cv2.COLORMAP_JET)
            # Blend: 30% heatmap over 70% original image.
            Feature_image = 0.3 * Feature_image + 0.7 * input_np
            Feature_image = np.clip(Feature_image, 0, 255).astype(np.uint8)
            # OpenCV images are BGR; convert so Gradio displays correct colours.
            Feature_image = cv2.cvtColor(Feature_image, cv2.COLOR_BGR2RGB)
            Feature_image_list.append(Feature_image)
            # path_to_featureimage = f"/home/qixuan/tmp/FeatureImage/FI{S}.jpg"
            # cv2.imwrite(path_to_featureimage, Feature_image)
    print("len of Features:", len(Interpretable_Features))
    return Feature_image_list
def generate_interpretable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn", n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None):
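    # Load the finetuned Q-SENN checkpoint, classify the input image, then
    # sample one dataset image of the predicted class and three of other
    # classes and collect feature-heatmap overlays for all of them.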
    n_classes = dataset_constants[dataset]["num_classes"]
    model = get_model(arch, n_classes, reduced_strides)
    tr = transforms.ToTensor()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
    # map_location lets the checkpoint load on CPU-only machines as well.
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth", map_location=device)
    selection = torch.load(folder / "SlDD_Selection_50.pt", map_location=device)
    state_dict['linear.selection'] = selection
    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: freeze batch-norm statistics and dropout
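    # Prepare the user's image and run one forward pass;
    # with_feature_maps/with_final_features make the model also return its
    # spatial feature maps and the selected sparse features.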
    input = tr(input)
    input = input.unsqueeze(0)
    input = input.to(device)
    model = model.to(device)
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
    print("final features:", final_features)
    output = output.detach().cpu().numpy()
    output = np.argmax(output) + 1  # 1-indexed class id, matching the CUB label files
    print("outputclass:", output)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    labels = pd.read_csv(data_dir / "image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    options_output = labels[labels['target'] == output]
    options_output = options_output.sample(1)
    others = labels[labels['target'] != output]
    options_others = others.sample(3)
    options = pd.concat([options_others, options_output], ignore_index=True)
    shuffled_options = options.sample(frac=1).reset_index(drop=True)
    print("shuffled:", shuffled_options)
    op = []
    for i in shuffled_options['img_id']:
        print(i)
        filenames = namelist.loc[namelist['img_id'] == i, 'file_name'].values[0]
        targets = shuffled_options.loc[shuffled_options['img_id'] == i, 'target'].values[0]
        print("targets", targets)
        print("name", filenames)
        classes = classlist.loc[classlist['cl_id'] == targets, 'class_name'].values[0]
        print(data_dir / f"images/{filenames}")
        # cv2.imread needs a str path and returns a BGR array.
        op_img = cv2.imread(str(data_dir / "images" / filenames))
        # The model expects RGB input, so convert before building the tensor.
        op_images = tr(cv2.cvtColor(op_img, cv2.COLOR_BGR2RGB))
        op_images = op_images.unsqueeze(0)
        op_images = op_images.to(device)
        with torch.no_grad():
            OP, feature_maps_op = model(op_images, with_feature_maps=True, with_final_features=False)
        print("OP:", OP,
              "feature_maps_op:", feature_maps_op.shape)
        opt = overlapping_features_on_input(model, OP, feature_maps_op, op_img, targets)
        op += opt
    return op
def post_next_image(op):
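    # Pop the image that was just shown and return the next overlay, or end
    # the feedback loop once the queue is exhausted.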
if len(op)<=1: | |
return [],None, "all done, thank you!" | |
else: | |
op=op[1:len(op)] | |
return op,op[0], "Is this feature also in your input?" | |
def get_features_on_interface(input):
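    # Gradio callback: build the overlay queue for the uploaded image, show
    # the first overlay, and disable the generate button while the loop runs.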
    op = generate_interpretable_output(input, dataset="CUB2011",
                                       arch="resnet50", seed=123456,
                                       model_type="qsenn", n_features=50, n_per_class=5,
                                       img_size=448, reduced_strides=False, folder=None)
    return op, op[0], "Is this feature also in your input?", gr.update(interactive=False)
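
# --- Gradio UI ---
# One upload widget, one display widget for the current feature overlay, and
# Yes/No buttons that both simply advance to the next overlay (the answers
# are not stored in this version).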
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Interpretable Bird Classification</h1>")
    image_input = gr.Image()
    image_output = gr.Image()
    text_output = gr.Markdown()
    but_generate = gr.Button("Get some interpretable features")
    but_feedback_y = gr.Button("Yes")
    but_feedback_n = gr.Button("No")
    image_list = gr.State([])
    but_generate.click(fn=get_features_on_interface, inputs=image_input, outputs=[image_list, image_output, text_output, but_generate])
    but_feedback_y.click(fn=post_next_image, inputs=image_list, outputs=[image_list, image_output, text_output])
    but_feedback_n.click(fn=post_next_image, inputs=image_list, outputs=[image_list, image_output, text_output])

demo.launch()
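
# On Hugging Face Spaces this script is served automatically (assuming it is
# named app.py); locally, run `python app.py` and open the printed URL.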