tyriaa committed on
Commit 5bd3244 · 1 Parent(s): 2241a90

Initialisation 00001

Files changed (3)
  1. .DS_Store +0 -0
  2. build_sam.py +0 -167
  3. sam2_image_predictor.py +0 -466
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
build_sam.py DELETED
@@ -1,167 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import logging
8
- import os
9
-
10
- import torch
11
- from hydra import compose
12
- from hydra.utils import instantiate
13
- from omegaconf import OmegaConf
14
-
15
- import sam2
16
-
17
- # Check if the user is running Python from the parent directory of the sam2 repo
18
- # (i.e. the directory where this repo is cloned into) -- this is not supported since
19
- # it could shadow the sam2 package and cause issues.
20
- if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21
- # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
22
- # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23
- # This typically happens because the user is running Python from the parent directory
24
- # that contains the sam2 repo they cloned.
25
- raise RuntimeError(
26
- "You're likely running Python from the parent directory of the sam2 repository "
27
- "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28
- "This is not supported since the `sam2` Python package could be shadowed by the "
29
- "repository name (the repository is also named `sam2` and contains the Python package "
30
- "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31
- "rather than its parent dir, or from your home directory) after installing SAM 2."
32
- )
33
-
34
-
35
- HF_MODEL_ID_TO_FILENAMES = {
36
- "facebook/sam2-hiera-tiny": (
37
- "configs/sam2/sam2_hiera_t.yaml",
38
- "sam2_hiera_tiny.pt",
39
- ),
40
- "facebook/sam2-hiera-small": (
41
- "configs/sam2/sam2_hiera_s.yaml",
42
- "sam2_hiera_small.pt",
43
- ),
44
- "facebook/sam2-hiera-base-plus": (
45
- "configs/sam2/sam2_hiera_b+.yaml",
46
- "sam2_hiera_base_plus.pt",
47
- ),
48
- "facebook/sam2-hiera-large": (
49
- "configs/sam2/sam2_hiera_l.yaml",
50
- "sam2_hiera_large.pt",
51
- ),
52
- "facebook/sam2.1-hiera-tiny": (
53
- "configs/sam2.1/sam2.1_hiera_t.yaml",
54
- "sam2.1_hiera_tiny.pt",
55
- ),
56
- "facebook/sam2.1-hiera-small": (
57
- "configs/sam2.1/sam2.1_hiera_s.yaml",
58
- "sam2.1_hiera_small.pt",
59
- ),
60
- "facebook/sam2.1-hiera-base-plus": (
61
- "configs/sam2.1/sam2.1_hiera_b+.yaml",
62
- "sam2.1_hiera_base_plus.pt",
63
- ),
64
- "facebook/sam2.1-hiera-large": (
65
- "configs/sam2.1/sam2.1_hiera_l.yaml",
66
- "sam2.1_hiera_large.pt",
67
- ),
68
- }
69
-
70
-
71
- def build_sam2(
72
- config_file,
73
- ckpt_path=None,
74
- device="cuda",
75
- mode="eval",
76
- hydra_overrides_extra=[],
77
- apply_postprocessing=True,
78
- **kwargs,
79
- ):
80
-
81
- if apply_postprocessing:
82
- hydra_overrides_extra = hydra_overrides_extra.copy()
83
- hydra_overrides_extra += [
84
- # dynamically fall back to multi-mask if the single mask is not stable
85
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88
- ]
89
- # Read config and init model
90
- cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91
- OmegaConf.resolve(cfg)
92
- model = instantiate(cfg.model, _recursive_=True)
93
- _load_checkpoint(model, ckpt_path)
94
- model = model.to(device)
95
- if mode == "eval":
96
- model.eval()
97
- return model
98
-
99
-
100
- def build_sam2_video_predictor(
101
- config_file,
102
- ckpt_path=None,
103
- device="cuda",
104
- mode="eval",
105
- hydra_overrides_extra=[],
106
- apply_postprocessing=True,
107
- **kwargs,
108
- ):
109
- hydra_overrides = [
110
- "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
111
- ]
112
- if apply_postprocessing:
113
- hydra_overrides_extra = hydra_overrides_extra.copy()
114
- hydra_overrides_extra += [
115
- # dynamically fall back to multi-mask if the single mask is not stable
116
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
117
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
118
- "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
119
- # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
120
- "++model.binarize_mask_from_pts_for_mem_enc=true",
121
- # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
122
- "++model.fill_hole_area=8",
123
- ]
124
- hydra_overrides.extend(hydra_overrides_extra)
125
-
126
- # Read config and init model
127
- cfg = compose(config_name=config_file, overrides=hydra_overrides)
128
- OmegaConf.resolve(cfg)
129
- model = instantiate(cfg.model, _recursive_=True)
130
- _load_checkpoint(model, ckpt_path)
131
- model = model.to(device)
132
- if mode == "eval":
133
- model.eval()
134
- return model
135
-
136
-
137
- def _hf_download(model_id):
138
- from huggingface_hub import hf_hub_download
139
-
140
- config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
141
- ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
142
- return config_name, ckpt_path
143
-
144
-
145
- def build_sam2_hf(model_id, **kwargs):
146
- config_name, ckpt_path = _hf_download(model_id)
147
- return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
148
-
149
-
150
- def build_sam2_video_predictor_hf(model_id, **kwargs):
151
- config_name, ckpt_path = _hf_download(model_id)
152
- return build_sam2_video_predictor(
153
- config_file=config_name, ckpt_path=ckpt_path, **kwargs
154
- )
155
-
156
-
157
- def _load_checkpoint(model, ckpt_path):
158
- if ckpt_path is not None:
159
- sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
160
- missing_keys, unexpected_keys = model.load_state_dict(sd)
161
- if missing_keys:
162
- logging.error(missing_keys)
163
- raise RuntimeError()
164
- if unexpected_keys:
165
- logging.error(unexpected_keys)
166
- raise RuntimeError()
167
- logging.info("Loaded checkpoint sucessfully")
 
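For reference, the deleted build_sam.py provided the model-building entry points used by this Space. Below is a minimal usage sketch, assuming the `sam2` package is installed and a CUDA device may or may not be present; the model ID is taken from HF_MODEL_ID_TO_FILENAMES above and the local paths are placeholders, not something introduced by this commit.

# Minimal usage sketch for the helpers in the deleted build_sam.py.
# Assumes the `sam2` package is installed; paths and model ID are placeholders.
import torch
from sam2.build_sam import build_sam2, build_sam2_hf

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the config + checkpoint from the Hugging Face Hub and build the model.
model = build_sam2_hf("facebook/sam2.1-hiera-tiny", device=device)

# Or build from a local config and checkpoint (paths are placeholders).
# model = build_sam2(
#     config_file="configs/sam2.1/sam2.1_hiera_t.yaml",
#     ckpt_path="./checkpoints/sam2.1_hiera_tiny.pt",
#     device=device,
# )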
 
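Likewise, a hedged sketch of the video-predictor path from the deleted build_sam.py; the model ID and the downstream calls mentioned in the comments are assumptions based on the public sam2 repository, not part of this commit.

# Minimal sketch for build_sam2_video_predictor_hf from the deleted build_sam.py.
# The model ID and the follow-up video API mentioned below are assumptions.
import torch
from sam2.build_sam import build_sam2_video_predictor_hf

device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = build_sam2_video_predictor_hf("facebook/sam2.1-hiera-tiny", device=device)

# Downstream, frames are typically processed with the predictor's video API,
# e.g. state = predictor.init_state(video_path="path/to/frames") followed by
# predictor.propagate_in_video(state); see the sam2 repository for the exact API.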
sam2_image_predictor.py DELETED
@@ -1,466 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
-
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import logging
8
-
9
- from typing import List, Optional, Tuple, Union
10
-
11
- import numpy as np
12
- import torch
13
- from PIL.Image import Image
14
-
15
- from sam2.modeling.sam2_base import SAM2Base
16
-
17
- from sam2.utils.transforms import SAM2Transforms
18
-
19
-
20
- class SAM2ImagePredictor:
21
- def __init__(
22
- self,
23
- sam_model: SAM2Base,
24
- mask_threshold=0.0,
25
- max_hole_area=0.0,
26
- max_sprinkle_area=0.0,
27
- **kwargs,
28
- ) -> None:
29
- """
30
- Uses SAM-2 to calculate the image embedding for an image, and then
31
- allow repeated, efficient mask prediction given prompts.
32
-
33
- Arguments:
34
- sam_model (Sam-2): The model to use for mask prediction.
35
- mask_threshold (float): The threshold to use when converting mask logits
36
- to binary masks. Masks are thresholded at 0 by default.
37
- max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
38
- the maximum area of max_hole_area in low_res_masks.
39
- max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
40
- the maximum area of max_sprinkle_area in low_res_masks.
41
- """
42
- super().__init__()
43
- self.model = sam_model
44
- self._transforms = SAM2Transforms(
45
- resolution=self.model.image_size,
46
- mask_threshold=mask_threshold,
47
- max_hole_area=max_hole_area,
48
- max_sprinkle_area=max_sprinkle_area,
49
- )
50
-
51
- # Predictor state
52
- self._is_image_set = False
53
- self._features = None
54
- self._orig_hw = None
55
- # Whether the predictor is set for single image or a batch of images
56
- self._is_batch = False
57
-
58
- # Predictor config
59
- self.mask_threshold = mask_threshold
60
-
61
- # Spatial dim for backbone feature maps
62
- self._bb_feat_sizes = [
63
- (256, 256),
64
- (128, 128),
65
- (64, 64),
66
- ]
67
-
68
- @classmethod
69
- def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
70
- """
71
- Load a pretrained model from the Hugging Face hub.
72
-
73
- Arguments:
74
- model_id (str): The Hugging Face repository ID.
75
- **kwargs: Additional arguments to pass to the model constructor.
76
-
77
- Returns:
78
- (SAM2ImagePredictor): The loaded model.
79
- """
80
- from sam2.build_sam import build_sam2_hf
81
-
82
- sam_model = build_sam2_hf(model_id, **kwargs)
83
- return cls(sam_model, **kwargs)
84
-
85
- @torch.no_grad()
86
- def set_image(
87
- self,
88
- image: Union[np.ndarray, Image],
89
- ) -> None:
90
- """
91
- Calculates the image embeddings for the provided image, allowing
92
- masks to be predicted with the 'predict' method.
93
-
94
- Arguments:
95
- image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
96
- with pixel values in [0, 255].
97
- image_format (str): The color format of the image, in ['RGB', 'BGR'].
98
- """
99
- self.reset_predictor()
100
- # Transform the image to the form expected by the model
101
- if isinstance(image, np.ndarray):
102
- logging.info("For numpy array image, we assume (HxWxC) format")
103
- self._orig_hw = [image.shape[:2]]
104
- elif isinstance(image, Image):
105
- w, h = image.size
106
- self._orig_hw = [(h, w)]
107
- else:
108
- raise NotImplementedError("Image format not supported")
109
-
110
- input_image = self._transforms(image)
111
- input_image = input_image[None, ...].to(self.device)
112
-
113
- assert (
114
- len(input_image.shape) == 4 and input_image.shape[1] == 3
115
- ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
116
- logging.info("Computing image embeddings for the provided image...")
117
- backbone_out = self.model.forward_image(input_image)
118
- _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
119
- # Add no_mem_embed, which is added to the lowest-res feat. map during training on videos
120
- if self.model.directly_add_no_mem_embed:
121
- vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
122
-
123
- feats = [
124
- feat.permute(1, 2, 0).view(1, -1, *feat_size)
125
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
126
- ][::-1]
127
- self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
128
- self._is_image_set = True
129
- logging.info("Image embeddings computed.")
130
-
131
- @torch.no_grad()
132
- def set_image_batch(
133
- self,
134
- image_list: List[np.ndarray],
135
- ) -> None:
136
- """
137
- Calculates the image embeddings for the provided image batch, allowing
138
- masks to be predicted with the 'predict_batch' method.
139
-
140
- Arguments:
141
- image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
142
- with pixel values in [0, 255].
143
- """
144
- self.reset_predictor()
145
- assert isinstance(image_list, list)
146
- self._orig_hw = []
147
- for image in image_list:
148
- assert isinstance(
149
- image, np.ndarray
150
- ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
151
- self._orig_hw.append(image.shape[:2])
152
- # Transform the image to the form expected by the model
153
- img_batch = self._transforms.forward_batch(image_list)
154
- img_batch = img_batch.to(self.device)
155
- batch_size = img_batch.shape[0]
156
- assert (
157
- len(img_batch.shape) == 4 and img_batch.shape[1] == 3
158
- ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
159
- logging.info("Computing image embeddings for the provided images...")
160
- backbone_out = self.model.forward_image(img_batch)
161
- _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
162
- # Add no_mem_embed, which is added to the lowest-res feat. map during training on videos
163
- if self.model.directly_add_no_mem_embed:
164
- vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
165
-
166
- feats = [
167
- feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
168
- for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
169
- ][::-1]
170
- self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
171
- self._is_image_set = True
172
- self._is_batch = True
173
- logging.info("Image embeddings computed.")
174
-
175
- def predict_batch(
176
- self,
177
- point_coords_batch: List[np.ndarray] = None,
178
- point_labels_batch: List[np.ndarray] = None,
179
- box_batch: List[np.ndarray] = None,
180
- mask_input_batch: List[np.ndarray] = None,
181
- multimask_output: bool = True,
182
- return_logits: bool = False,
183
- normalize_coords=True,
184
- ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
185
- """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
186
- It returns a tuple of lists of masks, ious, and low_res_masks_logits.
187
- """
188
- assert self._is_batch, "This function should only be used when in batched mode"
189
- if not self._is_image_set:
190
- raise RuntimeError(
191
- "An image must be set with .set_image_batch(...) before mask prediction."
192
- )
193
- num_images = len(self._features["image_embed"])
194
- all_masks = []
195
- all_ious = []
196
- all_low_res_masks = []
197
- for img_idx in range(num_images):
198
- # Transform input prompts
199
- point_coords = (
200
- point_coords_batch[img_idx] if point_coords_batch is not None else None
201
- )
202
- point_labels = (
203
- point_labels_batch[img_idx] if point_labels_batch is not None else None
204
- )
205
- box = box_batch[img_idx] if box_batch is not None else None
206
- mask_input = (
207
- mask_input_batch[img_idx] if mask_input_batch is not None else None
208
- )
209
- mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
210
- point_coords,
211
- point_labels,
212
- box,
213
- mask_input,
214
- normalize_coords,
215
- img_idx=img_idx,
216
- )
217
- masks, iou_predictions, low_res_masks = self._predict(
218
- unnorm_coords,
219
- labels,
220
- unnorm_box,
221
- mask_input,
222
- multimask_output,
223
- return_logits=return_logits,
224
- img_idx=img_idx,
225
- )
226
- masks_np = masks.squeeze(0).float().detach().cpu().numpy()
227
- iou_predictions_np = (
228
- iou_predictions.squeeze(0).float().detach().cpu().numpy()
229
- )
230
- low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
231
- all_masks.append(masks_np)
232
- all_ious.append(iou_predictions_np)
233
- all_low_res_masks.append(low_res_masks_np)
234
-
235
- return all_masks, all_ious, all_low_res_masks
236
-
237
- def predict(
238
- self,
239
- point_coords: Optional[np.ndarray] = None,
240
- point_labels: Optional[np.ndarray] = None,
241
- box: Optional[np.ndarray] = None,
242
- mask_input: Optional[np.ndarray] = None,
243
- multimask_output: bool = True,
244
- return_logits: bool = False,
245
- normalize_coords=True,
246
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
247
- """
248
- Predict masks for the given input prompts, using the currently set image.
249
-
250
- Arguments:
251
- point_coords (np.ndarray or None): A Nx2 array of point prompts to the
252
- model. Each point is in (X,Y) in pixels.
253
- point_labels (np.ndarray or None): A length N array of labels for the
254
- point prompts. 1 indicates a foreground point and 0 indicates a
255
- background point.
256
- box (np.ndarray or None): A length 4 array given a box prompt to the
257
- model, in XYXY format.
258
- mask_input (np.ndarray): A low resolution mask input to the model, typically
259
- coming from a previous prediction iteration. Has form 1xHxW, where
260
- for SAM, H=W=256.
261
- multimask_output (bool): If true, the model will return three masks.
262
- For ambiguous input prompts (such as a single click), this will often
263
- produce better masks than a single prediction. If only a single
264
- mask is needed, the model's predicted quality score can be used
265
- to select the best mask. For non-ambiguous prompts, such as multiple
266
- input prompts, multimask_output=False can give better results.
267
- return_logits (bool): If true, returns un-thresholded masks logits
268
- instead of a binary mask.
269
- normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
270
-
271
- Returns:
272
- (np.ndarray): The output masks in CxHxW format, where C is the
273
- number of masks, and (H, W) is the original image size.
274
- (np.ndarray): An array of length C containing the model's
275
- predictions for the quality of each mask.
276
- (np.ndarray): An array of shape CxHxW, where C is the number
277
- of masks and H=W=256. These low resolution logits can be passed to
278
- a subsequent iteration as mask input.
279
- """
280
- if not self._is_image_set:
281
- raise RuntimeError(
282
- "An image must be set with .set_image(...) before mask prediction."
283
- )
284
-
285
- # Transform input prompts
286
-
287
- mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
288
- point_coords, point_labels, box, mask_input, normalize_coords
289
- )
290
-
291
- masks, iou_predictions, low_res_masks = self._predict(
292
- unnorm_coords,
293
- labels,
294
- unnorm_box,
295
- mask_input,
296
- multimask_output,
297
- return_logits=return_logits,
298
- )
299
-
300
- masks_np = masks.squeeze(0).float().detach().cpu().numpy()
301
- iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
302
- low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
303
- return masks_np, iou_predictions_np, low_res_masks_np
304
-
305
- def _prep_prompts(
306
- self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
307
- ):
308
-
309
- unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
310
- if point_coords is not None:
311
- assert (
312
- point_labels is not None
313
- ), "point_labels must be supplied if point_coords is supplied."
314
- point_coords = torch.as_tensor(
315
- point_coords, dtype=torch.float, device=self.device
316
- )
317
- unnorm_coords = self._transforms.transform_coords(
318
- point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
319
- )
320
- labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
321
- if len(unnorm_coords.shape) == 2:
322
- unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
323
- if box is not None:
324
- box = torch.as_tensor(box, dtype=torch.float, device=self.device)
325
- unnorm_box = self._transforms.transform_boxes(
326
- box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
327
- ) # Bx2x2
328
- if mask_logits is not None:
329
- mask_input = torch.as_tensor(
330
- mask_logits, dtype=torch.float, device=self.device
331
- )
332
- if len(mask_input.shape) == 3:
333
- mask_input = mask_input[None, :, :, :]
334
- return mask_input, unnorm_coords, labels, unnorm_box
335
-
336
- @torch.no_grad()
337
- def _predict(
338
- self,
339
- point_coords: Optional[torch.Tensor],
340
- point_labels: Optional[torch.Tensor],
341
- boxes: Optional[torch.Tensor] = None,
342
- mask_input: Optional[torch.Tensor] = None,
343
- multimask_output: bool = True,
344
- return_logits: bool = False,
345
- img_idx: int = -1,
346
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
347
- """
348
- Predict masks for the given input prompts, using the currently set image.
349
- Input prompts are batched torch tensors and are expected to already be
350
- transformed to the input frame using SAM2Transforms.
351
-
352
- Arguments:
353
- point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
354
- model. Each point is in (X,Y) in pixels.
355
- point_labels (torch.Tensor or None): A BxN array of labels for the
356
- point prompts. 1 indicates a foreground point and 0 indicates a
357
- background point.
358
- boxes (np.ndarray or None): A Bx4 array given a box prompt to the
359
- model, in XYXY format.
360
- mask_input (np.ndarray): A low resolution mask input to the model, typically
361
- coming from a previous prediction iteration. Has form Bx1xHxW, where
362
- for SAM, H=W=256. Masks returned by a previous iteration of the
363
- predict method do not need further transformation.
364
- multimask_output (bool): If true, the model will return three masks.
365
- For ambiguous input prompts (such as a single click), this will often
366
- produce better masks than a single prediction. If only a single
367
- mask is needed, the model's predicted quality score can be used
368
- to select the best mask. For non-ambiguous prompts, such as multiple
369
- input prompts, multimask_output=False can give better results.
370
- return_logits (bool): If true, returns un-thresholded masks logits
371
- instead of a binary mask.
372
-
373
- Returns:
374
- (torch.Tensor): The output masks in BxCxHxW format, where C is the
375
- number of masks, and (H, W) is the original image size.
376
- (torch.Tensor): An array of shape BxC containing the model's
377
- predictions for the quality of each mask.
378
- (torch.Tensor): An array of shape BxCxHxW, where C is the number
379
- of masks and H=W=256. These low res logits can be passed to
380
- a subsequent iteration as mask input.
381
- """
382
- if not self._is_image_set:
383
- raise RuntimeError(
384
- "An image must be set with .set_image(...) before mask prediction."
385
- )
386
-
387
- if point_coords is not None:
388
- concat_points = (point_coords, point_labels)
389
- else:
390
- concat_points = None
391
-
392
- # Embed prompts
393
- if boxes is not None:
394
- box_coords = boxes.reshape(-1, 2, 2)
395
- box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
396
- box_labels = box_labels.repeat(boxes.size(0), 1)
397
- # we merge "boxes" and "points" into a single "concat_points" input (where
398
- # boxes are added at the beginning) to sam_prompt_encoder
399
- if concat_points is not None:
400
- concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
401
- concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
402
- concat_points = (concat_coords, concat_labels)
403
- else:
404
- concat_points = (box_coords, box_labels)
405
-
406
- sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
407
- points=concat_points,
408
- boxes=None,
409
- masks=mask_input,
410
- )
411
-
412
- # Predict masks
413
- batched_mode = (
414
- concat_points is not None and concat_points[0].shape[0] > 1
415
- ) # multi object prediction
416
- high_res_features = [
417
- feat_level[img_idx].unsqueeze(0)
418
- for feat_level in self._features["high_res_feats"]
419
- ]
420
- low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
421
- image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
422
- image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
423
- sparse_prompt_embeddings=sparse_embeddings,
424
- dense_prompt_embeddings=dense_embeddings,
425
- multimask_output=multimask_output,
426
- repeat_image=batched_mode,
427
- high_res_features=high_res_features,
428
- )
429
-
430
- # Upscale the masks to the original image resolution
431
- masks = self._transforms.postprocess_masks(
432
- low_res_masks, self._orig_hw[img_idx]
433
- )
434
- low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
435
- if not return_logits:
436
- masks = masks > self.mask_threshold
437
-
438
- return masks, iou_predictions, low_res_masks
439
-
440
- def get_image_embedding(self) -> torch.Tensor:
441
- """
442
- Returns the image embeddings for the currently set image, with
443
- shape 1xCxHxW, where C is the embedding dimension and (H,W) are
444
- the embedding spatial dimension of SAM (typically C=256, H=W=64).
445
- """
446
- if not self._is_image_set:
447
- raise RuntimeError(
448
- "An image must be set with .set_image(...) to generate an embedding."
449
- )
450
- assert (
451
- self._features is not None
452
- ), "Features must exist if an image has been set."
453
- return self._features["image_embed"]
454
-
455
- @property
456
- def device(self) -> torch.device:
457
- return self.model.device
458
-
459
- def reset_predictor(self) -> None:
460
- """
461
- Resets the image embeddings and other state variables.
462
- """
463
- self._is_image_set = False
464
- self._features = None
465
- self._orig_hw = None
466
- self._is_batch = False
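
For completeness, a minimal usage sketch of the deleted SAM2ImagePredictor, following the docstrings above; the model ID, image path, and click coordinates are placeholders for illustration only.

# Minimal usage sketch for the deleted SAM2ImagePredictor; the model ID,
# image path, and click coordinates below are placeholders.
import numpy as np
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-tiny")

# set_image expects an RGB image (HWC np.ndarray or PIL Image, values in [0, 255]).
image = np.array(Image.open("example.jpg").convert("RGB"))
predictor.set_image(image)

# One foreground click at pixel (x=500, y=375); label 1 = foreground, 0 = background.
point_coords = np.array([[500, 375]], dtype=np.float32)
point_labels = np.array([1], dtype=np.int32)

masks, iou_scores, low_res_logits = predictor.predict(
    point_coords=point_coords,
    point_labels=point_labels,
    multimask_output=True,  # three candidate masks for an ambiguous single click
)

# masks has shape (3, H, W); keep the candidate with the highest predicted IoU.
best_mask = masks[int(np.argmax(iou_scores))]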