wondervictor commited on
Commit
3dd984b
·
verified ·
1 Parent(s): d2690c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -140,8 +140,8 @@ mask_adapter = None
140
  @torch.no_grad()
141
  @torch.autocast(device_type="cuda", dtype=torch.float32)
142
  def inference_box(input_img, img_state,class_names_input):
143
- if len(img_state.selected_bboxes) != 2:
144
- return None
145
  mp.set_start_method("spawn", force=True)
146
 
147
  box_points = img_state.selected_bboxes
@@ -239,6 +239,12 @@ def get_bbox_with_draw(image, img_state, evt: gr.SelectData):
239
  )
240
  return img_state, image
241
 
 
 
 
 
 
 
242
  def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
243
  cfg = setup_cfg(cfg)
244
  global sam2_model, clip_model, mask_adapter
@@ -356,7 +362,7 @@ with gr.Blocks() as demo:
356
  [input_image, img_state_bbox],
357
  outputs=[img_state_bbox, input_image]
358
  ).then(
359
- inference_box,
360
  inputs=[input_image, img_state_bbox,class_names_input_box],
361
  outputs=[output_image_box]
362
  )
 
140
  @torch.no_grad()
141
  @torch.autocast(device_type="cuda", dtype=torch.float32)
142
  def inference_box(input_img, img_state,class_names_input):
143
+ # if len(img_state.selected_bboxes) != 2:
144
+ # return None
145
  mp.set_start_method("spawn", force=True)
146
 
147
  box_points = img_state.selected_bboxes
 
239
  )
240
  return img_state, image
241
 
242
+ def check_and_infer_box(img_state_bbox, input_image, class_names_input_box):
243
+ if len(img_state_bbox.selected_bboxes) == 2:
244
+ return inference_box(input_image, img_state_bbox, class_names_input_box)
245
+ return None
246
+
247
+
248
  def initialize_models(sam_path, adapter_pth, model_cfg, cfg):
249
  cfg = setup_cfg(cfg)
250
  global sam2_model, clip_model, mask_adapter
 
362
  [input_image, img_state_bbox],
363
  outputs=[img_state_bbox, input_image]
364
  ).then(
365
+ check_and_infer_box,
366
  inputs=[input_image, img_state_bbox,class_names_input_box],
367
  outputs=[output_image_box]
368
  )