fffiloni committed
Commit 0479145 · verified · 1 Parent(s): 2ed6668

Update app.py

Files changed (1)
  app.py  +150 -49
app.py CHANGED
@@ -7,12 +7,53 @@ import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
- from sam2.build_sam import build_sam2
- from sam2.sam2_image_predictor import SAM2ImagePredictor
+ from sam2.build_sam import build_sam2_video_predictor

def preprocess_image(image):
    return image, gr.State([]), gr.State([]), image

+ def preprocess_video_in(video_path):
+
+     # Generate a unique ID based on the current date and time
+     unique_id = datetime.now().strftime('%Y%m%d%H%M%S')
+     output_dir = f'frames_{unique_id}'
+
+     # Create the output directory
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Open the video file
+     cap = cv2.VideoCapture(video_path)
+
+     if not cap.isOpened():
+         print("Error: Could not open video.")
+         return None
+
+     frame_number = 0
+     first_frame = None
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Format the frame filename as '00000.jpg'
+         frame_filename = os.path.join(output_dir, f'{frame_number:05d}.jpg')
+
+         # Save the frame as a JPEG file
+         cv2.imwrite(frame_filename, frame)
+
+         # Store the first frame
+         if frame_number == 0:
+             first_frame = frame_filename
+
+         frame_number += 1
+
+     # Release the video capture object
+     cap.release()
+
+     # 'first_frame' is the first frame extracted from video_in
+     return first_frame, gr.State([]), gr.State([]), first_frame, first_frame
+
def get_point(point_type, tracking_points, trackings_input_label, first_frame_path, evt: gr.SelectData):
    print(f"You selected {evt.value} at {evt.index} from {evt.target}")

@@ -56,27 +97,23 @@ if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

- def show_mask(mask, ax, random_color=False, borders = True):
+ def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
-         color = np.array([30/255, 144/255, 255/255, 0.6])
+         cmap = plt.get_cmap("tab10")
+         cmap_idx = 0 if obj_id is None else obj_id
+         color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
-     mask = mask.astype(np.uint8)
-     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-     if borders:
-         import cv2
-         contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
-         # Try to smooth contours
-         contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
-         mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

- def show_points(coords, labels, ax, marker_size=375):
+
+ def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
-     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]

@@ -130,10 +167,12 @@ def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_l
    return combined_images, mask_images


- def sam_process(input_image, checkpoint, tracking_points, trackings_input_label):
-     image = Image.open(input_image)
-     image = np.array(image.convert("RGB"))
+ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_input_label):
+     # 1. We need to preprocess the video and store frames in the right directory
+     # Remember to use a unique ID for the frames directory

+
+     # Load the model according to the user's choice
    if checkpoint == "tiny":
        sam2_checkpoint = "./checkpoints/sam2_hiera_tiny.pt"
        model_cfg = "sam2_hiera_t.yaml"

@@ -147,56 +186,118 @@ def sam_process(input_image, checkpoint, tracking_points, trackings_input_label)
        sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
        model_cfg = "sam2_hiera_l.yaml"

-     sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
-
-     predictor = SAM2ImagePredictor(sam2_model)
-
-     predictor.set_image(image)
-
-     input_point = np.array(tracking_points.value)
-     input_label = np.array(trackings_input_label.value)
-
-     print(predictor._features["image_embed"].shape, predictor._features["image_embed"][-1].shape)
-
-     masks, scores, logits = predictor.predict(
-         point_coords=input_point,
-         point_labels=input_label,
-         multimask_output=False,
-     )
-     sorted_ind = np.argsort(scores)[::-1]
-     masks = masks[sorted_ind]
-     scores = scores[sorted_ind]
-     logits = logits[sorted_ind]
-
-     print(masks.shape)
-
-     results, mask_results = show_masks(image, masks, scores, point_coords=input_point, input_labels=input_label, borders=True)
-     print(results)
-
-     return results[0], mask_results[0]
+     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+
+     # `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
+     video_dir = "./videos/bedroom"
+
+     # scan all the JPEG frame names in this directory
+     frame_names = [
+         p for p in os.listdir(video_dir)
+         if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+     ]
+     frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+
+     inference_state = predictor.init_state(video_path=video_dir)
+
+     # segment and track one object
+     predictor.reset_state(inference_state)  # if any previous tracking, reset
+
+     # Add new point
+     ann_frame_idx = 0  # the frame index we interact with
+     ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integer)
+
+     # Let's add a positive click at (x, y) = (210, 350) to get started
+     points = np.array(tracking_points.value, dtype=np.float32)
+     # for labels, `1` means positive click and `0` means negative click
+     labels = np.array(trackings_input_label.value, np.int32)
+     _, out_obj_ids, out_mask_logits = predictor.add_new_points(
+         inference_state=inference_state,
+         frame_idx=ann_frame_idx,
+         obj_id=ann_obj_id,
+         points=points,
+         labels=labels,
+     )
+
+     # Create the plot
+     plt.figure(figsize=(12, 8))
+     plt.title(f"frame {ann_frame_idx}")
+     plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
+     show_points(points, labels, plt.gca())
+     show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
+
+     # Save the plot as a JPG file
+     output_filename = "output_frame.jpg"
+     plt.savefig(output_filename, format='jpg')
+     plt.close()
+
+     """
+     #### PROPAGATION ####
+
+     # Define a directory to save the JPEG images
+     frames_output_dir = "frames_output_images"
+     os.makedirs(frames_output_dir, exist_ok=True)
+
+     # Initialize a list to store file paths of saved images
+     jpeg_images = []
+
+     # run propagation throughout the video and collect the results in a dict
+     video_segments = {}  # video_segments contains the per-frame segmentation results
+     for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
+         video_segments[out_frame_idx] = {
+             out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
+             for i, out_obj_id in enumerate(out_obj_ids)
+         }
+
+     # render the segmentation results every few frames
+     vis_frame_stride = 15
+     plt.close("all")
+     for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
+         plt.figure(figsize=(6, 4))
+         plt.title(f"frame {out_frame_idx}")
+         plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
+         for out_obj_id, out_mask in video_segments[out_frame_idx].items():
+             show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
+
+         # Define the output filename and save the figure as a JPEG file
+         output_filename = os.path.join(frames_output_dir, f"frame_{out_frame_idx}.jpg")
+         plt.savefig(output_filename, format='jpg')
+
+         # Append the file path to the list
+         jpeg_images.append(output_filename)
+
+         # Close the plot
+         plt.close()
+     """
+     # OLD
+
+     return output_filename

with gr.Blocks() as demo:
    first_frame_path = gr.State()
    tracking_points = gr.State([])
    trackings_input_label = gr.State([])
    with gr.Column():
-         gr.Markdown("# SAM2 Image Predictor")
-         gr.Markdown("This is a simple demo for image segmentation with SAM2.")
+         gr.Markdown("# SAM2 Video Predictor")
+         gr.Markdown("This is a simple demo for video segmentation with SAM2.")
        gr.Markdown("""Instructions:

-         1. Upload your image
-         2. With 'include' point type selected, Click on the object to mask
+         1. Upload your video
+         2. With 'include' point type selected, click on the object to mask on the first frame
        3. Switch to 'exclude' point type if you want to specify an area to avoid
        4. Submit !
        """)
        with gr.Row():
            with gr.Column():
-                 input_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
+                 input_first_frame_image = gr.Image(label="input image", interactive=False, type="filepath", visible=False)
                points_map = gr.Image(
                    label="points map",
                    type="filepath",
                    interactive=True
                )
+                 video_in = gr.Video(label="Video IN")
                with gr.Row():
                    point_type = gr.Radio(label="point type", choices=["include", "exclude"], value="include")
                    clear_points_btn = gr.Button("Clear Points")

@@ -204,19 +305,19 @@ with gr.Blocks() as demo:
                submit_btn = gr.Button("Submit")
            with gr.Column():
                output_result = gr.Image()
-                 output_result_mask = gr.Image()
+                 # output_result_mask = gr.Image()

    clear_points_btn.click(
        fn = preprocess_image,
-         inputs = input_image,
+         inputs = input_first_frame_image,
        outputs = [first_frame_path, tracking_points, trackings_input_label, points_map],
        queue=False
    )

-     points_map.upload(
-         fn = preprocess_image,
-         inputs = [points_map],
-         outputs = [first_frame_path, tracking_points, trackings_input_label, input_image],
+     video_in.upload(
+         fn = preprocess_video_in,
+         inputs = [video_in],
+         outputs = [first_frame_path, tracking_points, trackings_input_label, input_first_frame_image, points_map],
        queue = False
    )

@@ -229,8 +330,8 @@ with gr.Blocks() as demo:

    submit_btn.click(
        fn = sam_process,
-         inputs = [input_image, checkpoint, tracking_points, trackings_input_label],
-         outputs = [output_result, output_result_mask]
+         inputs = [input_first_frame_image, checkpoint, tracking_points, trackings_input_label],
+         outputs = [output_result]
    )

    demo.launch(show_api=False, show_error=True)
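Note on wiring (not part of this commit): the new preprocess_video_in writes the extracted frames into a per-upload frames_<unique_id> directory and returns the path of the first frame (frames_<unique_id>/00000.jpg), which reaches sam_process as input_first_frame_image. The new sam_process, however, still points init_state at the hardcoded example directory "./videos/bedroom". A minimal sketch of recovering the per-upload frames directory from the first-frame path is shown below; the helper name is hypothetical.

import os

# Hypothetical helper, not in this commit: derive the frames directory created by
# preprocess_video_in from the first-frame path it returned,
# e.g. "frames_20240801120000/00000.jpg" -> "frames_20240801120000".
def frames_dir_from_first_frame(input_first_frame_image: str) -> str:
    return os.path.dirname(input_first_frame_image)

# Inside sam_process this could replace the hardcoded example directory:
#     video_dir = frames_dir_from_first_frame(input_first_frame_image)
#     inference_state = predictor.init_state(video_path=video_dir)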
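The propagation section of sam_process is committed commented out (the triple-quoted block): re-enabled, it would run predictor.propagate_in_video over the whole clip and save one matplotlib rendering per visualized frame into frames_output_images. One possible follow-up, sketched below with OpenCV (already imported by app.py), would pack those stills back into an MP4; this is not something the commit does, and the helper name, fps default, and numeric-suffix sort are assumptions.

import os
import cv2

# Hypothetical helper (not in this commit): pack the JPEGs written by the
# commented-out propagation loop into an MP4 that Gradio could display.
def frames_to_mp4(frames_output_dir, output_path="output_video.mp4", fps=12):
    # frame_<idx>.jpg names are not zero-padded, so sort by the numeric suffix
    frame_files = sorted(
        (f for f in os.listdir(frames_output_dir) if f.endswith(".jpg")),
        key=lambda f: int(os.path.splitext(f)[0].split("_")[-1]),
    )
    first = cv2.imread(os.path.join(frames_output_dir, frame_files[0]))
    height, width = first.shape[:2]
    writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for name in frame_files:
        frame = cv2.imread(os.path.join(frames_output_dir, name))
        writer.write(cv2.resize(frame, (width, height)))  # keep a constant frame size
    writer.release()
    return output_path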
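For reference, the prompt arrays that the new sam_process builds from the Gradio state are plain NumPy arrays: tracking_points accumulates [x, y] pixel coordinates from clicks on points_map, and trackings_input_label stores 1 for an 'include' click and 0 for an 'exclude' click. The snippet below only illustrates the shapes and dtypes the code passes to predictor.add_new_points; the click values are made-up examples.

import numpy as np

# Example click data standing in for tracking_points.value / trackings_input_label.value
tracking_points_value = [[210, 350], [250, 220]]   # (x, y) pixel coordinates
trackings_input_label_value = [1, 0]               # 1 = include click, 0 = exclude click

points = np.array(tracking_points_value, dtype=np.float32)
labels = np.array(trackings_input_label_value, np.int32)

print(points.shape, points.dtype)   # (2, 2) float32 -- one row per click
print(labels.shape, labels.dtype)   # (2,)   int32   -- one label per click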