Update app.py
app.py
CHANGED
@@ -167,51 +167,6 @@ def show_box(box, ax):
     w, h = box[2] - box[0], box[3] - box[1]
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
 
-def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
-    combined_images = []  # List to store filenames of images with masks overlaid
-    mask_images = []  # List to store filenames of separate mask images
-
-    for i, (mask, score) in enumerate(zip(masks, scores)):
-        # ---- Original Image with Mask Overlaid ----
-        plt.figure(figsize=(10, 10))
-        plt.imshow(image)
-        show_mask(mask, plt.gca(), borders=borders)  # Draw the mask with borders
-        """
-        if point_coords is not None:
-            assert input_labels is not None
-            show_points(point_coords, input_labels, plt.gca())
-        """
-        if box_coords is not None:
-            show_box(box_coords, plt.gca())
-        if len(scores) > 1:
-            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
-        plt.axis('off')
-
-        # Save the figure as a JPG file
-        combined_filename = f"combined_image_{i+1}.jpg"
-        plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
-        combined_images.append(combined_filename)
-
-        plt.close()  # Close the figure to free up memory
-
-        # ---- Separate Mask Image (White Mask on Black Background) ----
-        # Create a black image
-        mask_image = np.zeros_like(image, dtype=np.uint8)
-
-        # The mask is a binary array where the masked area is 1, else 0.
-        # Convert the mask to a white color in the mask_image
-        mask_layer = (mask > 0).astype(np.uint8) * 255
-        for c in range(3):  # Assuming RGB, repeat mask for all channels
-            mask_image[:, :, c] = mask_layer
-
-        # Save the mask image
-        mask_filename = f"mask_image_{i+1}.png"
-        Image.fromarray(mask_image).save(mask_filename)
-        mask_images.append(mask_filename)
-
-        plt.close()  # Close the figure to free up memory
-
-    return combined_images, mask_images
 
 def load_model(checkpoint):
     # Load model accordingly to user's choice
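An aside on the deleted show_masks helper above: its per-channel loop that paints the binary mask white is equivalent to a single NumPy stack. A minimal, loop-free sketch; only `mask` comes from the code above, everything else is plain NumPy and not part of app.py:

import numpy as np

# Assumes `mask` is a 2-D array where the masked area is nonzero.
mask_layer = (mask > 0).astype(np.uint8) * 255  # white where masked
mask_image = np.dstack([mask_layer] * 3)        # replicate across R, G, B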
@@ -254,9 +209,11 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     # segment and track one object
     # predictor.reset_state(inference_state) # if any previous tracking, reset
 
+    new_working_frame = None
     # Add new point
     if working_frame == None:
         ann_frame_idx = 0 # the frame index we interact with
+        new_working_frame = "frames_output_images/frame_0.jpg"
     else:
         # Use a regular expression to find the integer
         match = re.search(r'frame_(\d+)', working_frame)
@@ -264,6 +221,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
         # Extract the integer from the match
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
+        new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
 
     ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
 
@@ -292,7 +250,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     plt.close()
     torch.cuda.empty_cache()
 
-    return "output_first_frame.jpg", frame_names, inference_state
+    return "output_first_frame.jpg", frame_names, inference_state, new_working_frame
 
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
     #### PROPAGATION ####
@@ -346,7 +304,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
     print(f"JPEG_IMAGES: {jpeg_images}")
 
     if vis_frame_type == "check":
-        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=True)
     elif vis_frame_type == "render":
         # Create a video clip from the image sequence
         original_fps = get_video_fps(video_in)
@@ -378,7 +336,7 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
         new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
-    return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame
+    return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
@@ -453,19 +411,19 @@ with gr.Blocks() as demo:
         queue = False
     )
 
-
+
     working_frame.change(
        fn = switch_working_frame,
        inputs = [working_frame, scanned_frames, video_frames_dir],
-       outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map],
+       outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, working_frame],
        queue=False
    )
-
+
 
    submit_btn.click(
        fn = sam_process,
        inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir, scanned_frames, working_frame],
-       outputs = [output_result, stored_frame_names, stored_inference_state]
+       outputs = [output_result, stored_frame_names, stored_inference_state, working_frame]
    )
 
    propagate_btn.click(
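Taken together, the additions thread the selected frame back through the UI: sam_process now returns new_working_frame, the submit_btn.click wiring feeds that value into the working_frame component, and propagate_to_all reveals the frame chooser via gr.update(choices=..., visible=True). A minimal sketch of this round-trip pattern, assuming Gradio is installed; the component names and handler logic here are illustrative stand-ins, not the ones in app.py:

import gradio as gr

def process(frame_name):
    # Stand-in for sam_process: do some work, then return the frame we
    # ended up on so the selector component reflects it (mirrors the
    # new fourth return value feeding back into working_frame).
    return f"processed {frame_name}", frame_name

with gr.Blocks() as demo:
    working_frame = gr.Radio(choices=["frame_0.jpg", "frame_1.jpg"],
                             value="frame_0.jpg", label="Working frame")
    output = gr.Textbox(label="Result")
    submit = gr.Button("Submit")
    # Listing working_frame among the outputs is what lets the handler
    # update the very component the user picks from.
    submit.click(fn=process, inputs=[working_frame],
                 outputs=[output, working_frame])

demo.launch()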