fffiloni commited on
Commit
2ac24da
·
verified ·
1 Parent(s): 59c23c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -1
app.py CHANGED
@@ -299,7 +299,7 @@ def get_mask_sam_process(
299
  available_frames_to_check.append(working_frame)
300
  print(available_frames_to_check)
301
 
302
- return "output_first_frame.jpg", frame_names, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
303
 
304
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
305
  #### PROPAGATION ####
@@ -392,12 +392,19 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
392
  new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
393
  return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
394
 
 
 
 
 
 
 
395
  with gr.Blocks() as demo:
396
  first_frame_path = gr.State()
397
  tracking_points = gr.State([])
398
  trackings_input_label = gr.State([])
399
  video_frames_dir = gr.State()
400
  scanned_frames = gr.State()
 
401
  stored_inference_state = gr.State()
402
  stored_frame_names = gr.State()
403
  available_frames_to_check = gr.State([])
@@ -442,6 +449,7 @@ with gr.Blocks() as demo:
442
  with gr.Row():
443
  vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
444
  propagate_btn = gr.Button("Propagate", scale=1)
 
445
  output_propagated = gr.Gallery(label="Propagated Mask samples gallery", visible=False)
446
  output_video = gr.Video(visible=False)
447
  # output_result_mask = gr.Image()
@@ -524,11 +532,19 @@ with gr.Blocks() as demo:
524
  outputs = [
525
  output_result,
526
  stored_frame_names,
 
527
  stored_inference_state,
528
  working_frame,
529
  ]
530
  )
531
 
 
 
 
 
 
 
 
532
  propagate_btn.click(
533
  fn = update_ui,
534
  inputs = [vis_frame_type],
 
299
  available_frames_to_check.append(working_frame)
300
  print(available_frames_to_check)
301
 
302
+ return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
303
 
304
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
305
  #### PROPAGATION ####
 
392
  new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
393
  return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
394
 
395
+ def reset_propagation(predictor, stored_inference_state):
396
+
397
+ predictor.reset_state(stored_inference_state)
398
+ print(f"RESET State: {stored_inference_state} ")
399
+ return stored_inference_state
400
+
401
  with gr.Blocks() as demo:
402
  first_frame_path = gr.State()
403
  tracking_points = gr.State([])
404
  trackings_input_label = gr.State([])
405
  video_frames_dir = gr.State()
406
  scanned_frames = gr.State()
407
+ loaded_predictor = gr.State()
408
  stored_inference_state = gr.State()
409
  stored_frame_names = gr.State()
410
  available_frames_to_check = gr.State([])
 
449
  with gr.Row():
450
  vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
451
  propagate_btn = gr.Button("Propagate", scale=1)
452
+ reset_prpgt_brn = gr.Button("Reset", scale=0.75)
453
  output_propagated = gr.Gallery(label="Propagated Mask samples gallery", visible=False)
454
  output_video = gr.Video(visible=False)
455
  # output_result_mask = gr.Image()
 
532
  outputs = [
533
  output_result,
534
  stored_frame_names,
535
+ loaded_predictor,
536
  stored_inference_state,
537
  working_frame,
538
  ]
539
  )
540
 
541
+ reset_prpgt_brn.click(
542
+ fn = reset_propagation,
543
+ inputs = [loaded_predictor, stored_inference_state],
544
+ outputs = [stored_inference_state],
545
+ queue=False
546
+ )
547
+
548
  propagate_btn.click(
549
  fn = update_ui,
550
  inputs = [vis_frame_type],