fffiloni commited on
Commit
f065d89
·
verified ·
1 Parent(s): 11d0cdc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -24
app.py CHANGED
@@ -177,6 +177,19 @@ def show_mask(mask, ax, obj_id=None, random_color=False):
177
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
178
  ax.imshow(mask_image)
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  def show_points(coords, labels, ax, marker_size=200):
182
  pos_points = coords[labels==1]
@@ -319,7 +332,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
319
 
320
  # Initialize a list to store file paths of saved images
321
  jpeg_images = []
322
- masks_frames = []
323
 
324
  # run propagation throughout the video and collect the results in a dict
325
  video_segments = {} # video_segments contains the per-frame segmentation results
@@ -343,20 +356,6 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
343
  for out_obj_id, out_mask in video_segments[out_frame_idx].items():
344
  show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
345
 
346
- # Save the raw binary mask as a separate image
347
- mask_filename = os.path.join(mask_frames_output_dir, f"mask_{out_frame_idx}.jpg")
348
- binary_mask = np.squeeze(out_mask) # Ensure the mask is 2D
349
- binary_mask = (binary_mask * 255).astype(np.uint8) # Scale mask to 0-255
350
-
351
- if binary_mask.ndim != 2: # Ensure it's 2D for PIL
352
- raise ValueError(f"Mask has invalid dimensions: {binary_mask.shape}")
353
-
354
- mask_image = Image.fromarray(binary_mask)
355
- mask_image.save(mask_filename) # Save the mask as a JPEG
356
- masks_frames.append(mask_filename) # Append to the list of masks
357
-
358
- print(f"MASKS FRAMES: {masks_frames}")
359
-
360
  # Define the output filename and save the figure as a JPEG file
361
  output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
362
  plt.savefig(output_filename, format='jpg')
@@ -370,6 +369,23 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
370
  if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
371
  available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
 
375
  torch.cuda.empty_cache()
@@ -392,18 +408,30 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
392
  codec='libx264'
393
  )
394
 
395
-
396
- # Create the video clip
397
- mask_clip = ImageSequenceClip(masks_frames, fps=fps)
398
 
399
- # Define the output file path
400
- mask_final_vid_output_path = "mask_output_video.mp4"
401
-
402
- # Write the video to a file
403
- mask_clip.write_videofile(mask_final_vid_output_path, codec='libx264')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
 
406
- return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_final_vid_output_path
407
 
408
  def update_ui(vis_frame_type):
409
  if vis_frame_type == "check":
 
177
  mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
178
  ax.imshow(mask_image)
179
 
180
+ def show_white_mask(mask, ax):
181
+ # Ensure mask is binary (values 0 or 1)
182
+ mask = (mask > 0).astype(float) # Convert to binary mask
183
+ h, w = mask.shape[-2:]
184
+
185
+ # Create a white mask (RGBA: [1, 1, 1, alpha])
186
+ alpha = 1.0 # Fully opaque
187
+ color = np.array([1, 1, 1, alpha])
188
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
189
+
190
+ # Display black background
191
+ ax.imshow(np.zeros((h, w, 3), dtype=float)) # Black background
192
+ ax.imshow(mask_image) # Overlay white mask
193
 
194
  def show_points(coords, labels, ax, marker_size=200):
195
  pos_points = coords[labels==1]
 
332
 
333
  # Initialize a list to store file paths of saved images
334
  jpeg_images = []
335
+ masks_images = []
336
 
337
  # run propagation throughout the video and collect the results in a dict
338
  video_segments = {} # video_segments contains the per-frame segmentation results
 
356
  for out_obj_id, out_mask in video_segments[out_frame_idx].items():
357
  show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  # Define the output filename and save the figure as a JPEG file
360
  output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
361
  plt.savefig(output_filename, format='jpg')
 
369
  if f"frame_{out_frame_idx}.jpg" not in available_frames_to_check:
370
  available_frames_to_check.append(f"frame_{out_frame_idx}.jpg")
371
 
372
+ # Step 2: Create and store a black-and-white mask image using show_white_mask
373
+ # Create a figure without displaying it for the white mask
374
+ fig, ax = plt.subplots(figsize=(6, 4))
375
+ ax.axis("off") # Remove axes for a clean mask
376
+
377
+ # Overlay each mask as white on a black background
378
+ for out_mask in video_segments[out_frame_idx].values():
379
+ show_white_mask(out_mask, ax)
380
+
381
+ # Save the white mask figure to an image in memory
382
+ mask_filename = os.path.join(masks_output_dir, f"mask_{out_frame_idx}.jpg")
383
+ fig.savefig(mask_filename, format='jpg', bbox_inches="tight", pad_inches=0)
384
+ plt.close(fig)
385
+
386
+ # Add the saved mask image to the masks_images array
387
+ masks_images.append(mask_filename)
388
+
389
 
390
 
391
  torch.cuda.empty_cache()
 
408
  codec='libx264'
409
  )
410
 
411
+ print("MAKING MASK VIDEO ...")
 
 
412
 
413
+ # Create a video from the masks_images array
414
+ mask_video_filename = "final_masks_video.mp4"
415
+
416
+ # Get the dimensions of the first mask image
417
+ frame = cv2.imread(masks_images[0])
418
+ height, width, _ = frame.shape
419
+
420
+ # Define the video writer
421
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
422
+ fps = original_fps # Frames per second
423
+ video_writer = cv2.VideoWriter(mask_video_filename, fourcc, fps, (width, height))
424
+
425
+ # Write each mask image to the video
426
+ for mask_path in masks_images:
427
+ frame = cv2.imread(mask_path)
428
+ video_writer.write(frame)
429
+
430
+ video_writer.release()
431
+ print(f"Mask Video saved at {mask_video_filename}")
432
 
433
 
434
+ return gr.update(value=None), gr.update(value=final_vid_output_path), working_frame, available_frames_to_check, gr.update(visible=True), mask_video_filename
435
 
436
  def update_ui(vis_frame_type):
437
  if vis_frame_type == "check":