import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import cv2
import matplotlib

import clip

from utils import *

matplotlib.get_cachedir()
plt.rc('font', family="Times New Roman")
from sklearn import metrics

import torch
from torchvision import transforms

from tqdm import tqdm

from models.submodular_vit_efficient_plus import MultiModalSubModularExplanationEfficientPlus

data_transform = transforms.Compose(
    [
        transforms.Resize(
            (224,224), interpolation=transforms.InterpolationMode.BICUBIC
        ),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
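
# Note: the mean/std above are the standard CLIP preprocessing statistics,
# i.e. the same normalization used by the transform returned by clip.load().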

class CLIPModel_Super(torch.nn.Module):
    def __init__(self, 
                 type="ViT-L/14", 
                 download_root=None,
                 device = "cuda"):
        super().__init__()
        self.device = device
        self.model, _ = clip.load(type, device=self.device, download_root=download_root)
        
        self.model = self.model.type(torch.float32)
        
    def forward(self, vision_inputs):
        """
        Input:
            vision_inputs: torch.size([B,C,W,H])
        Output:
            embeddings: a d-dimensional vector torch.size([B,d])
        """
        vision_inputs = vision_inputs.type(torch.float32)
        
        with torch.no_grad():
            image_features = self.model.encode_image(vision_inputs)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        
        return image_features
    
def transform_vision_data(image):
    """
    Input:
        image: An image read by opencv [w,h,c]
    Output:
        image: After preproccessing, is a tensor [c,w,h]
    """
    image = Image.fromarray(image)
    image = data_transform(image)
    return image
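
# Note: images in this demo arrive as RGB numpy arrays (from Gradio and PIL),
# which is what the normalization above assumes; if you load images with
# cv2.imread instead, convert BGR -> RGB first.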

def zeroshot_classifier(model, classnames, templates, device):
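    """
    Build CLIP zero-shot classification weights: encode every prompt template
    for each class name, average and re-normalize the text embeddings, and
    stack them into a [num_classes, d] matrix (scaled by 100 to act as logits).
    """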
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates] #format with class
            texts = clip.tokenize(texts).to(device) #tokenize
            
            with torch.no_grad():
                class_embeddings = model.model.encode_text(texts)

            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights).to(device)  # respect the requested device instead of hard-coding .cuda()
    return zeroshot_weights*100
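
# A zero-shot prediction is then just the argmax over image/text similarities,
# mirroring how this matrix is used further below, e.g.:
#   logits = vis_model(images) @ semantic_feature.T   # [B, num_classes]
#   predicted_class = logits.argmax(dim=-1)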

device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# Instantiate model
vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
vis_model.eval()
vis_model.to(device)
print("load clip model")

semantic_path = "./clip_vitb_imagenet_zeroweights.pt"
if os.path.exists(semantic_path):
    semantic_feature = torch.load(semantic_path, map_location="cpu")
    semantic_feature = semantic_feature.to(device)
    semantic_feature = semantic_feature.type(torch.float32)
else:
    semantic_feature = zeroshot_classifier(vis_model, imagenet_classes, imagenet_templates, device)
    torch.save(semantic_feature, semantic_path)


explainer = MultiModalSubModularExplanationEfficientPlus(
        vis_model, semantic_feature, transform_vision_data, device=device, 
        lambda1=0.01, 
        lambda2=0.05, 
        lambda3=20., 
        lambda4=5.)
explainer.org_semantic_feature = semantic_feature
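# lambda1-lambda4 weight the individual terms of the submodular scoring function
# (presumably the confidence, effectiveness, consistency and collaboration terms
# used by the explainer); the values above are the ones shipped with this demo.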

def add_value_decrease(smdl_mask, json_file):
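    """
    Turn the per-step scores recorded by the explainer into a dense attribution
    map: regions selected earlier accumulate a smaller score decrease and thus
    end up with higher values once the map is rescaled to [0, 1].
    """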
    single_mask = np.zeros_like(smdl_mask[0].mean(-1))
    
    value_list_1 = np.array(json_file["consistency_score"]) + np.array(json_file["collaboration_score"])
    
    value_list_2 = np.array([json_file["baseline_score"]] + json_file["consistency_score"][:-1]) + np.array([1 - json_file["org_score"]] + json_file["collaboration_score"][:-1])
    
    value_list = value_list_1 - value_list_2
    
    values = []
    value = 0
    for smdl_single_mask, smdl_value in zip(smdl_mask, value_list):
        value = value - abs(smdl_value)
        single_mask[smdl_single_mask.sum(-1)>0] = value
        values.append(value)
    
    attribution_map = single_mask - single_mask.min()
    attribution_map /= attribution_map.max()
    
    return attribution_map, np.array(values)

def visualization(image, submodular_image_set, saved_json_file, index=None, compute_params=True):
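    """
    Render a three-panel figure: the attribution map, the searched region with
    the highest recognition score (or the one picked by `index`), and the
    insertion curve. Returns the rendered figure as a numpy image and, when
    compute_params is True, also the highest confidence, the insertion AUC and
    the best step index.
    """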
    
    attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
    vis_image, heatmap = gen_cam(image, norm_image(attribution_map))
    
    insertion_ours_images = []
    # deletion_ours_images = []

    insertion_image = np.zeros_like(submodular_image_set[0])  # start from an all-zero canvas
    insertion_ours_images.append(insertion_image)
    # deletion_ours_images.append(image - insertion_image)
    for smdl_sub_mask in submodular_image_set[:]:
        insertion_image = insertion_image.copy() + smdl_sub_mask
        insertion_ours_images.append(insertion_image)
        # deletion_ours_images.append(image - insertion_image)

    insertion_ours_images_input_results = np.array([1-saved_json_file["collaboration_score"][-1]] + saved_json_file["consistency_score"])

    if index is None:
        ours_best_index = np.argmax(insertion_ours_images_input_results)
    else:
        ours_best_index = index
    x = [(insertion_ours_image.sum(-1)!=0).sum() / (image.shape[0] * image.shape[1]) for insertion_ours_image in insertion_ours_images]
    i = len(x)

    fig, [ax1, ax2, ax3] = plt.subplots(1,3, gridspec_kw = {'width_ratios':[1, 1, 1.5]}, figsize=(30,8))
    ax1.spines["left"].set_visible(False)
    ax1.spines["right"].set_visible(False)
    ax1.spines["top"].set_visible(False)
    ax1.spines["bottom"].set_visible(False)
    ax1.xaxis.set_visible(False)
    ax1.yaxis.set_visible(False)
    ax1.set_title('Attribution Map', fontsize=54)
    ax1.set_facecolor('white')
    ax1.imshow(vis_image[...,::-1].astype(np.uint8))
    
    ax2.spines["left"].set_visible(False)
    ax2.spines["right"].set_visible(False)
    ax2.spines["top"].set_visible(False)
    ax2.spines["bottom"].set_visible(False)
    ax2.xaxis.set_visible(True)
    ax2.yaxis.set_visible(False)
    ax2.set_title('Searched Region', fontsize=54)
    ax2.set_facecolor('white')
    ax2.set_xlabel("Confidence {:.4f}".format(insertion_ours_images_input_results[ours_best_index]), fontsize=44)
    ax2.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)

    ax3.set_xlim((0, 1))
    ax3.set_ylim((0, 1))
    
    ax3.set_ylabel('Recognition Score', fontsize=44)
    ax3.set_xlabel('Percentage of image revealed', fontsize=44)
    ax3.tick_params(axis='both', which='major', labelsize=36)

    x_ = x[:i]
    ours_y = insertion_ours_images_input_results[:i]
    ax3.plot(x_, ours_y, color='dodgerblue', linewidth=3.5)  # draw curve
    ax3.set_facecolor('white')
    ax3.spines['bottom'].set_color('black')
    ax3.spines['bottom'].set_linewidth(2.0)
    ax3.spines['top'].set_color('none')
    ax3.spines['left'].set_color('black')
    ax3.spines['left'].set_linewidth(2.0)
    ax3.spines['right'].set_color('none')

    # plt.legend(["Ours"], fontsize=40, loc="upper left")
    ax3.scatter(x_[-1], ours_y[-1], color='dodgerblue', s=54)  # Plot latest point
    # Fill the area under the curve with light blue
    ax3.fill_between(x_, ours_y, color='dodgerblue', alpha=0.1)

    kernel = np.ones((3, 3), dtype=np.uint8)
    # ax3.plot([x_[ours_best_index], x_[ours_best_index]], [0, 1], color='red', linewidth=3.5)  # draw a red line
    ax3.axvline(x=x_[int(ours_best_index)], color='red', linewidth=3.5)  # mark the best step with a red vertical line

    # Ours
    mask = (image - insertion_ours_images[int(ours_best_index)]).mean(-1)
    mask[mask>0] = 1

    if int(ours_best_index) != 0:
        dilate = cv2.dilate(mask, kernel, iterations=3)  # 'iterations' must be a keyword argument (the third positional parameter is dst)
        # erosion = cv2.erode(dilate, kernel, iterations=3)
        # dilate = cv2.dilate(erosion, kernel, 2)
        edge = dilate - mask
        # erosion = cv2.erode(dilate, kernel, iterations=1)

    image_debug = image.copy()

    image_debug[mask>0] = image_debug[mask>0] * 0.5
    if int(ours_best_index) != 0:
        image_debug[edge>0] = np.array([255,0,0])
    ax2.imshow(image_debug)
    
    if compute_params:
        auc = metrics.auc(x, insertion_ours_images_input_results)
    
    ax3.set_title('Insertion Curve', fontsize=54)
    
    fig.tight_layout()
    fig.canvas.draw()
    img_curve = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img_curve = img_curve.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    
    plt.close(fig)  # close the figure to free resources
    
    if compute_params:
        return img_curve, insertion_ours_images_input_results.max(), auc, ours_best_index
    else:
        return img_curve
    
def gen_cam(image, mask):
    """
    Generate heatmap
        :param image: [H,W,C]
        :param mask: [H,W],range 0-1
        :return: tuple(cam,heatmap)
    """
    # Read image
    # image = cv2.resize(cv2.imread(image_path), (224,224))
    # mask->heatmap
    heatmap = cv2.applyColorMap(np.uint8(mask), cv2.COLORMAP_COOL)
    heatmap = np.float32(heatmap)

    # merge heatmap to original image
    cam = 0.5*heatmap + 0.5*np.float32(image)
    return cam, (heatmap).astype(np.uint8)
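
# Note: cv2.applyColorMap returns a BGR image, which is why the attribution
# panel is displayed with vis_image[..., ::-1] to flip it back to RGB.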

def norm_image(image):
    """
    Normalization image
    :param image: [H,W,C]
    :return:
    """
    image = image.copy()
    image -= np.max(np.min(image), 0)
    image /= np.max(image)
    image *= 255.
    return np.uint8(image)
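
# In this script norm_image only receives attribution maps that are already in
# [0, 1], so it effectively just rescales them to [0, 255].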

def read_image(file_path):
    image = Image.open(file_path)
    image = image.convert("RGB")
    image = image.resize((512,512))
    return np.array(image)

# Example images shown as clickable thumbnails
default_images = {
    # "Default Image": read_image("images/shark.png"),
    "Example: tiger shark": read_image("images/shark.png"),
    "Example: quail": read_image("images/bird.png"),
    "Example: tabby cat or lion": read_image("images/cat_lion.jpeg"),
    "Example: rabbit or duck": read_image("images/rabbit-duck.jpg"),
}

def interpret_image(uploaded_image, slider, text_input):
    # Use the uploaded image if there is one; otherwise return one empty value per Gradio output
    if uploaded_image is not None:
        image = np.array(uploaded_image)
    else:
        return None, 0, 0, None, None, None

    image = cv2.resize(image, (224, 224))
    element_sets_V = SubRegionDivision(image, mode="slico", region_size=40) 
    
    explainer.k = len(element_sets_V)
    
    # Cache the results globally so the slider callback can re-render without recomputing
    global submodular_image_set
    global saved_json_file
    
    image_input = explainer.preproccessing_function(image).unsqueeze(0)
    predicted_class = (explainer.model(image_input.to(explainer.device)) @ explainer.semantic_feature.T).argmax().cpu().item()
    
    # input
    if text_input == "":
        target_id = predicted_class
    else:
        if text_input in imagenet_classes:
            target_id = imagenet_classes.index(text_input)
        else:
            target_id = -1
            texts = [text_input]
            texts = clip.tokenize(texts).to(device) #tokenize
            
            with torch.no_grad():
                class_embeddings = vis_model.model.encode_text(texts)
                class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
                class_embeddings = class_embeddings.to(device) * 100
            
            explainer.semantic_feature = torch.cat((explainer.org_semantic_feature, class_embeddings), dim=0)
    
    # global im
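    # submodular_image_set holds the selected sub-regions in search order, and
    # saved_json_file records the per-step scores (consistency, collaboration,
    # baseline, org) that the visualization below relies on.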
    submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=target_id)
    
    # attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
    # im, heatmap = gen_cam(image, norm_image(attribution_map))
    
    image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, index=None)

    if target_id == -1:
        text_output_class = "The method explains which regions lead CLIP to describe the image as \"{}\".".format(text_input)
    else:
        text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])

    text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])
    
    explainer.semantic_feature = explainer.org_semantic_feature

    return image_curve, highest_confidence, insertion_auc_score, text_output_class, text_output_predict, None

def predict_image(uploaded_image):
    # Use the uploaded image if there is one; otherwise there is nothing to predict
    if uploaded_image is not None:
        image = np.array(uploaded_image)
    else:
        return None

    image = cv2.resize(image, (224, 224))
    
    image_input = explainer.preproccessing_function(image).unsqueeze(0)
    predicted_class = (explainer.model(image_input.to(explainer.device)) @ explainer.semantic_feature.T).argmax().cpu().item()
    
    text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])
    
    return text_output_predict
    

def visualization_slider(uploaded_image, slider):
    # Use the uploaded image if there is one; the slider is only meaningful after an interpretation has run
    if uploaded_image is None or "submodular_image_set" not in globals():
        return None
    image = np.array(uploaded_image)

    image = cv2.resize(image, (224, 224))
    
    image_curve = visualization(image, submodular_image_set, saved_json_file, index=slider, compute_params=False)
    
    return image_curve

def update_image(thumbnail_name):
    # Return the image data for the selected thumbnail
    return default_images[thumbnail_name]

# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Semantic Region Attribution and Mistake Discovery via Submodular Subset Selection")  # 使用Markdown添加标题
    
    gr.Markdown("Since huggingface only has ordinary CPUs available, our sub-region division is relatively coarse-grained, which may affect the attribution performance. The inference time is about 5 minutes (GPU is about 4s). If you are interested, you can try our source code. We have written many scripts to facilitate visualization.")
    with gr.Row():
        with gr.Column():
            # First row: image upload box and usage instructions
            with gr.Row():
                # Image upload input
                image_input = gr.Image(label="Upload Image", type="numpy")
                
                # Instructions panel (the default-image thumbnail and button below are commented out)
                with gr.Column():
                    # gr.Image(value=default_images["Default Image"], type="numpy")
                    # button_default = gr.Button(value="Default Image")
                    # button_default.click(
                    #     fn=lambda k="Default Image": update_image(k),
                    #     inputs=[],
                    #     outputs=image_input
                    # )
                    gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)
            
            # Text input for the category or description to explain
            text_input = gr.Textbox(label="Text Input", placeholder="You can choose what you want to explain. You can enter a word (e.g., 'Rabbit') or a description (e.g., 'A photo of a rabbit'). If the input is empty, the model will explain the predicted category.")
            
            # Second row: example thumbnails
            with gr.Row():
                for key in default_images.keys():
                    with gr.Column():
                        gr.Image(value=default_images[key], type="numpy")
                        button = gr.Button(value=key)
                        button.click(
                            fn=lambda k=key: update_image(k),
                            inputs=[],
                            outputs=image_input
                        )
            
        with gr.Column():
            # Output image and controls
            image_output = gr.Image(label="Output Image")
            
            slider = gr.Slider(minimum=0, maximum=34, step=1, label="Number of Sub-regions")
            
            text_output_predict = gr.Textbox(label="Predicted Category")
            text_output_class = gr.Textbox(label="Explaining Category")
            
            with gr.Row():
                # Highest confidence and insertion AUC score shown side by side
                text_output_confidence = gr.Textbox(label="Highest Confidence")
                text_output_auc = gr.Textbox(label="Insertion AUC Score")
            
            with gr.Row():
                predict_button = gr.Button("Model Inference")
                interpret_button = gr.Button("Interpreting Model")

    # Click handler for the "Interpreting Model" button
    interpret_button.click(
        fn=interpret_image,
        inputs=[image_input, slider, text_input],
        outputs=[image_output, text_output_confidence, text_output_auc, text_output_class, text_output_predict, text_input]
    )
    
    predict_button.click(
        fn=predict_image,
        inputs=[image_input],
        outputs=[text_output_predict]
    )
    
    # The slider updates the visualization in real time
    slider.change(
        fn=visualization_slider,
        inputs=[image_input, slider],
        outputs=[image_output]
    )

# Launch the Gradio app
demo.launch()