import gradio as gr
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
from PIL import Image
from pathlib import Path

from load_model import extract_sel_mean_std_bias_assignemnt
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants, normalize_params
from get_data import get_augmentation

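# Gradio demo for interpretable bird classification on CUB-200-2011: the uploaded image is
# classified by a Q-SENN/SLDD model (ResNet-50 backbone), reference images are sampled from
# the dataset (one of the predicted class, three of other classes), and the positively
# weighted feature maps are shown as heatmap overlays so the user can judge whether each
# feature also appears in the uploaded image.
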
def overlapping_features_on_input(model, output, feature_maps, input, target):
    # Overlay every positively weighted feature map of the target class on the input image.
    W = model.linear.layer.weight
    output = output.detach().cpu().numpy()
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()

    if target is not None:
        # CUB class ids are 1-indexed, while the rows of the class weight matrix are 0-indexed.
        label = target - 1
    else:
        label = np.argmax(output)

    Interpretable_Selection = W[label, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []
    Feature_image_list = []

    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] > 0:
            Interpretable_Features.append(feature_maps[S])
            # Upsample the feature map to the input resolution, scale it to [0, 255] and
            # blend it as a JET heatmap over the input image.
            Feature_image = cv2.resize(feature_maps[S], (w, h))
            Feature_image = (Feature_image - np.min(Feature_image)) / (np.max(Feature_image) - np.min(Feature_image)) * 255
            Feature_image = Feature_image.astype(np.uint8)
            Feature_image = cv2.applyColorMap(Feature_image, cv2.COLORMAP_JET)
            Feature_image = cv2.cvtColor(Feature_image, cv2.COLOR_BGR2RGB)  # applyColorMap returns BGR
            Feature_image = 0.3 * Feature_image + 0.7 * input_np
            Feature_image = np.clip(Feature_image, 0, 255).astype(np.uint8)
            Feature_image_list.append(Feature_image)

    print("len of Features:", len(Interpretable_Features))
    return Feature_image_list

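# Expected local artifacts (defaults used by generate_interpretable_output below; adjust
# folder/data_dir if your checkpoint or the CUB-200-2011 data live elsewhere):
#   tmp/resnet50/CUB2011/123456/qsenn_50_5_FinetunedModel.pth  - finetuned model checkpoint
#   tmp/resnet50/CUB2011/123456/SlDD_Selection_50.pt           - feature selection for the sparse layer
#   ~/tmp/Datasets/CUB200/CUB_200_2011/                        - dataset with images.txt,
#                                                                image_class_labels.txt, classes.txt, images/
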
def generate_interpretable_output(input, dataset="CUB2011", arch="resnet50", seed=123456, model_type="qsenn",
                                  n_features=50, n_per_class=5, img_size=448, reduced_strides=False, folder=None):
    n_classes = dataset_constants[dataset]["num_classes"]

    model = get_model(arch, n_classes, reduced_strides)
    # Preprocessing for the user's uploaded image.
    tr = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    # Preprocessing for the reference images sampled from the dataset.
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")

    # Load the finetuned checkpoint and the feature selection for the sparse layer.
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth", map_location=device)
    selection = torch.load(folder / "SlDD_Selection_50.pt", map_location=device)
    state_dict['linear.selection'] = selection

    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)

    # Classify the uploaded image.
    input = Image.fromarray(input)
    input = tr(input)
    input = input.unsqueeze(0)
    input = input.to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
    print("final features:", final_features)
    output = output.detach().cpu().numpy()
    output = np.argmax(output) + 1  # shift to the 1-indexed class ids used in image_class_labels.txt
    print("outputclass:", output)

    # Sample one reference image of the predicted class and three images of other classes.
    data_dir = Path.home() / "tmp/Datasets/CUB200/CUB_200_2011/"
    labels = pd.read_csv(data_dir / "image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    options_output = labels[labels['target'] == output].sample(1)
    others = labels[labels['target'] != output]
    options_others = others.sample(3)
    options = pd.concat([options_others, options_output], ignore_index=True)
    shuffled_options = options.sample(frac=1).reset_index(drop=True)
    print("shuffled:", shuffled_options)
    op = []

    for i in shuffled_options['img_id']:
        filenames = namelist.loc[namelist['img_id'] == i, 'file_name'].values[0]
        targets = shuffled_options.loc[shuffled_options['img_id'] == i, 'target'].values[0]
        classes = classlist.loc[classlist['cl_id'] == targets, 'class_name'].values[0]  # class name (currently unused)

        op_img = cv2.imread(str(data_dir / f"images/{filenames}"))
        op_img = cv2.cvtColor(op_img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model and Gradio expect RGB

        op_imag = Image.fromarray(op_img)
        op_images = TR(op_imag)
        op_images = op_images.unsqueeze(0)
        op_images = op_images.to(device)
        with torch.no_grad():
            OP, feature_maps_op = model(op_images, with_feature_maps=True, with_final_features=False)

        # Overlay the positively weighted feature maps of this image's class on the image.
        opt = overlapping_features_on_input(model, OP, feature_maps_op, op_img, targets)
        op += opt

    return op

def post_next_image(op):
    # Drop the overlay that was just shown and display the next one, if any.
    if len(op) <= 1:
        return [], None, "all done, thank you!"
    else:
        op = op[1:]
        return op, op[0], "Is this feature also in your input?"

def get_features_on_interface(input):
    op = generate_interpretable_output(input, dataset="CUB2011",
                                       arch="resnet50", seed=123456,
                                       model_type="qsenn", n_features=50, n_per_class=5,
                                       img_size=448, reduced_strides=False, folder=None)
    return op, op[0], "Is this feature also in your input?", gr.update(interactive=False)

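# Interface wiring: uploading an image and clicking the generate button produces the list of
# feature overlays; both "Yes" and "No" advance to the next overlay, with the remaining
# overlays kept in a gr.State list between clicks.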
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Interpretable Bird Classification</h1>")
    image_input = gr.Image()
    image_output = gr.Image()
    text_output = gr.Markdown()
    but_generate = gr.Button("Get some interpretable features")
    but_feedback_y = gr.Button("Yes")
    but_feedback_n = gr.Button("No")
    image_list = gr.State([])
    but_generate.click(fn=get_features_on_interface, inputs=image_input,
                       outputs=[image_list, image_output, text_output, but_generate])
    but_feedback_y.click(fn=post_next_image, inputs=image_list,
                         outputs=[image_list, image_output, text_output])
    but_feedback_n.click(fn=post_next_image, inputs=image_list,
                         outputs=[image_list, image_output, text_output])

demo.launch()