Update app.py
app.py CHANGED
@@ -408,6 +408,7 @@ def get_mask_sam_process(
     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
 
+'''
 #@spaces.GPU
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
     # use bfloat16 for the entire notebook
@@ -505,6 +506,107 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
         codec='libx264'
     )
 
+    return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
+'''
+
+import json
+import numpy as np
+
+def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
+    # use bfloat16 for the entire notebook
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+    if torch.cuda.get_device_properties(0).major >= 8:
+        # turn on tfloat32 for Ampere GPUs
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
+    #### PROPAGATION ####
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
+    # set predictor
+    inference_state = stored_inference_state
+
+    if torch.cuda.is_available():
+        inference_state["device"] = 'cuda'
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    else:
+        inference_state["device"] = 'cpu'
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
+
+    frame_names = stored_frame_names
+    video_dir = video_frames_dir
+
+    # Define a directory to save the JPEG images
+    frames_output_dir = "frames_output_images"
+    os.makedirs(frames_output_dir, exist_ok=True)
+
+    # Initialize a list to store file paths of saved images
+    jpeg_images = []
+
+    # Initialize a list to store mask area ratios
+    mask_area_ratios = []
+
+    # run propagation throughout the video and collect the results in a dict
+    video_segments = {}  # video_segments contains the per-frame segmentation results
+    out_obj_ids, out_mask_logits = predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False)
+    print(out_obj_ids)
+    for frame_idx in range(0, inference_state['num_frames']):
+        video_segments[frame_idx] = {out_obj_ids[0]: (out_mask_logits[frame_idx] > 0.0).cpu().numpy()}
+
+        # Calculate mask area ratio
+        mask = video_segments[frame_idx][out_obj_ids[0]]
+        mask_area = np.sum(mask)  # Number of True pixels in the mask
+        total_area = mask.shape[0] * mask.shape[1]  # Total number of pixels in the frame
+        mask_area_ratio = mask_area / total_area  # Ratio of mask area to total area
+
+        mask_area_ratio = mask_area / np.ones_like(mask).sum()
+
+        mask_area_ratios.append(mask_area_ratio)
+
+    # Save mask area ratios as a JSON file
+    mask_area_ratios_dict = {f"frame_{frame_idx}": ratio for frame_idx, ratio in enumerate(mask_area_ratios)}
+    with open("mask_area_ratios.json", "w") as f:
+        json.dump(mask_area_ratios_dict, f, indent=4)
+
+    # render the segmentation results every few frames
+    if vis_frame_type == "check":
+        vis_frame_stride = 15
+    elif vis_frame_type == "render":
+        vis_frame_stride = 1
+
+    plt.close("all")
+    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
+        plt.figure(figsize=(6, 4))
+        plt.title(f"frame {out_frame_idx}")
+        plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
+        for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+            show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+
+        # Define the output filename and save the figure as a JPEG file
+        output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
+        plt.savefig(output_filename, format='jpg')
+
+        # Close the plot
+        plt.close()
+
+        # Append the file path to the list
+        jpeg_images.append(output_filename)
+
+        if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
+            available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
+
+    torch.cuda.empty_cache()
+    print(f"JPEG_IMAGES: {jpeg_images}")
+
+    if vis_frame_type == "check":
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=available_frames_to_check, value=working_frame, visible=True), available_frames_to_check, gr.update(visible=True), mask_area_ratios_dict
+    elif vis_frame_type == "render":
+        # Create a video clip from the image sequence
+        original_fps = get_video_fps(video_in)
+        clip = ImageSequenceClip(jpeg_images, fps=original_fps // 6)
+        final_vid_output_path = "output_video.mp4"
+        clip.write_videofile(final_vid_output_path, codec='libx264')
+
     return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True)
 
 def update_ui(vis_frame_type):
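
A note on the area arithmetic in the new loop: mask_area_ratio is assigned twice, and only the second assignment, mask_area / np.ones_like(mask).sum(), survives to be appended. np.ones_like(mask).sum() counts every element of the array (the same as mask.size), which coincides with mask.shape[0] * mask.shape[1] only when the mask is strictly 2-D. A toy sketch of the arithmetic; the (1, H, W) leading channel axis is an assumption about the thresholded SAM2 logits, not something this commit states:

import numpy as np

# Toy stand-in for one frame's mask, as produced by
# (out_mask_logits[frame_idx] > 0.0).cpu().numpy().
# The (1, H, W) shape is an assumed example, not taken from the commit.
mask = np.zeros((1, 4, 6), dtype=bool)
mask[0, 1:3, 2:5] = True                 # 6 masked pixels out of 24

mask_area = np.sum(mask)                 # 6 True pixels
total_area = np.ones_like(mask).sum()    # 24, identical to mask.size
print(mask_area / total_area)            # 0.25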
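Downstream of this commit, mask_area_ratios.json maps frame keys to the fraction of the frame covered by the predicted mask, so each value lies in [0, 1]. A minimal sketch of how a separate script could consume the file; the 0.25 threshold is an illustrative assumption, not part of this Space:

import json

# Load the per-frame coverage written by the new propagate_to_all()
with open("mask_area_ratios.json") as f:
    ratios = json.load(f)  # e.g. {"frame_0": 0.12, "frame_1": 0.13, ...}

# Hypothetical use: flag frames where the tracked object covers
# more than a quarter of the image.
threshold = 0.25
flagged = [name for name, ratio in ratios.items() if ratio > threshold]
print(f"{len(flagged)} of {len(ratios)} frames exceed {threshold:.0%} coverage")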