Mar2Ding commited on
Commit
bcb3f59
·
verified ·
1 Parent(s): 9ab77fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -177,13 +177,7 @@ def get_point(point_type, tracking_points, trackings_input_label, input_first_fr
177
 
178
  return tracking_points, trackings_input_label, selected_point_map
179
 
180
- # use bfloat16 for the entire notebook
181
- # torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
182
 
183
- # if torch.cuda.get_device_properties(0).major >= 8:
184
- # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
185
- # torch.backends.cuda.matmul.allow_tf32 = True
186
- # torch.backends.cudnn.allow_tf32 = True
187
 
188
  def show_mask(mask, ax, obj_id=None, random_color=False):
189
  if random_color:
@@ -335,6 +329,14 @@ def get_mask_sam_process(
335
 
336
  @spaces.GPU
337
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
 
 
 
 
 
 
 
 
338
  #### PROPAGATION ####
339
  sam2_checkpoint, model_cfg = load_model(checkpoint)
340
  # set predictor
@@ -530,9 +532,7 @@ with gr.Blocks(css=css) as demo:
530
  with gr.Group():
531
  with gr.Row():
532
  vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
533
- # Use gr.Column to center the button vertically
534
- with gr.Column():
535
- propagate_btn = gr.Button("Propagate", scale=2)
536
 
537
  reset_prpgt_brn = gr.Button("Reset", visible=False)
538
  output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)
 
177
 
178
  return tracking_points, trackings_input_label, selected_point_map
179
 
 
 
180
 
 
 
 
 
181
 
182
  def show_mask(mask, ax, obj_id=None, random_color=False):
183
  if random_color:
 
329
 
330
  @spaces.GPU
331
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
332
+ # use bfloat16 for the entire notebook
333
+ torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
334
+
335
+ if torch.cuda.get_device_properties(0).major >= 8:
336
+ # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
337
+ torch.backends.cuda.matmul.allow_tf32 = True
338
+ torch.backends.cudnn.allow_tf32 = True
339
+
340
  #### PROPAGATION ####
341
  sam2_checkpoint, model_cfg = load_model(checkpoint)
342
  # set predictor
 
532
  with gr.Group():
533
  with gr.Row():
534
  vis_frame_type = gr.Radio(label="Propagation level", choices=["check", "render"], value="check", scale=2)
535
+ propagate_btn = gr.Button("Propagate", scale=2)
 
 
536
 
537
  reset_prpgt_brn = gr.Button("Reset", visible=False)
538
  output_propagated = gr.Gallery(label="Propagated Mask samples gallery", columns=4, visible=False)