Spaces:
Running
on
T4
Running
on
T4
Update app.py
Browse files
app.py
CHANGED
@@ -177,6 +177,19 @@ def show_mask(mask, ax, obj_id=None, random_color=False):
|
|
177 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
178 |
ax.imshow(mask_image)
|
179 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
def show_points(coords, labels, ax, marker_size=200):
|
182 |
pos_points = coords[labels==1]
|
@@ -319,7 +332,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
319 |
|
320 |
# Initialize a list to store file paths of saved images
|
321 |
jpeg_images = []
|
322 |
-
|
323 |
|
324 |
# run propagation throughout the video and collect the results in a dict
|
325 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
@@ -343,20 +356,6 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
343 |
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
|
344 |
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
|
345 |
|
346 |
-
# Save the raw binary mask as a separate image
|
347 |
-
mask_filename = os.path.join(mask_frames_output_dir, f"mask_{out_frame_idx}.jpg")
|
348 |
-
binary_mask = np.squeeze(out_mask) # Ensure the mask is 2D
|
349 |
-
binary_mask = (binary_mask * 255).astype(np.uint8) # Scale mask to 0-255
|
350 |
-
|
351 |
-
if binary_mask.ndim != 2: # Ensure it's 2D for PIL
|
352 |
-
raise ValueError(f"Mask has invalid dimensions: {binary_mask.shape}")
|
353 |
-
|
354 |
-
mask_image = Image.fromarray(binary_mask)
|
355 |
-
mask_image.save(mask_filename) # Save the mask as a JPEG
|
356 |
-
masks_frames.append(mask_filename) # Append to the list of masks
|
357 |
-
|
358 |
-
print(f"MASKS FRAMES: {masks_frames}")
|
359 |
-
|
360 |
# Define the output filename and save the figure as a JPEG file
|
361 |
output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
|
362 |
plt.savefig(output_filename, format='jpg')
|
@@ -370,6 +369,23 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
370 |
if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
|
371 |
available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
|
372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
373 |
|
374 |
|
375 |
torch.cuda.empty_cache()
|
@@ -392,18 +408,30 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
|
|
392 |
codec='libx264'
|
393 |
)
|
394 |
|
395 |
-
|
396 |
-
# Create the video clip
|
397 |
-
mask_clip = ImageSequenceClip(masks_frames, fps=fps)
|
398 |
|
399 |
-
#
|
400 |
-
|
401 |
-
|
402 |
-
#
|
403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
|
406 |
-
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True),
|
407 |
|
408 |
def update_ui(vis_frame_type):
|
409 |
if vis_frame_type == "check":
|
|
|
177 |
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
178 |
ax.imshow(mask_image)
|
179 |
|
180 |
+
def show_white_mask(mask, ax):
    """Draw ``mask`` on ``ax`` as an opaque white region over a black background.

    Several masks may be drawn on the same axes by calling this function
    repeatedly: the black background is painted only when the axes does not
    hold any image yet, so later calls add their white region on top of the
    earlier ones instead of hiding them.

    Args:
        mask: Array whose last two dimensions are (height, width) and whose
            total size is height * width. Any value > 0 is foreground.
        ax: Matplotlib-style axes exposing ``imshow`` and ``get_images``.
    """
    # Binarize: every strictly positive value becomes 1.0, the rest 0.0.
    mask = (mask > 0).astype(float)
    h, w = mask.shape[-2:]

    # White RGBA colour. Multiplying by the binary mask zeroes all four
    # channels (alpha included) outside the mask, so those pixels stay
    # transparent and whatever was drawn before remains visible there.
    alpha = 1.0  # Fully opaque inside the mask
    color = np.array([1, 1, 1, alpha])
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    # The background is an opaque RGB image. Painting it on every call would
    # cover masks drawn by previous calls, so only paint it on empty axes.
    if not ax.get_images():
        ax.imshow(np.zeros((h, w, 3), dtype=float))  # Black background
    ax.imshow(mask_image)  # Overlay white mask
|
193 |
|
194 |
def show_points(coords, labels, ax, marker_size=200):
|
195 |
pos_points = coords[labels==1]
|
|
|
332 |
|
333 |
# Initialize a list to store file paths of saved images
|
334 |
jpeg_images = []
|
335 |
+
masks_images = []
|
336 |
|
337 |
# run propagation throughout the video and collect the results in a dict
|
338 |
video_segments = {} # video_segments contains the per-frame segmentation results
|
|
|
356 |
for out_obj_id, out_mask in video_segments[out_frame_idx].items():
|
357 |
show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
|
358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
# Define the output filename and save the figure as a JPEG file
|
360 |
output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
|
361 |
plt.savefig(output_filename, format='jpg')
|
|
|
369 |
if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
|
370 |
available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
|
371 |
|
372 |
+
# Step 2: Create and store a black-and-white mask image using show_white_mask
|
373 |
+
# Create a figure without displaying it for the white mask
|
374 |
+
fig, ax = plt.subplots(figsize=(6, 4))
|
375 |
+
ax.axis("off") # Remove axes for a clean mask
|
376 |
+
|
377 |
+
# Overlay each mask as white on a black background
|
378 |
+
for out_mask in video_segments[out_frame_idx].values():
|
379 |
+
show_white_mask(out_mask, ax)
|
380 |
+
|
381 |
+
# Save the white mask figure as a JPEG file in the masks output directory
|
382 |
+
mask_filename = os.path.join(masks_output_dir, f"mask_{out_frame_idx}.jpg")
|
383 |
+
fig.savefig(mask_filename, format='jpg', bbox_inches="tight", pad_inches=0)
|
384 |
+
plt.close(fig)
|
385 |
+
|
386 |
+
# Add the saved mask image to the masks_images array
|
387 |
+
masks_images.append(mask_filename)
|
388 |
+
|
389 |
|
390 |
|
391 |
torch.cuda.empty_cache()
|
|
|
408 |
codec='libx264'
|
409 |
)
|
410 |
|
411 |
+
print("MAKING MASK VIDEO ...")
|
|
|
|
|
412 |
|
413 |
+
# Create a video from the masks_images array
|
414 |
+
mask_video_filename = "final_masks_video.mp4"
|
415 |
+
|
416 |
+
# Get the dimensions of the first mask image
|
417 |
+
frame = cv2.imread(masks_images[0])
|
418 |
+
height, width, _ = frame.shape
|
419 |
+
|
420 |
+
# Define the video writer
|
421 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
422 |
+
fps = original_fps # Frames per second
|
423 |
+
video_writer = cv2.VideoWriter(mask_video_filename, fourcc, fps, (width, height))
|
424 |
+
|
425 |
+
# Write each mask image to the video
|
426 |
+
for mask_path in masks_images:
|
427 |
+
frame = cv2.imread(mask_path)
|
428 |
+
video_writer.write(frame)
|
429 |
+
|
430 |
+
video_writer.release()
|
431 |
+
print(f"Mask Video saved at {mask_video_filename}")
|
432 |
|
433 |
|
434 |
+
return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename
|
435 |
|
436 |
def update_ui(vis_frame_type):
|
437 |
if vis_frame_type == "check":
|