fffiloni commited on
Commit
6c326ed
·
verified ·
1 Parent(s): 01fc1ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -301,7 +301,7 @@ def get_mask_sam_process(
301
 
302
  return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
303
 
304
- def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
305
  #### PROPAGATION ####
306
  sam2_checkpoint, model_cfg = load_model(checkpoint)
307
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
@@ -349,11 +349,14 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
349
  # Append the file path to the list
350
  jpeg_images.append(output_filename)
351
 
 
 
 
352
  torch.cuda.empty_cache()
353
  print(f"JPEG_IMAGES: {jpeg_images}")
354
 
355
  if vis_frame_type == "check":
356
- return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=True)
357
  elif vis_frame_type == "render":
358
  # Create a video clip from the image sequence
359
  original_fps = get_video_fps(video_in)
@@ -369,7 +372,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
369
  codec='libx264'
370
  )
371
 
372
- return gr.update(value=None), gr.update(value=final_vid_output_path), None
373
 
374
  def update_ui(vis_frame_type):
375
  if vis_frame_type == "check":
@@ -396,7 +399,7 @@ def reset_propagation(predictor, stored_inference_state):
396
 
397
  predictor.reset_state(stored_inference_state)
398
  print(f"RESET State: {stored_inference_state} ")
399
- return gr.update(value=None, visible=False), stored_inference_state, None
400
 
401
  with gr.Blocks() as demo:
402
  first_frame_path = gr.State()
@@ -437,7 +440,7 @@ with gr.Blocks() as demo:
437
 
438
  with gr.Row():
439
  checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
440
- submit_btn = gr.Button("Submit", size="lg")
441
 
442
  with gr.Accordion("Your video IN", open=True) as video_in_drawer:
443
  video_in = gr.Video(label="Video IN")
@@ -541,7 +544,7 @@ with gr.Blocks() as demo:
541
  reset_prpgt_brn.click(
542
  fn = reset_propagation,
543
  inputs = [loaded_predictor, stored_inference_state],
544
- outputs = [output_propagated, stored_inference_state, output_result],
545
  queue=False
546
  )
547
 
@@ -552,8 +555,8 @@ with gr.Blocks() as demo:
552
  queue=False
553
  ).then(
554
  fn = propagate_to_all,
555
- inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type],
556
- outputs = [output_propagated, output_video, working_frame]
557
  )
558
 
559
  demo.launch(show_api=False, show_error=True)
 
301
 
302
  return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
303
 
304
+ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, progress=gr.Progress(track_tqdm=True)):
305
  #### PROPAGATION ####
306
  sam2_checkpoint, model_cfg = load_model(checkpoint)
307
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
349
  # Append the file path to the list
350
  jpeg_images.append(output_filename)
351
 
352
+ if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
353
+ available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
354
+
355
  torch.cuda.empty_cache()
356
  print(f"JPEG_IMAGES: {jpeg_images}")
357
 
358
  if vis_frame_type == "check":
359
+ return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=None, visible=True), available_frames_to_check
360
  elif vis_frame_type == "render":
361
  # Create a video clip from the image sequence
362
  original_fps = get_video_fps(video_in)
 
372
  codec='libx264'
373
  )
374
 
375
+ return gr.update(value=None), gr.update(value=final_vid_output_path), None, available_frames_to_check
376
 
377
  def update_ui(vis_frame_type):
378
  if vis_frame_type == "check":
 
399
 
400
  predictor.reset_state(stored_inference_state)
401
  print(f"RESET State: {stored_inference_state} ")
402
+ return gr.update(value=None, visible=False), stored_inference_state, None, gr.State([])
403
 
404
  with gr.Blocks() as demo:
405
  first_frame_path = gr.State()
 
440
 
441
  with gr.Row():
442
  checkpoint = gr.Dropdown(label="Checkpoint", choices=["tiny", "small", "base-plus", "large"], value="tiny")
443
+ submit_btn = gr.Button("Get Mask", size="lg")
444
 
445
  with gr.Accordion("Your video IN", open=True) as video_in_drawer:
446
  video_in = gr.Video(label="Video IN")
 
544
  reset_prpgt_brn.click(
545
  fn = reset_propagation,
546
  inputs = [loaded_predictor, stored_inference_state],
547
+ outputs = [output_propagated, stored_inference_state, output_result, available_frames_to_check],
548
  queue=False
549
  )
550
 
 
555
  queue=False
556
  ).then(
557
  fn = propagate_to_all,
558
+ inputs = [video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check],
559
+ outputs = [output_propagated, output_video, working_frame, available_frames_to_check]
560
  )
561
 
562
  demo.launch(show_api=False, show_error=True)