Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
|
@@ -256,13 +256,15 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
|
|
| 256 |
|
| 257 |
return "output_first_frame.jpg", frame_names, inference_state
|
| 258 |
|
| 259 |
-
def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir):
|
| 260 |
#### PROPAGATION ####
|
| 261 |
sam2_checkpoint, model_cfg = load_model(checkpoint)
|
| 262 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
|
|
|
| 263 |
inference_state = stored_inference_state
|
| 264 |
frame_names = stored_frame_names
|
| 265 |
video_dir = video_frames_dir
|
|
|
|
| 266 |
# Define a directory to save the JPEG images
|
| 267 |
frames_output_dir = "frames_output_images"
|
| 268 |
os.makedirs(frames_output_dir, exist_ok=True)
|
|
@@ -279,7 +281,10 @@ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, vid
|
|
| 279 |
}
|
| 280 |
|
| 281 |
# render the segmentation results every few frames
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
| 283 |
plt.close("all")
|
| 284 |
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
|
| 285 |
plt.figure(figsize=(6, 4))
|
|
@@ -298,7 +303,11 @@ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, vid
|
|
| 298 |
# Close the plot
|
| 299 |
plt.close()
|
| 300 |
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
with gr.Blocks() as demo:
|
| 304 |
first_frame_path = gr.State()
|
|
@@ -323,7 +332,7 @@ with gr.Blocks() as demo:
|
|
| 323 |
points_map = gr.Image(
|
| 324 |
label="points map",
|
| 325 |
type="filepath",
|
| 326 |
-
interactive=
|
| 327 |
)
|
| 328 |
video_in = gr.Video(label="Video IN")
|
| 329 |
with gr.Row():
|
|
@@ -333,8 +342,11 @@ with gr.Blocks() as demo:
|
|
| 333 |
submit_btn = gr.Button("Submit")
|
| 334 |
with gr.Column():
|
| 335 |
output_result = gr.Image()
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
|
|
|
| 338 |
# output_result_mask = gr.Image()
|
| 339 |
|
| 340 |
clear_points_btn.click(
|
|
@@ -366,8 +378,8 @@ with gr.Blocks() as demo:
|
|
| 366 |
|
| 367 |
propagate_btn.click(
|
| 368 |
fn = propagate_to_all,
|
| 369 |
-
inputs = [checkpoint, stored_inference_state, stored_frame_names, video_frames_dir],
|
| 370 |
-
outputs = [output_propagated]
|
| 371 |
)
|
| 372 |
|
| 373 |
demo.launch(show_api=False, show_error=True)
|
|
|
|
| 256 |
|
| 257 |
return "output_first_frame.jpg", frame_names, inference_state
|
| 258 |
|
| 259 |
+
def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type):
|
| 260 |
#### PROPAGATION ####
|
| 261 |
sam2_checkpoint, model_cfg = load_model(checkpoint)
|
| 262 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
| 263 |
+
|
| 264 |
inference_state = stored_inference_state
|
| 265 |
frame_names = stored_frame_names
|
| 266 |
video_dir = video_frames_dir
|
| 267 |
+
|
| 268 |
# Define a directory to save the JPEG images
|
| 269 |
frames_output_dir = "frames_output_images"
|
| 270 |
os.makedirs(frames_output_dir, exist_ok=True)
|
|
|
|
| 281 |
}
|
| 282 |
|
| 283 |
# render the segmentation results every few frames
|
| 284 |
+
if vis_frame_type == "check":
|
| 285 |
+
vis_frame_stride = 15
|
| 286 |
+
elif vis_frame_type == "render":
|
| 287 |
+
vis_frame_stride = 1
|
| 288 |
plt.close("all")
|
| 289 |
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
|
| 290 |
plt.figure(figsize=(6, 4))
|
|
|
|
| 303 |
# Close the plot
|
| 304 |
plt.close()
|
| 305 |
|
| 306 |
+
if vis_frame_type == "check":
|
| 307 |
+
return gr.update(value=jpeg_images, visible=True), gr.update(visible=False, value=None)
|
| 308 |
+
elif vis_frame_type == "render":
|
| 309 |
+
return gr.update(visible=False, value=None), gr.update(value=final_vid, visible=True)
|
| 310 |
+
|
| 311 |
|
| 312 |
with gr.Blocks() as demo:
|
| 313 |
first_frame_path = gr.State()
|
|
|
|
| 332 |
points_map = gr.Image(
|
| 333 |
label="points map",
|
| 334 |
type="filepath",
|
| 335 |
+
interactive=False
|
| 336 |
)
|
| 337 |
video_in = gr.Video(label="Video IN")
|
| 338 |
with gr.Row():
|
|
|
|
| 342 |
submit_btn = gr.Button("Submit")
|
| 343 |
with gr.Column():
|
| 344 |
output_result = gr.Image()
|
| 345 |
+
with gr.Row():
|
| 346 |
+
vis_frame_type = gr.Radio(choices=["check", "render"], value="render", scale=2)
|
| 347 |
+
propagate_btn = gr.Button("Propagate", scale=1)
|
| 348 |
+
output_propagated = gr.Gallery(visible=False)
|
| 349 |
+
output_video = gr.Video(visible=False)
|
| 350 |
# output_result_mask = gr.Image()
|
| 351 |
|
| 352 |
clear_points_btn.click(
|
|
|
|
| 378 |
|
| 379 |
propagate_btn.click(
|
| 380 |
fn = propagate_to_all,
|
| 381 |
+
inputs = [checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type],
|
| 382 |
+
outputs = [output_propagated, output_video]
|
| 383 |
)
|
| 384 |
|
| 385 |
demo.launch(show_api=False, show_error=True)
|