Update app.py
app.py
CHANGED
@@ -28,7 +28,7 @@ from PIL import Image, ImageFilter
 from sam2.build_sam import build_sam2_video_predictor
 
 def preprocess_image(image):
-    return image, gr.State([]), gr.State([]), image
+    return image, gr.State([]), gr.State([]), image, gr.State([])
 
 def preprocess_video_in(video_path):
 
@@ -70,7 +70,7 @@ def preprocess_video_in(video_path):
     cap.release()
 
     # 'image' is the first frame extracted from video_in
-    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir
+    return first_frame, gr.State([]), gr.State([]), first_frame, first_frame, output_dir, gr.State([]), gr.State([])
 
 def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
     print(f"You selected {evt.value} at {evt.index} from {evt.target}")
@@ -184,12 +184,7 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
 
     return combined_images, mask_images
 
-
-def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, frames_output_dir):
-    # 1. We need to preprocess the video and store frames in the right directory
-    # - Remember to use a unique ID for the folder
-
-
+def load_model(checkpoint):
     # Load model accordingly to user's choice
     if checkpoint == "tiny":
         sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
@@ -203,13 +198,20 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     elif checkpoint == "large":
         sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
         model_cfg = "sam2_hiera_l.yaml"
+
+    return sam2_checkpoint, model_cfg
 
+def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir):
+    # 1. We need to preprocess the video and store frames in the right directory
+    # - Remember to use a unique ID for the folder
+
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
 
     # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
-    print(f"STATE FRAME OUTPUT DIRECTORY: {frames_output_dir}")
-    video_dir = frames_output_dir
+    print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
+    video_dir = video_frames_dir
 
     # scan all the JPEG frame names in this directory
     frame_names = [
@@ -248,13 +250,18 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
 
     # Save the plot as a JPG file
-
-    plt.savefig(
+    first_frame_output_filename = "output_first_frame.jpg"
+    plt.savefig(first_frame_output_filename, format='jpg')
     plt.close()
 
-
-    #### PROPAGATION ####
+    return "output_first_frame.jpg", frame_names, inference_state
 
+def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names):
+    #### PROPAGATION ####
+    sam2_checkpoint, model_cfg = load_model(checkpoint)
+    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+    inference_state = stored_inference_state.value
+    frame_names = stored_frame_names.value
     # Define a directory to save the JPEG images
     frames_output_dir = "frames_output_images"
     os.makedirs(frames_output_dir, exist_ok=True)
@@ -289,16 +296,16 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
 
         # Close the plot
         plt.close()
-
-    # OLD
 
-    return
+    return jpeg_images
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
     tracking_points = gr.State([])
     trackings_input_label = gr.State([])
-
+    video_frames_dir = gr.State()
+    stored_inference_state = gr.State([])
+    stored_frame_names = gr.State([])
     with gr.Column():
         gr.Markdown("# SAM2 Video Predictor")
         gr.Markdown("This is a simple demo for video segmentation with SAM2.")
@@ -325,20 +332,21 @@ with gr.Blocks() as demo:
             submit_btn = gr.Button("Submit")
         with gr.Column():
             output_result = gr.Image()
+            propagate_btn = gr.Button("Propagate")
             output_propagated = gr.Gallery()
             # output_result_mask = gr.Image()
 
     clear_points_btn.click(
         fn = preprocess_image,
         inputs = input_first_frame_image,
-        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map],
+        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map, stored_inference_state],
         queue=False
     )
 
    video_in.upload(
        fn = preprocess_video_in,
        inputs = [video_in],
-        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map,
+        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, video_frames_dir, stored_inference_state, stored_frame_names],
        queue = False
    )
 
@@ -351,8 +359,14 @@ with gr.Blocks() as demo:
 
    submit_btn.click(
        fn = sam_process,
-        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label,
-        outputs = [output_result,
+        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir],
+        outputs = [output_result, stored_frame_names, stored_inference_state]
+    )
+
+    propagate_btn.click(
+        fn = propagate_to_all,
+        inputs = [checkpoint, stored_inference_state, stored_frame_names],
+        outputs = [output_propagated]
    )
 
 demo.launch(show_api=False, show_error=True)
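For context on the truncated frame-scanning lines above: the app appears to follow the SAM2 video-predictor example, which lists the extracted JPEGs and sorts them by integer frame index. A minimal sketch of that pattern; the directory name is illustrative and this is a reconstruction, not necessarily the commit's exact code:

import os

video_dir = "extracted_frames"  # hypothetical; the app passes video_frames_dir here

# keep only JPEG files, then sort numerically so frame 10 sorts after frame 9
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1].lower() in [".jpg", ".jpeg"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))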
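The three new gr.State components are what split the workflow into two clicks: Submit runs sam_process once and stashes its frame names and inference state, and Propagate feeds them back into propagate_to_all. A minimal sketch of that Gradio state round-trip pattern, with illustrative names rather than the app's:

import gradio as gr

def prepare(text):
    # expensive first step: build a result and stash it in session state
    state = {"frames": [text.upper()] * 3}
    return "prepared", state

def reuse(state):
    # second step: read the stashed result instead of recomputing it
    return ", ".join(state["frames"])

with gr.Blocks() as demo:
    stored = gr.State()  # per-session storage, never rendered
    inp = gr.Textbox(label="input")
    status = gr.Textbox(label="status")
    result = gr.Textbox(label="result")
    gr.Button("Prepare").click(prepare, inputs=inp, outputs=[status, stored])
    gr.Button("Reuse").click(reuse, inputs=stored, outputs=result)

demo.launch()

One detail worth knowing: Gradio passes a gr.State's current value into the callback, so the function receives the value itself rather than a wrapper object.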