RuoyuChen commited on
Commit
8c0a35a
·
1 Parent(s): 55c3fad

reduce time

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -89,8 +89,8 @@ def zeroshot_classifier(model, classnames, templates, device):
89
  zeroshot_weights = torch.stack(zeroshot_weights).cuda()
90
  return zeroshot_weights*100
91
 
92
- # device = "cuda" if torch.cuda.is_available() else "cpu"
93
- device = "cpu"
94
  # Instantiate model
95
  vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
96
  vis_model.eval()
@@ -232,6 +232,7 @@ def visualization(image, submodular_image_set, saved_json_file, index=None, comp
232
 
233
  ax3.set_title('Insertion Curve', fontsize=54)
234
 
 
235
  fig.canvas.draw()
236
  img_curve = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
237
  img_curve = img_curve.reshape(fig.canvas.get_width_height()[::-1] + (3,))
@@ -331,6 +332,8 @@ def update_image(thumbnail_name):
331
  # 创建 Gradio 界面
332
  with gr.Blocks() as demo:
333
  gr.Markdown("# Semantic Region Attribution via Submodular Subset Selection") # 使用Markdown添加标题
 
 
334
  with gr.Row():
335
  with gr.Column():
336
  # 第一排:上传图像输入框和一个缩略图
 
89
  zeroshot_weights = torch.stack(zeroshot_weights).cuda()
90
  return zeroshot_weights*100
91
 
92
+ device = "cuda" if torch.cuda.is_available() else "cpu"
93
+ # device = "cpu"
94
  # Instantiate model
95
  vis_model = CLIPModel_Super("ViT-B/16", device=device, download_root="./ckpt")
96
  vis_model.eval()
 
232
 
233
  ax3.set_title('Insertion Curve', fontsize=54)
234
 
235
+ fig.tight_layout()
236
  fig.canvas.draw()
237
  img_curve = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
238
  img_curve = img_curve.reshape(fig.canvas.get_width_height()[::-1] + (3,))
 
332
  # 创建 Gradio 界面
333
  with gr.Blocks() as demo:
334
  gr.Markdown("# Semantic Region Attribution via Submodular Subset Selection") # 使用Markdown添加标题
335
+
336
+ gr.Markdown("Since huggingface only has ordinary CPUs available, our sub-region division is relatively coarse-grained, which may affect the model performance. The inference time is about 5 minutes. If you are interested, you can try our source code. We have written many scripts to facilitate visualization.")
337
  with gr.Row():
338
  with gr.Column():
339
  # 第一排:上传图像输入框和一个缩略图