File size: 6,349 Bytes
40e7aed
 
 
 
 
 
 
 
 
 
09b357e
 
 
40e7aed
 
 
 
 
 
 
 
 
 
 
 
09b357e
40e7aed
 
09b357e
40e7aed
 
09b357e
40e7aed
 
09b357e
40e7aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09b357e
 
 
 
 
 
061a0da
40e7aed
345909a
40e7aed
09b357e
 
40e7aed
 
 
 
 
09b357e
40e7aed
 
 
 
 
09b357e
 
 
 
 
 
ea385b8
40e7aed
09b357e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40e7aed
09b357e
 
40e7aed
09b357e
 
 
 
 
 
 
 
 
 
40e7aed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
from load_model import extract_sel_mean_std_bias_assignemnt
from pathlib import Path
from architectures.model_mapping import get_model
from configs.dataset_params import dataset_constants
import torch
import torchvision.transforms as transforms
import pandas as pd
import cv2
import numpy as np
from PIL import Image
from get_data import get_augmentation
from configs.dataset_params import normalize_params

def overlapping_features_on_input(model,output, feature_maps, input, target):
    W=model.linear.layer.weight
    output=output.detach().cpu().numpy()
    feature_maps=feature_maps.detach().cpu().numpy().squeeze()

    if target !=None:
     label=target
    else:
     label=np.argmax(output)+1

    Interpretable_Selection= W[label,:]
    print("W",Interpretable_Selection)
    input_np=np.array(input)
    h,w= input.shape[:2]
    print("h,w:",h,w)
    Interpretable_Features=[]
    Feature_image_list=[]
    
    for S in range(len(Interpretable_Selection)):
        if Interpretable_Selection[S] > 0:

               Interpretable_Features.append(feature_maps[S])
               Feature_image=cv2.resize(feature_maps[S],(w,h))
               Feature_image=((Feature_image-np.min(Feature_image))/(np.max(Feature_image)-np.min(Feature_image)))*255
               Feature_image=Feature_image.astype(np.uint8)
               Feature_image=cv2.applyColorMap(Feature_image,cv2.COLORMAP_JET)
               Feature_image=0.3*Feature_image+0.7*input_np
               Feature_image=np.clip(Feature_image, 0, 255).astype(np.uint8)
               Feature_image_list.append(Feature_image)
               #path_to_featureimage=f"/home/qixuan/tmp/FeatureImage/FI{S}.jpg"
               #cv2.imwrite(path_to_featureimage,Feature_image)
    print("len of Features:",len(Interpretable_Features))

    return Feature_image_list


def genreate_intepriable_output(input,dataset="CUB2011", arch="resnet50",seed=123456, model_type="qsenn", n_features = 50, n_per_class=5, img_size=448, reduced_strides=False, folder = None):
    n_classes = dataset_constants[dataset]["num_classes"]
  
    model = get_model(arch, n_classes, reduced_strides)
    tr=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])
    TR=get_augmentation(0.1, img_size, False, False, True, True, normalize_params["CUB2011"])
    device = torch.device("cpu")
    if folder is None:
        folder = Path(f"tmp/{arch}/{dataset}/{seed}/")
    
    state_dict = torch.load(folder / f"{model_type}_{n_features}_{n_per_class}_FinetunedModel.pth")
    selection= torch.load(folder / f"SlDD_Selection_50.pt")
    state_dict['linear.selection']=selection
    
    feature_sel, sparse_layer, current_mean, current_std, bias_sparse = extract_sel_mean_std_bias_assignemnt(state_dict)
    model.set_model_sldd(feature_sel, sparse_layer, current_mean, current_std, bias_sparse)
    model.load_state_dict(state_dict)
    input=Image.fromarray(input)

    input = tr(input)
    input= input.unsqueeze(0)
    input= input.to(device)
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        output, feature_maps, final_features = model(input, with_feature_maps=True, with_final_features=True)
        print("final features:",final_features)
        output=output.detach().cpu().numpy()
        output= np.argmax(output)+1
        

        print("outputclass:",output)
        data_dir=Path.home()/"tmp/Datasets/CUB200/CUB_200_2011/"
        labels = pd.read_csv(data_dir/"image_class_labels.txt", sep=' ', names=['img_id', 'target'])
        namelist=pd.read_csv(data_dir/"images.txt",sep=' ',names=['img_id','file_name'])
        classlist=pd.read_csv(data_dir/"classes.txt",sep=' ',names=['cl_id','class_name'])
        options_output=labels[labels['target']==output]
        options_output=options_output.sample(1)
        others=labels[labels['target']!=output]
        options_others=others.sample(3)
        options = pd.concat([options_others, options_output], ignore_index=True)
        shuffled_options = options.sample(frac=1).reset_index(drop=True)
        print("shuffled:",shuffled_options)
        op=[]


        for i in shuffled_options['img_id']:

            filenames=namelist.loc[namelist['img_id']==i,'file_name'].values[0]
            targets=shuffled_options.loc[shuffled_options['img_id']==i,'target'].values[0]


            classes=classlist.loc[classlist['cl_id']==targets, 'class_name'].values[0]
            

            op_img=cv2.imread(data_dir/f"images/{filenames}")
            
            op_imag=Image.fromarray(op_img)
            op_images=TR(op_imag)
            op_images=op_images.unsqueeze(0)
            op_images=op_images.to(device)
            OP, feature_maps_op =model(op_images,with_feature_maps=True,with_final_features=False)
      
            opt= overlapping_features_on_input(model,OP, feature_maps_op,op_img,targets)
            op+=opt
    
    return op

def post_next_image(op):
    if len(op)<=1:
        return [],None, "all done, thank you!"
    else:
        op=op[1:len(op)]
        return op,op[0], "Is this feature also in your input?"

def get_features_on_interface(input):
    op=genreate_intepriable_output(input,dataset="CUB2011", 
                                arch="resnet50",seed=123456, 
                                model_type="qsenn", n_features = 50,n_per_class=5,
                                img_size=448, reduced_strides=False, folder = None)
    return op, op[0],"Is this feature also in your input?",gr.update(interactive=False)


with gr.Blocks() as demo:

    gr.Markdown("<h1 style='text-align: center;'>Interiable Bird Classification</h1>")
    image_input=gr.Image()
    image_output=gr.Image()
    text_output=gr.Markdown()
    but_generate=gr.Button("Get some interpriable Features")
    but_feedback_y=gr.Button("Yes")
    but_feedback_n=gr.Button("No")
    image_list = gr.State([])
    but_generate.click(fn=get_features_on_interface, inputs=image_input, outputs=[image_list,image_output,text_output,but_generate])
    but_feedback_y.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output])
    but_feedback_n.click(fn=post_next_image, inputs=image_list, outputs=[image_list,image_output,text_output])

demo.launch()