svjack committed on
Commit
a8002a7
·
verified ·
1 Parent(s): 340485a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py CHANGED
@@ -408,6 +408,7 @@ def get_mask_sam_process(
408
  # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
409
  return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
410
 
 
411
  #@spaces.GPU
412
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
413
  # use bfloat16 for the entire notebook
@@ -505,6 +506,107 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
505
  codec='libx264'
506
  )
507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
509
 
510
  def update_ui(vis_frame_type):
 
408
  # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
409
  return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
410
 
411
+ '''
412
  #@spaces.GPU
413
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
414
  # use bfloat16 for the entire notebook
 
506
  codec='libx264'
507
  )
508
 
509
+ return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
510
+ '''
511
+
512
+ import json
513
+ import numpy as np
514
+
515
def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
    """Propagate the SAM2 mask through every video frame and render the results.

    Also computes, per frame, the ratio of masked pixels to total pixels and
    dumps the ratios to ``mask_area_ratios.json``.

    Parameters
    ----------
    video_in : str
        Path to the input video (used to recover the original FPS for "render").
    checkpoint : str
        Model-size key passed to ``load_model`` to resolve checkpoint/config.
    stored_inference_state : dict
        Predictor inference state produced by the earlier point-selection step.
    stored_frame_names : list[str]
        Filenames of the extracted frames, in order.
    video_frames_dir : str
        Directory holding the extracted frames.
    vis_frame_type : str
        "check" (preview every 15th frame) or "render" (all frames + video).
    available_frames_to_check : list[str]
        Mutable list of frame names selectable in the UI; extended in place.
    working_frame : str
        Currently selected frame, echoed back into the dropdown update.

    Returns
    -------
    tuple
        Gradio component updates. NOTE(review): the "check" branch returns 6
        values and the "render" branch 5 — confirm the event wiring expects
        branch-dependent output arity.

    Raises
    ------
    ValueError
        If ``vis_frame_type`` is neither "check" nor "render" (previously this
        surfaced as a confusing UnboundLocalError on ``vis_frame_stride``).
    """
    # use bfloat16 for the entire notebook
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

    if torch.cuda.get_device_properties(0).major >= 8:
        # turn on tfloat32 for Ampere GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    #### PROPAGATION ####
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    # set predictor
    inference_state = stored_inference_state

    if torch.cuda.is_available():
        inference_state["device"] = 'cuda'
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
    else:
        inference_state["device"] = 'cpu'
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')

    frame_names = stored_frame_names
    video_dir = video_frames_dir

    # Define a directory to save the JPEG images
    frames_output_dir = "frames_output_images"
    os.makedirs(frames_output_dir, exist_ok=True)

    # Initialize a list to store file paths of saved images
    jpeg_images = []

    # Initialize a list to store mask area ratios
    mask_area_ratios = []

    # run propagation throughout the video and collect the results in a dict
    video_segments = {}  # video_segments contains the per-frame segmentation results
    # NOTE(review): out_mask_logits is indexed per frame below — assumes this
    # predictor variant returns all frames' logits at once; confirm against
    # the build_sam2_video_predictor implementation in use.
    out_obj_ids, out_mask_logits = predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False)
    print(out_obj_ids)
    for frame_idx in range(0, inference_state['num_frames']):
        mask = (out_mask_logits[frame_idx] > 0.0).cpu().numpy()
        video_segments[frame_idx] = {out_obj_ids[0]: mask}

        # Ratio of masked (True) pixels to total pixel count. ``mask.size``
        # replaces the former ``np.ones_like(mask).sum()``, which allocated a
        # whole array just to count elements; a duplicate dead computation of
        # the ratio (shape[0]*shape[1]) was removed. float() keeps the value
        # a plain Python float for JSON serialization below.
        mask_area_ratios.append(float(np.sum(mask) / mask.size))

    # Save mask area ratios as a JSON file
    mask_area_ratios_dict = {f"frame_{frame_idx}": ratio for frame_idx, ratio in enumerate(mask_area_ratios)}
    with open("mask_area_ratios.json", "w") as f:
        json.dump(mask_area_ratios_dict, f, indent=4)

    # render the segmentation results every few frames
    if vis_frame_type == "check":
        vis_frame_stride = 15
    elif vis_frame_type == "render":
        vis_frame_stride = 1
    else:
        # Fail loudly instead of hitting UnboundLocalError on vis_frame_stride.
        raise ValueError(f"unexpected vis_frame_type: {vis_frame_type!r}")

    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        plt.figure(figsize=(6, 4))
        plt.title(f"frame {out_frame_idx}")
        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)

        # Define the output filename and save the figure as a JPEG file
        output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
        plt.savefig(output_filename, format='jpg')

        # Close the plot
        plt.close()

        # Append the file path to the list
        jpeg_images.append(output_filename)

        if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
            available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")

    torch.cuda.empty_cache()
    print(f"JPEG_IMAGES: {jpeg_images}")

    if vis_frame_type == "check":
        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), mask_area_ratios_dict
    elif vis_frame_type == "render":
        # Create a video clip from the image sequence
        original_fps = get_video_fps(video_in)
        clip = ImageSequenceClip(jpeg_images, fps=original_fps // 6)
        final_vid_output_path = "output_video.mp4"
        clip.write_videofile(final_vid_output_path, codec='libx264')

        return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
611
 
612
  def update_ui(vis_frame_type):