tyriaa committed
Commit 1677992 · 1 Parent(s): 4aecc3e

Initialisation 000

Files changed (1)
  1. sam2_image_predictor.py +466 -0
sam2_image_predictor.py ADDED
@@ -0,0 +1,466 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL.Image import Image

from sam2.modeling.sam2_base import SAM2Base

from sam2.utils.transforms import SAM2Transforms


class SAM2ImagePredictor:
    def __init__(
        self,
        sam_model: SAM2Base,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        **kwargs,
    ) -> None:
        """
        Uses SAM-2 to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam-2): The model to use for mask prediction.
          mask_threshold (float): The threshold to use when converting mask logits
            to binary masks. Masks are thresholded at 0 by default.
          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
            the maximum area of max_hole_area in low_res_masks.
          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
            the maximum area of max_sprinkle_area in low_res_masks.
        """
        super().__init__()
        self.model = sam_model
        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )

        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold

        # Spatial dim for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

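    # Illustrative sketch (added comment, not part of the original file): constructing a
    # predictor from an already-built SAM-2 model. The config/checkpoint paths are
    # placeholders, and `build_sam2` is assumed to be available in sam2.build_sam.
    #
    #   from sam2.build_sam import build_sam2
    #   sam_model = build_sam2("sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt")
    #   predictor = SAM2ImagePredictor(sam_model, mask_threshold=0.0)
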
    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2ImagePredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_hf

        sam_model = build_sam2_hf(model_id, **kwargs)
        return cls(sam_model, **kwargs)

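    # Illustrative sketch (added comment, not part of the original file): loading a
    # predictor from the Hugging Face hub. The repo ID below is an example and may differ.
    #
    #   predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
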
    @torch.no_grad()
    def set_image(
        self,
        image: Union[np.ndarray, Image],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray or PIL Image): The input image to embed, in RGB format
            with pixel values in [0, 255]. If an np.ndarray, it should be in HWC format.
        """
        self.reset_predictor()
        # Transform the image to the form expected by the model
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        elif isinstance(image, Image):
            w, h = image.size
            self._orig_hw = [(h, w)]
        else:
            raise NotImplementedError("Image format not supported")

        input_image = self._transforms(image)
        input_image = input_image[None, ...].to(self.device)

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        backbone_out = self.model.forward_image(input_image)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest res. feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")

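    # Illustrative sketch (added comment, not part of the original file): setting an image
    # before prediction. Both numpy (HWC, RGB, uint8) and PIL inputs are accepted; the
    # file path below is a placeholder.
    #
    #   import numpy as np
    #   from PIL import Image as PILImage
    #   image = np.array(PILImage.open("photo.jpg").convert("RGB"))
    #   predictor.set_image(image)
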
    @torch.no_grad()
    def set_image_batch(
        self,
        image_list: List[np.ndarray],
    ) -> None:
        """
        Calculates the image embeddings for the provided image batch, allowing
        masks to be predicted with the 'predict_batch' method.

        Arguments:
          image_list (List[np.ndarray]): The input images to embed, in RGB format.
            Each image should be in HWC format with pixel values in [0, 255].
        """
        self.reset_predictor()
        assert isinstance(image_list, list)
        self._orig_hw = []
        for image in image_list:
            assert isinstance(
                image, np.ndarray
            ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
            self._orig_hw.append(image.shape[:2])
        # Transform the images to the form expected by the model
        img_batch = self._transforms.forward_batch(image_list)
        img_batch = img_batch.to(self.device)
        batch_size = img_batch.shape[0]
        assert (
            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
        logging.info("Computing image embeddings for the provided images...")
        backbone_out = self.model.forward_image(img_batch)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest res. feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        self._is_batch = True
        logging.info("Image embeddings computed.")

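    # Illustrative sketch (added comment, not part of the original file): embedding a
    # batch of images at once; each entry is an HWC RGB numpy array, and image_a/image_b
    # are placeholders.
    #
    #   predictor.set_image_batch([image_a, image_b])
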
    def predict_batch(
        self,
        point_coords_batch: List[np.ndarray] = None,
        point_labels_batch: List[np.ndarray] = None,
        box_batch: List[np.ndarray] = None,
        mask_input_batch: List[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """This function is very similar to predict(...); it is used in batched mode, when the model is expected to generate predictions on multiple images.
        It returns a tuple of lists of masks, ious, and low_res_masks_logits.
        """
        assert self._is_batch, "This function should only be used when in batched mode"
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image_batch(...) before mask prediction."
            )
        num_images = len(self._features["image_embed"])
        all_masks = []
        all_ious = []
        all_low_res_masks = []
        for img_idx in range(num_images):
            # Transform input prompts
            point_coords = (
                point_coords_batch[img_idx] if point_coords_batch is not None else None
            )
            point_labels = (
                point_labels_batch[img_idx] if point_labels_batch is not None else None
            )
            box = box_batch[img_idx] if box_batch is not None else None
            mask_input = (
                mask_input_batch[img_idx] if mask_input_batch is not None else None
            )
            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
                point_coords,
                point_labels,
                box,
                mask_input,
                normalize_coords,
                img_idx=img_idx,
            )
            masks, iou_predictions, low_res_masks = self._predict(
                unnorm_coords,
                labels,
                unnorm_box,
                mask_input,
                multimask_output,
                return_logits=return_logits,
                img_idx=img_idx,
            )
            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
            iou_predictions_np = (
                iou_predictions.squeeze(0).float().detach().cpu().numpy()
            )
            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
            all_masks.append(masks_np)
            all_ious.append(iou_predictions_np)
            all_low_res_masks.append(low_res_masks_np)

        return all_masks, all_ious, all_low_res_masks

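    # Illustrative sketch (added comment, not part of the original file): batched
    # prediction with one foreground click per image; per-image prompt lists line up
    # with the images passed to set_image_batch, and the coordinates are placeholders.
    #
    #   masks_list, ious_list, low_res_list = predictor.predict_batch(
    #       point_coords_batch=[np.array([[120, 80]]), np.array([[300, 200]])],
    #       point_labels_batch=[np.array([1]), np.array([1])],
    #       multimask_output=True,
    #   )
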
    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.
          normalize_coords (bool): If true, the point coordinates will be normalized to
            the range [0,1], and point_coords is expected to be w.r.t. image dimensions.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts
        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )

        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )

        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

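    # Illustrative sketch (added comment, not part of the original file): single-image
    # prediction from one foreground click, keeping the mask with the highest predicted
    # IoU score. The click coordinates are placeholders.
    #
    #   masks, scores, low_res_logits = predictor.predict(
    #       point_coords=np.array([[120, 80]]),  # (X, Y) in pixels
    #       point_labels=np.array([1]),          # 1 = foreground, 0 = background
    #       multimask_output=True,
    #   )
    #   best_mask = masks[scores.argmax()]
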
    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):
        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device
            )
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            if len(unnorm_coords.shape) == 2:
                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
        if box is not None:
            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            mask_input = torch.as_tensor(
                mask_logits, dtype=torch.float, device=self.device
            )
            if len(mask_input.shape) == 3:
                mask_input = mask_input[None, :, :, :]
        return mask_input, unnorm_coords, labels, unnorm_box

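    # Note (added for clarity, not part of the original file): _prep_prompts accepts the
    # same array shapes documented in predict(): point_coords Nx2, point_labels N, box of
    # length 4 (XYXY), and mask_logits 1xHxW with H=W=256; a leading batch dimension is
    # added here when it is missing.
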
    @torch.no_grad()
    def _predict(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        img_idx: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using SAM2Transforms.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
            box_labels = box_labels.repeat(boxes.size(0), 1)
            # we merge "boxes" and "points" into a single "concat_points" input (where
            # boxes are added at the beginning) to sam_prompt_encoder
            if concat_points is not None:
                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
                concat_points = (concat_coords, concat_labels)
            else:
                concat_points = (box_coords, box_labels)

        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=concat_points,
            boxes=None,
            masks=mask_input,
        )

        # Predict masks
        batched_mode = (
            concat_points is not None and concat_points[0].shape[0] > 1
        )  # multi object prediction
        high_res_features = [
            feat_level[img_idx].unsqueeze(0)
            for feat_level in self._features["high_res_feats"]
        ]
        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        # Upscale the masks to the original image resolution
        masks = self._transforms.postprocess_masks(
            low_res_masks, self._orig_hw[img_idx]
        )
        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks

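    # Note (added for clarity, not part of the original file): a box prompt is encoded as
    # two corner points with the reserved labels 2 (top-left) and 3 (bottom-right), which
    # is why boxes are reshaped to Bx2x2 and concatenated in front of any point prompts
    # before being passed to sam_prompt_encoder.
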
    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self._features is not None
        ), "Features must exist if an image has been set."
        return self._features["image_embed"]

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_predictor(self) -> None:
        """
        Resets the image embeddings and other state variables.
        """
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        self._is_batch = False
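

# Illustrative end-to-end sketch (added comment, not part of the original file); the
# repo ID, image path, and click coordinates are placeholders.
#
#   import numpy as np
#   from PIL import Image as PILImage
#
#   predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
#   image = np.array(PILImage.open("photo.jpg").convert("RGB"))
#   predictor.set_image(image)
#   masks, scores, _ = predictor.predict(
#       point_coords=np.array([[120, 80]]), point_labels=np.array([1])
#   )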