Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import subprocess
|
2 |
import re
|
|
|
3 |
|
4 |
# Define the command to be executed
|
5 |
command = ["python", "setup.py", "build_ext", "--inplace"]
|
@@ -43,7 +44,7 @@ def get_video_fps(video_path):
|
|
43 |
|
44 |
return fps
|
45 |
|
46 |
-
def
|
47 |
# we clean all
|
48 |
return [
|
49 |
image, # first_frame_path
|
@@ -59,10 +60,10 @@ def preprocess_video_in(video_path):
|
|
59 |
unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
|
60 |
|
61 |
# Set directory with this ID to store video frames
|
62 |
-
|
63 |
|
64 |
# Create the output directory
|
65 |
-
os.makedirs(
|
66 |
|
67 |
### Process video frames ###
|
68 |
# Open the video file
|
@@ -87,7 +88,7 @@ def preprocess_video_in(video_path):
|
|
87 |
break
|
88 |
|
89 |
# Format the frame filename as '00000.jpg'
|
90 |
-
frame_filename = os.path.join(
|
91 |
|
92 |
# Save the frame as a JPEG file
|
93 |
cv2.imwrite(frame_filename, frame)
|
@@ -103,12 +104,11 @@ def preprocess_video_in(video_path):
|
|
103 |
|
104 |
# scan all the JPEG frame names in this directory
|
105 |
scanned_frames = [
|
106 |
-
p for p in os.listdir(
|
107 |
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
108 |
]
|
109 |
scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
110 |
print(f"SCANNED_FRAMES: {scanned_frames}")
|
111 |
-
|
112 |
|
113 |
return [
|
114 |
first_frame, # first_frame_path
|
@@ -116,7 +116,7 @@ def preprocess_video_in(video_path):
|
|
116 |
gr.State([]), # trackings_input_label
|
117 |
first_frame, # input_first_frame_image
|
118 |
first_frame, # points_map
|
119 |
-
|
120 |
scanned_frames, # scanned_frames
|
121 |
None, # stored_inference_state
|
122 |
None, # stored_frame_names
|
@@ -195,46 +195,61 @@ def load_model(checkpoint):
|
|
195 |
if checkpoint == "tiny":
|
196 |
sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
|
197 |
model_cfg = "sam2_hiera_t.yaml"
|
198 |
-
return sam2_checkpoint, model_cfg
|
199 |
elif checkpoint == "samll":
|
200 |
sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
|
201 |
model_cfg = "sam2_hiera_s.yaml"
|
202 |
-
return sam2_checkpoint, model_cfg
|
203 |
elif checkpoint == "base-plus":
|
204 |
sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
|
205 |
model_cfg = "sam2_hiera_b+.yaml"
|
206 |
-
return sam2_checkpoint, model_cfg
|
207 |
elif checkpoint == "large":
|
208 |
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
209 |
model_cfg = "sam2_hiera_l.yaml"
|
210 |
-
return sam2_checkpoint, model_cfg
|
211 |
|
212 |
|
213 |
|
214 |
-
def
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
sam2_checkpoint, model_cfg = load_model(checkpoint)
|
|
|
|
|
|
|
219 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
|
|
220 |
|
221 |
-
|
222 |
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
223 |
print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
|
224 |
video_dir = video_frames_dir
|
225 |
|
226 |
# scan all the JPEG frame names in this directory
|
227 |
frame_names = scanned_frames
|
228 |
-
|
|
|
229 |
inference_state = predictor.init_state(video_path=video_dir)
|
|
|
230 |
|
231 |
# segment and track one object
|
232 |
# predictor.reset_state(inference_state) # if any previous tracking, reset
|
233 |
|
|
|
234 |
new_working_frame = None
|
235 |
# Add new point
|
236 |
-
if working_frame
|
237 |
-
ann_frame_idx = 0 # the frame index we interact with
|
238 |
new_working_frame = "frames_output_images/frame_0.jpg"
|
239 |
else:
|
240 |
# Use a regular expression to find the integer
|
@@ -244,6 +259,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
|
|
244 |
frame_number = int(match.group(1))
|
245 |
ann_frame_idx = frame_number
|
246 |
new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
|
|
|
247 |
|
248 |
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
249 |
|
@@ -458,7 +474,7 @@ with gr.Blocks() as demo:
|
|
458 |
|
459 |
# Clear every points clicked and added to the map
|
460 |
clear_points_btn.click(
|
461 |
-
fn =
|
462 |
inputs = input_first_frame_image, # we get the untouched hidden image
|
463 |
outputs = [
|
464 |
first_frame_path,
|
@@ -480,9 +496,21 @@ with gr.Blocks() as demo:
|
|
480 |
"""
|
481 |
|
482 |
submit_btn.click(
|
483 |
-
fn =
|
484 |
-
inputs = [
|
485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
)
|
487 |
|
488 |
propagate_btn.click(
|
|
|
1 |
import subprocess
|
2 |
import re
|
3 |
+
from typing import List, Tuple, Optional
|
4 |
|
5 |
# Define the command to be executed
|
6 |
command = ["python", "setup.py", "build_ext", "--inplace"]
|
|
|
44 |
|
45 |
return fps
|
46 |
|
47 |
+
def clear_points(image):
|
48 |
# we clean all
|
49 |
return [
|
50 |
image, # first_frame_path
|
|
|
60 |
unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
|
61 |
|
62 |
# Set directory with this ID to store video frames
|
63 |
+
extracted_frames_output_dir = f'frames_{unique_id}'
|
64 |
|
65 |
# Create the output directory
|
66 |
+
os.makedirs(extracted_frames_output_dir, exist_ok=True)
|
67 |
|
68 |
### Process video frames ###
|
69 |
# Open the video file
|
|
|
88 |
break
|
89 |
|
90 |
# Format the frame filename as '00000.jpg'
|
91 |
+
frame_filename = os.path.join(extracted_frames_output_dir, f'{frame_number:05d}.jpg')
|
92 |
|
93 |
# Save the frame as a JPEG file
|
94 |
cv2.imwrite(frame_filename, frame)
|
|
|
104 |
|
105 |
# scan all the JPEG frame names in this directory
|
106 |
scanned_frames = [
|
107 |
+
p for p in os.listdir(extracted_frames_output_dir)
|
108 |
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
|
109 |
]
|
110 |
scanned_frames.sort(key=lambda p: int(os.path.splitext(p)[0]))
|
111 |
print(f"SCANNED_FRAMES: {scanned_frames}")
|
|
|
112 |
|
113 |
return [
|
114 |
first_frame, # first_frame_path
|
|
|
116 |
gr.State([]), # trackings_input_label
|
117 |
first_frame, # input_first_frame_image
|
118 |
first_frame, # points_map
|
119 |
+
extracted_frames_output_dir, # video_frames_dir
|
120 |
scanned_frames, # scanned_frames
|
121 |
None, # stored_inference_state
|
122 |
None, # stored_frame_names
|
|
|
195 |
if checkpoint == "tiny":
|
196 |
sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
|
197 |
model_cfg = "sam2_hiera_t.yaml"
|
198 |
+
return [sam2_checkpoint, model_cfg]
|
199 |
elif checkpoint == "samll":
|
200 |
sam2_checkpoint = "./checkpoints/sam2_hiera_small.pt"
|
201 |
model_cfg = "sam2_hiera_s.yaml"
|
202 |
+
return [sam2_checkpoint, model_cfg]
|
203 |
elif checkpoint == "base-plus":
|
204 |
sam2_checkpoint = "./checkpoints/sam2_hiera_base_plus.pt"
|
205 |
model_cfg = "sam2_hiera_b+.yaml"
|
206 |
+
return [sam2_checkpoint, model_cfg]
|
207 |
elif checkpoint == "large":
|
208 |
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
|
209 |
model_cfg = "sam2_hiera_l.yaml"
|
210 |
+
return [sam2_checkpoint, model_cfg]
|
211 |
|
212 |
|
213 |
|
214 |
+
def get_mask_sam_process(
|
215 |
+
input_first_frame_image,
|
216 |
+
checkpoint,
|
217 |
+
tracking_points,
|
218 |
+
trackings_input_label,
|
219 |
+
video_frames_dir, # extracted_frames_output_dir defined in 'preprocess_video_in' function
|
220 |
+
scanned_frames,
|
221 |
+
working_frame: str = None, # current frame being added points
|
222 |
+
progress=gr.Progress(track_tqdm=True)
|
223 |
+
):
|
224 |
+
|
225 |
+
# get model and model config paths
|
226 |
+
print(f"USER CHOSEN CHECKPOINT: {checkpoint}")
|
227 |
sam2_checkpoint, model_cfg = load_model(checkpoint)
|
228 |
+
print("MODEL LOADED")
|
229 |
+
|
230 |
+
# set predictor
|
231 |
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
|
232 |
+
print("PREDICTOR READY")
|
233 |
|
|
|
234 |
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
|
235 |
print(f"STATE FRAME OUTPUT DIRECTORY: {video_frames_dir}")
|
236 |
video_dir = video_frames_dir
|
237 |
|
238 |
# scan all the JPEG frame names in this directory
|
239 |
frame_names = scanned_frames
|
240 |
+
|
241 |
+
# Init SAM2 inference_state
|
242 |
inference_state = predictor.init_state(video_path=video_dir)
|
243 |
+
print("NEW INFERENCE_STATE INITIATED")
|
244 |
|
245 |
# segment and track one object
|
246 |
# predictor.reset_state(inference_state) # if any previous tracking, reset
|
247 |
|
248 |
+
### HANDLING WORKING FRAME
|
249 |
new_working_frame = None
|
250 |
# Add new point
|
251 |
+
if working_frame is None:
|
252 |
+
ann_frame_idx = 0 # the frame index we interact with, 0 if it is the first frame
|
253 |
new_working_frame = "frames_output_images/frame_0.jpg"
|
254 |
else:
|
255 |
# Use a regular expression to find the integer
|
|
|
259 |
frame_number = int(match.group(1))
|
260 |
ann_frame_idx = frame_number
|
261 |
new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
|
262 |
+
print(f"NEW_WORKING_FRAME PATH: {new_working_frame}")
|
263 |
|
264 |
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
|
265 |
|
|
|
474 |
|
475 |
# Clear every points clicked and added to the map
|
476 |
clear_points_btn.click(
|
477 |
+
fn = clear_points,
|
478 |
inputs = input_first_frame_image, # we get the untouched hidden image
|
479 |
outputs = [
|
480 |
first_frame_path,
|
|
|
496 |
"""
|
497 |
|
498 |
submit_btn.click(
|
499 |
+
fn = get_mask_sam_process,
|
500 |
+
inputs = [
|
501 |
+
input_first_frame_image,
|
502 |
+
checkpoint,
|
503 |
+
tracking_points,
|
504 |
+
trackings_input_label,
|
505 |
+
video_frames_dir,
|
506 |
+
scanned_frames,
|
507 |
+
working_frame,
|
508 |
+
],
|
509 |
+
outputs = [
|
510 |
+
output_result,
|
511 |
+
stored_frame_names,
|
512 |
+
stored_inference_state,
|
513 |
+
]
|
514 |
)
|
515 |
|
516 |
propagate_btn.click(
|