Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -169,12 +169,12 @@ def get_point(point_type, tracking_points, trackings_input_label, input_first_fr
     return tracking_points, trackings_input_label, selected_point_map
 
 # use bfloat16 for the entire notebook
-torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
+# torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
 
-if torch.cuda.get_device_properties(0).major >= 8:
-    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
+# if torch.cuda.get_device_properties(0).major >= 8:
+#     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+#     torch.backends.cuda.matmul.allow_tf32 = True
+#     torch.backends.cudnn.allow_tf32 = True
 
 def show_mask(mask, ax, obj_id=None, random_color=False):
     if random_color:
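Note on this hunk: disabling the block makes sense on a CPU-only host, since torch.cuda.get_device_properties(0) raises when no GPU is visible and would crash the Space at import time. If the speedups are still wanted whenever a GPU happens to be present, one guarded variant looks like this (a sketch under that assumption, not what this commit does):

    import torch

    if torch.cuda.is_available():
        # bfloat16 autocast and TF32 only apply when a GPU is attached
        torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
        if torch.cuda.get_device_properties(0).major >= 8:
            # TF32 speeds up matmuls/convolutions on Ampere and newer GPUs
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True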
@@ -240,7 +240,11 @@ def get_mask_sam_process(
     print("MODEL LOADED")
 
     # set predictor
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    if torch.cuda.is_available():
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    else:
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
+
     print("PREDICTOR READY")
 
     # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
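The same is_available() fallback is added again in propagate_to_all further down, so a small helper could keep the two call sites in sync. A sketch, assuming the import path app.py already uses; the helper name is mine, not part of the commit:

    import torch
    from sam2.build_sam import build_sam2_video_predictor

    def build_predictor(model_cfg, sam2_checkpoint):
        # The default build targets CUDA; fall back to CPU so the
        # Space can still start when no GPU is allocated.
        if torch.cuda.is_available():
            return build_sam2_video_predictor(model_cfg, sam2_checkpoint)
        return build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')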
@@ -287,6 +291,8 @@ def get_mask_sam_process(
     points = np.array(tracking_points, dtype=np.float32)
     # for labels, `1` means positive click and `0` means negative click
     labels = np.array(trackings_input_label, np.int32)
+
+
     _, out_obj_ids, out_mask_logits = predictor.add_new_points(
         inference_state=inference_state,
         frame_idx=ann_frame_idx,
@@ -306,7 +312,7 @@ def get_mask_sam_process(
     first_frame_output_filename = "output_first_frame.jpg"
     plt.savefig(first_frame_output_filename, format='jpg')
     plt.close()
-    torch.cuda.empty_cache()
+    # torch.cuda.empty_cache()
 
     # Assuming available_frames_to_check.value is a list
     if working_frame not in available_frames_to_check:
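Note: commenting out torch.cuda.empty_cache() avoids touching CUDA on a CPU-only host. If the cleanup is wanted back when a GPU is present, it can be guarded instead (my suggestion, not the commit's choice):

    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached GPU memory after saving the frame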
@@ -316,11 +322,15 @@ def get_mask_sam_process(
     # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
     return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=110)
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
     #### PROPAGATION ####
     sam2_checkpoint, model_cfg = load_model(checkpoint)
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    if torch.cuda.is_available():
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    else:
+        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
+
 
     inference_state = stored_inference_state
     frame_names = stored_frame_names
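For context on the decorator: on ZeroGPU Spaces, @spaces.GPU attaches a GPU for each decorated call, and duration caps that allocation in seconds, so 110 gives each propagation run just under two minutes of GPU time. A minimal usage sketch (the function here is a placeholder for illustration):

    import spaces

    @spaces.GPU(duration=110)  # request up to 110 s of GPU time per call
    def gpu_task():
        # the real propagate_to_all runs SAM 2 mask propagation here
        ...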