RuoyuChen commited on
Commit
64a13b7
·
1 Parent(s): 015bb14

reduce time

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -89,8 +89,8 @@ def zeroshot_classifier(model, classnames, templates, device):
89
  zeroshot_weights = torch.stack(zeroshot_weights).cuda()
90
  return zeroshot_weights*100
91
 
92
- device = "cuda" if torch.cuda.is_available() else "cpu"
93
- # device = "cuda"
94
  # Instantiate model
95
  vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
96
  vis_model.eval()
@@ -135,7 +135,10 @@ def add_value_decrease(smdl_mask, json_file):
135
 
136
  return attribution_map, np.array(values)
137
 
138
- def visualization(image, submodular_image_set, saved_json_file, vis_image, index=None, compute_params=True):
 
 
 
139
 
140
  insertion_ours_images = []
141
  # deletion_ours_images = []
@@ -166,7 +169,7 @@ def visualization(image, submodular_image_set, saved_json_file, vis_image, index
166
  ax1.yaxis.set_visible(False)
167
  ax1.set_title('Attribution Map', fontsize=54)
168
  ax1.set_facecolor('white')
169
- ax1.imshow(vis_image.astype(np.uint8))
170
 
171
  ax2.spines["left"].set_visible(False)
172
  ax2.spines["right"].set_visible(False)
@@ -289,19 +292,20 @@ def interpret_image(uploaded_image, slider, text_input):
289
  return None, 0, 0
290
 
291
  image = cv2.resize(image, (224, 224))
292
- element_sets_V = SubRegionDivision(image, mode="slico", region_size=30)
293
 
294
  explainer.k = len(element_sets_V)
 
295
 
296
  global submodular_image_set
297
  global saved_json_file
298
- global im
299
  submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=None)
300
 
301
- attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
302
- im, heatmap = gen_cam(image, norm_image(attribution_map))
303
 
304
- image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, im, index=None)
305
 
306
  text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])
307
 
@@ -316,7 +320,7 @@ def visualization_slider(uploaded_image, slider):
316
 
317
  image = cv2.resize(image, (224, 224))
318
 
319
- image_curve = visualization(image, submodular_image_set, saved_json_file, im, index=slider, compute_params=False)
320
 
321
  return image_curve
322
 
@@ -326,6 +330,7 @@ def update_image(thumbnail_name):
326
 
327
  # 创建 Gradio 界面
328
  with gr.Blocks() as demo:
 
329
  with gr.Row():
330
  with gr.Column():
331
  # 第一排:上传图像输入框和一个缩略图
@@ -364,7 +369,7 @@ with gr.Blocks() as demo:
364
  # 输出图像和控件
365
  image_output = gr.Image(label="Output Image")
366
 
367
- slider = gr.Slider(minimum=0, maximum=50, step=1, label="Confidence Slider")
368
 
369
  text_output_class = gr.Textbox(label="Explaining Category")
370
  with gr.Row():
 
89
  zeroshot_weights = torch.stack(zeroshot_weights).cuda()
90
  return zeroshot_weights*100
91
 
92
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
93
+ device = "cpu"
94
  # Instantiate model
95
  vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
96
  vis_model.eval()
 
135
 
136
  return attribution_map, np.array(values)
137
 
138
+ def visualization(image, submodular_image_set, saved_json_file, index=None, compute_params=True):
139
+
140
+ attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
141
+ vis_image, heatmap = gen_cam(image, norm_image(attribution_map))
142
 
143
  insertion_ours_images = []
144
  # deletion_ours_images = []
 
169
  ax1.yaxis.set_visible(False)
170
  ax1.set_title('Attribution Map', fontsize=54)
171
  ax1.set_facecolor('white')
172
+ ax1.imshow(vis_image[...,::-1].astype(np.uint8))
173
 
174
  ax2.spines["left"].set_visible(False)
175
  ax2.spines["right"].set_visible(False)
 
292
  return None, 0, 0
293
 
294
  image = cv2.resize(image, (224, 224))
295
+ element_sets_V = SubRegionDivision(image, mode="slico", region_size=40)
296
 
297
  explainer.k = len(element_sets_V)
298
+ print(len(element_sets_V))
299
 
300
  global submodular_image_set
301
  global saved_json_file
302
+ # global im
303
  submodular_image, submodular_image_set, saved_json_file = explainer(element_sets_V, id=None)
304
 
305
+ # attribution_map, value_list = add_value_decrease(submodular_image_set, saved_json_file)
306
+ # im, heatmap = gen_cam(image, norm_image(attribution_map))
307
 
308
+ image_curve, highest_confidence, insertion_auc_score, ours_best_index = visualization(image, submodular_image_set, saved_json_file, index=None)
309
 
310
  text_output_class = "The method explains why the CLIP (ViT-B/16) model identifies an image as {}.".format(imagenet_classes[explainer.target_label])
311
 
 
320
 
321
  image = cv2.resize(image, (224, 224))
322
 
323
+ image_curve = visualization(image, submodular_image_set, saved_json_file, index=slider, compute_params=False)
324
 
325
  return image_curve
326
 
 
330
 
331
  # 创建 Gradio 界面
332
  with gr.Blocks() as demo:
333
+ gr.Markdown("# Semantic Region Attribution via Submodular Subset Selection") # 使用Markdown添加标题
334
  with gr.Row():
335
  with gr.Column():
336
  # 第一排:上传图像输入框和一个缩略图
 
369
  # 输出图像和控件
370
  image_output = gr.Image(label="Output Image")
371
 
372
+ slider = gr.Slider(minimum=0, maximum=34, step=1, label="Number of Sub-regions")
373
 
374
  text_output_class = gr.Textbox(label="Explaining Category")
375
  with gr.Row():