Initialisation 00001

Files changed:
- .DS_Store +0 -0
- build_sam.py +0 -167
- sam2_image_predictor.py +0 -466
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
build_sam.py
DELETED
@@ -1,167 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

import sam2

# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
    # If the user has "sam2/sam2" in their path, they are likey importing the repo itself
    # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
    # This typically happens because the user is running Python from the parent directory
    # that contains the sam2 repo they cloned.
    raise RuntimeError(
        "You're likely running Python from the parent directory of the sam2 repository "
        "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
        "This is not supported since the `sam2` Python package could be shadowed by the "
        "repository name (the repository is also named `sam2` and contains the Python package "
        "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
        "rather than its parent dir, or from your home directory) after installing SAM 2."
    )


HF_MODEL_ID_TO_FILENAMES = {
    "facebook/sam2-hiera-tiny": (
        "configs/sam2/sam2_hiera_t.yaml",
        "sam2_hiera_tiny.pt",
    ),
    "facebook/sam2-hiera-small": (
        "configs/sam2/sam2_hiera_s.yaml",
        "sam2_hiera_small.pt",
    ),
    "facebook/sam2-hiera-base-plus": (
        "configs/sam2/sam2_hiera_b+.yaml",
        "sam2_hiera_base_plus.pt",
    ),
    "facebook/sam2-hiera-large": (
        "configs/sam2/sam2_hiera_l.yaml",
        "sam2_hiera_large.pt",
    ),
    "facebook/sam2.1-hiera-tiny": (
        "configs/sam2.1/sam2.1_hiera_t.yaml",
        "sam2.1_hiera_tiny.pt",
    ),
    "facebook/sam2.1-hiera-small": (
        "configs/sam2.1/sam2.1_hiera_s.yaml",
        "sam2.1_hiera_small.pt",
    ),
    "facebook/sam2.1-hiera-base-plus": (
        "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_base_plus.pt",
    ),
    "facebook/sam2.1-hiera-large": (
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "sam2.1_hiera_large.pt",
    ),
}


def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def _hf_download(model_id):
    from huggingface_hub import hf_hub_download

    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return config_name, ckpt_path


def build_sam2_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2_video_predictor(
        config_file=config_name, ckpt_path=ckpt_path, **kwargs
    )


def _load_checkpoint(model, ckpt_path):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
        missing_keys, unexpected_keys = model.load_state_dict(sd)
        if missing_keys:
            logging.error(missing_keys)
            raise RuntimeError()
        if unexpected_keys:
            logging.error(unexpected_keys)
            raise RuntimeError()
        logging.info("Loaded checkpoint sucessfully")
sam2_image_predictor.py
DELETED
@@ -1,466 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL.Image import Image

from sam2.modeling.sam2_base import SAM2Base

from sam2.utils.transforms import SAM2Transforms


class SAM2ImagePredictor:
    def __init__(
        self,
        sam_model: SAM2Base,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        **kwargs,
    ) -> None:
        """
        Uses SAM-2 to calculate the image embedding for an image, and then
        allow repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam-2): The model to use for mask prediction.
          mask_threshold (float): The threshold to use when converting mask logits
            to binary masks. Masks are thresholded at 0 by default.
          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
            the maximum area of max_hole_area in low_res_masks.
          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
            the maximum area of max_sprinkle_area in low_res_masks.
        """
        super().__init__()
        self.model = sam_model
        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )

        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold

        # Spatial dim for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2ImagePredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_hf

        sam_model = build_sam2_hf(model_id, **kwargs)
        return cls(sam_model, **kwargs)

    @torch.no_grad()
    def set_image(
        self,
        image: Union[np.ndarray, Image],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
            with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        self.reset_predictor()
        # Transform the image to the form expected by the model
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        elif isinstance(image, Image):
            w, h = image.size
            self._orig_hw = [(h, w)]
        else:
            raise NotImplementedError("Image format not supported")

        input_image = self._transforms(image)
        input_image = input_image[None, ...].to(self.device)

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        backbone_out = self.model.forward_image(input_image)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")

    @torch.no_grad()
    def set_image_batch(
        self,
        image_list: List[Union[np.ndarray]],
    ) -> None:
        """
        Calculates the image embeddings for the provided image batch, allowing
        masks to be predicted with the 'predict_batch' method.

        Arguments:
          image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
            with pixel values in [0, 255].
        """
        self.reset_predictor()
        assert isinstance(image_list, list)
        self._orig_hw = []
        for image in image_list:
            assert isinstance(
                image, np.ndarray
            ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
            self._orig_hw.append(image.shape[:2])
        # Transform the image to the form expected by the model
        img_batch = self._transforms.forward_batch(image_list)
        img_batch = img_batch.to(self.device)
        batch_size = img_batch.shape[0]
        assert (
            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
        logging.info("Computing image embeddings for the provided images...")
        backbone_out = self.model.forward_image(img_batch)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        self._is_batch = True
        logging.info("Image embeddings computed.")

    def predict_batch(
        self,
        point_coords_batch: List[np.ndarray] = None,
        point_labels_batch: List[np.ndarray] = None,
        box_batch: List[np.ndarray] = None,
        mask_input_batch: List[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
        It returns a tuple of lists of masks, ious, and low_res_masks_logits.
        """
        assert self._is_batch, "This function should only be used when in batched mode"
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image_batch(...) before mask prediction."
            )
        num_images = len(self._features["image_embed"])
        all_masks = []
        all_ious = []
        all_low_res_masks = []
        for img_idx in range(num_images):
            # Transform input prompts
            point_coords = (
                point_coords_batch[img_idx] if point_coords_batch is not None else None
            )
            point_labels = (
                point_labels_batch[img_idx] if point_labels_batch is not None else None
            )
            box = box_batch[img_idx] if box_batch is not None else None
            mask_input = (
                mask_input_batch[img_idx] if mask_input_batch is not None else None
            )
            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
                point_coords,
                point_labels,
                box,
                mask_input,
                normalize_coords,
                img_idx=img_idx,
            )
            masks, iou_predictions, low_res_masks = self._predict(
                unnorm_coords,
                labels,
                unnorm_box,
                mask_input,
                multimask_output,
                return_logits=return_logits,
                img_idx=img_idx,
            )
            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
            iou_predictions_np = (
                iou_predictions.squeeze(0).float().detach().cpu().numpy()
            )
            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
            all_masks.append(masks_np)
            all_ious.append(iou_predictions_np)
            all_low_res_masks.append(low_res_masks_np)

        return all_masks, all_ious, all_low_res_masks

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts

        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )

        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )

        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):

        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device
            )
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            if len(unnorm_coords.shape) == 2:
                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
        if box is not None:
            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            mask_input = torch.as_tensor(
                mask_logits, dtype=torch.float, device=self.device
            )
            if len(mask_input.shape) == 3:
                mask_input = mask_input[None, :, :, :]
        return mask_input, unnorm_coords, labels, unnorm_box

    @torch.no_grad()
    def _predict(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        img_idx: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using SAM2Transforms.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
            box_labels = box_labels.repeat(boxes.size(0), 1)
            # we merge "boxes" and "points" into a single "concat_points" input (where
            # boxes are added at the beginning) to sam_prompt_encoder
            if concat_points is not None:
                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
                concat_points = (concat_coords, concat_labels)
            else:
                concat_points = (box_coords, box_labels)

        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=concat_points,
            boxes=None,
            masks=mask_input,
        )

        # Predict masks
        batched_mode = (
            concat_points is not None and concat_points[0].shape[0] > 1
        )  # multi object prediction
        high_res_features = [
            feat_level[img_idx].unsqueeze(0)
            for feat_level in self._features["high_res_feats"]
        ]
        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        # Upscale the masks to the original image resolution
        masks = self._transforms.postprocess_masks(
            low_res_masks, self._orig_hw[img_idx]
        )
        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self._features is not None
        ), "Features must exist if an image has been set."
        return self._features["image_embed"]

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_predictor(self) -> None:
        """
        Resets the image embeddings and other state variables.
        """
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        self._is_batch = False