Upload 10 files
- __init__.py +11 -0
- __init__.pyc +0 -0
- automatic_mask_generator.py +454 -0
- build_sam.py +167 -0
- sam2_hiera_b+.yaml +1 -0
- sam2_hiera_l.yaml +1 -0
- sam2_hiera_s.yaml +1 -0
- sam2_hiera_t.yaml +1 -0
- sam2_image_predictor.py +466 -0
- sam2_video_predictor.py +1172 -0
__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from hydra import initialize_config_module
from hydra.core.global_hydra import GlobalHydra

if not GlobalHydra.instance().is_initialized():
    initialize_config_module("sam2", version_base="1.2")
__init__.pyc
ADDED
Binary file (352 Bytes).
automatic_mask_generator.py
ADDED
@@ -0,0 +1,454 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

from sam2.modeling.sam2_base import SAM2Base
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.amg import (
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    MaskData,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


class SAM2AutomaticMaskGenerator:
    def __init__(
        self,
        model: SAM2Base,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.8,
        stability_score_thresh: float = 0.95,
        stability_score_offset: float = 1.0,
        mask_threshold: float = 0.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
        use_m2m: bool = False,
        multimask_output: bool = True,
        **kwargs,
    ) -> None:
        """
        Using a SAM 2 model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM 2 with a HieraL backbone.

        Arguments:
          model (Sam): The SAM 2 model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          mask_threshold (float): Threshold for binarizing the mask logits.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
          use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
          multimask_output (bool): Whether to output multimask at each point of the grid.
        """

        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            try:
                from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
            except ImportError as e:
                print("Please install pycocotools")
                raise e

        self.predictor = SAM2ImagePredictor(
            model,
            max_hole_area=min_mask_region_area,
            max_sprinkle_area=min_mask_region_area,
        )
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.mask_threshold = mask_threshold
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
        self.use_m2m = use_m2m
        self.multimask_output = multimask_output

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2AutomaticMaskGenerator): The loaded model.
        """
        from sam2.build_sam import build_sam2_hf

        sam_model = build_sam2_hf(model_id, **kwargs)
        return cls(sam_model, **kwargs)

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. This is filtered by the pred_iou_thresh parameter.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality. This
                is filtered on using the stability_score_thresh parameter.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [
                coco_encode_rle(rle) for rle in mask_data["rles"]
            ]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[:2]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)
        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]
        self.predictor.set_image(cropped_im)

        # Get points for this crop
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(
                points, cropped_im_size, crop_box, orig_size, normalize=True
            )
            data.cat(batch_data)
            del batch_data
        self.predictor.reset_predictor()

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
        normalize=False,
    ) -> MaskData:
        orig_h, orig_w = orig_size

        # Run model on this batch
        points = torch.as_tensor(
            points, dtype=torch.float32, device=self.predictor.device
        )
        in_points = self.predictor._transforms.transform_coords(
            points, normalize=normalize, orig_hw=im_size
        )
        in_labels = torch.ones(
            in_points.shape[0], dtype=torch.int, device=in_points.device
        )
        masks, iou_preds, low_res_masks = self.predictor._predict(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=self.multimask_output,
            return_logits=True,
        )

        # Serialize predictions and store in MaskData
        data = MaskData(
            masks=masks.flatten(0, 1),
            iou_preds=iou_preds.flatten(0, 1),
            points=points.repeat_interleave(masks.shape[1], dim=0),
            low_res_masks=low_res_masks.flatten(0, 1),
        )
        del masks

        if not self.use_m2m:
            # Filter by predicted IoU
            if self.pred_iou_thresh > 0.0:
                keep_mask = data["iou_preds"] > self.pred_iou_thresh
                data.filter(keep_mask)

            # Calculate and filter by stability score
            data["stability_score"] = calculate_stability_score(
                data["masks"], self.mask_threshold, self.stability_score_offset
            )
            if self.stability_score_thresh > 0.0:
                keep_mask = data["stability_score"] >= self.stability_score_thresh
                data.filter(keep_mask)
        else:
            # One step refinement using previous mask predictions
            in_points = self.predictor._transforms.transform_coords(
                data["points"], normalize=normalize, orig_hw=im_size
            )
            labels = torch.ones(
                in_points.shape[0], dtype=torch.int, device=in_points.device
            )
            masks, ious = self.refine_with_m2m(
                in_points, labels, data["low_res_masks"], self.points_per_batch
            )
            data["masks"] = masks.squeeze(1)
            data["iou_preds"] = ious.squeeze(1)

            if self.pred_iou_thresh > 0.0:
                keep_mask = data["iou_preds"] > self.pred_iou_thresh
                data.filter(keep_mask)

            data["stability_score"] = calculate_stability_score(
                data["masks"], self.mask_threshold, self.stability_score_offset
            )
            if self.stability_score_thresh > 0.0:
                keep_mask = data["stability_score"] >= self.stability_score_thresh
                data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > self.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(
            data["boxes"], crop_box, [0, 0, orig_w, orig_h]
        )
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data

    def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
        new_masks = []
        new_iou_preds = []

        for cur_points, cur_point_labels, low_res_mask in batch_iterator(
            points_per_batch, points, point_labels, low_res_masks
        ):
            best_masks, best_iou_preds, _ = self.predictor._predict(
                cur_points[:, None, :],
                cur_point_labels[:, None],
                mask_input=low_res_mask[:, None, :],
                multimask_output=False,
                return_logits=True,
            )
            new_masks.append(best_masks)
            new_iou_preds.append(best_iou_preds)
        masks = torch.cat(new_masks, dim=0)
        return masks, torch.cat(new_iou_preds, dim=0)
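For context, below is a minimal usage sketch of the generator defined in the file above. It is not part of this upload: the config name and checkpoint path are illustrative assumptions (they match the names used in build_sam.py below), and the zero-valued array is only a stand-in for a real HWC uint8 RGB image.

# Minimal usage sketch for SAM2AutomaticMaskGenerator (paths and the input image are assumptions).
import numpy as np

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Illustrative config/checkpoint names; substitute the ones you actually downloaded.
sam2_model = build_sam2(
    "configs/sam2/sam2_hiera_l.yaml", "sam2_hiera_large.pt", device="cuda"
)
mask_generator = SAM2AutomaticMaskGenerator(sam2_model, points_per_side=32)

image = np.zeros((720, 1280, 3), dtype=np.uint8)  # stand-in for a real HWC uint8 RGB image
masks = mask_generator.generate(image)
for record in masks:
    # Each record carries the mask plus its box, area, and quality scores.
    print(record["bbox"], record["area"], record["predicted_iou"])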
build_sam.py
ADDED
@@ -0,0 +1,167 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf

import sam2

# Check if the user is running Python from the parent directory of the sam2 repo
# (i.e. the directory where this repo is cloned into) -- this is not supported since
# it could shadow the sam2 package and cause issues.
if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
    # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
    # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
    # This typically happens because the user is running Python from the parent directory
    # that contains the sam2 repo they cloned.
    raise RuntimeError(
        "You're likely running Python from the parent directory of the sam2 repository "
        "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
        "This is not supported since the `sam2` Python package could be shadowed by the "
        "repository name (the repository is also named `sam2` and contains the Python package "
        "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
        "rather than its parent dir, or from your home directory) after installing SAM 2."
    )


HF_MODEL_ID_TO_FILENAMES = {
    "facebook/sam2-hiera-tiny": (
        "configs/sam2/sam2_hiera_t.yaml",
        "sam2_hiera_tiny.pt",
    ),
    "facebook/sam2-hiera-small": (
        "configs/sam2/sam2_hiera_s.yaml",
        "sam2_hiera_small.pt",
    ),
    "facebook/sam2-hiera-base-plus": (
        "configs/sam2/sam2_hiera_b+.yaml",
        "sam2_hiera_base_plus.pt",
    ),
    "facebook/sam2-hiera-large": (
        "configs/sam2/sam2_hiera_l.yaml",
        "sam2_hiera_large.pt",
    ),
    "facebook/sam2.1-hiera-tiny": (
        "configs/sam2.1/sam2.1_hiera_t.yaml",
        "sam2.1_hiera_tiny.pt",
    ),
    "facebook/sam2.1-hiera-small": (
        "configs/sam2.1/sam2.1_hiera_s.yaml",
        "sam2.1_hiera_small.pt",
    ),
    "facebook/sam2.1-hiera-base-plus": (
        "configs/sam2.1/sam2.1_hiera_b+.yaml",
        "sam2.1_hiera_base_plus.pt",
    ),
    "facebook/sam2.1-hiera-large": (
        "configs/sam2.1/sam2.1_hiera_l.yaml",
        "sam2.1_hiera_large.pt",
    ),
}


def build_sam2(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):

    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
        ]
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
    **kwargs,
):
    hydra_overrides = [
        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]
    hydra_overrides.extend(hydra_overrides_extra)

    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model


def _hf_download(model_id):
    from huggingface_hub import hf_hub_download

    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
    return config_name, ckpt_path


def build_sam2_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)


def build_sam2_video_predictor_hf(model_id, **kwargs):
    config_name, ckpt_path = _hf_download(model_id)
    return build_sam2_video_predictor(
        config_file=config_name, ckpt_path=ckpt_path, **kwargs
    )


def _load_checkpoint(model, ckpt_path):
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
        missing_keys, unexpected_keys = model.load_state_dict(sd)
        if missing_keys:
            logging.error(missing_keys)
            raise RuntimeError()
        if unexpected_keys:
            logging.error(unexpected_keys)
            raise RuntimeError()
        logging.info("Loaded checkpoint successfully")
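For reference, a short sketch of how the builder helpers above are typically called. The model IDs come from the HF_MODEL_ID_TO_FILENAMES table in the file; everything else (device choice, variable names) is illustrative.

# Sketch: loading SAM 2 through the Hugging Face hub helpers defined above.
from sam2.build_sam import build_sam2_hf, build_sam2_video_predictor_hf

# Model IDs are keys of HF_MODEL_ID_TO_FILENAMES; the config and checkpoint are fetched automatically.
image_model = build_sam2_hf("facebook/sam2-hiera-large", device="cuda")
video_predictor = build_sam2_video_predictor_hf("facebook/sam2-hiera-tiny", device="cuda")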
sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1 @@
configs/sam2/sam2_hiera_b+.yaml
sam2_hiera_l.yaml
ADDED
@@ -0,0 +1 @@
configs/sam2/sam2_hiera_l.yaml
sam2_hiera_s.yaml
ADDED
@@ -0,0 +1 @@
configs/sam2/sam2_hiera_s.yaml
sam2_hiera_t.yaml
ADDED
@@ -0,0 +1 @@
configs/sam2/sam2_hiera_t.yaml
sam2_image_predictor.py
ADDED
@@ -0,0 +1,466 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL.Image import Image

from sam2.modeling.sam2_base import SAM2Base

from sam2.utils.transforms import SAM2Transforms


class SAM2ImagePredictor:
    def __init__(
        self,
        sam_model: SAM2Base,
        mask_threshold=0.0,
        max_hole_area=0.0,
        max_sprinkle_area=0.0,
        **kwargs,
    ) -> None:
        """
        Uses SAM-2 to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam-2): The model to use for mask prediction.
          mask_threshold (float): The threshold to use when converting mask logits
            to binary masks. Masks are thresholded at 0 by default.
          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
            the maximum area of max_hole_area in low_res_masks.
          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
            the maximum area of max_sprinkle_area in low_res_masks.
        """
        super().__init__()
        self.model = sam_model
        self._transforms = SAM2Transforms(
            resolution=self.model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )

        # Predictor state
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        # Whether the predictor is set for single image or a batch of images
        self._is_batch = False

        # Predictor config
        self.mask_threshold = mask_threshold

        # Spatial dim for backbone feature maps
        self._bb_feat_sizes = [
            (256, 256),
            (128, 128),
            (64, 64),
        ]

    @classmethod
    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
        """
        Load a pretrained model from the Hugging Face hub.

        Arguments:
          model_id (str): The Hugging Face repository ID.
          **kwargs: Additional arguments to pass to the model constructor.

        Returns:
          (SAM2ImagePredictor): The loaded model.
        """
        from sam2.build_sam import build_sam2_hf

        sam_model = build_sam2_hf(model_id, **kwargs)
        return cls(sam_model, **kwargs)

    @torch.no_grad()
    def set_image(
        self,
        image: Union[np.ndarray, Image],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
            with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        self.reset_predictor()
        # Transform the image to the form expected by the model
        if isinstance(image, np.ndarray):
            logging.info("For numpy array image, we assume (HxWxC) format")
            self._orig_hw = [image.shape[:2]]
        elif isinstance(image, Image):
            w, h = image.size
            self._orig_hw = [(h, w)]
        else:
            raise NotImplementedError("Image format not supported")

        input_image = self._transforms(image)
        input_image = input_image[None, ...].to(self.device)

        assert (
            len(input_image.shape) == 4 and input_image.shape[1] == 3
        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
        logging.info("Computing image embeddings for the provided image...")
        backbone_out = self.model.forward_image(input_image)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest-res feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        logging.info("Image embeddings computed.")

    @torch.no_grad()
    def set_image_batch(
        self,
        image_list: List[Union[np.ndarray]],
    ) -> None:
        """
        Calculates the image embeddings for the provided image batch, allowing
        masks to be predicted with the 'predict_batch' method.

        Arguments:
          image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
            with pixel values in [0, 255].
        """
        self.reset_predictor()
        assert isinstance(image_list, list)
        self._orig_hw = []
        for image in image_list:
            assert isinstance(
                image, np.ndarray
            ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
            self._orig_hw.append(image.shape[:2])
        # Transform the image to the form expected by the model
        img_batch = self._transforms.forward_batch(image_list)
        img_batch = img_batch.to(self.device)
        batch_size = img_batch.shape[0]
        assert (
            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
        logging.info("Computing image embeddings for the provided images...")
        backbone_out = self.model.forward_image(img_batch)
        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
        # Add no_mem_embed, which is added to the lowest-res feat. map during training on videos
        if self.model.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
        self._is_image_set = True
        self._is_batch = True
        logging.info("Image embeddings computed.")

    def predict_batch(
        self,
        point_coords_batch: List[np.ndarray] = None,
        point_labels_batch: List[np.ndarray] = None,
        box_batch: List[np.ndarray] = None,
        mask_input_batch: List[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
        """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
        It returns a tuple of lists of masks, ious, and low_res_masks_logits.
        """
        assert self._is_batch, "This function should only be used when in batched mode"
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image_batch(...) before mask prediction."
            )
        num_images = len(self._features["image_embed"])
        all_masks = []
        all_ious = []
        all_low_res_masks = []
        for img_idx in range(num_images):
            # Transform input prompts
            point_coords = (
                point_coords_batch[img_idx] if point_coords_batch is not None else None
            )
            point_labels = (
                point_labels_batch[img_idx] if point_labels_batch is not None else None
            )
            box = box_batch[img_idx] if box_batch is not None else None
            mask_input = (
                mask_input_batch[img_idx] if mask_input_batch is not None else None
            )
            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
                point_coords,
                point_labels,
                box,
                mask_input,
                normalize_coords,
                img_idx=img_idx,
            )
            masks, iou_predictions, low_res_masks = self._predict(
                unnorm_coords,
                labels,
                unnorm_box,
                mask_input,
                multimask_output,
                return_logits=return_logits,
                img_idx=img_idx,
            )
            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
            iou_predictions_np = (
                iou_predictions.squeeze(0).float().detach().cpu().numpy()
            )
            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
            all_masks.append(masks_np)
            all_ious.append(iou_predictions_np)
            all_low_res_masks.append(low_res_masks_np)

        return all_masks, all_ious, all_low_res_masks

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        normalize_coords=True,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.
          normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts

        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
            point_coords, point_labels, box, mask_input, normalize_coords
        )

        masks, iou_predictions, low_res_masks = self._predict(
            unnorm_coords,
            labels,
            unnorm_box,
            mask_input,
            multimask_output,
            return_logits=return_logits,
        )

        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np

    def _prep_prompts(
        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
    ):

        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device
            )
            unnorm_coords = self._transforms.transform_coords(
                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )
            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            if len(unnorm_coords.shape) == 2:
                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
        if box is not None:
            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
            unnorm_box = self._transforms.transform_boxes(
                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
            )  # Bx2x2
        if mask_logits is not None:
            mask_input = torch.as_tensor(
                mask_logits, dtype=torch.float, device=self.device
            )
            if len(mask_input.shape) == 3:
                mask_input = mask_input[None, :, :, :]
        return mask_input, unnorm_coords, labels, unnorm_box

    @torch.no_grad()
    def _predict(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        img_idx: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using SAM2Transforms.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded masks logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low res logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            concat_points = (point_coords, point_labels)
        else:
            concat_points = None

        # Embed prompts
        if boxes is not None:
            box_coords = boxes.reshape(-1, 2, 2)
            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
            box_labels = box_labels.repeat(boxes.size(0), 1)
            # we merge "boxes" and "points" into a single "concat_points" input (where
            # boxes are added at the beginning) to sam_prompt_encoder
            if concat_points is not None:
                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
                concat_points = (concat_coords, concat_labels)
            else:
                concat_points = (box_coords, box_labels)

        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
            points=concat_points,
            boxes=None,
            masks=mask_input,
        )

        # Predict masks
        batched_mode = (
            concat_points is not None and concat_points[0].shape[0] > 1
        )  # multi object prediction
        high_res_features = [
            feat_level[img_idx].unsqueeze(0)
            for feat_level in self._features["high_res_feats"]
        ]
        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        # Upscale the masks to the original image resolution
        masks = self._transforms.postprocess_masks(
            low_res_masks, self._orig_hw[img_idx]
        )
        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
        if not return_logits:
            masks = masks > self.mask_threshold

        return masks, iou_predictions, low_res_masks

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
        the embedding spatial dimension of SAM (typically C=256, H=W=64).
        """
        if not self._is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self._features is not None
        ), "Features must exist if an image has been set."
        return self._features["image_embed"]

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_predictor(self) -> None:
        """
        Resets the image embeddings and other state variables.
        """
        self._is_image_set = False
        self._features = None
        self._orig_hw = None
        self._is_batch = False
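As a quick reference for the predictor above, a minimal point-prompt sketch follows. The model ID matches the table in build_sam.py; the image path and click coordinates are illustrative assumptions.

# Sketch: point-prompted prediction with SAM2ImagePredictor (image path and click are assumptions).
import numpy as np
from PIL import Image

from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-small")
image = np.array(Image.open("example.jpg").convert("RGB"))  # hypothetical HWC RGB input

predictor.set_image(image)  # computes and caches the image embedding once
masks, iou_scores, low_res_logits = predictor.predict(
    point_coords=np.array([[500, 375]]),  # one foreground click, (X, Y) in pixels
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, iou_scores)  # CxHxW binary masks and one quality score per mask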
sam2_video_predictor.py
ADDED
@@ -0,0 +1,1172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import warnings
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
15 |
+
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
|
16 |
+
|
17 |
+
|
18 |
+
class SAM2VideoPredictor(SAM2Base):
|
19 |
+
"""The predictor class to handle user interactions and manage inference states."""
|
20 |
+
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
fill_hole_area=0,
|
24 |
+
# whether to apply non-overlapping constraints on the output object masks
|
25 |
+
non_overlap_masks=False,
|
26 |
+
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
27 |
+
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
|
28 |
+
clear_non_cond_mem_around_input=False,
|
29 |
+
# whether to also clear non-conditioning memory of the surrounding frames in multi-object tracking (only effective when `clear_non_cond_mem_around_input` is True).
|
30 |
+
clear_non_cond_mem_for_multi_obj=False,
|
31 |
+
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
32 |
+
# if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
|
33 |
+
add_all_frames_to_correct_as_cond=False,
|
34 |
+
**kwargs,
|
35 |
+
):
|
36 |
+
super().__init__(**kwargs)
|
37 |
+
self.fill_hole_area = fill_hole_area
|
38 |
+
self.non_overlap_masks = non_overlap_masks
|
39 |
+
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
40 |
+
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
41 |
+
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
42 |
+
|
43 |
+
@torch.inference_mode()
|
44 |
+
def init_state(
|
45 |
+
self,
|
46 |
+
video_path,
|
47 |
+
offload_video_to_cpu=False,
|
48 |
+
offload_state_to_cpu=False,
|
49 |
+
async_loading_frames=False,
|
50 |
+
):
|
51 |
+
"""Initialize an inference state."""
|
52 |
+
compute_device = self.device # device of the model
|
53 |
+
images, video_height, video_width = load_video_frames(
|
54 |
+
video_path=video_path,
|
55 |
+
image_size=self.image_size,
|
56 |
+
offload_video_to_cpu=offload_video_to_cpu,
|
57 |
+
async_loading_frames=async_loading_frames,
|
58 |
+
compute_device=compute_device,
|
59 |
+
)
|
60 |
+
inference_state = {}
|
61 |
+
inference_state["images"] = images
|
62 |
+
inference_state["num_frames"] = len(images)
|
63 |
+
# whether to offload the video frames to CPU memory
|
64 |
+
# turning on this option saves the GPU memory with only a very small overhead
|
65 |
+
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
|
66 |
+
# whether to offload the inference state to CPU memory
|
67 |
+
# turning on this option saves the GPU memory at the cost of a lower tracking fps
|
68 |
+
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
|
69 |
+
# and from 24 to 21 when tracking two objects)
|
70 |
+
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
|
71 |
+
# the original video height and width, used for resizing final output scores
|
72 |
+
inference_state["video_height"] = video_height
|
73 |
+
inference_state["video_width"] = video_width
|
74 |
+
inference_state["device"] = compute_device
|
75 |
+
if offload_state_to_cpu:
|
76 |
+
inference_state["storage_device"] = torch.device("cpu")
|
77 |
+
else:
|
78 |
+
inference_state["storage_device"] = compute_device
|
79 |
+
# inputs on each frame
|
80 |
+
inference_state["point_inputs_per_obj"] = {}
|
81 |
+
inference_state["mask_inputs_per_obj"] = {}
|
82 |
+
# visual features on a small number of recently visited frames for quick interactions
|
83 |
+
inference_state["cached_features"] = {}
|
84 |
+
# values that don't change across frames (so we only need to hold one copy of them)
|
85 |
+
inference_state["constants"] = {}
|
86 |
+
# mapping between client-side object id and model-side object index
|
87 |
+
inference_state["obj_id_to_idx"] = OrderedDict()
|
88 |
+
inference_state["obj_idx_to_id"] = OrderedDict()
|
89 |
+
inference_state["obj_ids"] = []
|
90 |
+
# A storage to hold the model's tracking results and states on each frame
|
91 |
+
inference_state["output_dict"] = {
|
92 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
93 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
94 |
+
}
|
95 |
+
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
96 |
+
inference_state["output_dict_per_obj"] = {}
|
97 |
+
# A temporary storage to hold new outputs when the user interacts with a frame
|
98 |
+
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
99 |
+
inference_state["temp_output_dict_per_obj"] = {}
|
100 |
+
# Frames that already hold consolidated outputs from click or mask inputs
|
101 |
+
# (we directly use their consolidated outputs during tracking)
|
102 |
+
inference_state["consolidated_frame_inds"] = {
|
103 |
+
"cond_frame_outputs": set(), # set containing frame indices
|
104 |
+
"non_cond_frame_outputs": set(), # set containing frame indices
|
105 |
+
}
|
106 |
+
# metadata for each tracking frame (e.g. which direction it's tracked)
|
107 |
+
inference_state["tracking_has_started"] = False
|
108 |
+
inference_state["frames_already_tracked"] = {}
|
109 |
+
# Warm up the visual backbone and cache the image feature on frame 0
|
110 |
+
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
111 |
+
return inference_state
|
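A minimal usage sketch for `init_state` (illustrative, not part of the uploaded file). It assumes `predictor` is an already-built SAM2VideoPredictor (e.g. via `build_sam.py` or the `from_pretrained` classmethod below) and uses a placeholder frame-folder path; the exact video inputs accepted are whatever `load_video_frames` supports in this build.

state = predictor.init_state(
    video_path="videos/bedroom",   # placeholder: a folder of extracted video frames
    offload_video_to_cpu=True,     # keep frames in CPU RAM to save GPU memory
    offload_state_to_cpu=False,    # keep the tracking state on the GPU for speed
)
print(state["num_frames"], state["video_height"], state["video_width"])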
112 |
+
|
113 |
+
@classmethod
|
114 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
|
115 |
+
"""
|
116 |
+
Load a pretrained model from the Hugging Face hub.
|
117 |
+
|
118 |
+
Arguments:
|
119 |
+
model_id (str): The Hugging Face repository ID.
|
120 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
(SAM2VideoPredictor): The loaded model.
|
124 |
+
"""
|
125 |
+
from sam2.build_sam import build_sam2_video_predictor_hf
|
126 |
+
|
127 |
+
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
128 |
+
return sam_model
|
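A hedged sketch of constructing the predictor through `from_pretrained`. The Hugging Face repo id below is an assumption (substitute the SAM 2 checkpoint repository you actually use), and the import path assumes these files are importable as the `sam2` package.

from sam2.sam2_video_predictor import SAM2VideoPredictor

predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")  # assumed repo id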
129 |
+
|
130 |
+
def _obj_id_to_idx(self, inference_state, obj_id):
|
131 |
+
"""Map client-side object id to model-side object index."""
|
132 |
+
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
133 |
+
if obj_idx is not None:
|
134 |
+
return obj_idx
|
135 |
+
|
136 |
+
# This is a new object id not sent to the server before. We only allow adding
|
137 |
+
# new objects *before* the tracking starts.
|
138 |
+
allow_new_object = not inference_state["tracking_has_started"]
|
139 |
+
if allow_new_object:
|
140 |
+
# get the next object slot
|
141 |
+
obj_idx = len(inference_state["obj_id_to_idx"])
|
142 |
+
inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
143 |
+
inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
144 |
+
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
|
145 |
+
# set up input and output structures for this object
|
146 |
+
inference_state["point_inputs_per_obj"][obj_idx] = {}
|
147 |
+
inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
148 |
+
inference_state["output_dict_per_obj"][obj_idx] = {
|
149 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
150 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
151 |
+
}
|
152 |
+
inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
153 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
154 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
155 |
+
}
|
156 |
+
return obj_idx
|
157 |
+
else:
|
158 |
+
raise RuntimeError(
|
159 |
+
f"Cannot add new object id {obj_id} after tracking starts. "
|
160 |
+
f"All existing object ids: {inference_state['obj_ids']}. "
|
161 |
+
f"Please call 'reset_state' to restart from scratch."
|
162 |
+
)
|
163 |
+
|
164 |
+
def _obj_idx_to_id(self, inference_state, obj_idx):
|
165 |
+
"""Map model-side object index to client-side object id."""
|
166 |
+
return inference_state["obj_idx_to_id"][obj_idx]
|
167 |
+
|
168 |
+
def _get_obj_num(self, inference_state):
|
169 |
+
"""Get the total number of unique object ids received so far in this session."""
|
170 |
+
return len(inference_state["obj_idx_to_id"])
|
171 |
+
|
172 |
+
@torch.inference_mode()
|
173 |
+
def add_new_points_or_box(
|
174 |
+
self,
|
175 |
+
inference_state,
|
176 |
+
frame_idx,
|
177 |
+
obj_id,
|
178 |
+
points=None,
|
179 |
+
labels=None,
|
180 |
+
clear_old_points=True,
|
181 |
+
normalize_coords=True,
|
182 |
+
box=None,
|
183 |
+
):
|
184 |
+
"""Add new points to a frame."""
|
185 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
186 |
+
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
187 |
+
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
188 |
+
|
189 |
+
if (points is not None) != (labels is not None):
|
190 |
+
raise ValueError("points and labels must be provided together")
|
191 |
+
if points is None and box is None:
|
192 |
+
raise ValueError("at least one of points or box must be provided as input")
|
193 |
+
|
194 |
+
if points is None:
|
195 |
+
points = torch.zeros(0, 2, dtype=torch.float32)
|
196 |
+
elif not isinstance(points, torch.Tensor):
|
197 |
+
points = torch.tensor(points, dtype=torch.float32)
|
198 |
+
if labels is None:
|
199 |
+
labels = torch.zeros(0, dtype=torch.int32)
|
200 |
+
elif not isinstance(labels, torch.Tensor):
|
201 |
+
labels = torch.tensor(labels, dtype=torch.int32)
|
202 |
+
if points.dim() == 2:
|
203 |
+
points = points.unsqueeze(0) # add batch dimension
|
204 |
+
if labels.dim() == 1:
|
205 |
+
labels = labels.unsqueeze(0) # add batch dimension
|
206 |
+
|
207 |
+
# If `box` is provided, we add it as the first two points with labels 2 and 3
|
208 |
+
# along with the user-provided points (consistent with how SAM 2 is trained).
|
209 |
+
if box is not None:
|
210 |
+
if not clear_old_points:
|
211 |
+
raise ValueError(
|
212 |
+
"cannot add box without clearing old points, since "
|
213 |
+
"box prompt must be provided before any point prompt "
|
214 |
+
"(please use clear_old_points=True instead)"
|
215 |
+
)
|
216 |
+
if inference_state["tracking_has_started"]:
|
217 |
+
warnings.warn(
|
218 |
+
"You are adding a box after tracking starts. SAM 2 may not always be "
|
219 |
+
"able to incorporate a box prompt for *refinement*. If you intend to "
|
220 |
+
"use box prompt as an *initial* input before tracking, please call "
|
221 |
+
"'reset_state' on the inference state to restart from scratch.",
|
222 |
+
category=UserWarning,
|
223 |
+
stacklevel=2,
|
224 |
+
)
|
225 |
+
if not isinstance(box, torch.Tensor):
|
226 |
+
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
227 |
+
box_coords = box.reshape(1, 2, 2)
|
228 |
+
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
|
229 |
+
box_labels = box_labels.reshape(1, 2)
|
230 |
+
points = torch.cat([box_coords, points], dim=1)
|
231 |
+
labels = torch.cat([box_labels, labels], dim=1)
|
232 |
+
|
233 |
+
if normalize_coords:
|
234 |
+
video_H = inference_state["video_height"]
|
235 |
+
video_W = inference_state["video_width"]
|
236 |
+
points = points / torch.tensor([video_W, video_H]).to(points.device)
|
237 |
+
# scale the (normalized) coordinates by the model's internal image size
|
238 |
+
points = points * self.image_size
|
239 |
+
points = points.to(inference_state["device"])
|
240 |
+
labels = labels.to(inference_state["device"])
|
241 |
+
|
242 |
+
if not clear_old_points:
|
243 |
+
point_inputs = point_inputs_per_frame.get(frame_idx, None)
|
244 |
+
else:
|
245 |
+
point_inputs = None
|
246 |
+
point_inputs = concat_points(point_inputs, points, labels)
|
247 |
+
|
248 |
+
point_inputs_per_frame[frame_idx] = point_inputs
|
249 |
+
mask_inputs_per_frame.pop(frame_idx, None)
|
250 |
+
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
251 |
+
# frame, meaning that the input points are used to generate segments on this frame without
|
252 |
+
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
253 |
+
# the input points will be used to correct the already tracked masks.
|
254 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
255 |
+
# whether to track in reverse time order
|
256 |
+
if is_init_cond_frame:
|
257 |
+
reverse = False
|
258 |
+
else:
|
259 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
260 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
261 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
262 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
263 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
264 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
265 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
266 |
+
|
267 |
+
# Get any previously predicted mask logits on this object and feed it along with
|
268 |
+
# the new clicks into the SAM mask decoder.
|
269 |
+
prev_sam_mask_logits = None
|
270 |
+
# lookup temporary output dict first, which contains the most recent output
|
271 |
+
# (if not found, then lookup conditioning and non-conditioning frame output)
|
272 |
+
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
|
273 |
+
if prev_out is None:
|
274 |
+
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
275 |
+
if prev_out is None:
|
276 |
+
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
277 |
+
|
278 |
+
if prev_out is not None and prev_out["pred_masks"] is not None:
|
279 |
+
device = inference_state["device"]
|
280 |
+
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
281 |
+
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
282 |
+
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
283 |
+
current_out, _ = self._run_single_frame_inference(
|
284 |
+
inference_state=inference_state,
|
285 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
286 |
+
frame_idx=frame_idx,
|
287 |
+
batch_size=1, # run on the slice of a single object
|
288 |
+
is_init_cond_frame=is_init_cond_frame,
|
289 |
+
point_inputs=point_inputs,
|
290 |
+
mask_inputs=None,
|
291 |
+
reverse=reverse,
|
292 |
+
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
293 |
+
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
|
294 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
295 |
+
# them into memory.
|
296 |
+
run_mem_encoder=False,
|
297 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
298 |
+
)
|
299 |
+
# Add the output to the output dict (to be used as future memory)
|
300 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
301 |
+
|
302 |
+
# Resize the output mask to the original video resolution
|
303 |
+
obj_ids = inference_state["obj_ids"]
|
304 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
305 |
+
inference_state,
|
306 |
+
frame_idx,
|
307 |
+
is_cond=is_cond,
|
308 |
+
run_mem_encoder=False,
|
309 |
+
consolidate_at_video_res=True,
|
310 |
+
)
|
311 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
312 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
313 |
+
)
|
314 |
+
return frame_idx, obj_ids, video_res_masks
|
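An illustrative prompt sketch for `add_new_points_or_box` (all coordinates and ids are made up): one positive click plus a box for object id 1 on frame 0, reusing the `predictor` and `state` from the sketches above.

import numpy as np

points = np.array([[210.0, 350.0]], dtype=np.float32)  # (x, y) in original video pixels
labels = np.array([1], dtype=np.int32)                  # 1 = positive click, 0 = negative
box = np.array([150.0, 280.0, 320.0, 420.0], dtype=np.float32)  # (x0, y0, x1, y1)

frame_idx, obj_ids, video_res_masks = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,
    obj_id=1,
    points=points,
    labels=labels,
    box=box,
)
# `video_res_masks` holds mask logits of shape (num_objects, 1, H, W); threshold at 0.
first_mask = (video_res_masks[0, 0] > 0.0).cpu().numpy()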
315 |
+
|
316 |
+
def add_new_points(self, *args, **kwargs):
|
317 |
+
"""Deprecated method. Please use `add_new_points_or_box` instead."""
|
318 |
+
return self.add_new_points_or_box(*args, **kwargs)
|
319 |
+
|
320 |
+
@torch.inference_mode()
|
321 |
+
def add_new_mask(
|
322 |
+
self,
|
323 |
+
inference_state,
|
324 |
+
frame_idx,
|
325 |
+
obj_id,
|
326 |
+
mask,
|
327 |
+
):
|
328 |
+
"""Add new mask to a frame."""
|
329 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
330 |
+
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
331 |
+
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
332 |
+
|
333 |
+
if not isinstance(mask, torch.Tensor):
|
334 |
+
mask = torch.tensor(mask, dtype=torch.bool)
|
335 |
+
assert mask.dim() == 2
|
336 |
+
mask_H, mask_W = mask.shape
|
337 |
+
mask_inputs_orig = mask[None, None] # add batch and channel dimension
|
338 |
+
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
|
339 |
+
|
340 |
+
# resize the mask if it doesn't match the model's image size
|
341 |
+
if mask_H != self.image_size or mask_W != self.image_size:
|
342 |
+
mask_inputs = torch.nn.functional.interpolate(
|
343 |
+
mask_inputs_orig,
|
344 |
+
size=(self.image_size, self.image_size),
|
345 |
+
align_corners=False,
|
346 |
+
mode="bilinear",
|
347 |
+
antialias=True, # use antialias for downsampling
|
348 |
+
)
|
349 |
+
mask_inputs = (mask_inputs >= 0.5).float()
|
350 |
+
else:
|
351 |
+
mask_inputs = mask_inputs_orig
|
352 |
+
|
353 |
+
mask_inputs_per_frame[frame_idx] = mask_inputs
|
354 |
+
point_inputs_per_frame.pop(frame_idx, None)
|
355 |
+
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
356 |
+
# frame, meaning that the input points are used to generate segments on this frame without
|
357 |
+
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
358 |
+
# the input points will be used to correct the already tracked masks.
|
359 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
360 |
+
# whether to track in reverse time order
|
361 |
+
if is_init_cond_frame:
|
362 |
+
reverse = False
|
363 |
+
else:
|
364 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
365 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
366 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
367 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
368 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
369 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
370 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
371 |
+
|
372 |
+
current_out, _ = self._run_single_frame_inference(
|
373 |
+
inference_state=inference_state,
|
374 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
375 |
+
frame_idx=frame_idx,
|
376 |
+
batch_size=1, # run on the slice of a single object
|
377 |
+
is_init_cond_frame=is_init_cond_frame,
|
378 |
+
point_inputs=None,
|
379 |
+
mask_inputs=mask_inputs,
|
380 |
+
reverse=reverse,
|
381 |
+
# Skip the memory encoder when adding clicks or mask. We execute the memory encoder
|
382 |
+
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
|
383 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
384 |
+
# them into memory.
|
385 |
+
run_mem_encoder=False,
|
386 |
+
)
|
387 |
+
# Add the output to the output dict (to be used as future memory)
|
388 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
389 |
+
|
390 |
+
# Resize the output mask to the original video resolution
|
391 |
+
obj_ids = inference_state["obj_ids"]
|
392 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
393 |
+
inference_state,
|
394 |
+
frame_idx,
|
395 |
+
is_cond=is_cond,
|
396 |
+
run_mem_encoder=False,
|
397 |
+
consolidate_at_video_res=True,
|
398 |
+
)
|
399 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
400 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
401 |
+
)
|
402 |
+
return frame_idx, obj_ids, video_res_masks
|
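A sketch of a mask prompt via `add_new_mask` for a second object (the rectangle below is hypothetical). A 2-D boolean array at the original video resolution is the typical input; the method resizes it to the model's internal image size.

import numpy as np

H, W = state["video_height"], state["video_width"]
init_mask = np.zeros((H, W), dtype=bool)
init_mask[100:200, 150:300] = True  # hypothetical region covering the object

frame_idx, obj_ids, video_res_masks = predictor.add_new_mask(
    inference_state=state, frame_idx=0, obj_id=2, mask=init_mask
)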
403 |
+
|
404 |
+
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
405 |
+
"""
|
406 |
+
Resize the object scores to the original video resolution (video_res_masks)
|
407 |
+
and apply non-overlapping constraints for final output.
|
408 |
+
"""
|
409 |
+
device = inference_state["device"]
|
410 |
+
video_H = inference_state["video_height"]
|
411 |
+
video_W = inference_state["video_width"]
|
412 |
+
any_res_masks = any_res_masks.to(device, non_blocking=True)
|
413 |
+
if any_res_masks.shape[-2:] == (video_H, video_W):
|
414 |
+
video_res_masks = any_res_masks
|
415 |
+
else:
|
416 |
+
video_res_masks = torch.nn.functional.interpolate(
|
417 |
+
any_res_masks,
|
418 |
+
size=(video_H, video_W),
|
419 |
+
mode="bilinear",
|
420 |
+
align_corners=False,
|
421 |
+
)
|
422 |
+
if self.non_overlap_masks:
|
423 |
+
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
|
424 |
+
return any_res_masks, video_res_masks
|
425 |
+
|
426 |
+
def _consolidate_temp_output_across_obj(
|
427 |
+
self,
|
428 |
+
inference_state,
|
429 |
+
frame_idx,
|
430 |
+
is_cond,
|
431 |
+
run_mem_encoder,
|
432 |
+
consolidate_at_video_res=False,
|
433 |
+
):
|
434 |
+
"""
|
435 |
+
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
|
436 |
+
a frame into a single output for all objects, including
|
437 |
+
1) fill any missing objects either from `output_dict_per_obj` (if they exist in
|
438 |
+
`output_dict_per_obj` for this frame) or leave them as placeholder values
|
439 |
+
(if they don't exist in `output_dict_per_obj` for this frame);
|
440 |
+
2) if specified, rerun the memory encoder after applying non-overlapping constraints
|
441 |
+
on the object scores.
|
442 |
+
"""
|
443 |
+
batch_size = self._get_obj_num(inference_state)
|
444 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
445 |
+
# Optionally, we allow consolidating the temporary outputs at the original
|
446 |
+
# video resolution (to provide a better editing experience for mask prompts).
|
447 |
+
if consolidate_at_video_res:
|
448 |
+
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
449 |
+
consolidated_H = inference_state["video_height"]
|
450 |
+
consolidated_W = inference_state["video_width"]
|
451 |
+
consolidated_mask_key = "pred_masks_video_res"
|
452 |
+
else:
|
453 |
+
consolidated_H = consolidated_W = self.image_size // 4
|
454 |
+
consolidated_mask_key = "pred_masks"
|
455 |
+
|
456 |
+
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
|
457 |
+
# will be added when rerunning the memory encoder after applying non-overlapping
|
458 |
+
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
459 |
+
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
460 |
+
consolidated_out = {
|
461 |
+
"maskmem_features": None,
|
462 |
+
"maskmem_pos_enc": None,
|
463 |
+
consolidated_mask_key: torch.full(
|
464 |
+
size=(batch_size, 1, consolidated_H, consolidated_W),
|
465 |
+
fill_value=NO_OBJ_SCORE,
|
466 |
+
dtype=torch.float32,
|
467 |
+
device=inference_state["storage_device"],
|
468 |
+
),
|
469 |
+
"obj_ptr": torch.full(
|
470 |
+
size=(batch_size, self.hidden_dim),
|
471 |
+
fill_value=NO_OBJ_SCORE,
|
472 |
+
dtype=torch.float32,
|
473 |
+
device=inference_state["device"],
|
474 |
+
),
|
475 |
+
"object_score_logits": torch.full(
|
476 |
+
size=(batch_size, 1),
|
477 |
+
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
478 |
+
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
479 |
+
fill_value=10.0,
|
480 |
+
dtype=torch.float32,
|
481 |
+
device=inference_state["device"],
|
482 |
+
),
|
483 |
+
}
|
484 |
+
empty_mask_ptr = None
|
485 |
+
for obj_idx in range(batch_size):
|
486 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
487 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
488 |
+
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
|
489 |
+
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
490 |
+
# we fall back and look up its previous output in "output_dict_per_obj".
|
491 |
+
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
|
492 |
+
# "output_dict_per_obj" to find a previous output for this object.
|
493 |
+
if out is None:
|
494 |
+
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
|
495 |
+
if out is None:
|
496 |
+
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
|
497 |
+
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
|
498 |
+
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
499 |
+
# placeholder above) and set its object pointer to be a dummy pointer.
|
500 |
+
if out is None:
|
501 |
+
# Fill in dummy object pointers for those objects without any inputs or
|
502 |
+
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
503 |
+
# i.e. when we need to build the memory for tracking).
|
504 |
+
if run_mem_encoder:
|
505 |
+
if empty_mask_ptr is None:
|
506 |
+
empty_mask_ptr = self._get_empty_mask_ptr(
|
507 |
+
inference_state, frame_idx
|
508 |
+
)
|
509 |
+
# fill object pointer with a dummy pointer (based on an empty mask)
|
510 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
511 |
+
continue
|
512 |
+
# Add the temporary object output mask to consolidated output mask
|
513 |
+
obj_mask = out["pred_masks"]
|
514 |
+
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
|
515 |
+
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
|
516 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
|
517 |
+
else:
|
518 |
+
# Resize first if temporary object mask has a different resolution
|
519 |
+
resized_obj_mask = torch.nn.functional.interpolate(
|
520 |
+
obj_mask,
|
521 |
+
size=consolidated_pred_masks.shape[-2:],
|
522 |
+
mode="bilinear",
|
523 |
+
align_corners=False,
|
524 |
+
)
|
525 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
526 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
527 |
+
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
528 |
+
"object_score_logits"
|
529 |
+
]
|
530 |
+
|
531 |
+
# Optionally, apply non-overlapping constraints on the consolidated scores
|
532 |
+
# and rerun the memory encoder
|
533 |
+
if run_mem_encoder:
|
534 |
+
device = inference_state["device"]
|
535 |
+
high_res_masks = torch.nn.functional.interpolate(
|
536 |
+
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
537 |
+
size=(self.image_size, self.image_size),
|
538 |
+
mode="bilinear",
|
539 |
+
align_corners=False,
|
540 |
+
)
|
541 |
+
if self.non_overlap_masks_for_mem_enc:
|
542 |
+
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
543 |
+
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
544 |
+
inference_state=inference_state,
|
545 |
+
frame_idx=frame_idx,
|
546 |
+
batch_size=batch_size,
|
547 |
+
high_res_masks=high_res_masks,
|
548 |
+
object_score_logits=consolidated_out["object_score_logits"],
|
549 |
+
is_mask_from_pts=True, # these frames are what the user interacted with
|
550 |
+
)
|
551 |
+
consolidated_out["maskmem_features"] = maskmem_features
|
552 |
+
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
553 |
+
|
554 |
+
return consolidated_out
|
555 |
+
|
556 |
+
def _get_empty_mask_ptr(self, inference_state, frame_idx):
|
557 |
+
"""Get a dummy object pointer based on an empty mask on the current frame."""
|
558 |
+
# A dummy (empty) mask with a single object
|
559 |
+
batch_size = 1
|
560 |
+
mask_inputs = torch.zeros(
|
561 |
+
(batch_size, 1, self.image_size, self.image_size),
|
562 |
+
dtype=torch.float32,
|
563 |
+
device=inference_state["device"],
|
564 |
+
)
|
565 |
+
|
566 |
+
# Retrieve correct image features
|
567 |
+
(
|
568 |
+
_,
|
569 |
+
_,
|
570 |
+
current_vision_feats,
|
571 |
+
current_vision_pos_embeds,
|
572 |
+
feat_sizes,
|
573 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
574 |
+
|
575 |
+
# Feed the empty mask and image feature above to get a dummy object pointer
|
576 |
+
current_out = self.track_step(
|
577 |
+
frame_idx=frame_idx,
|
578 |
+
is_init_cond_frame=True,
|
579 |
+
current_vision_feats=current_vision_feats,
|
580 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
581 |
+
feat_sizes=feat_sizes,
|
582 |
+
point_inputs=None,
|
583 |
+
mask_inputs=mask_inputs,
|
584 |
+
output_dict={},
|
585 |
+
num_frames=inference_state["num_frames"],
|
586 |
+
track_in_reverse=False,
|
587 |
+
run_mem_encoder=False,
|
588 |
+
prev_sam_mask_logits=None,
|
589 |
+
)
|
590 |
+
return current_out["obj_ptr"]
|
591 |
+
|
592 |
+
@torch.inference_mode()
|
593 |
+
def propagate_in_video_preflight(self, inference_state):
|
594 |
+
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
595 |
+
# Tracking has started and we don't allow adding new objects until the session is reset.
|
596 |
+
inference_state["tracking_has_started"] = True
|
597 |
+
batch_size = self._get_obj_num(inference_state)
|
598 |
+
|
599 |
+
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
600 |
+
# add them into "output_dict".
|
601 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
602 |
+
output_dict = inference_state["output_dict"]
|
603 |
+
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
604 |
+
# temporary outputs have been added (either in this call or any previous calls
|
605 |
+
# to `propagate_in_video_preflight`).
|
606 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
607 |
+
for is_cond in [False, True]:
|
608 |
+
# Separately consolidate conditioning and non-conditioning temp outputs
|
609 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
610 |
+
# Find all the frames that contain temporary outputs for any objects
|
611 |
+
# (these should be the frames that have just received clicks or mask inputs
|
612 |
+
# via `add_new_points_or_box` or `add_new_mask`)
|
613 |
+
temp_frame_inds = set()
|
614 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
615 |
+
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
616 |
+
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
617 |
+
# consolidate the temporary output across all objects on this frame
|
618 |
+
for frame_idx in temp_frame_inds:
|
619 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
620 |
+
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
621 |
+
)
|
622 |
+
# merge them into "output_dict" and also create per-object slices
|
623 |
+
output_dict[storage_key][frame_idx] = consolidated_out
|
624 |
+
self._add_output_per_object(
|
625 |
+
inference_state, frame_idx, consolidated_out, storage_key
|
626 |
+
)
|
627 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
628 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
629 |
+
)
|
630 |
+
if clear_non_cond_mem:
|
631 |
+
# clear non-conditioning memory of the surrounding frames
|
632 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
633 |
+
|
634 |
+
# clear temporary outputs in `temp_output_dict_per_obj`
|
635 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
636 |
+
obj_temp_output_dict[storage_key].clear()
|
637 |
+
|
638 |
+
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
639 |
+
# output on the same frame in "non_cond_frame_outputs"
|
640 |
+
for frame_idx in output_dict["cond_frame_outputs"]:
|
641 |
+
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
642 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
643 |
+
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
644 |
+
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
645 |
+
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
646 |
+
assert frame_idx in output_dict["cond_frame_outputs"]
|
647 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
648 |
+
|
649 |
+
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
650 |
+
# with either points or mask inputs (which should be true under a correct workflow).
|
651 |
+
all_consolidated_frame_inds = (
|
652 |
+
consolidated_frame_inds["cond_frame_outputs"]
|
653 |
+
| consolidated_frame_inds["non_cond_frame_outputs"]
|
654 |
+
)
|
655 |
+
input_frames_inds = set()
|
656 |
+
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
657 |
+
input_frames_inds.update(point_inputs_per_frame.keys())
|
658 |
+
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
659 |
+
input_frames_inds.update(mask_inputs_per_frame.keys())
|
660 |
+
assert all_consolidated_frame_inds == input_frames_inds
|
661 |
+
|
662 |
+
@torch.inference_mode()
|
663 |
+
def propagate_in_video(
|
664 |
+
self,
|
665 |
+
inference_state,
|
666 |
+
start_frame_idx=None,
|
667 |
+
max_frame_num_to_track=None,
|
668 |
+
reverse=False,
|
669 |
+
):
|
670 |
+
"""Propagate the input points across frames to track in the entire video."""
|
671 |
+
self.propagate_in_video_preflight(inference_state)
|
672 |
+
|
673 |
+
output_dict = inference_state["output_dict"]
|
674 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
675 |
+
obj_ids = inference_state["obj_ids"]
|
676 |
+
num_frames = inference_state["num_frames"]
|
677 |
+
batch_size = self._get_obj_num(inference_state)
|
678 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
679 |
+
raise RuntimeError("No points are provided; please add points first")
|
680 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
681 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
682 |
+
)
|
683 |
+
|
684 |
+
# set start index, end index, and processing order
|
685 |
+
if start_frame_idx is None:
|
686 |
+
# default: start from the earliest frame with input points
|
687 |
+
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
688 |
+
if max_frame_num_to_track is None:
|
689 |
+
# default: track all the frames in the video
|
690 |
+
max_frame_num_to_track = num_frames
|
691 |
+
if reverse:
|
692 |
+
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
|
693 |
+
if start_frame_idx > 0:
|
694 |
+
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
|
695 |
+
else:
|
696 |
+
processing_order = [] # skip reverse tracking if starting from frame 0
|
697 |
+
else:
|
698 |
+
end_frame_idx = min(
|
699 |
+
start_frame_idx + max_frame_num_to_track, num_frames - 1
|
700 |
+
)
|
701 |
+
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
702 |
+
|
703 |
+
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
704 |
+
# We skip those frames already in consolidated outputs (these are frames
|
705 |
+
# that received input clicks or mask). Note that we cannot directly run
|
706 |
+
# batched forward on them via `_run_single_frame_inference` because the
|
707 |
+
# number of clicks on each object might be different.
|
708 |
+
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
709 |
+
storage_key = "cond_frame_outputs"
|
710 |
+
current_out = output_dict[storage_key][frame_idx]
|
711 |
+
pred_masks = current_out["pred_masks"]
|
712 |
+
if clear_non_cond_mem:
|
713 |
+
# clear non-conditioning memory of the surrounding frames
|
714 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
715 |
+
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
716 |
+
storage_key = "non_cond_frame_outputs"
|
717 |
+
current_out = output_dict[storage_key][frame_idx]
|
718 |
+
pred_masks = current_out["pred_masks"]
|
719 |
+
else:
|
720 |
+
storage_key = "non_cond_frame_outputs"
|
721 |
+
current_out, pred_masks = self._run_single_frame_inference(
|
722 |
+
inference_state=inference_state,
|
723 |
+
output_dict=output_dict,
|
724 |
+
frame_idx=frame_idx,
|
725 |
+
batch_size=batch_size,
|
726 |
+
is_init_cond_frame=False,
|
727 |
+
point_inputs=None,
|
728 |
+
mask_inputs=None,
|
729 |
+
reverse=reverse,
|
730 |
+
run_mem_encoder=True,
|
731 |
+
)
|
732 |
+
output_dict[storage_key][frame_idx] = current_out
|
733 |
+
# Create slices of per-object outputs for subsequent interaction with each
|
734 |
+
# individual object after tracking.
|
735 |
+
self._add_output_per_object(
|
736 |
+
inference_state, frame_idx, current_out, storage_key
|
737 |
+
)
|
738 |
+
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
739 |
+
|
740 |
+
# Resize the output mask to the original video resolution (we directly use
|
741 |
+
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
742 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
743 |
+
inference_state, pred_masks
|
744 |
+
)
|
745 |
+
yield frame_idx, obj_ids, video_res_masks
|
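An assumed driver loop for `propagate_in_video`: it consumes the generator and keeps a boolean mask per object per frame; the backward pass is optional and shown only for completeness.

video_segments = {}
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(state):
    video_segments[frame_idx] = {
        obj_id: (video_res_masks[i, 0] > 0.0).cpu().numpy()
        for i, obj_id in enumerate(obj_ids)
    }

# Optionally also track backwards from the prompted frame.
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(state, reverse=True):
    video_segments.setdefault(frame_idx, {}).update(
        {obj_id: (video_res_masks[i, 0] > 0.0).cpu().numpy() for i, obj_id in enumerate(obj_ids)}
    )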
746 |
+
|
747 |
+
def _add_output_per_object(
|
748 |
+
self, inference_state, frame_idx, current_out, storage_key
|
749 |
+
):
|
750 |
+
"""
|
751 |
+
Split a multi-object output into per-object output slices and add them into
|
752 |
+
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
753 |
+
"""
|
754 |
+
maskmem_features = current_out["maskmem_features"]
|
755 |
+
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
756 |
+
|
757 |
+
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
758 |
+
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
759 |
+
|
760 |
+
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
761 |
+
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
762 |
+
obj_slice = slice(obj_idx, obj_idx + 1)
|
763 |
+
obj_out = {
|
764 |
+
"maskmem_features": None,
|
765 |
+
"maskmem_pos_enc": None,
|
766 |
+
"pred_masks": current_out["pred_masks"][obj_slice],
|
767 |
+
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
768 |
+
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
769 |
+
}
|
770 |
+
if maskmem_features is not None:
|
771 |
+
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
772 |
+
if maskmem_pos_enc is not None:
|
773 |
+
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
774 |
+
obj_output_dict[storage_key][frame_idx] = obj_out
|
775 |
+
|
776 |
+
@torch.inference_mode()
|
777 |
+
def clear_all_prompts_in_frame(
|
778 |
+
self, inference_state, frame_idx, obj_id, need_output=True
|
779 |
+
):
|
780 |
+
"""Remove all input points or mask in a specific frame for a given object."""
|
781 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
782 |
+
|
783 |
+
# Clear the conditioning information on the given frame
|
784 |
+
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
785 |
+
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
786 |
+
|
787 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
788 |
+
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
789 |
+
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
790 |
+
|
791 |
+
# Check and see if there are still any inputs left on this frame
|
792 |
+
batch_size = self._get_obj_num(inference_state)
|
793 |
+
frame_has_input = False
|
794 |
+
for obj_idx2 in range(batch_size):
|
795 |
+
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
796 |
+
frame_has_input = True
|
797 |
+
break
|
798 |
+
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
799 |
+
frame_has_input = True
|
800 |
+
break
|
801 |
+
|
802 |
+
# If this frame has no remaining inputs for any objects, we further clear its
|
803 |
+
# conditioning frame status
|
804 |
+
if not frame_has_input:
|
805 |
+
output_dict = inference_state["output_dict"]
|
806 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
807 |
+
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
808 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
809 |
+
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
810 |
+
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
811 |
+
if out is not None:
|
812 |
+
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
813 |
+
# so we "downgrade" its output (if it exists) to a non-conditioning frame output.
|
814 |
+
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
815 |
+
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
816 |
+
# Similarly, do it for the sliced output on each object.
|
817 |
+
for obj_idx2 in range(batch_size):
|
818 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
819 |
+
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
820 |
+
if obj_out is not None:
|
821 |
+
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
822 |
+
|
823 |
+
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
824 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
825 |
+
self._reset_tracking_results(inference_state)
|
826 |
+
|
827 |
+
if not need_output:
|
828 |
+
return
|
829 |
+
# Finally, output updated masks per object (after removing the inputs above)
|
830 |
+
obj_ids = inference_state["obj_ids"]
|
831 |
+
is_cond = any(
|
832 |
+
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
833 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
834 |
+
)
|
835 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
836 |
+
inference_state,
|
837 |
+
frame_idx,
|
838 |
+
is_cond=is_cond,
|
839 |
+
run_mem_encoder=False,
|
840 |
+
consolidate_at_video_res=True,
|
841 |
+
)
|
842 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
843 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
844 |
+
)
|
845 |
+
return frame_idx, obj_ids, video_res_masks
|
846 |
+
|
847 |
+
@torch.inference_mode()
|
848 |
+
def reset_state(self, inference_state):
|
849 |
+
"""Remove all input points or mask in all frames throughout the video."""
|
850 |
+
self._reset_tracking_results(inference_state)
|
851 |
+
# Remove all object ids
|
852 |
+
inference_state["obj_id_to_idx"].clear()
|
853 |
+
inference_state["obj_idx_to_id"].clear()
|
854 |
+
inference_state["obj_ids"].clear()
|
855 |
+
inference_state["point_inputs_per_obj"].clear()
|
856 |
+
inference_state["mask_inputs_per_obj"].clear()
|
857 |
+
inference_state["output_dict_per_obj"].clear()
|
858 |
+
inference_state["temp_output_dict_per_obj"].clear()
|
859 |
+
|
860 |
+
def _reset_tracking_results(self, inference_state):
|
861 |
+
"""Reset all tracking inputs and results across the videos."""
|
862 |
+
for v in inference_state["point_inputs_per_obj"].values():
|
863 |
+
v.clear()
|
864 |
+
for v in inference_state["mask_inputs_per_obj"].values():
|
865 |
+
v.clear()
|
866 |
+
for v in inference_state["output_dict_per_obj"].values():
|
867 |
+
v["cond_frame_outputs"].clear()
|
868 |
+
v["non_cond_frame_outputs"].clear()
|
869 |
+
for v in inference_state["temp_output_dict_per_obj"].values():
|
870 |
+
v["cond_frame_outputs"].clear()
|
871 |
+
v["non_cond_frame_outputs"].clear()
|
872 |
+
inference_state["output_dict"]["cond_frame_outputs"].clear()
|
873 |
+
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
|
874 |
+
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
|
875 |
+
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
|
876 |
+
inference_state["tracking_has_started"] = False
|
877 |
+
inference_state["frames_already_tracked"].clear()
|
878 |
+
|
879 |
+
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
880 |
+
"""Compute the image features on a given frame."""
|
881 |
+
# Look up in the cache first
|
882 |
+
image, backbone_out = inference_state["cached_features"].get(
|
883 |
+
frame_idx, (None, None)
|
884 |
+
)
|
885 |
+
if backbone_out is None:
|
886 |
+
# Cache miss -- we will run inference on a single image
|
887 |
+
device = inference_state["device"]
|
888 |
+
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
889 |
+
backbone_out = self.forward_image(image)
|
890 |
+
# Cache the most recent frame's feature (for repeated interactions with
|
891 |
+
# a frame; we can use an LRU cache for more frames in the future).
|
892 |
+
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
893 |
+
|
894 |
+
# expand the features to have the same dimension as the number of objects
|
895 |
+
expanded_image = image.expand(batch_size, -1, -1, -1)
|
896 |
+
expanded_backbone_out = {
|
897 |
+
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
|
898 |
+
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
|
899 |
+
}
|
900 |
+
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
901 |
+
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
902 |
+
batch_size, -1, -1, -1
|
903 |
+
)
|
904 |
+
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
905 |
+
pos = pos.expand(batch_size, -1, -1, -1)
|
906 |
+
expanded_backbone_out["vision_pos_enc"][i] = pos
|
907 |
+
|
908 |
+
features = self._prepare_backbone_features(expanded_backbone_out)
|
909 |
+
features = (expanded_image,) + features
|
910 |
+
return features
|
911 |
+
|
912 |
+
def _run_single_frame_inference(
|
913 |
+
self,
|
914 |
+
inference_state,
|
915 |
+
output_dict,
|
916 |
+
frame_idx,
|
917 |
+
batch_size,
|
918 |
+
is_init_cond_frame,
|
919 |
+
point_inputs,
|
920 |
+
mask_inputs,
|
921 |
+
reverse,
|
922 |
+
run_mem_encoder,
|
923 |
+
prev_sam_mask_logits=None,
|
924 |
+
):
|
925 |
+
"""Run tracking on a single frame based on current inputs and previous memory."""
|
926 |
+
# Retrieve correct image features
|
927 |
+
(
|
928 |
+
_,
|
929 |
+
_,
|
930 |
+
current_vision_feats,
|
931 |
+
current_vision_pos_embeds,
|
932 |
+
feat_sizes,
|
933 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
934 |
+
|
935 |
+
# point and mask should not appear as input simultaneously on the same frame
|
936 |
+
assert point_inputs is None or mask_inputs is None
|
937 |
+
current_out = self.track_step(
|
938 |
+
frame_idx=frame_idx,
|
939 |
+
is_init_cond_frame=is_init_cond_frame,
|
940 |
+
current_vision_feats=current_vision_feats,
|
941 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
942 |
+
feat_sizes=feat_sizes,
|
943 |
+
point_inputs=point_inputs,
|
944 |
+
mask_inputs=mask_inputs,
|
945 |
+
output_dict=output_dict,
|
946 |
+
num_frames=inference_state["num_frames"],
|
947 |
+
track_in_reverse=reverse,
|
948 |
+
run_mem_encoder=run_mem_encoder,
|
949 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
950 |
+
)
|
951 |
+
|
952 |
+
# optionally offload the output to CPU memory to save GPU space
|
953 |
+
storage_device = inference_state["storage_device"]
|
954 |
+
maskmem_features = current_out["maskmem_features"]
|
955 |
+
if maskmem_features is not None:
|
956 |
+
maskmem_features = maskmem_features.to(torch.bfloat16)
|
957 |
+
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
958 |
+
pred_masks_gpu = current_out["pred_masks"]
|
959 |
+
# potentially fill holes in the predicted masks
|
960 |
+
if self.fill_hole_area > 0:
|
961 |
+
pred_masks_gpu = fill_holes_in_mask_scores(
|
962 |
+
pred_masks_gpu, self.fill_hole_area
|
963 |
+
)
|
964 |
+
pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
|
965 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
966 |
+
maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
|
967 |
+
# object pointer is a small tensor, so we always keep it on GPU memory for fast access
|
968 |
+
obj_ptr = current_out["obj_ptr"]
|
969 |
+
object_score_logits = current_out["object_score_logits"]
|
970 |
+
# make a compact version of this frame's output to reduce the state size
|
971 |
+
compact_current_out = {
|
972 |
+
"maskmem_features": maskmem_features,
|
973 |
+
"maskmem_pos_enc": maskmem_pos_enc,
|
974 |
+
"pred_masks": pred_masks,
|
975 |
+
"obj_ptr": obj_ptr,
|
976 |
+
"object_score_logits": object_score_logits,
|
977 |
+
}
|
978 |
+
return compact_current_out, pred_masks_gpu
|
979 |
+
|
980 |
+
def _run_memory_encoder(
|
981 |
+
self,
|
982 |
+
inference_state,
|
983 |
+
frame_idx,
|
984 |
+
batch_size,
|
985 |
+
high_res_masks,
|
986 |
+
object_score_logits,
|
987 |
+
is_mask_from_pts,
|
988 |
+
):
|
989 |
+
"""
|
990 |
+
Run the memory encoder on `high_res_masks`. This is usually after applying
|
991 |
+
non-overlapping constraints to object scores. Since their scores changed, their
|
992 |
+
memory also needs to be computed again with the memory encoder.
|
993 |
+
"""
|
994 |
+
# Retrieve correct image features
|
995 |
+
_, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
|
996 |
+
inference_state, frame_idx, batch_size
|
997 |
+
)
|
998 |
+
maskmem_features, maskmem_pos_enc = self._encode_new_memory(
|
999 |
+
current_vision_feats=current_vision_feats,
|
1000 |
+
feat_sizes=feat_sizes,
|
1001 |
+
pred_masks_high_res=high_res_masks,
|
1002 |
+
object_score_logits=object_score_logits,
|
1003 |
+
is_mask_from_pts=is_mask_from_pts,
|
1004 |
+
)
|
1005 |
+
|
1006 |
+
# optionally offload the output to CPU memory to save GPU space
|
1007 |
+
storage_device = inference_state["storage_device"]
|
1008 |
+
maskmem_features = maskmem_features.to(torch.bfloat16)
|
1009 |
+
maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
|
1010 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
1011 |
+
maskmem_pos_enc = self._get_maskmem_pos_enc(
|
1012 |
+
inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
|
1013 |
+
)
|
1014 |
+
return maskmem_features, maskmem_pos_enc
|
1015 |
+
|
1016 |
+
def _get_maskmem_pos_enc(self, inference_state, current_out):
|
1017 |
+
"""
|
1018 |
+
`maskmem_pos_enc` is the same across frames and objects, so we cache it as
|
1019 |
+
a constant in the inference session to reduce session storage size.
|
1020 |
+
"""
|
1021 |
+
model_constants = inference_state["constants"]
|
1022 |
+
# "out_maskmem_pos_enc" should be either a list of tensors or None
|
1023 |
+
out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
1024 |
+
if out_maskmem_pos_enc is not None:
|
1025 |
+
if "maskmem_pos_enc" not in model_constants:
|
1026 |
+
assert isinstance(out_maskmem_pos_enc, list)
|
1027 |
+
# only take the slice for one object, since it's the same across objects
|
1028 |
+
maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
|
1029 |
+
model_constants["maskmem_pos_enc"] = maskmem_pos_enc
|
1030 |
+
else:
|
1031 |
+
maskmem_pos_enc = model_constants["maskmem_pos_enc"]
|
1032 |
+
# expand the cached maskmem_pos_enc to the actual batch size
|
1033 |
+
batch_size = out_maskmem_pos_enc[0].size(0)
|
1034 |
+
expanded_maskmem_pos_enc = [
|
1035 |
+
x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
|
1036 |
+
]
|
1037 |
+
else:
|
1038 |
+
expanded_maskmem_pos_enc = None
|
1039 |
+
return expanded_maskmem_pos_enc
|
1040 |
+
|
1041 |
+
@torch.inference_mode()
|
1042 |
+
def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
|
1043 |
+
"""
|
1044 |
+
Remove an object id from the tracking state. If strict is True, we check whether
|
1045 |
+
the object id actually exists and raise an error if it doesn't exist.
|
1046 |
+
"""
|
1047 |
+
old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
|
1048 |
+
updated_frames = []
|
1049 |
+
# Check whether this object_id to remove actually exists and possibly raise an error.
|
1050 |
+
if old_obj_idx_to_rm is None:
|
1051 |
+
if not strict:
|
1052 |
+
return inference_state["obj_ids"], updated_frames
|
1053 |
+
raise RuntimeError(
|
1054 |
+
f"Cannot remove object id {obj_id} as it doesn't exist. "
|
1055 |
+
f"All existing object ids: {inference_state['obj_ids']}."
|
1056 |
+
)
|
1057 |
+
|
1058 |
+
# If this is the only remaining object id, we simply reset the state.
|
1059 |
+
if len(inference_state["obj_id_to_idx"]) == 1:
|
1060 |
+
self.reset_state(inference_state)
|
1061 |
+
return inference_state["obj_ids"], updated_frames
|
1062 |
+
|
1063 |
+
# There are still remaining objects after removing this object id. In this case,
|
1064 |
+
# we need to delete the object storage from inference state tensors.
|
1065 |
+
# Step 0: clear the input on those frames where this object id has point or mask input
|
1066 |
+
# (note that this step is required as it might downgrade conditioning frames to
|
1067 |
+
# non-conditioning ones)
|
1068 |
+
obj_input_frames_inds = set()
|
1069 |
+
obj_input_frames_inds.update(
|
1070 |
+
inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
|
1071 |
+
)
|
1072 |
+
obj_input_frames_inds.update(
|
1073 |
+
inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
|
1074 |
+
)
|
1075 |
+
for frame_idx in obj_input_frames_inds:
|
1076 |
+
self.clear_all_prompts_in_frame(
|
1077 |
+
inference_state, frame_idx, obj_id, need_output=False
|
1078 |
+
)
|
1079 |
+
|
1080 |
+
# Step 1: Update the object id mapping (note that it must be done after Step 0,
|
1081 |
+
# since Step 0 still requires the old object id mappings in inference_state)
|
1082 |
+
old_obj_ids = inference_state["obj_ids"]
|
1083 |
+
old_obj_inds = list(range(len(old_obj_ids)))
|
1084 |
+
remain_old_obj_inds = old_obj_inds.copy()
|
1085 |
+
remain_old_obj_inds.remove(old_obj_idx_to_rm)
|
1086 |
+
new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
|
1087 |
+
new_obj_inds = list(range(len(new_obj_ids)))
|
1088 |
+
# build new mappings
|
1089 |
+
old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
|
1090 |
+
inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
|
1091 |
+
inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
|
1092 |
+
inference_state["obj_ids"] = new_obj_ids
|
1093 |
+
|
1094 |
+
# Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
|
1095 |
+
# (note that "consolidated_frame_inds" doesn't need to be updated in this step as
|
1096 |
+
# it's already handled in Step 0)
|
1097 |
+
def _map_keys(container):
|
1098 |
+
new_kvs = []
|
1099 |
+
for k in old_obj_inds:
|
1100 |
+
v = container.pop(k)
|
1101 |
+
if k in old_idx_to_new_idx:
|
1102 |
+
new_kvs.append((old_idx_to_new_idx[k], v))
|
1103 |
+
container.update(new_kvs)
|
1104 |
+
|
1105 |
+
_map_keys(inference_state["point_inputs_per_obj"])
|
1106 |
+
_map_keys(inference_state["mask_inputs_per_obj"])
|
1107 |
+
_map_keys(inference_state["output_dict_per_obj"])
|
1108 |
+
_map_keys(inference_state["temp_output_dict_per_obj"])
|
1109 |
+
|
1110 |
+
# Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
|
1111 |
+
def _slice_state(output_dict, storage_key):
|
1112 |
+
for frame_idx, out in output_dict[storage_key].items():
|
1113 |
+
out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
|
1114 |
+
out["maskmem_pos_enc"] = [
|
1115 |
+
x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
|
1116 |
+
]
|
1117 |
+
# "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
|
1118 |
+
out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
|
1119 |
+
out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
|
1120 |
+
out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
|
1121 |
+
out["object_score_logits"] = out["object_score_logits"][
|
1122 |
+
remain_old_obj_inds
|
1123 |
+
]
|
1124 |
+
# also update the per-object slices
|
1125 |
+
self._add_output_per_object(
|
1126 |
+
inference_state, frame_idx, out, storage_key
|
1127 |
+
)
|
1128 |
+
|
1129 |
+
_slice_state(inference_state["output_dict"], "cond_frame_outputs")
|
1130 |
+
_slice_state(inference_state["output_dict"], "non_cond_frame_outputs")
|
1131 |
+
|
1132 |
+
# Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
|
1133 |
+
# could show an updated mask for objects previously occluded by the object being removed
|
1134 |
+
if need_output:
|
1135 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
1136 |
+
for frame_idx in obj_input_frames_inds:
|
1137 |
+
is_cond = any(
|
1138 |
+
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
1139 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
1140 |
+
)
|
1141 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
1142 |
+
inference_state,
|
1143 |
+
frame_idx,
|
1144 |
+
is_cond=is_cond,
|
1145 |
+
run_mem_encoder=False,
|
1146 |
+
consolidate_at_video_res=True,
|
1147 |
+
)
|
1148 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
1149 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
1150 |
+
)
|
1151 |
+
updated_frames.append((frame_idx, video_res_masks))
|
1152 |
+
|
1153 |
+
return inference_state["obj_ids"], updated_frames
|
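A short sketch of `remove_object` mid-session (dropping the hypothetical object id 2 from the earlier mask-prompt sketch); the returned pairs carry refreshed video-resolution mask logits for the frames that had prompts for that object.

remaining_obj_ids, updated_frames = predictor.remove_object(state, obj_id=2, strict=False)
for frame_idx, video_res_masks in updated_frames:
    refreshed_masks = (video_res_masks > 0.0).cpu().numpy()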
1154 |
+
|
1155 |
+
def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
|
1156 |
+
"""
|
1157 |
+
Remove the non-conditioning memory around the input frame. When users provide
|
1158 |
+
correction clicks, the surrounding frames' non-conditioning memories can still
|
1159 |
+
contain outdated object appearance information and could confuse the model.
|
1160 |
+
|
1161 |
+
This method clears those non-conditioning memories surrounding the interacted
|
1162 |
+
frame to avoid giving the model both old and new information about the object.
|
1163 |
+
"""
|
1164 |
+
r = self.memory_temporal_stride_for_eval
|
1165 |
+
frame_idx_begin = frame_idx - r * self.num_maskmem
|
1166 |
+
frame_idx_end = frame_idx + r * self.num_maskmem
|
1167 |
+
output_dict = inference_state["output_dict"]
|
1168 |
+
non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
|
1169 |
+
for t in range(frame_idx_begin, frame_idx_end + 1):
|
1170 |
+
non_cond_frame_outputs.pop(t, None)
|
1171 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
1172 |
+
obj_output_dict["non_cond_frame_outputs"].pop(t, None)
|