Spaces:
Sleeping
Sleeping
add model debugging
Browse files- app.py +71 -13
- images/cat_lion.jpeg +0 -0
- images/rabbit-duck.jpg +0 -0
app.py
CHANGED
@@ -113,6 +113,7 @@ explainer = MultiModalSubModularExplanationEfficientPlus(
|
|
113 |
lambda2=0.05,
|
114 |
lambda3=20.,
|
115 |
lambda4=5.)
|
|
|
116 |
|
117 |
def add_value_decrease(smdl_mask, json_file):
|
118 |
single_mask = np.zeros_like(smdl_mask[0].mean(-1))
|
@@ -276,13 +277,16 @@ def norm_image(image):
|
|
276 |
def read_image(file_path):
|
277 |
image = Image.open(file_path)
|
278 |
image = image.convert("RGB")
|
|
|
279 |
return np.array(image)
|
280 |
|
281 |
# 使用同一个示例图像 "shark.png"
|
282 |
default_images = {
|
283 |
# "Default Image": read_image("images/shark.png"),
|
284 |
-
"Example:
|
285 |
-
"Example:
|
|
|
|
|
286 |
}
|
287 |
|
288 |
def interpret_image(uploaded_image, slider, text_input):
|
@@ -296,21 +300,66 @@ def interpret_image(uploaded_image, slider, text_input):
|
|
296 |
element_sets_V = SubRegionDivision(image, mode="slico", region_size=40)
|
297 |
|
298 |
explainer.k = len(element_sets_V)
|
299 |
-
print(len(element_sets_V))
|
300 |
|
301 |
global submodular_image_set
|
302 |
global saved_json_file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
303 |
# global im
|
304 |
-
submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=
|
305 |
|
306 |
# attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
|
307 |
# im, heatmap = gen_cam(image, norm_image(attribution_map))
|
308 |
|
309 |
image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, index=None)
|
310 |
|
311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
def visualization_slider(uploaded_image, slider):
|
316 |
# 使用上传的图像(如果有),否则使用生成的图像
|
@@ -331,7 +380,7 @@ def update_image(thumbnail_name):
|
|
331 |
|
332 |
# 创建 Gradio 界面
|
333 |
with gr.Blocks() as demo:
|
334 |
-
gr.Markdown("# Semantic Region Attribution via Submodular Subset Selection") # 使用Markdown添加标题
|
335 |
|
336 |
gr.Markdown("Since huggingface only has ordinary CPUs available, our sub-region division is relatively coarse-grained, which may affect the attribution performance. The inference time is about 5 minutes (GPU is about 4s). If you are interested, you can try our source code. We have written many scripts to facilitate visualization.")
|
337 |
with gr.Row():
|
@@ -352,6 +401,9 @@ with gr.Blocks() as demo:
|
|
352 |
# )
|
353 |
gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)
|
354 |
|
|
|
|
|
|
|
355 |
# 第二排:两个缩略图
|
356 |
with gr.Row():
|
357 |
for key in default_images.keys():
|
@@ -363,10 +415,6 @@ with gr.Blocks() as demo:
|
|
363 |
inputs=[],
|
364 |
outputs=image_input
|
365 |
)
|
366 |
-
|
367 |
-
# 文本输入框和滑块
|
368 |
-
text_input = gr.Textbox(label="Text Input", placeholder="Enter some text here... (optional)")
|
369 |
-
|
370 |
|
371 |
with gr.Column():
|
372 |
# 输出图像和控件
|
@@ -374,19 +422,29 @@ with gr.Blocks() as demo:
|
|
374 |
|
375 |
slider = gr.Slider(minimum=0, maximum=34, step=1, label="Number of Sub-regions")
|
376 |
|
|
|
377 |
text_output_class = gr.Textbox(label="Explaining Category")
|
|
|
378 |
with gr.Row():
|
379 |
# 最高置信度和插入 AUC Score 并排显示
|
380 |
text_output_confidence = gr.Textbox(label="Highest Confidence")
|
381 |
text_output_auc = gr.Textbox(label="Insertion AUC Score")
|
382 |
|
383 |
-
|
|
|
|
|
384 |
|
385 |
# 定义解释模型按钮点击事件
|
386 |
interpret_button.click(
|
387 |
fn=interpret_image,
|
388 |
inputs=[image_input, slider, text_input],
|
389 |
-
outputs=[image_output, text_output_confidence, text_output_auc, text_output_class]
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
)
|
391 |
|
392 |
# 实时更新的滑块
|
|
|
113 |
lambda2=0.05,
|
114 |
lambda3=20.,
|
115 |
lambda4=5.)
|
116 |
+
explainer.org_semantic_feature = semantic_feature
|
117 |
|
118 |
def add_value_decrease(smdl_mask, json_file):
|
119 |
single_mask = np.zeros_like(smdl_mask[0].mean(-1))
|
|
|
277 |
def read_image(file_path):
|
278 |
image = Image.open(file_path)
|
279 |
image = image.convert("RGB")
|
280 |
+
image = image.resize((512,512))
|
281 |
return np.array(image)
|
282 |
|
283 |
# 使用同一个示例图像 "shark.png"
|
284 |
default_images = {
|
285 |
# "Default Image": read_image("images/shark.png"),
|
286 |
+
"Example: tiger shark": read_image("images/shark.png"),
|
287 |
+
"Example: quail": read_image("images/bird.png"), # 所有选项都使用相同的图片
|
288 |
+
"Example: tabby cat or lion": read_image("images/cat_lion.jpeg"),
|
289 |
+
"Example: rabbit or duck": read_image("images/rabbit-duck.jpg"),
|
290 |
}
|
291 |
|
292 |
def interpret_image(uploaded_image, slider, text_input):
|
|
|
300 |
element_sets_V = SubRegionDivision(image, mode="slico", region_size=40)
|
301 |
|
302 |
explainer.k = len(element_sets_V)
|
|
|
303 |
|
304 |
global submodular_image_set
|
305 |
global saved_json_file
|
306 |
+
|
307 |
+
image_input = explainer.preproccessing_function(image).unsqueeze(0)
|
308 |
+
predicted_class = (explainer.model(image_input.to(explainer.device)) @ explainer.semantic_feature.T).argmax().cpu().item()
|
309 |
+
|
310 |
+
# input
|
311 |
+
if text_input == "":
|
312 |
+
target_id = predicted_class
|
313 |
+
else:
|
314 |
+
if text_input in imagenet_classes:
|
315 |
+
target_id = imagenet_classes.index(text_input)
|
316 |
+
else:
|
317 |
+
target_id = -1
|
318 |
+
texts = [text_input]
|
319 |
+
texts = clip.tokenize(texts).to(device) #tokenize
|
320 |
+
|
321 |
+
with torch.no_grad():
|
322 |
+
class_embeddings = vis_model.model.encode_text(texts)
|
323 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
324 |
+
class_embeddings = class_embeddings.to(device) * 100
|
325 |
+
|
326 |
+
explainer.semantic_feature = torch.cat((explainer.org_semantic_feature, class_embeddings), dim=0)
|
327 |
+
|
328 |
# global im
|
329 |
+
submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=target_id)
|
330 |
|
331 |
# attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
|
332 |
# im, heatmap = gen_cam(image, norm_image(attribution_map))
|
333 |
|
334 |
image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, index=None)
|
335 |
|
336 |
+
if target_id == -1:
|
337 |
+
text_output_class = "This method explains that CLIP is interested in describing \"{}\".".format(text_input)
|
338 |
+
else:
|
339 |
+
text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])
|
340 |
+
|
341 |
+
text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])
|
342 |
+
|
343 |
+
explainer.semantic_feature = explainer.org_semantic_feature
|
344 |
+
|
345 |
+
return image_curve, highest_confidence, insertion_auc_score, text_output_class, text_output_predict, None
|
346 |
+
|
347 |
+
def predict_image(uploaded_image):
|
348 |
+
# 使用上传的图像(如果有),否则使用生成的图像
|
349 |
+
if uploaded_image is not None:
|
350 |
+
image = np.array(uploaded_image)
|
351 |
+
else:
|
352 |
+
return None, 0, 0
|
353 |
|
354 |
+
image = cv2.resize(image, (224, 224))
|
355 |
+
|
356 |
+
image_input = explainer.preproccessing_function(image).unsqueeze(0)
|
357 |
+
predicted_class = (explainer.model(image_input.to(explainer.device)) @ explainer.semantic_feature.T).argmax().cpu().item()
|
358 |
+
|
359 |
+
text_output_predict = "The image is predicted as {}".format(imagenet_classes[predicted_class])
|
360 |
+
|
361 |
+
return text_output_predict
|
362 |
+
|
363 |
|
364 |
def visualization_slider(uploaded_image, slider):
|
365 |
# 使用上传的图像(如果有),否则使用生成的图像
|
|
|
380 |
|
381 |
# 创建 Gradio 界面
|
382 |
with gr.Blocks() as demo:
|
383 |
+
gr.Markdown("# Semantic Region Attribution and Mistake Discovery via Submodular Subset Selection") # 使用Markdown添加标题
|
384 |
|
385 |
gr.Markdown("Since huggingface only has ordinary CPUs available, our sub-region division is relatively coarse-grained, which may affect the attribution performance. The inference time is about 5 minutes (GPU is about 4s). If you are interested, you can try our source code. We have written many scripts to facilitate visualization.")
|
386 |
with gr.Row():
|
|
|
401 |
# )
|
402 |
gr.Textbox("Thank you for using our interpretable attribution method, which originates from the ICLR 2024 Oral paper titled \"Less is More: Fewer Interpretable Regions via Submodular Subset Selection.\" We have now implemented this method on the multimodal ViT model and achieved promising results in explaining model predictions. A key feature of our approach is its ability to clarify the reasons behind the model's prediction errors. We invite you to try out this demo and explore its capabilities. The source code is available at https://github.com/RuoyuChen10/SMDL-Attribution.\nYou can upload an image yourself or select one from the following, then click the button Interpreting Model to get the result. The demo currently does not support selecting categories or descriptions by yourself. If you are interested, you can try it from the source code.", label="Instructions for use", interactive=False)
|
403 |
|
404 |
+
# 文本输入框和滑块
|
405 |
+
text_input = gr.Textbox(label="Text Input", placeholder="You can choose what you want to explain. You can enter a word (e.g., 'Rabbit') or a description (e.g., 'A photo of a rabbit'). If the input is empty, the model will explain the predicted category.")
|
406 |
+
|
407 |
# 第二排:两个缩略图
|
408 |
with gr.Row():
|
409 |
for key in default_images.keys():
|
|
|
415 |
inputs=[],
|
416 |
outputs=image_input
|
417 |
)
|
|
|
|
|
|
|
|
|
418 |
|
419 |
with gr.Column():
|
420 |
# 输出图像和控件
|
|
|
422 |
|
423 |
slider = gr.Slider(minimum=0, maximum=34, step=1, label="Number of Sub-regions")
|
424 |
|
425 |
+
text_output_predict = gr.Textbox(label="Predicted Category")
|
426 |
text_output_class = gr.Textbox(label="Explaining Category")
|
427 |
+
|
428 |
with gr.Row():
|
429 |
# 最高置信度和插入 AUC Score 并排显示
|
430 |
text_output_confidence = gr.Textbox(label="Highest Confidence")
|
431 |
text_output_auc = gr.Textbox(label="Insertion AUC Score")
|
432 |
|
433 |
+
with gr.Row():
|
434 |
+
predict_button = gr.Button("Model Inference")
|
435 |
+
interpret_button = gr.Button("Interpreting Model")
|
436 |
|
437 |
# 定义解释模型按钮点击事件
|
438 |
interpret_button.click(
|
439 |
fn=interpret_image,
|
440 |
inputs=[image_input, slider, text_input],
|
441 |
+
outputs=[image_output, text_output_confidence, text_output_auc, text_output_class, text_output_predict, text_input]
|
442 |
+
)
|
443 |
+
|
444 |
+
predict_button.click(
|
445 |
+
fn=predict_image,
|
446 |
+
inputs=[image_input],
|
447 |
+
outputs=[text_output_predict]
|
448 |
)
|
449 |
|
450 |
# 实时更新的滑块
|
images/cat_lion.jpeg
ADDED
![]() |
images/rabbit-duck.jpg
ADDED
![]() |