Update app.py
app.py CHANGED
@@ -177,13 +177,7 @@ def get_point(point_type, tracking_points, trackings_input_label, input_first_fr

    return tracking_points, trackings_input_label, selected_point_map

-# use bfloat16 for the entire notebook
-# torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
-
-# if torch.cuda.get_device_properties(0).major >= 8:
-#     # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
-#     torch.backends.cuda.matmul.allow_tf32 = True
-#     torch.backends.cudnn.allow_tf32 = True
+

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:

@@ -335,6 +329,14 @@ def get_mask_sam_process(

@spaces.GPU
def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
+    # use bfloat16 for the entire notebook
+    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+
+    if torch.cuda.get_device_properties(0).major >= 8:
+        # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+
    #### PROPAGATION ####
    sam2_checkpoint, model_cfg = load_model(checkpoint)
    # set predictor

@@ -530,9 +532,7 @@ with gr.Blocks(css=css) as demo:
            with gr.Group():
                with gr.Row():
                    vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
-
-                    with gr.Column():
-                        propagate_btn = gr.Button("Propagate", scale=2)
+                    propagate_btn = gr.Button("Propagate", scale=2)

            reset_prpgt_brn = gr.Button("Reset", visible=False)
            output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
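Note on the last hunk: the stray blank line and the nested gr.Column() are dropped so the Propagate button sits directly next to the propagation-level radio in the same row. As a standalone illustration, here is a minimal sketch of that layout with assumed top-level nesting (the real app builds it inside a larger gr.Blocks(css=css) tree):

import gradio as gr

with gr.Blocks() as demo:
    with gr.Group():
        with gr.Row():
            # Radio and button now share the same row, with no intermediate Column
            vis_frame_type = gr.Radio(
                label="Propagation level",
                choices=["check", "render"],
                value="check",
                scale=2,
            )
            propagate_btn = gr.Button("Propagate", scale=2)
    reset_prpgt_brn = gr.Button("Reset", visible=False)
    output_propagated = gr.Gallery(
        label="Propagated Mask samples gallery", columns=4, visible=False
    )

# demo.launch()  # uncomment to preview the layout locally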