fffiloni committed
Commit f2010da · verified · 1 Parent(s): e799376

Update app.py

Files changed (1):
  app.py +36 -22
app.py CHANGED
@@ -28,7 +28,7 @@ from PIL import Image, ImageFilter
 from sam2.build_sam import build_sam2_video_predictor
 
 def preprocess_image(image):
-    return image, gr.State([]), gr.State([]), image
+    return image, gr.State([]), gr.State([]), image, gr.State([])
 
 def preprocess_video_in(video_path):
 
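Review note: the callbacks reset point state by returning fresh gr.State([]) objects. That may work under the Gradio version pinned by this Space, but the documented pattern is to return plain values and let the gr.State outputs receive them. A minimal wiring sketch of that pattern (the reset_points name is hypothetical, not part of this commit):

    import gradio as gr

    def reset_points(image):
        # Return plain values; Gradio writes them into the gr.State outputs,
        # so the point lists are cleared without constructing new components.
        return image, [], [], image

    with gr.Blocks() as demo:
        first_frame_path = gr.State()
        tracking_points = gr.State([])
        trackings_input_label = gr.State([])
        points_map = gr.Image()
        clear_points_btn = gr.Button("Clear points")
        clear_points_btn.click(
            fn=reset_points,
            inputs=points_map,
            outputs=[first_frame_path, tracking_points, trackings_input_label, points_map],
            queue=False,
        )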
@@ -70,7 +70,7 @@ def preprocess_video_in(video_path):
     cap.release()
 
     # 'image' is the first frame extracted from video_in
-    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir
+    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir, gr.State([]), gr.State([])
 
 def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
@@ -184,12 +184,7 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
 
     return combined_images, mask_images
 
-
-def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, frames_output_dir):
-    # 1. We need to preprocess the video and store frames in the right directory
-    # — Remember to use a unique ID for the directory
-
-
+def load_model(checkpoint):
     # Load model according to the user's choice
     if checkpoint == "tiny":
         sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
@@ -203,13 +198,20 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     elif checkpoint == "large":
         sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
         model_cfg = "sam2_hiera_l.yaml"
+
+    return sam2_checkpoint, model_cfg
 
+def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir):
+    # 1. We need to preprocess the video and store frames in the right directory
+    # — Remember to use a unique ID for the directory
+
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
 
     # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
-    print(f"STATE FRAME OUTPUT DIRECTORY: {frames_output_dir}")
-    video_dir = frames_output_dir
+    print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
+    video_dir = video_frames_dir
 
     # scan all the JPEG frame names in this directory
     frame_names = [
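The checkpoint selection now lives in load_model, which sam_process and the new propagate_to_all can share. A table-driven equivalent is a possible follow-up (sketch only; this diff shows just the "tiny" and "large" branches, the middle sizes are elided by the hunk, so only those two entries are listed here):

    # Maps the UI choice to (checkpoint path, model config); values as in the diff.
    SAM2_MODELS = {
        "tiny":  ("./checkpoints/sam2_hiera_tiny.pt",  "sam2_hiera_t.yaml"),
        "large": ("./checkpoints/sam2_hiera_large.pt", "sam2_hiera_l.yaml"),
    }

    def load_model(checkpoint):
        # A KeyError on an unknown choice fails louder than an if/elif chain
        # that can fall through with sam2_checkpoint unbound.
        return SAM2_MODELS[checkpoint]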
@@ -248,13 +250,18 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
         show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
 
     # Save the plot as a JPG file
-    output_filename = "output_frame.jpg"
-    plt.savefig(output_filename, format='jpg')
+    first_frame_output_filename = "output_first_frame.jpg"
+    plt.savefig(first_frame_output_filename, format='jpg')
     plt.close()
 
-
-    #### PROPAGATION ####
+    return "output_first_frame.jpg", frame_names, inference_state
 
+def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names):
+    #### PROPAGATION ####
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    inference_state = stored_inference_state.value
+    frame_names = stored_frame_names.value
     # Define a directory to save the JPEG images
     frames_output_dir = "frames_output_images"
     os.makedirs(frames_output_dir, exist_ok=True)
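Review note: when a gr.State is listed in an event's inputs, the handler receives the stored value itself, not the component, so stored_inference_state.value and stored_frame_names.value will raise AttributeError as soon as the Propagate button is clicked. A corrected sketch of the function head (the rest of the body is unchanged from the diff):

    def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names):
        #### PROPAGATION ####
        sam2_checkpoint, model_cfg = load_model(checkpoint)
        predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
        # gr.State inputs arrive unwrapped: these are already the objects
        # that sam_process returned, so no .value indirection is needed.
        inference_state = stored_inference_state
        frame_names = stored_frame_names
        # ... propagation loop as in the diff ...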
@@ -289,16 +296,16 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
 
         # Close the plot
         plt.close()
-
-    # OLD
 
-    return output_filename, jpeg_images
+    return jpeg_images
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
     tracking_points = gr.State([])
     trackings_input_label = gr.State([])
-    frames_output_dir = gr.State()
+    video_frames_dir = gr.State()
+    stored_inference_state = gr.State([])
+    stored_frame_names = gr.State([])
     with gr.Column():
         gr.Markdown("# SAM2 Video Predictor")
         gr.Markdown("This is a simple demo for video segmentation with SAM2.")
@@ -325,20 +332,21 @@ with gr.Blocks() as demo:
             submit_btn = gr.Button("Submit")
         with gr.Column():
             output_result = gr.Image()
+            propagate_btn = gr.Button("Propagate")
             output_propagated = gr.Gallery()
             # output_result_mask = gr.Image()
 
     clear_points_btn.click(
         fn = preprocess_image,
         inputs = input_first_frame_image,
-        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map],
+        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map, stored_inference_state],
         queue=False
     )
 
     video_in.upload(
         fn = preprocess_video_in,
         inputs = [video_in],
-        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, frames_output_dir],
+        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, video_frames_dir, stored_inference_state, stored_frame_names],
         queue = False
     )
 
@@ -351,8 +359,14 @@ with gr.Blocks() as demo:
 
     submit_btn.click(
         fn = sam_process,
-        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, frames_output_dir],
-        outputs = [output_result, output_propagated]
+        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir],
+        outputs = [output_result, stored_frame_names, stored_inference_state]
+    )
+
+    propagate_btn.click(
+        fn = propagate_to_all,
+        inputs = [checkpoint, stored_inference_state, stored_frame_names],
+        outputs = [output_propagated]
     )
 
 demo.launch(show_api=False, show_error=True)
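The "use a unique ID for the directory" TODO still applies: both the extracted input frames and propagate_to_all's hard-coded frames_output_images folder are shared across sessions on a public Space. One hedged way to scope them per run (hypothetical helper, not part of this commit):

    import os
    import uuid

    def make_run_dir(base="frames_output_images"):
        # Give each run its own subfolder so concurrent users
        # don't overwrite each other's frames.
        run_dir = os.path.join(base, uuid.uuid4().hex[:8])
        os.makedirs(run_dir, exist_ok=True)
        return run_dir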
 