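"""Gradio demo utilities for a Q-SENN/SlDD interpretable bird classifier (CUB-200-2011).

The functions below load a finetuned sparse model, overlay its class-relevant
feature maps on an input image, and build a small "guess the prediction" game
from training-set examples.
"""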
import random
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image

from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants, normalize_params
from evaluation.diversity import MultiKCrossChannelMaxPooledSum
from get_data import get_augmentation
from load_model import extract_sel_mean_std_bias_assignemnt


def overlapping_features_on_input(model, output, feature_maps, input, target):
    """Overlay class-relevant feature maps on the input image.

    Returns a list holding the resized input followed by one colorized overlay
    per feature with a positive weight for the target (or predicted) class.
    """
    W = model.linear.layer.weight
    feature_maps = feature_maps.detach().cpu().numpy().squeeze()
    print("feature_maps", feature_maps.shape)

    if target is not None:
        label = target - 1  # CUB class ids are 1-indexed
    else:
        output = output.detach().cpu().numpy()
        label = np.argmax(output)

    Interpretable_Selection = W[label, :]
    print("W", Interpretable_Selection)
    input_np = np.array(input)
    h, w = input_np.shape[:2]
    print("h,w:", h, w)
    Interpretable_Features = []

    input_np = cv2.resize(input_np, (448, 448))
    Feature_image_list = [input_np]

    # Cycle through single- and two-channel colorizations:
    # red, green, blue, yellow, purple (magenta), cyan.
    COLOR_CHANNELS = {'R': (0,), 'G': (1,), 'B': (2,),
                      'Y': (0, 1), 'P': (0, 2), 'C': (1, 2)}
    COLOR = list(COLOR_CHANNELS)
    color_id = 0

    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] > 0:
            Interpretable_Features.append(feature_maps[S])
            Feature_image = cv2.resize(feature_maps[S], (448, 448))
            # Normalize the activation map to the 0-255 range.
            Feature_image = ((Feature_image - np.min(Feature_image))
                             / (np.max(Feature_image) - np.min(Feature_image)) * 255)
            Feature_image = Feature_image.astype(np.uint8)

            # Paint the activation into the channels of the current color.
            color = COLOR[color_id % len(COLOR)]
            Feature_image_color = np.zeros_like(input_np)
            for channel in COLOR_CHANNELS[color]:
                Feature_image_color[:, :, channel] = Feature_image
            Feature_image = Feature_image_color
            color_id += 1

            # Gamma > 1 suppresses weak activations relative to strong ones.
            Feature_image = np.power(Feature_image, 1.3)

            # Blend with a grayscale version of the input so the colored
            # activation stands out; the list keeps the original color input.
            input_np = cv2.cvtColor(input_np, cv2.COLOR_BGR2GRAY)
            input_np = cv2.cvtColor(input_np, cv2.COLOR_GRAY2BGR)
            Feature_image = 0.2 * Feature_image + 0.8 * input_np

            # Renormalize the blend to uint8 and swap the channel order.
            Feature_image = ((Feature_image - np.min(Feature_image))
                             / (np.max(Feature_image) - np.min(Feature_image)) * 255)
            Feature_image = Feature_image.astype(np.uint8)
            Feature_image = cv2.cvtColor(Feature_image, cv2.COLOR_RGB2BGR)
            Feature_image_list.append(Feature_image)

    print("len of Features:", len(Interpretable_Features))
    return Feature_image_list


def generate_interpretable_output(input, dataset="CUB2011", arch="resnet50", seed=123456,
                                  model_type="qsenn", n_features=50, n_per_class=5,
                                  img_size=448, reduced_strides=False, folder=None,
                                  with_featuremaps=True):
    n_classes = dataset_constants[dataset]["num_classes"]

    input = Image.fromarray(input)
    print("input shape", input.size)

    model = get_model(arch, n_classes, reduced_strides)
    tr = transform_input_img(input, img_size)

    device = torch.device("cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
    model.load_state_dict(torch.load(folder / "Trained_DenseModel.pth", map_location=device))
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth",
                            map_location=device)
    selection = torch.load(folder / "SlDD_Selection_50.pt", map_location=device)
    state_dict['linear.selection'] = selection

    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)

    input = tr(input)
    input = input.unsqueeze(0)
    input = input.to(device)
    model = model.to(device)
    model.eval()

    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True,
                                                     with_final_features=True)
    print("featuremap size:", feature_maps.size())
    output_np = output.detach().cpu().numpy()
    output_np = np.argmax(output_np) + 1  # back to 1-indexed CUB class ids

    if with_featuremaps:
        return output_np, model, feature_maps
    else:
        return output_np, model


def get_options_from_trainingset(output, model, TR, device, with_other_class):
    print("outputclass:", output)
    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    labels = pd.read_csv(data_dir / "image_class_labels.txt", sep=' ', names=['img_id', 'target'])
    namelist = pd.read_csv(data_dir / "images.txt", sep=' ', names=['img_id', 'file_name'])
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    options_output = labels[labels['target'] == output]
    print(options_output)
    print(labels)
    options = options_output.sample(4)

    if with_other_class:
        # Draw three distractor classes and four images from each of them.
        other_targets = random.sample([i for i in range(1, 200) if i != output], 3)
        all_targets = [output] + other_targets
        for tg in other_targets:
            others = labels[labels['target'] == tg]
            options_others = others.sample(4)
            options = pd.concat([options, options_others], ignore_index=True)
    else:
        all_targets = [output]

    print("options:", options)
    op = []

    W = model.linear.layer.weight
    model.eval()
    with torch.no_grad():
        for t in all_targets:
            options_class = options[options['target'] == t]
            op_class = []

            # Indices (1-based, among the class's nonzero weights) of features
            # that contribute negatively to this class.
            W_class = W[t - 1, :]
            features_id = [f for f in W_class if f != 0]
            features_id_neg = [i + 1 for i, x in enumerate(features_id) if x < 0]

            for i in options_class['img_id']:
                print(i)
                filenames = namelist.loc[namelist['img_id'] == i, 'file_name'].values[0]
                targets = options.loc[options['img_id'] == i, 'target'].values[0]
                print("targets", targets)
                print("name", filenames)

                classes = classlist.loc[classlist['cl_id'] == targets, 'class_name'].values[0]
                print(data_dir / f"images/{filenames}")

                op_img = cv2.imread(str(data_dir / f"images/{filenames}"))
                op_img = cv2.cvtColor(op_img, cv2.COLOR_BGR2RGB)

                op_imag = Image.fromarray(op_img)
                op_images = TR(op_imag)
                op_images = op_images.unsqueeze(0)
                op_images = op_images.to(device)
                OP, feature_maps_op = model(op_images, with_feature_maps=True, with_final_features=False)

                print("OP:", OP, "feature_maps_op:", feature_maps_op.shape)
                opt = overlapping_features_on_input(model, OP, feature_maps_op, op_img, targets)
                # Stack the input and its feature overlays vertically.
                image_arrays = [np.array(img) for img in opt]
                concatenated_image = np.concatenate(image_arrays, axis=0)
                op_class.append(concatenated_image)

            # Place the four example stacks of this class side by side.
            op_class_arrays = [np.array(img) for img in op_class]
            concatenate_class = np.concatenate(op_class_arrays, axis=1)
            op.append((concatenate_class, features_id_neg))
    return op


def transform_input_img(input, img_size):
    # PIL's .size is (width, height); scale the shorter side to img_size,
    # then center-crop a square.
    w, h = input.size
    if w <= h:
        w_new = img_size
        h_new = int(img_size * h / w)
    else:
        h_new = img_size
        w_new = int(img_size * w / h)

    return transforms.Compose([
        transforms.Resize((h_new, w_new)),  # torchvision expects (height, width)
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
    ])
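
# Note: the pipeline above reproduces torchvision's shorter-side semantics, i.e.
# it is equivalent (up to rounding) to transforms.Resize(img_size) followed by
# transforms.CenterCrop(img_size).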
def post_next_image(OPT: str, key: str):
    if OPT == key:
        msg = "Congratulations! You simulated the model's prediction correctly this time."
    else:
        msg = f"Sorry, the model actually predicted {key}."
    # Disable the four answer buttons once an answer has been given.
    return (msg, gr.update(interactive=False), gr.update(interactive=False),
            gr.update(interactive=False), gr.update(interactive=False))


def get_features_on_interface(input):
    img_size = 448
    output, model = generate_interpretable_output(input, dataset="CUB2011",
                                                  arch="resnet50", seed=123456,
                                                  model_type="qsenn", n_features=50, n_per_class=5,
                                                  img_size=448, reduced_strides=False, folder=None,
                                                  with_featuremaps=False)
    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    op = get_options_from_trainingset(output, model, TR, device, with_other_class=True)
    key = op[0][0]  # explanation image of the class the model actually predicted
    random.shuffle(op)
    option = [(op[0][0], "A"),
              (op[1][0], "B"),
              (op[2][0], "C"),
              (op[3][0], "D")]
    for value, char in option:
        if np.array_equal(value, key):
            key_op = char
            print("key", key_op)

    return (option, key_op,
            "These are class explanations from our model for different classes. "
            "Which of these classes did our model predict?",
            gr.update(interactive=False))


def direct_inference(input):
    img_size = 448
    output, model, feature_maps = generate_interpretable_output(input, dataset="CUB2011",
                                                                arch="resnet50", seed=123456,
                                                                model_type="qsenn", n_features=50, n_per_class=5,
                                                                img_size=448, reduced_strides=False, folder=None,
                                                                with_featuremaps=True)

    TR = get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    concatenated_image = get_options_from_trainingset(output, model, TR, device, with_other_class=False)

    # Re-apply the deterministic input transform to get a displayable uint8 image.
    Input = Image.fromarray(input)
    tr = transform_input_img(Input, img_size)
    Input = tr(Input)
    image_np = (Input * 255).clamp(0, 255).byte()
    image_np = image_np.permute(1, 2, 0).numpy()

    ORI = overlapping_features_on_input(model, output, feature_maps, image_np, output)
    ORI_arrays = [np.array(img) for img in ORI]
    concatenated_ORI = np.concatenate(ORI_arrays, axis=0)

    print(concatenated_ORI.shape, concatenated_image[0][0].shape)
    concatenated_image_final_array = np.concatenate((concatenated_ORI, concatenated_image[0][0]), axis=1)
    print(concatenated_image_final_array.shape)

    data_dir = Path("tmp/Datasets/CUB200/CUB_200_2011/")
    classlist = pd.read_csv(data_dir / "classes.txt", sep=' ', names=['cl_id', 'class_name'])
    output_name = classlist.loc[classlist['cl_id'] == output, 'class_name'].values[0]
    if concatenated_image[0][1]:
        output_name_and_features = f"{output_name}, features {', '.join(map(str, concatenated_image[0][1]))} are negative."
    else:
        output_name_and_features = f"{output_name}, all features are positive."

    return concatenated_image_final_array, output_name_and_features


def filter_with_diversity(featuremaps, output, weight):
    localizer = MultiKCrossChannelMaxPooledSum(range(1, 6), weight, None)
    localizer(output.to("cpu"), featuremaps.to("cpu"))

    locality, exclusive_locality = localizer.get_result()
    diversity = locality[4]  # entry for k=5 in range(1, 6)
    return diversity.item()
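
# Minimal launch sketch (an assumption, not part of the original app wiring):
# it exposes direct_inference through a basic Gradio interface. Component
# labels and layout are illustrative only.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=direct_inference,
        inputs=gr.Image(type="numpy", label="Input bird image"),
        outputs=[gr.Image(label="Feature overlays"),
                 gr.Textbox(label="Predicted class and negative features")],
    )
    demo.launch()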