fffiloni committed (verified) · Commit 40eec4d · Parent(s): f0c76f5

Update app.py

Files changed (1): app.py (+10, -52)
app.py CHANGED
@@ -167,51 +167,6 @@ def show_box(box, ax):
     w, h = box[2] - box[0], box[3] - box[1]
     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
 
-def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
-    combined_images = []  # List to store filenames of images with masks overlaid
-    mask_images = []  # List to store filenames of separate mask images
-
-    for i, (mask, score) in enumerate(zip(masks, scores)):
-        # ---- Original Image with Mask Overlaid ----
-        plt.figure(figsize=(10, 10))
-        plt.imshow(image)
-        show_mask(mask, plt.gca(), borders=borders)  # Draw the mask with borders
-        """
-        if point_coords is not None:
-            assert input_labels is not None
-            show_points(point_coords, input_labels, plt.gca())
-        """
-        if box_coords is not None:
-            show_box(box_coords, plt.gca())
-        if len(scores) > 1:
-            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
-        plt.axis('off')
-
-        # Save the figure as a JPG file
-        combined_filename = f"combined_image_{i+1}.jpg"
-        plt.savefig(combined_filename, format='jpg', bbox_inches='tight')
-        combined_images.append(combined_filename)
-
-        plt.close()  # Close the figure to free up memory
-
-        # ---- Separate Mask Image (White Mask on Black Background) ----
-        # Create a black image
-        mask_image = np.zeros_like(image, dtype=np.uint8)
-
-        # The mask is a binary array where the masked area is 1, else 0.
-        # Convert the mask to a white color in the mask_image
-        mask_layer = (mask > 0).astype(np.uint8) * 255
-        for c in range(3):  # Assuming RGB, repeat mask for all channels
-            mask_image[:, :, c] = mask_layer
-
-        # Save the mask image
-        mask_filename = f"mask_image_{i+1}.png"
-        Image.fromarray(mask_image).save(mask_filename)
-        mask_images.append(mask_filename)
-
-        plt.close()  # Close the figure to free up memory
-
-    return combined_images, mask_images
 
 def load_model(checkpoint):
     # Load model accordingly to user's choice
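The deleted `show_masks` helper was unused after this commit, but its white-on-black mask export is the reusable part. For reference, a minimal standalone sketch of that technique, using only NumPy and Pillow (the helper name is hypothetical):

```python
import numpy as np
from PIL import Image

def save_binary_mask(mask: np.ndarray, path: str) -> None:
    # Hypothetical helper: render a 0/1 (or boolean) mask as a
    # white-on-black RGB PNG, as the deleted show_masks loop did.
    layer = (mask > 0).astype(np.uint8) * 255       # masked pixels -> 255 (white)
    rgb = np.stack([layer, layer, layer], axis=-1)  # repeat for R, G, B channels
    Image.fromarray(rgb).save(path)

save_binary_mask(np.eye(64, dtype=np.uint8), "mask_image_1.png")
```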
@@ -254,9 +209,11 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     # segment and track one object
     # predictor.reset_state(inference_state) # if any previous tracking, reset
 
+    new_working_frame = None
     # Add new point
     if working_frame == None:
         ann_frame_idx = 0 # the frame index we interact with
+        new_working_frame = "frames_output_images/frame_0.jpg"
     else:
         # Use a regular expression to find the integer
         match = re.search(r'frame_(\d+)', working_frame)
@@ -264,6 +221,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
         # Extract the integer from the match
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
+        new_working_frame = f"frames_output_images/frame_{ann_frame_idx}.jpg"
 
     ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
 
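The new `new_working_frame` bookkeeping reuses the `frame_(\d+)` pattern already used elsewhere in the app; the hard-coded `frames_output_images/` prefix assumes that is where the extracted frames live. A small sketch of the parsing step in isolation:

```python
import re

def frame_index(name: str) -> int:
    # Mirrors sam_process: pull N out of a "frame_N" style filename,
    # falling back to frame 0 when no index is present.
    match = re.search(r'frame_(\d+)', name)
    return int(match.group(1)) if match else 0

print(frame_index("frames_output_images/frame_17.jpg"))  # -> 17
```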
@@ -292,7 +250,7 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
     plt.close()
     torch.cuda.empty_cache()
 
-    return "output_first_frame.jpg", frame_names, inference_state
+    return "output_first_frame.jpg", frame_names, inference_state, new_working_frame
 
 def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type, progress=gr.Progress(track_tqdm=True)):
     #### PROPAGATION ####
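Gradio maps a handler's returned tuple positionally onto the event's `outputs` list, so this fourth return value only takes effect together with the matching `outputs` change to `submit_btn.click` in the final hunk below. A minimal sketch of that contract (return values are placeholders; in the app `working_frame` is a dropdown, a textbox stands in here):

```python
import gradio as gr

def process():
    # Four returns -> four entries in outputs=[...], matched by position.
    return "output_first_frame.jpg", ["frame_0.jpg"], {"sam": "state"}, "frame_0.jpg"

with gr.Blocks() as demo:
    output_result = gr.Image()
    stored_frame_names = gr.State()
    stored_inference_state = gr.State()
    working_frame = gr.Textbox()
    gr.Button("Submit").click(
        fn=process, inputs=[],
        outputs=[output_result, stored_frame_names, stored_inference_state, working_frame])
```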
@@ -346,7 +304,7 @@ def propagate_to_all(video_in, checkpoint, stored_inference_state, stored_frame_
     print(f"JPEG_IMAGES: {jpeg_images}")
 
     if vis_frame_type == "check":
-        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=False)
+        return gr.update(value=jpeg_images), gr.update(value=None), gr.update(choices=jpeg_images, value=None, visible=True)
     elif vis_frame_type == "render":
         # Create a video clip from the image sequence
         original_fps = get_video_fps(video_in)
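Flipping `visible=False` to `visible=True` is what now reveals the frame picker after a "check" pass: `gr.update` patches properties (choices, value, visibility) on an existing component rather than recreating it. A sketch of the pattern:

```python
import gradio as gr

with gr.Blocks() as demo:
    jpeg_images = gr.State(["frame_0.jpg", "frame_1.jpg"])  # stand-in for the app's list
    working_frame = gr.Dropdown(choices=[], visible=False, label="working frame")
    check_btn = gr.Button("check")
    # Repopulate the dropdown and reveal it once frames are available.
    check_btn.click(fn=lambda imgs: gr.update(choices=imgs, value=None, visible=True),
                    inputs=[jpeg_images], outputs=[working_frame])
```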
@@ -378,7 +336,7 @@ def switch_working_frame(working_frame, scanned_frames, video_frames_dir):
         frame_number = int(match.group(1))
         ann_frame_idx = frame_number
         new_working_frame = os.path.join(video_frames_dir, scanned_frames[ann_frame_idx])
-    return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame
+    return new_working_frame, gr.State([]), gr.State([]), new_working_frame, new_working_frame, new_working_frame
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
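The sixth return value pairs with `working_frame` being appended to this handler's `outputs` below. As an aside, the `gr.State([])` returns clear the click-point lists; returning the plain value is the more conventional way to reset a `gr.State` output in recent Gradio versions. A minimal sketch of that variant:

```python
import gradio as gr

def reset_points():
    # Clear both point lists; plain values map positionally onto the State outputs.
    return [], []

with gr.Blocks() as demo:
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    gr.Button("clear").click(fn=reset_points, inputs=[],
                             outputs=[tracking_points, trackings_input_label])
```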
@@ -453,19 +411,19 @@ with gr.Blocks() as demo:
         queue = False
     )
 
-    """
+
     working_frame.change(
         fn = switch_working_frame,
         inputs = [working_frame, scanned_frames, video_frames_dir],
-        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map],
+        outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map, working_frame],
         queue=False
     )
-    """
+
 
     submit_btn.click(
         fn = sam_process,
         inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label, video_frames_dir, scanned_frames, working_frame],
-        outputs = [output_result, stored_frame_names, stored_inference_state]
+        outputs = [output_result, stored_frame_names, stored_inference_state, working_frame]
     )
 
     propagate_btn.click(
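With the commented-out block re-enabled, `working_frame` is both the trigger and an output of its own `.change` handler. Since Gradio fires `.change` on programmatic updates as well as user input, the handler should echo back a stable value so the event does not re-trigger itself. A minimal sketch of the wiring (handler body simplified, choices are placeholders):

```python
import gradio as gr

def switch(frame_choice):
    # Echo the selection back unchanged; .change also fires on programmatic
    # updates, so returning a new value each time would loop the event.
    return frame_choice, frame_choice

with gr.Blocks() as demo:
    working_frame = gr.Dropdown(choices=["frame_0.jpg", "frame_1.jpg"])
    points_map = gr.Image()
    working_frame.change(fn=switch, inputs=[working_frame],
                         outputs=[points_map, working_frame], queue=False)
```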
 