Mar2Ding commited on
Commit
3145ec5
·
verified ·
1 Parent(s): 4ed4483

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -169,12 +169,12 @@ def get_point(point_type, tracking_points, trackings_input_label, input_first_fr
169
  return tracking_points, trackings_input_label, selected_point_map
170
 
171
  # use bfloat16 for the entire notebook
172
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
173
 
174
- if torch.cuda.get_device_properties(0).major >= 8:
175
- # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
176
- torch.backends.cuda.matmul.allow_tf32 = True
177
- torch.backends.cudnn.allow_tf32 = True
178
 
179
  def show_mask(mask, ax, obj_id=None, random_color=False):
180
  if random_color:
@@ -240,7 +240,11 @@ def get_mask_sam_process(
240
  print("MODEL LOADED")
241
 
242
  # set predictor
243
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
 
 
 
244
  print("PREDICTOR READY")
245
 
246
  # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
@@ -287,6 +291,8 @@ def get_mask_sam_process(
287
  points = np.array(tracking_points, dtype=np.float32)
288
  # for labels, `1` means positive click and `0` means negative click
289
  labels = np.array(trackings_input_label, np.int32)
 
 
290
  _, out_obj_ids, out_mask_logits = predictor.add_new_points(
291
  inference_state=inference_state,
292
  frame_idx=ann_frame_idx,
@@ -306,7 +312,7 @@ def get_mask_sam_process(
306
  first_frame_output_filename = "output_first_frame.jpg"
307
  plt.savefig(first_frame_output_filename, format='jpg')
308
  plt.close()
309
- torch.cuda.empty_cache()
310
 
311
  # Assuming available_frames_to_check.value is a list
312
  if working_frame not in available_frames_to_check:
@@ -316,11 +322,15 @@ def get_mask_sam_process(
316
  # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
317
  return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
318
 
319
- @spaces.GPU(duration=120)
320
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
321
  #### PROPAGATION ####
322
  sam2_checkpoint, model_cfg = load_model(checkpoint)
323
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
 
 
 
324
 
325
  inference_state = stored_inference_state
326
  frame_names = stored_frame_names
 
169
  return tracking_points, trackings_input_label, selected_point_map
170
 
171
  # use bfloat16 for the entire notebook
172
+ # torch.autocast(device_type="cpu", dtype=torch.bfloat16).__enter__()
173
 
174
+ # if torch.cuda.get_device_properties(0).major >= 8:
175
+ # # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
176
+ # torch.backends.cuda.matmul.allow_tf32 = True
177
+ # torch.backends.cudnn.allow_tf32 = True
178
 
179
  def show_mask(mask, ax, obj_id=None, random_color=False):
180
  if random_color:
 
240
  print("MODEL LOADED")
241
 
242
  # set predictor
243
+ if torch.cuda.is_available():
244
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
245
+ else:
246
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
247
+
248
  print("PREDICTOR READY")
249
 
250
  # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
 
291
  points = np.array(tracking_points, dtype=np.float32)
292
  # for labels, `1` means positive click and `0` means negative click
293
  labels = np.array(trackings_input_label, np.int32)
294
+
295
+
296
  _, out_obj_ids, out_mask_logits = predictor.add_new_points(
297
  inference_state=inference_state,
298
  frame_idx=ann_frame_idx,
 
312
  first_frame_output_filename = "output_first_frame.jpg"
313
  plt.savefig(first_frame_output_filename, format='jpg')
314
  plt.close()
315
+ # torch.cuda.empty_cache()
316
 
317
  # Assuming available_frames_to_check.value is a list
318
  if working_frame not in available_frames_to_check:
 
322
  # return gr.update(visible=True), "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=True)
323
  return "output_first_frame.jpg", frame_names, predictor, inference_state, gr.update(choices=available_frames_to_check, value=working_frame, visible=False)
324
 
325
+ @spaces.GPU(duration=110)
326
  def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, available_frames_to_check, working_frame, progress=gr.Progress(track_tqdm=True)):
327
  #### PROPAGATION ####
328
  sam2_checkpoint, model_cfg = load_model(checkpoint)
329
+ if torch.cuda.is_available():
330
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
331
+ else:
332
+ predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device='cpu')
333
+
334
 
335
  inference_state = stored_inference_state
336
  frame_names = stored_frame_names