RuoyuChen committed on
Commit f53377b · 1 Parent(s): 9e49331

add model debugging

Files changed (3)
  1. app.py +71 -13
  2. images/cat_lion.jpeg +0 -0
  3. images/rabbit-duck.jpg +0 -0
app.py CHANGED
@@ -113,6 +113,7 @@ explainer = MultiModalSubModularExplanationEfficientPlus(
      lambda2=0.05,
      lambda3=20.,
      lambda4=5.)
+ explainer.org_semantic_feature = semantic_feature

  def add_value_decrease(smdl_mask, json_file):
      single_mask = np.zeros_like(smdl_mask[0].mean(-1))
@@ -276,13 +277,16 @@ def norm_image(image):
  def read_image(file_path):
      image = Image.open(file_path)
      image = image.convert("RGB")
+     image = image.resize((512,512))
      return np.array(image)

  # Use the same example image "shark.png"
  default_images = {
      # "Default Image": read_image("images/shark.png"),
-     "Example: Tiger Shark": read_image("images/shark.png"),
-     "Example: Quail": read_image("images/bird.png")  # all options use the same image
+     "Example: tiger shark": read_image("images/shark.png"),
+     "Example: quail": read_image("images/bird.png"),  # all options use the same image
+     "Example: tabby cat or lion": read_image("images/cat_lion.jpeg"),
+     "Example: rabbit or duck": read_image("images/rabbit-duck.jpg"),
  }

  def interpret_image(uploaded_image, slider, text_input):
@@ -296,21 +300,66 @@ def interpret_image(uploaded_image, slider, text_input):
      element_sets_V = SubRegionDivision(image, mode="slico", region_size=40)

      explainer.k = len(element_sets_V)
-     print(len(element_sets_V))

      global submodular_image_set
      global saved_json_file
+
+     image_input = explainer.preproccessing_function(image).unsqueeze(0)
+     predicted_class = (explainer.model(image_input.to(explainer.device)) @ explainer.semantic_feature.T).argmax().cpu().item()
+
+     # input
+     if text_input == "":
+         target_id = predicted_class
+     else:
+         if text_input in imagenet_classes:
+             target_id = imagenet_classes.index(text_input)
+         else:
+             target_id = -1
+             texts = [text_input]
+             texts = clip.tokenize(texts).to(device)  # tokenize
+
+             with torch.no_grad():
+                 class_embeddings = vis_model.model.encode_text(texts)
+                 class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
+                 class_embeddings = class_embeddings.to(device) * 100
+
+             explainer.semantic_feature = torch.cat((explainer.org_semantic_feature, class_embeddings), dim=0)
+
      # global im
-     submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=None)
+     submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=target_id)

      # attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
      # im, heatmap = gen_cam(image, norm_image(attribution_map))

      image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, index=None)

-     text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])
+     if target_id == -1:
+         text_output_class = "This method explains that CLIP is interested in describing \"{}\".".format(text_input)
+     else:
+         text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])
+
+     text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])
+
+     explainer.semantic_feature = explainer.org_semantic_feature
+
+     return image_curve, highest_confidence, insertion_auc_score, text_output_class, text_output_predict, None
+
+ def predict_image(uploaded_image):
+     # Use the uploaded image (if provided); otherwise use the generated image
+     if uploaded_image is not None:
+         image = np.array(uploaded_image)
+     else:
+         return None, 0, 0

-     return image_curve, highest_confidence, insertion_auc_score, text_output_class
+     image = cv2.resize(image, (224, 224))
+
+     image_input = explainer.preproccessing_function(image).unsqueeze(0)
+     predicted_class = (explainer.model(image_input.to(explainer.device)) @ explainer.semantic_feature.T).argmax().cpu().item()
+
+     text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])
+
+     return text_output_predict
+

  def visualization_slider(uploaded_image, slider):
      # Use the uploaded image (if provided); otherwise use the generated image
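Aside: the target-selection logic added above follows the standard CLIP zero-shot recipe behind `explainer.model(...) @ explainer.semantic_feature.T`: encode the image, encode the candidate texts, normalize, and take the argmax of the similarities. Below is a minimal standalone sketch of that recipe, assuming the OpenAI `clip` package; the prompt list is illustrative, and this is not the exact wiring in app.py, which routes through the explainer's preprocessing function and its cached semantic_feature matrix.

import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Illustrative candidate prompts; app.py scores against the full ImageNet class list instead.
prompts = ["a photo of a tiger shark", "a photo of a quail", "a photo of a rabbit"]
image = preprocess(Image.open("images/rabbit-duck.jpg")).unsqueeze(0).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(clip.tokenize(prompts).to(device))
    # Normalize so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

# The highest-scoring prompt plays the role of predicted_class in interpret_image.
scores = image_features @ text_features.T
print(prompts[scores.argmax().item()])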
@@ -331,7 +380,7 @@ def update_image(thumbnail_name):

  # Create the Gradio interface
  with gr.Blocks() as demo:
-     gr.Markdown("# Semantic Region Attribution via Submodular Subset Selection")  # add the title with Markdown
+     gr.Markdown("# Semantic Region Attribution and Mistake Discovery via Submodular Subset Selection")  # add the title with Markdown

      gr.Markdown("Since huggingface only has ordinary CPUs available, our sub-region division is relatively coarse-grained, which may affect the attribution performance. The inference time is about 5 minutes (GPU is about 4s). If you are interested, you can try our source code. We have written many scripts to facilitate visualization.")
      with gr.Row():
@@ -352,6 +401,9 @@ with gr.Blocks() as demo:
              # )
              gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)

+             # Text input box and slider
+             text_input = gr.Textbox(label="Text Input", placeholder="You can choose what you want to explain. You can enter a word (e.g., 'Rabbit') or a description (e.g., 'A photo of a rabbit'). If the input is empty, the model will explain the predicted category.")
+
              # Second row: two thumbnails
              with gr.Row():
                  for key in default_images.keys():
@@ -363,10 +415,6 @@ with gr.Blocks() as demo:
                          inputs=[],
                          outputs=image_input
                      )
-
-             # Text input box and slider
-             text_input = gr.Textbox(label="Text Input", placeholder="Enter some text here... (optional)")
-

          with gr.Column():
              # Output image and controls
@@ -374,19 +422,29 @@ with gr.Blocks() as demo:

              slider = gr.Slider(minimum=0, maximum=34, step=1, label="Number of Sub-regions")

+             text_output_predict = gr.Textbox(label="Predicted Category")
              text_output_class = gr.Textbox(label="Explaining Category")
+
              with gr.Row():
                  # Display the highest confidence and the insertion AUC score side by side
                  text_output_confidence = gr.Textbox(label="Highest Confidence")
                  text_output_auc = gr.Textbox(label="Insertion AUC Score")

-             interpret_button = gr.Button("Interpreting Model")
+             with gr.Row():
+                 predict_button = gr.Button("Model Inference")
+                 interpret_button = gr.Button("Interpreting Model")

              # Define the click event for the interpret-model button
              interpret_button.click(
                  fn=interpret_image,
                  inputs=[image_input, slider, text_input],
-                 outputs=[image_output, text_output_confidence, text_output_auc, text_output_class]
+                 outputs=[image_output, text_output_confidence, text_output_auc, text_output_class, text_output_predict, text_input]
+             )
+
+             predict_button.click(
+                 fn=predict_image,
+                 inputs=[image_input],
+                 outputs=[text_output_predict]
              )

              # Slider that updates in real time
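Aside: the UI change amounts to adding a second button whose callback only runs the forward pass, wired with the usual Gradio Blocks pattern. Below is a minimal self-contained sketch of that two-button wiring; the callbacks are placeholders standing in for predict_image and interpret_image, not the real functions.

import gradio as gr

def fake_predict(image):
    # Stand-in for predict_image: return only the predicted category.
    return "predicted: tabby cat"

def fake_interpret(image, text):
    # Stand-in for interpret_image: explain either the typed text or the predicted category.
    return "explaining: {}".format(text or "predicted category")

with gr.Blocks() as demo:
    image_input = gr.Image(label="Input Image")
    text_input = gr.Textbox(label="Text Input")
    text_output_predict = gr.Textbox(label="Predicted Category")
    text_output_class = gr.Textbox(label="Explaining Category")

    with gr.Row():
        predict_button = gr.Button("Model Inference")
        interpret_button = gr.Button("Interpreting Model")

    # Each button gets its own click handler; outputs map onto the matching textboxes.
    predict_button.click(fn=fake_predict, inputs=[image_input], outputs=[text_output_predict])
    interpret_button.click(fn=fake_interpret, inputs=[image_input, text_input], outputs=[text_output_class])

demo.launch()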
 
images/cat_lion.jpeg ADDED
images/rabbit-duck.jpg ADDED