fffiloni committed on
Commit
96e731e
·
verified ·
1 Parent(s): d5deb3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -23
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import subprocess
2
  import re
 
3
 
4
  # Define the command to be executed
5
  command = ["python", "setup.py", "build_ext", "--inplace"]
@@ -43,7 +44,7 @@ def get_video_fps(video_path):
43
 
44
  return fps
45
 
46
- def preprocess_image(image):
47
  # we clean all
48
  return [
49
  image, # first_frame_path
@@ -59,10 +60,10 @@ def preprocess_video_in(video_path):
59
  unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
60
 
61
  # Set directory with this ID to store video frames
62
- output_dir = f'frames_{unique_id}'
63
 
64
  # Create the output directory
65
- os.makedirs(output_dir, exist_ok=True)
66
 
67
  ### Process video frames ###
68
  # Open the video file
@@ -87,7 +88,7 @@ def preprocess_video_in(video_path):
87
  break
88
 
89
  # Format the frame filename as '00000.jpg'
90
- frame_filename = os.path.join(output_dir, f'{frame_number:05d}.jpg')
91
 
92
  # Save the frame as a JPEG file
93
  cv2.imwrite(frame_filename, frame)
@@ -103,12 +104,11 @@ def preprocess_video_in(video_path):
103
 
104
  # scan all the JPEG frame names in this directory
105
  scanned_frames = [
106
- p for p in os.listdir(output_dir)
107
  if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
108
  ]
109
  scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
110
  print(f"SCANNED_FRAMES: {scanned_frames}")
111
-
112
 
113
  return [
114
  first_frame, # first_frame_path
@@ -116,7 +116,7 @@ def preprocess_video_in(video_path):
116
  gr.State([]), # trackings_input_label
117
  first_frame, # input_first_frame_image
118
  first_frame, # points_map
119
- output_dir, # video_frames_dir
120
  scanned_frames, # scanned_frames
121
  None, # stored_inference_state
122
  None, # stored_frame_names
@@ -195,46 +195,61 @@ def load_model(checkpoint):
195
  if checkpoint == "tiny":
196
  sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
197
  model_cfg = "sam2_hiera_t.yaml"
198
- return sam2_checkpoint, model_cfg
199
  elif checkpoint == "samll":
200
  sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
201
  model_cfg = "sam2_hiera_s.yaml"
202
- return sam2_checkpoint, model_cfg
203
  elif checkpoint == "base-plus":
204
  sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
205
  model_cfg = "sam2_hiera_b+.yaml"
206
- return sam2_checkpoint, model_cfg
207
  elif checkpoint == "large":
208
  sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
209
  model_cfg = "sam2_hiera_l.yaml"
210
- return sam2_checkpoint, model_cfg
211
 
212
 
213
 
214
- def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir, scanned_frames, working_frame, progress=gr.Progress(track_tqdm=True)):
215
- # 1. We need to preprocess the video and store frames in the right directory
216
- # — Penser à utiliser un ID unique pour le dossier
217
-
 
 
 
 
 
 
 
 
 
218
  sam2_checkpoint, model_cfg = load_model(checkpoint)
 
 
 
219
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
220
 
221
-
222
  # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
223
  print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
224
  video_dir = video_frames_dir
225
 
226
  # scan all the JPEG frame names in this directory
227
  frame_names = scanned_frames
228
-
 
229
  inference_state = predictor.init_state(video_path=video_dir)
 
230
 
231
  # segment and track one object
232
  # predictor.reset_state(inference_state) # if any previous tracking, reset
233
 
 
234
  new_working_frame = None
235
  # Add new point
236
- if working_frame == None:
237
- ann_frame_idx = 0 # the frame index we interact with
238
  new_working_frame = "frames_output_images/frame_0.jpg"
239
  else:
240
  # Use a regular expression to find the integer
@@ -244,6 +259,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
244
  frame_number = int(match.group(1))
245
  ann_frame_idx = frame_number
246
  new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
 
247
 
248
  ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
249
 
@@ -458,7 +474,7 @@ with gr.Blocks() as demo:
458
 
459
  # Clear every points clicked and added to the map
460
  clear_points_btn.click(
461
- fn = preprocess_image,
462
  inputs = input_first_frame_image, # we get the untouched hidden image
463
  outputs = [
464
  first_frame_path,
@@ -480,9 +496,21 @@ with gr.Blocks() as demo:
480
  """
481
 
482
  submit_btn.click(
483
- fn = sam_process,
484
- inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir, scanned_frames, working_frame],
485
- outputs = [output_result, stored_frame_names, stored_inference_state]
 
 
 
 
 
 
 
 
 
 
 
 
486
  )
487
 
488
  propagate_btn.click(
 
1
  import subprocess
2
  import re
3
+ from typing import List, Tuple, Optional
4
 
5
  # Define the command to be executed
6
  command = ["python", "setup.py", "build_ext", "--inplace"]
 
44
 
45
  return fps
46
 
47
+ def clear_points(image):
48
  # we clean all
49
  return [
50
  image, # first_frame_path
 
60
  unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
61
 
62
  # Set directory with this ID to store video frames
63
+ extracted_frames_output_dir = f'frames_{unique_id}'
64
 
65
  # Create the output directory
66
+ os.makedirs(extracted_frames_output_dir, exist_ok=True)
67
 
68
  ### Process video frames ###
69
  # Open the video file
 
88
  break
89
 
90
  # Format the frame filename as '00000.jpg'
91
+ frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
92
 
93
  # Save the frame as a JPEG file
94
  cv2.imwrite(frame_filename, frame)
 
104
 
105
  # scan all the JPEG frame names in this directory
106
  scanned_frames = [
107
+ p for p in os.listdir(extracted_frames_output_dir)
108
  if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
109
  ]
110
  scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
111
  print(f"SCANNED_FRAMES: {scanned_frames}")
 
112
 
113
  return [
114
  first_frame, # first_frame_path
 
116
  gr.State([]), # trackings_input_label
117
  first_frame, # input_first_frame_image
118
  first_frame, # points_map
119
+ extracted_frames_output_dir, # video_frames_dir
120
  scanned_frames, # scanned_frames
121
  None, # stored_inference_state
122
  None, # stored_frame_names
 
195
  if checkpoint == "tiny":
196
  sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
197
  model_cfg = "sam2_hiera_t.yaml"
198
+ return [sam2_checkpoint, model_cfg]
199
  elif checkpoint == "samll":
200
  sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
201
  model_cfg = "sam2_hiera_s.yaml"
202
+ return [sam2_checkpoint, model_cfg]
203
  elif checkpoint == "base-plus":
204
  sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
205
  model_cfg = "sam2_hiera_b+.yaml"
206
+ return [sam2_checkpoint, model_cfg]
207
  elif checkpoint == "large":
208
  sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
209
  model_cfg = "sam2_hiera_l.yaml"
210
+ return [sam2_checkpoint, model_cfg]
211
 
212
 
213
 
214
+ def get_mask_sam_process(
215
+ input_first_frame_image,
216
+ checkpoint,
217
+ tracking_points,
218
+ trackings_input_label,
219
+ video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
220
+ scanned_frames,
221
+ working_frame: str = None, # current frame being added points
222
+ progress=gr.Progress(track_tqdm=True)
223
+ ):
224
+
225
+ # get model and model config paths
226
+ print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
227
  sam2_checkpoint, model_cfg = load_model(checkpoint)
228
+ print("MODEL LOADED")
229
+
230
+ # set predictor
231
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
232
+ print("PREDICTOR READY")
233
 
 
234
  # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
235
  print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
236
  video_dir = video_frames_dir
237
 
238
  # scan all the JPEG frame names in this directory
239
  frame_names = scanned_frames
240
+
241
+ # Init SAM2 inference_state
242
  inference_state = predictor.init_state(video_path=video_dir)
243
+ print("NEW INFERENCE_STATE INITIATED")
244
 
245
  # segment and track one object
246
  # predictor.reset_state(inference_state) # if any previous tracking, reset
247
 
248
+ ### HANDLING WORKING FRAME
249
  new_working_frame = None
250
  # Add new point
251
+ if working_frame is None:
252
+ ann_frame_idx = 0 # the frame index we interact with, 0 if it is the first frame
253
  new_working_frame = "frames_output_images/frame_0.jpg"
254
  else:
255
  # Use a regular expression to find the integer
 
259
  frame_number = int(match.group(1))
260
  ann_frame_idx = frame_number
261
  new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
262
+ print(f"NEW_WORKING_FRAME PATH: {new_working_frame}")
263
 
264
  ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
265
 
 
474
 
475
  # Clear every points clicked and added to the map
476
  clear_points_btn.click(
477
+ fn = clear_points,
478
  inputs = input_first_frame_image, # we get the untouched hidden image
479
  outputs = [
480
  first_frame_path,
 
496
  """
497
 
498
  submit_btn.click(
499
+ fn = get_mask_sam_process,
500
+ inputs = [
501
+ input_first_frame_image,
502
+ checkpoint,
503
+ tracking_points,
504
+ trackings_input_label,
505
+ video_frames_dir,
506
+ scanned_frames,
507
+ working_frame,
508
+ ],
509
+ outputs = [
510
+ output_result,
511
+ stored_frame_names,
512
+ stored_inference_state,
513
+ ]
514
  )
515
 
516
  propagate_btn.click(