# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import time

import numpy as np
import torch
from tqdm import tqdm

from sam2.build_sam import build_sam2_video_predictor

# Only cuda supported
assert torch.cuda.is_available()
device = torch.device("cuda")

torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Config and checkpoint
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"

# Build video predictor with vos_optimized=True setting
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
)
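
# For a side-by-side baseline, the predictor could also be built without the
# VOS optimizations (a sketch, assuming vos_optimized defaults to False when
# omitted; left commented out so this script benchmarks only the optimized
# path and does not allocate a second model on the GPU):
# baseline_predictor = build_sam2_video_predictor(
#     model_cfg, sam2_checkpoint, device=device
# )
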
# Initialize with video
video_dir = "notebooks/videos/bedroom"

# scan all the JPEG frame names in this directory
frame_names = [
    p
    for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(video_path=video_dir)

# Number of runs, warmup etc
warm_up, runs = 5, 25
verbose = True
num_frames = len(frame_names)
total, count = 0, 0
torch.cuda.empty_cache()

# We will select an object with a click.
# See video_predictor_example.ipynb for more detailed explanation
ann_frame_idx, ann_obj_id = 0, 1
# Add a positive click at (x, y) = (210, 350)
# For labels, `1` means positive click
points = np.array([[210, 350]], dtype=np.float32)
labels = np.array([1], np.int32)

_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
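
# Optional: the same call also accepts a box prompt instead of point clicks.
# A minimal sketch, assuming the `box` argument of add_new_points_or_box
# takes [x_min, y_min, x_max, y_max] in pixel coordinates (as in the SAM2
# video predictor example notebook); the coordinates are illustrative, and
# the flag is off by default so the benchmark keeps the single-click prompt.
use_box_prompt = False
if use_box_prompt:
    box = np.array([300, 0, 500, 400], dtype=np.float32)
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        box=box,
    )
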
# Warmup and then average FPS over several runs
with torch.autocast("cuda", torch.bfloat16):
    with torch.inference_mode():
        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
            start = time.time()
            # Start tracking
            for (
                out_frame_idx,
                out_obj_ids,
                out_mask_logits,
            ) in predictor.propagate_in_video(inference_state):
                pass

            end = time.time()
            total += end - start
            count += 1
            if i == warm_up - 1:
                # Report throughput over the warmup runs, then reset the
                # accumulators so the final FPS covers only the timed runs.
                print("Warmup FPS: ", count * num_frames / total)
                total = 0
                count = 0

print("FPS: ", count * num_frames / total)
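
# Optional: inspect the masks from the last frame yielded above. A minimal
# sketch following the thresholding pattern in the SAM2 example notebooks;
# it assumes `out_mask_logits` holds one logit tensor per tracked object.
for obj_idx, obj_id in enumerate(out_obj_ids):
    # Threshold logits at 0.0 to obtain a boolean mask before moving to CPU.
    mask = (out_mask_logits[obj_idx] > 0.0).cpu().numpy()
    print(f"object {obj_id}: mask pixels = {int(mask.sum())}")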