Upload 10 files
- __init__.py +11 -0
- __init__.pyc +0 -0
- automatic_mask_generator.py +454 -0
- build_sam.py +167 -0
- sam2_hiera_b+.yaml +1 -0
- sam2_hiera_l.yaml +1 -0
- sam2_hiera_s.yaml +1 -0
- sam2_hiera_t.yaml +1 -0
- sam2_image_predictor.py +466 -0
- sam2_video_predictor.py +1172 -0
__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from hydra import initialize_config_module
+from hydra.core.global_hydra import GlobalHydra
+
+if not GlobalHydra.instance().is_initialized():
+    initialize_config_module("sam2", version_base="1.2")
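The file above only registers the packaged Hydra config module when `sam2` is imported. A minimal sketch of how that registration is consumed downstream (the config name is illustrative; build_sam.py later in this commit does the real composition):

import sam2  # importing the package runs the initialize_config_module(...) guard above
from hydra import compose
from hydra.utils import instantiate

# Config names resolve against the registered "sam2" config module
cfg = compose(config_name="configs/sam2/sam2_hiera_t.yaml")
model = instantiate(cfg.model, _recursive_=True)  # same pattern as build_sam2() below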
__init__.pyc
ADDED
Binary file (352 Bytes)
automatic_mask_generator.py
ADDED
@@ -0,0 +1,454 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.utils.amg import (
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    MaskData,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)
+
+
+class SAM2AutomaticMaskGenerator:
+    def __init__(
+        self,
+        model: SAM2Base,
+        points_per_side: Optional[int] = 32,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.8,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        mask_threshold: float = 0.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: Optional[List[np.ndarray]] = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+        use_m2m: bool = False,
+        multimask_output: bool = True,
+        **kwargs,
+    ) -> None:
+        """
+        Using a SAM 2 model, generates masks for the entire image.
+        Generates a grid of point prompts over the image, then filters
+        low quality and duplicate masks. The default settings are chosen
+        for SAM 2 with a HieraL backbone.
+
+        Arguments:
+          model (Sam): The SAM 2 model to use for mask prediction.
+          points_per_side (int or None): The number of points to be sampled
+            along one side of the image. The total number of points is
+            points_per_side**2. If None, 'point_grids' must provide explicit
+            point sampling.
+          points_per_batch (int): Sets the number of points run simultaneously
+            by the model. Higher numbers may be faster but use more GPU memory.
+          pred_iou_thresh (float): A filtering threshold in [0,1], using the
+            model's predicted mask quality.
+          stability_score_thresh (float): A filtering threshold in [0,1], using
+            the stability of the mask under changes to the cutoff used to binarize
+            the model's mask predictions.
+          stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+          mask_threshold (float): Threshold for binarizing the mask logits
+          box_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+            crops of the image. Sets the number of layers to run, where each
+            layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks between different crops.
+          crop_overlap_ratio (float): Sets the degree to which crops overlap.
+            In the first crop layer, crops will overlap by this fraction of
+            the image length. Later layers with more crops scale down this overlap.
+          crop_n_points_downscale_factor (int): The number of points-per-side
+            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+          point_grids (list(np.ndarray) or None): A list over explicit grids
+            of points used for sampling, normalized to [0,1]. The nth grid in the
+            list is used in the nth crop layer. Exclusive with points_per_side.
+          min_mask_region_area (int): If >0, postprocessing will be applied
+            to remove disconnected regions and holes in masks with area smaller
+            than min_mask_region_area. Requires opencv.
+          output_mode (str): The form masks are returned in. Can be 'binary_mask',
+            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+            For large resolutions, 'binary_mask' may consume large amounts of
+            memory.
+          use_m2m (bool): Whether to add a one step refinement using previous mask predictions.
+          multimask_output (bool): Whether to output multimask at each point of the grid.
+        """
+
+        assert (points_per_side is None) != (
+            point_grids is None
+        ), "Exactly one of points_per_side or point_grid must be provided."
+        if points_per_side is not None:
+            self.point_grids = build_all_layer_point_grids(
+                points_per_side,
+                crop_n_layers,
+                crop_n_points_downscale_factor,
+            )
+        elif point_grids is not None:
+            self.point_grids = point_grids
+        else:
+            raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+        assert output_mode in [
+            "binary_mask",
+            "uncompressed_rle",
+            "coco_rle",
+        ], f"Unknown output_mode {output_mode}."
+        if output_mode == "coco_rle":
+            try:
+                from pycocotools import mask as mask_utils  # type: ignore  # noqa: F401
+            except ImportError as e:
+                print("Please install pycocotools")
+                raise e
+
+        self.predictor = SAM2ImagePredictor(
+            model,
+            max_hole_area=min_mask_region_area,
+            max_sprinkle_area=min_mask_region_area,
+        )
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.mask_threshold = mask_threshold
+        self.box_nms_thresh = box_nms_thresh
+        self.crop_n_layers = crop_n_layers
+        self.crop_nms_thresh = crop_nms_thresh
+        self.crop_overlap_ratio = crop_overlap_ratio
+        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self.min_mask_region_area = min_mask_region_area
+        self.output_mode = output_mode
+        self.use_m2m = use_m2m
+        self.multimask_output = multimask_output
+
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2AutomaticMaskGenerator): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_hf
+
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model, **kwargs)
+
+    @torch.no_grad()
+    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+        """
+        Generates masks for the given image.
+
+        Arguments:
+          image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+        Returns:
+          list(dict(str, any)): A list over records for masks. Each record is
+            a dict containing the following keys:
+              segmentation (dict(str, any) or np.ndarray): The mask. If
+                output_mode='binary_mask', is an array of shape HW. Otherwise,
+                is a dictionary containing the RLE.
+              bbox (list(float)): The box around the mask, in XYWH format.
+              area (int): The area in pixels of the mask.
+              predicted_iou (float): The model's own prediction of the mask's
+                quality. This is filtered by the pred_iou_thresh parameter.
+              point_coords (list(list(float))): The point coordinates input
+                to the model to generate this mask.
+              stability_score (float): A measure of the mask's quality. This
+                is filtered on using the stability_score_thresh parameter.
+              crop_box (list(float)): The crop of the image used to generate
+                the mask, given in XYWH format.
+        """
+
+        # Generate masks
+        mask_data = self._generate_masks(image)
+
+        # Encode masks
+        if self.output_mode == "coco_rle":
+            mask_data["segmentations"] = [
+                coco_encode_rle(rle) for rle in mask_data["rles"]
+            ]
+        elif self.output_mode == "binary_mask":
+            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+        else:
+            mask_data["segmentations"] = mask_data["rles"]
+
+        # Write mask records
+        curr_anns = []
+        for idx in range(len(mask_data["segmentations"])):
+            ann = {
+                "segmentation": mask_data["segmentations"][idx],
+                "area": area_from_rle(mask_data["rles"][idx]),
+                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+                "predicted_iou": mask_data["iou_preds"][idx].item(),
+                "point_coords": [mask_data["points"][idx].tolist()],
+                "stability_score": mask_data["stability_score"][idx].item(),
+                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+            }
+            curr_anns.append(ann)
+
+        return curr_anns
+
+    def _generate_masks(self, image: np.ndarray) -> MaskData:
+        orig_size = image.shape[:2]
+        crop_boxes, layer_idxs = generate_crop_boxes(
+            orig_size, self.crop_n_layers, self.crop_overlap_ratio
+        )
+
+        # Iterate over image crops
+        data = MaskData()
+        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            data.cat(crop_data)
+
+        # Remove duplicate masks between crops
+        if len(crop_boxes) > 1:
+            # Prefer masks from smaller crops
+            scores = 1 / box_area(data["crop_boxes"])
+            scores = scores.to(data["boxes"].device)
+            keep_by_nms = batched_nms(
+                data["boxes"].float(),
+                scores,
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
+                iou_threshold=self.crop_nms_thresh,
+            )
+            data.filter(keep_by_nms)
+        data.to_numpy()
+        return data
+
+    def _process_crop(
+        self,
+        image: np.ndarray,
+        crop_box: List[int],
+        crop_layer_idx: int,
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        # Crop the image and calculate embeddings
+        x0, y0, x1, y1 = crop_box
+        cropped_im = image[y0:y1, x0:x1, :]
+        cropped_im_size = cropped_im.shape[:2]
+        self.predictor.set_image(cropped_im)
+
+        # Get points for this crop
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+        points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+        # Generate masks for this crop in batches
+        data = MaskData()
+        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+            batch_data = self._process_batch(
+                points, cropped_im_size, crop_box, orig_size, normalize=True
+            )
+            data.cat(batch_data)
+            del batch_data
+        self.predictor.reset_predictor()
+
+        # Remove duplicates within this crop.
+        keep_by_nms = batched_nms(
+            data["boxes"].float(),
+            data["iou_preds"],
+            torch.zeros_like(data["boxes"][:, 0]),  # categories
+            iou_threshold=self.box_nms_thresh,
+        )
+        data.filter(keep_by_nms)
+
+        # Return to the original image frame
+        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+        data["points"] = uncrop_points(data["points"], crop_box)
+        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+        return data
+
+    def _process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+        normalize=False,
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
+
+        # Run model on this batch
+        points = torch.as_tensor(
+            points, dtype=torch.float32, device=self.predictor.device
+        )
+        in_points = self.predictor._transforms.transform_coords(
+            points, normalize=normalize, orig_hw=im_size
+        )
+        in_labels = torch.ones(
+            in_points.shape[0], dtype=torch.int, device=in_points.device
+        )
+        masks, iou_preds, low_res_masks = self.predictor._predict(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=self.multimask_output,
+            return_logits=True,
+        )
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=points.repeat_interleave(masks.shape[1], dim=0),
+            low_res_masks=low_res_masks.flatten(0, 1),
+        )
+        del masks
+
+        if not self.use_m2m:
+            # Filter by predicted IoU
+            if self.pred_iou_thresh > 0.0:
+                keep_mask = data["iou_preds"] > self.pred_iou_thresh
+                data.filter(keep_mask)
+
+            # Calculate and filter by stability score
+            data["stability_score"] = calculate_stability_score(
+                data["masks"], self.mask_threshold, self.stability_score_offset
+            )
+            if self.stability_score_thresh > 0.0:
+                keep_mask = data["stability_score"] >= self.stability_score_thresh
+                data.filter(keep_mask)
+        else:
+            # One step refinement using previous mask predictions
+            in_points = self.predictor._transforms.transform_coords(
+                data["points"], normalize=normalize, orig_hw=im_size
+            )
+            labels = torch.ones(
+                in_points.shape[0], dtype=torch.int, device=in_points.device
+            )
+            masks, ious = self.refine_with_m2m(
+                in_points, labels, data["low_res_masks"], self.points_per_batch
+            )
+            data["masks"] = masks.squeeze(1)
+            data["iou_preds"] = ious.squeeze(1)
+
+            if self.pred_iou_thresh > 0.0:
+                keep_mask = data["iou_preds"] > self.pred_iou_thresh
+                data.filter(keep_mask)
+
+            data["stability_score"] = calculate_stability_score(
+                data["masks"], self.mask_threshold, self.stability_score_offset
+            )
+            if self.stability_score_thresh > 0.0:
+                keep_mask = data["stability_score"] >= self.stability_score_thresh
+                data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        # Filter boxes that touch crop boundaries
+        keep_mask = ~is_box_near_crop_edge(
+            data["boxes"], crop_box, [0, 0, orig_w, orig_h]
+        )
+        if not torch.all(keep_mask):
+            data.filter(keep_mask)
+
+        # Compress to RLE
+        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+        del data["masks"]
+
+        return data
+
+    @staticmethod
+    def postprocess_small_regions(
+        mask_data: MaskData, min_area: int, nms_thresh: float
+    ) -> MaskData:
+        """
+        Removes small disconnected regions and holes in masks, then reruns
+        box NMS to remove any new duplicates.
+
+        Edits mask_data in place.
+
+        Requires open-cv as a dependency.
+        """
+        if len(mask_data["rles"]) == 0:
+            return mask_data
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for rle in mask_data["rles"]:
+            mask = rle_to_mask(rle)
+
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and score=1 to unchanged masks
+            # so NMS will prefer ones that didn't need postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(masks)
+        keep_by_nms = batched_nms(
+            boxes.float(),
+            torch.as_tensor(scores),
+            torch.zeros_like(boxes[:, 0]),  # categories
+            iou_threshold=nms_thresh,
+        )
+
+        # Only recalculate RLEs for masks that have changed
+        for i_mask in keep_by_nms:
+            if scores[i_mask] == 0.0:
+                mask_torch = masks[i_mask].unsqueeze(0)
+                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
+        mask_data.filter(keep_by_nms)
+
+        return mask_data
+
+    def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
+        new_masks = []
+        new_iou_preds = []
+
+        for cur_points, cur_point_labels, low_res_mask in batch_iterator(
+            points_per_batch, points, point_labels, low_res_masks
+        ):
+            best_masks, best_iou_preds, _ = self.predictor._predict(
+                cur_points[:, None, :],
+                cur_point_labels[:, None],
+                mask_input=low_res_mask[:, None, :],
+                multimask_output=False,
+                return_logits=True,
+            )
+            new_masks.append(best_masks)
+            new_iou_preds.append(best_iou_preds)
+        masks = torch.cat(new_masks, dim=0)
+        return masks, torch.cat(new_iou_preds, dim=0)
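Taken together, generate() runs a point grid over the image (and optional crops), filters candidates by predicted IoU and stability score, de-duplicates with box NMS, and returns one record per mask. A minimal usage sketch; the config/checkpoint paths and the dummy image below are placeholders, not part of this commit:

import numpy as np
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Build a SAM 2 model from a local config and checkpoint (illustrative paths)
model = build_sam2("configs/sam2/sam2_hiera_t.yaml", "sam2_hiera_tiny.pt", device="cuda")
mask_generator = SAM2AutomaticMaskGenerator(model, points_per_side=32, pred_iou_thresh=0.8)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # HWC uint8 RGB placeholder
masks = mask_generator.generate(image)           # list of dicts: "segmentation", "bbox", "area", ...
print(len(masks))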
build_sam.py
ADDED
@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+
+import torch
+from hydra import compose
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+
+import sam2
+
+# Check if the user is running Python from the parent directory of the sam2 repo
+# (i.e. the directory where this repo is cloned into) -- this is not supported since
+# it could shadow the sam2 package and cause issues.
+if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
+    # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
+    # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
+    # This typically happens because the user is running Python from the parent directory
+    # that contains the sam2 repo they cloned.
+    raise RuntimeError(
+        "You're likely running Python from the parent directory of the sam2 repository "
+        "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
+        "This is not supported since the `sam2` Python package could be shadowed by the "
+        "repository name (the repository is also named `sam2` and contains the Python package "
+        "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
+        "rather than its parent dir, or from your home directory) after installing SAM 2."
+    )
+
+
+HF_MODEL_ID_TO_FILENAMES = {
+    "facebook/sam2-hiera-tiny": (
+        "configs/sam2/sam2_hiera_t.yaml",
+        "sam2_hiera_tiny.pt",
+    ),
+    "facebook/sam2-hiera-small": (
+        "configs/sam2/sam2_hiera_s.yaml",
+        "sam2_hiera_small.pt",
+    ),
+    "facebook/sam2-hiera-base-plus": (
+        "configs/sam2/sam2_hiera_b+.yaml",
+        "sam2_hiera_base_plus.pt",
+    ),
+    "facebook/sam2-hiera-large": (
+        "configs/sam2/sam2_hiera_l.yaml",
+        "sam2_hiera_large.pt",
+    ),
+    "facebook/sam2.1-hiera-tiny": (
+        "configs/sam2.1/sam2.1_hiera_t.yaml",
+        "sam2.1_hiera_tiny.pt",
+    ),
+    "facebook/sam2.1-hiera-small": (
+        "configs/sam2.1/sam2.1_hiera_s.yaml",
+        "sam2.1_hiera_small.pt",
+    ),
+    "facebook/sam2.1-hiera-base-plus": (
+        "configs/sam2.1/sam2.1_hiera_b+.yaml",
+        "sam2.1_hiera_base_plus.pt",
+    ),
+    "facebook/sam2.1-hiera-large": (
+        "configs/sam2.1/sam2.1_hiera_l.yaml",
+        "sam2.1_hiera_large.pt",
+    ),
+}
+
+
+def build_sam2(
+    config_file,
+    ckpt_path=None,
+    device="cuda",
+    mode="eval",
+    hydra_overrides_extra=[],
+    apply_postprocessing=True,
+    **kwargs,
+):
+
+    if apply_postprocessing:
+        hydra_overrides_extra = hydra_overrides_extra.copy()
+        hydra_overrides_extra += [
+            # dynamically fall back to multi-mask if the single mask is not stable
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+        ]
+    # Read config and init model
+    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
+    OmegaConf.resolve(cfg)
+    model = instantiate(cfg.model, _recursive_=True)
+    _load_checkpoint(model, ckpt_path)
+    model = model.to(device)
+    if mode == "eval":
+        model.eval()
+    return model
+
+
+def build_sam2_video_predictor(
+    config_file,
+    ckpt_path=None,
+    device="cuda",
+    mode="eval",
+    hydra_overrides_extra=[],
+    apply_postprocessing=True,
+    **kwargs,
+):
+    hydra_overrides = [
+        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
+    ]
+    if apply_postprocessing:
+        hydra_overrides_extra = hydra_overrides_extra.copy()
+        hydra_overrides_extra += [
+            # dynamically fall back to multi-mask if the single mask is not stable
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly what users see from clicking
+            "++model.binarize_mask_from_pts_for_mem_enc=true",
+            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
+            "++model.fill_hole_area=8",
+        ]
+    hydra_overrides.extend(hydra_overrides_extra)
+
+    # Read config and init model
+    cfg = compose(config_name=config_file, overrides=hydra_overrides)
+    OmegaConf.resolve(cfg)
+    model = instantiate(cfg.model, _recursive_=True)
+    _load_checkpoint(model, ckpt_path)
+    model = model.to(device)
+    if mode == "eval":
+        model.eval()
+    return model
+
+
+def _hf_download(model_id):
+    from huggingface_hub import hf_hub_download
+
+    config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
+    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
+    return config_name, ckpt_path
+
+
+def build_sam2_hf(model_id, **kwargs):
+    config_name, ckpt_path = _hf_download(model_id)
+    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
+
+
+def build_sam2_video_predictor_hf(model_id, **kwargs):
+    config_name, ckpt_path = _hf_download(model_id)
+    return build_sam2_video_predictor(
+        config_file=config_name, ckpt_path=ckpt_path, **kwargs
+    )
+
+
+def _load_checkpoint(model, ckpt_path):
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
+        missing_keys, unexpected_keys = model.load_state_dict(sd)
+        if missing_keys:
+            logging.error(missing_keys)
+            raise RuntimeError()
+        if unexpected_keys:
+            logging.error(unexpected_keys)
+            raise RuntimeError()
+        logging.info("Loaded checkpoint successfully")
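A short sketch of the Hugging Face loading path defined above: build_sam2_hf() looks up the config/checkpoint pair in HF_MODEL_ID_TO_FILENAMES, downloads the weights with hf_hub_download, and builds the model. The repo ID below is one of the table's keys; the device choice is illustrative:

from sam2.build_sam import build_sam2_hf
from sam2.sam2_image_predictor import SAM2ImagePredictor

model = build_sam2_hf("facebook/sam2-hiera-tiny", device="cuda")
predictor = SAM2ImagePredictor(model)
# Equivalent shortcut that goes through the same code path:
# predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")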
sam2_hiera_b+.yaml
ADDED
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_b+.yaml
sam2_hiera_l.yaml
ADDED
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_l.yaml
sam2_hiera_s.yaml
ADDED
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_s.yaml
sam2_hiera_t.yaml
ADDED
@@ -0,0 +1 @@
+configs/sam2/sam2_hiera_t.yaml
sam2_image_predictor.py
ADDED
@@ -0,0 +1,466 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL.Image import Image
+
+from sam2.modeling.sam2_base import SAM2Base
+
+from sam2.utils.transforms import SAM2Transforms
+
+
+class SAM2ImagePredictor:
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+        mask_threshold=0.0,
+        max_hole_area=0.0,
+        max_sprinkle_area=0.0,
+        **kwargs,
+    ) -> None:
+        """
+        Uses SAM-2 to calculate the image embedding for an image, and then
+        allows repeated, efficient mask prediction given prompts.
+
+        Arguments:
+          sam_model (Sam-2): The model to use for mask prediction.
+          mask_threshold (float): The threshold to use when converting mask logits
+            to binary masks. Masks are thresholded at 0 by default.
+          max_hole_area (int): If max_hole_area > 0, we fill small holes in up to
+            the maximum area of max_hole_area in low_res_masks.
+          max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to
+            the maximum area of max_sprinkle_area in low_res_masks.
+        """
+        super().__init__()
+        self.model = sam_model
+        self._transforms = SAM2Transforms(
+            resolution=self.model.image_size,
+            mask_threshold=mask_threshold,
+            max_hole_area=max_hole_area,
+            max_sprinkle_area=max_sprinkle_area,
+        )
+
+        # Predictor state
+        self._is_image_set = False
+        self._features = None
+        self._orig_hw = None
+        # Whether the predictor is set for single image or a batch of images
+        self._is_batch = False
+
+        # Predictor config
+        self.mask_threshold = mask_threshold
+
+        # Spatial dim for backbone feature maps
+        self._bb_feat_sizes = [
+            (256, 256),
+            (128, 128),
+            (64, 64),
+        ]
+
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2ImagePredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_hf
+
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model, **kwargs)
+
+    @torch.no_grad()
+    def set_image(
+        self,
+        image: Union[np.ndarray, Image],
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method.
+
+        Arguments:
+          image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image
+            with pixel values in [0, 255].
+          image_format (str): The color format of the image, in ['RGB', 'BGR'].
+        """
+        self.reset_predictor()
+        # Transform the image to the form expected by the model
+        if isinstance(image, np.ndarray):
+            logging.info("For numpy array image, we assume (HxWxC) format")
+            self._orig_hw = [image.shape[:2]]
+        elif isinstance(image, Image):
+            w, h = image.size
+            self._orig_hw = [(h, w)]
+        else:
+            raise NotImplementedError("Image format not supported")
+
+        input_image = self._transforms(image)
+        input_image = input_image[None, ...].to(self.device)
+
+        assert (
+            len(input_image.shape) == 4 and input_image.shape[1] == 3
+        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+        logging.info("Computing image embeddings for the provided image...")
+        backbone_out = self.model.forward_image(input_image)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2, 0).view(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+        self._is_image_set = True
+        logging.info("Image embeddings computed.")
+
+    @torch.no_grad()
+    def set_image_batch(
+        self,
+        image_list: List[Union[np.ndarray]],
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image batch, allowing
+        masks to be predicted with the 'predict_batch' method.
+
+        Arguments:
+          image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray
+            with pixel values in [0, 255].
+        """
+        self.reset_predictor()
+        assert isinstance(image_list, list)
+        self._orig_hw = []
+        for image in image_list:
+            assert isinstance(
+                image, np.ndarray
+            ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC"
+            self._orig_hw.append(image.shape[:2])
+        # Transform the image to the form expected by the model
+        img_batch = self._transforms.forward_batch(image_list)
+        img_batch = img_batch.to(self.device)
+        batch_size = img_batch.shape[0]
+        assert (
+            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
+        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
+        logging.info("Computing image embeddings for the provided images...")
+        backbone_out = self.model.forward_image(img_batch)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest res feat. map during training on videos
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+        self._is_image_set = True
+        self._is_batch = True
+        logging.info("Image embeddings computed.")
+
+    def predict_batch(
+        self,
+        point_coords_batch: List[np.ndarray] = None,
+        point_labels_batch: List[np.ndarray] = None,
+        box_batch: List[np.ndarray] = None,
+        mask_input_batch: List[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        normalize_coords=True,
+    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+        """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images.
+        It returns a tuple of lists of masks, ious, and low_res_masks_logits.
+        """
+        assert self._is_batch, "This function should only be used when in batched mode"
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image_batch(...) before mask prediction."
+            )
+        num_images = len(self._features["image_embed"])
+        all_masks = []
+        all_ious = []
+        all_low_res_masks = []
+        for img_idx in range(num_images):
+            # Transform input prompts
+            point_coords = (
+                point_coords_batch[img_idx] if point_coords_batch is not None else None
+            )
+            point_labels = (
+                point_labels_batch[img_idx] if point_labels_batch is not None else None
+            )
+            box = box_batch[img_idx] if box_batch is not None else None
+            mask_input = (
+                mask_input_batch[img_idx] if mask_input_batch is not None else None
+            )
+            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+                point_coords,
+                point_labels,
+                box,
+                mask_input,
+                normalize_coords,
+                img_idx=img_idx,
+            )
+            masks, iou_predictions, low_res_masks = self._predict(
+                unnorm_coords,
+                labels,
+                unnorm_box,
+                mask_input,
+                multimask_output,
+                return_logits=return_logits,
+                img_idx=img_idx,
+            )
+            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+            iou_predictions_np = (
+                iou_predictions.squeeze(0).float().detach().cpu().numpy()
+            )
+            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+            all_masks.append(masks_np)
+            all_ious.append(iou_predictions_np)
+            all_low_res_masks.append(low_res_masks_np)
+
+        return all_masks, all_ious, all_low_res_masks
+
+    def predict(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        normalize_coords=True,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+
+        Arguments:
+          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (np.ndarray or None): A length N array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          box (np.ndarray or None): A length 4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form 1xHxW, where
+            for SAM, H=W=256.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+          normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions.
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (np.ndarray): An array of length C containing the model's
+            predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks and H=W=256. These low resolution logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        # Transform input prompts
+
+        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+            point_coords, point_labels, box, mask_input, normalize_coords
+        )
+
+        masks, iou_predictions, low_res_masks = self._predict(
+            unnorm_coords,
+            labels,
+            unnorm_box,
+            mask_input,
+            multimask_output,
+            return_logits=return_logits,
+        )
+
+        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
+        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+        return masks_np, iou_predictions_np, low_res_masks_np
+
+    def _prep_prompts(
+        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+    ):
+
+        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
+        if point_coords is not None:
+            assert (
+                point_labels is not None
+            ), "point_labels must be supplied if point_coords is supplied."
+            point_coords = torch.as_tensor(
+                point_coords, dtype=torch.float, device=self.device
+            )
+            unnorm_coords = self._transforms.transform_coords(
+                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+            )
+            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+            if len(unnorm_coords.shape) == 2:
+                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
+        if box is not None:
+            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
+            unnorm_box = self._transforms.transform_boxes(
+                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+            )  # Bx2x2
+        if mask_logits is not None:
+            mask_input = torch.as_tensor(
+                mask_logits, dtype=torch.float, device=self.device
+            )
+            if len(mask_input.shape) == 3:
+                mask_input = mask_input[None, :, :, :]
+        return mask_input, unnorm_coords, labels, unnorm_box
+
+    @torch.no_grad()
+    def _predict(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        img_idx: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using SAM2Transforms.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        if point_coords is not None:
+            concat_points = (point_coords, point_labels)
+        else:
+            concat_points = None
+
+        # Embed prompts
+        if boxes is not None:
+            box_coords = boxes.reshape(-1, 2, 2)
+            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+            box_labels = box_labels.repeat(boxes.size(0), 1)
+            # we merge "boxes" and "points" into a single "concat_points" input (where
+            # boxes are added at the beginning) to sam_prompt_encoder
+            if concat_points is not None:
+                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                concat_points = (concat_coords, concat_labels)
+            else:
+                concat_points = (box_coords, box_labels)
+
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=concat_points,
+            boxes=None,
+            masks=mask_input,
+        )
+
+        # Predict masks
+        batched_mode = (
+            concat_points is not None and concat_points[0].shape[0] > 1
+        )  # multi object prediction
+        high_res_features = [
+            feat_level[img_idx].unsqueeze(0)
+            for feat_level in self._features["high_res_feats"]
+        ]
+        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=batched_mode,
+            high_res_features=high_res_features,
+        )
+
+        # Upscale the masks to the original image resolution
+        masks = self._transforms.postprocess_masks(
+            low_res_masks, self._orig_hw[img_idx]
+        )
+        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
+        if not return_logits:
+            masks = masks > self.mask_threshold
+
+        return masks, iou_predictions, low_res_masks
+
+    def get_image_embedding(self) -> torch.Tensor:
+        """
+        Returns the image embeddings for the currently set image, with
+        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+        the embedding spatial dimension of SAM (typically C=256, H=W=64).
+        """
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) to generate an embedding."
+            )
+        assert (
+            self._features is not None
+        ), "Features must exist if an image has been set."
+        return self._features["image_embed"]
+
+    @property
+    def device(self) -> torch.device:
+        return self.model.device
+
+    def reset_predictor(self) -> None:
+        """
+        Resets the image embeddings and other state variables.
+        """
+        self._is_image_set = False
+        self._features = None
+        self._orig_hw = None
+        self._is_batch = False
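A minimal prompting sketch for the predictor above, assuming a model built via build_sam.py from this commit; the dummy image and click location are placeholders. With multimask_output=True, predict() returns three candidate masks plus their predicted IoUs and low-res logits:

import numpy as np
from sam2.build_sam import build_sam2_hf
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor(build_sam2_hf("facebook/sam2-hiera-tiny", device="cuda"))
image = np.zeros((480, 640, 3), dtype=np.uint8)   # HWC uint8 RGB placeholder

predictor.set_image(image)                        # computes and caches the image embedding once
masks, ious, low_res_logits = predictor.predict(
    point_coords=np.array([[320, 240]]),          # one foreground click, (X, Y) in pixels
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, ious.shape)                    # (3, 480, 640) and (3,)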
sam2_video_predictor.py
ADDED
|
@@ -0,0 +1,1172 @@
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import warnings
|
| 8 |
+
from collections import OrderedDict
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
|
| 15 |
+
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SAM2VideoPredictor(SAM2Base):
|
| 19 |
+
"""The predictor class to handle user interactions and manage inference states."""
|
| 20 |
+
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
fill_hole_area=0,
|
| 24 |
+
# whether to apply non-overlapping constraints on the output object masks
|
| 25 |
+
non_overlap_masks=False,
|
| 26 |
+
# whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
|
| 27 |
+
# note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
|
| 28 |
+
clear_non_cond_mem_around_input=False,
|
| 29 |
+
# whether to also apply this clearing when tracking multiple objects (only effective when `clear_non_cond_mem_around_input` is True).
|
| 30 |
+
clear_non_cond_mem_for_multi_obj=False,
|
| 31 |
+
# if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
|
| 32 |
+
# if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
|
| 33 |
+
add_all_frames_to_correct_as_cond=False,
|
| 34 |
+
**kwargs,
|
| 35 |
+
):
|
| 36 |
+
super().__init__(**kwargs)
|
| 37 |
+
self.fill_hole_area = fill_hole_area
|
| 38 |
+
self.non_overlap_masks = non_overlap_masks
|
| 39 |
+
self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
|
| 40 |
+
self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
|
| 41 |
+
self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
|
| 42 |
+
|
| 43 |
+
@torch.inference_mode()
|
| 44 |
+
def init_state(
|
| 45 |
+
self,
|
| 46 |
+
video_path,
|
| 47 |
+
offload_video_to_cpu=False,
|
| 48 |
+
offload_state_to_cpu=False,
|
| 49 |
+
async_loading_frames=False,
|
| 50 |
+
):
|
| 51 |
+
"""Initialize an inference state."""
|
| 52 |
+
compute_device = self.device # device of the model
|
| 53 |
+
images, video_height, video_width = load_video_frames(
|
| 54 |
+
video_path=video_path,
|
| 55 |
+
image_size=self.image_size,
|
| 56 |
+
offload_video_to_cpu=offload_video_to_cpu,
|
| 57 |
+
async_loading_frames=async_loading_frames,
|
| 58 |
+
compute_device=compute_device,
|
| 59 |
+
)
|
| 60 |
+
inference_state = {}
|
| 61 |
+
inference_state["images"] = images
|
| 62 |
+
inference_state["num_frames"] = len(images)
|
| 63 |
+
# whether to offload the video frames to CPU memory
|
| 64 |
+
# turning on this option saves GPU memory with only a very small overhead
|
| 65 |
+
inference_state["offload_video_to_cpu"] = offload_video_to_cpu
|
| 66 |
+
# whether to offload the inference state to CPU memory
|
| 67 |
+
# turning on this option saves the GPU memory at the cost of a lower tracking fps
|
| 68 |
+
# (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
|
| 69 |
+
# and from 24 to 21 when tracking two objects)
|
| 70 |
+
inference_state["offload_state_to_cpu"] = offload_state_to_cpu
|
| 71 |
+
# the original video height and width, used for resizing final output scores
|
| 72 |
+
inference_state["video_height"] = video_height
|
| 73 |
+
inference_state["video_width"] = video_width
|
| 74 |
+
inference_state["device"] = compute_device
|
| 75 |
+
if offload_state_to_cpu:
|
| 76 |
+
inference_state["storage_device"] = torch.device("cpu")
|
| 77 |
+
else:
|
| 78 |
+
inference_state["storage_device"] = compute_device
|
| 79 |
+
# inputs on each frame
|
| 80 |
+
inference_state["point_inputs_per_obj"] = {}
|
| 81 |
+
inference_state["mask_inputs_per_obj"] = {}
|
| 82 |
+
# visual features on a small number of recently visited frames for quick interactions
|
| 83 |
+
inference_state["cached_features"] = {}
|
| 84 |
+
# values that don't change across frames (so we only need to hold one copy of them)
|
| 85 |
+
inference_state["constants"] = {}
|
| 86 |
+
# mapping between client-side object id and model-side object index
|
| 87 |
+
inference_state["obj_id_to_idx"] = OrderedDict()
|
| 88 |
+
inference_state["obj_idx_to_id"] = OrderedDict()
|
| 89 |
+
inference_state["obj_ids"] = []
|
| 90 |
+
# A storage to hold the model's tracking results and states on each frame
|
| 91 |
+
inference_state["output_dict"] = {
|
| 92 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 93 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 94 |
+
}
|
| 95 |
+
# Slice (view) of each object tracking results, sharing the same memory with "output_dict"
|
| 96 |
+
inference_state["output_dict_per_obj"] = {}
|
| 97 |
+
# A temporary storage to hold new outputs when the user interacts with a frame
|
| 98 |
+
# to add clicks or mask (it's merged into "output_dict" before propagation starts)
|
| 99 |
+
inference_state["temp_output_dict_per_obj"] = {}
|
| 100 |
+
# Frames that already hold consolidated outputs from click or mask inputs
|
| 101 |
+
# (we directly use their consolidated outputs during tracking)
|
| 102 |
+
inference_state["consolidated_frame_inds"] = {
|
| 103 |
+
"cond_frame_outputs": set(), # set containing frame indices
|
| 104 |
+
"non_cond_frame_outputs": set(), # set containing frame indices
|
| 105 |
+
}
|
| 106 |
+
# metadata for each tracking frame (e.g. which direction it's tracked)
|
| 107 |
+
inference_state["tracking_has_started"] = False
|
| 108 |
+
inference_state["frames_already_tracked"] = {}
|
| 109 |
+
# Warm up the visual backbone and cache the image feature on frame 0
|
| 110 |
+
self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
|
| 111 |
+
return inference_state
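`init_state` loads every frame up front and returns a plain dictionary that all later calls mutate in place. A hedged usage sketch follows; `build_sam2_video_predictor` is assumed to come from the accompanying `build_sam.py`, and the config/checkpoint/video paths are placeholders:

import torch
from sam2.build_sam import build_sam2_video_predictor  # assumed builder from build_sam.py

predictor = build_sam2_video_predictor(
    "sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"  # hypothetical paths
)
with torch.inference_mode():
    state = predictor.init_state(
        video_path="videos/example",      # hypothetical video file or folder of frames
        offload_video_to_cpu=True,        # keep decoded frames in CPU RAM
        offload_state_to_cpu=False,       # keep per-frame outputs on the GPU
    )
print(state["num_frames"], state["video_height"], state["video_width"])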
|
| 112 |
+
|
| 113 |
+
@classmethod
|
| 114 |
+
def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
|
| 115 |
+
"""
|
| 116 |
+
Load a pretrained model from the Hugging Face hub.
|
| 117 |
+
|
| 118 |
+
Arguments:
|
| 119 |
+
model_id (str): The Hugging Face repository ID.
|
| 120 |
+
**kwargs: Additional arguments to pass to the model constructor.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
(SAM2VideoPredictor): The loaded model.
|
| 124 |
+
"""
|
| 125 |
+
from sam2.build_sam import build_sam2_video_predictor_hf
|
| 126 |
+
|
| 127 |
+
sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
|
| 128 |
+
return sam_model
|
| 129 |
+
|
| 130 |
+
def _obj_id_to_idx(self, inference_state, obj_id):
|
| 131 |
+
"""Map client-side object id to model-side object index."""
|
| 132 |
+
obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
|
| 133 |
+
if obj_idx is not None:
|
| 134 |
+
return obj_idx
|
| 135 |
+
|
| 136 |
+
# This is a new object id not sent to the server before. We only allow adding
|
| 137 |
+
# new objects *before* the tracking starts.
|
| 138 |
+
allow_new_object = not inference_state["tracking_has_started"]
|
| 139 |
+
if allow_new_object:
|
| 140 |
+
# get the next object slot
|
| 141 |
+
obj_idx = len(inference_state["obj_id_to_idx"])
|
| 142 |
+
inference_state["obj_id_to_idx"][obj_id] = obj_idx
|
| 143 |
+
inference_state["obj_idx_to_id"][obj_idx] = obj_id
|
| 144 |
+
inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
|
| 145 |
+
# set up input and output structures for this object
|
| 146 |
+
inference_state["point_inputs_per_obj"][obj_idx] = {}
|
| 147 |
+
inference_state["mask_inputs_per_obj"][obj_idx] = {}
|
| 148 |
+
inference_state["output_dict_per_obj"][obj_idx] = {
|
| 149 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 150 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 151 |
+
}
|
| 152 |
+
inference_state["temp_output_dict_per_obj"][obj_idx] = {
|
| 153 |
+
"cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 154 |
+
"non_cond_frame_outputs": {}, # dict containing {frame_idx: <out>}
|
| 155 |
+
}
|
| 156 |
+
return obj_idx
|
| 157 |
+
else:
|
| 158 |
+
raise RuntimeError(
|
| 159 |
+
f"Cannot add new object id {obj_id} after tracking starts. "
|
| 160 |
+
f"All existing object ids: {inference_state['obj_ids']}. "
|
| 161 |
+
f"Please call 'reset_state' to restart from scratch."
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
def _obj_idx_to_id(self, inference_state, obj_idx):
|
| 165 |
+
"""Map model-side object index to client-side object id."""
|
| 166 |
+
return inference_state["obj_idx_to_id"][obj_idx]
|
| 167 |
+
|
| 168 |
+
def _get_obj_num(self, inference_state):
|
| 169 |
+
"""Get the total number of unique object ids received so far in this session."""
|
| 170 |
+
return len(inference_state["obj_idx_to_id"])
|
| 171 |
+
|
| 172 |
+
@torch.inference_mode()
|
| 173 |
+
def add_new_points_or_box(
|
| 174 |
+
self,
|
| 175 |
+
inference_state,
|
| 176 |
+
frame_idx,
|
| 177 |
+
obj_id,
|
| 178 |
+
points=None,
|
| 179 |
+
labels=None,
|
| 180 |
+
clear_old_points=True,
|
| 181 |
+
normalize_coords=True,
|
| 182 |
+
box=None,
|
| 183 |
+
):
|
| 184 |
+
"""Add new points to a frame."""
|
| 185 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
| 186 |
+
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
| 187 |
+
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
| 188 |
+
|
| 189 |
+
if (points is not None) != (labels is not None):
|
| 190 |
+
raise ValueError("points and labels must be provided together")
|
| 191 |
+
if points is None and box is None:
|
| 192 |
+
raise ValueError("at least one of points or box must be provided as input")
|
| 193 |
+
|
| 194 |
+
if points is None:
|
| 195 |
+
points = torch.zeros(0, 2, dtype=torch.float32)
|
| 196 |
+
elif not isinstance(points, torch.Tensor):
|
| 197 |
+
points = torch.tensor(points, dtype=torch.float32)
|
| 198 |
+
if labels is None:
|
| 199 |
+
labels = torch.zeros(0, dtype=torch.int32)
|
| 200 |
+
elif not isinstance(labels, torch.Tensor):
|
| 201 |
+
labels = torch.tensor(labels, dtype=torch.int32)
|
| 202 |
+
if points.dim() == 2:
|
| 203 |
+
points = points.unsqueeze(0) # add batch dimension
|
| 204 |
+
if labels.dim() == 1:
|
| 205 |
+
labels = labels.unsqueeze(0) # add batch dimension
|
| 206 |
+
|
| 207 |
+
# If `box` is provided, we add it as the first two points with labels 2 and 3
|
| 208 |
+
# along with the user-provided points (consistent with how SAM 2 is trained).
|
| 209 |
+
if box is not None:
|
| 210 |
+
if not clear_old_points:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
"cannot add box without clearing old points, since "
|
| 213 |
+
"box prompt must be provided before any point prompt "
|
| 214 |
+
"(please use clear_old_points=True instead)"
|
| 215 |
+
)
|
| 216 |
+
if inference_state["tracking_has_started"]:
|
| 217 |
+
warnings.warn(
|
| 218 |
+
"You are adding a box after tracking starts. SAM 2 may not always be "
|
| 219 |
+
"able to incorporate a box prompt for *refinement*. If you intend to "
|
| 220 |
+
"use box prompt as an *initial* input before tracking, please call "
|
| 221 |
+
"'reset_state' on the inference state to restart from scratch.",
|
| 222 |
+
category=UserWarning,
|
| 223 |
+
stacklevel=2,
|
| 224 |
+
)
|
| 225 |
+
if not isinstance(box, torch.Tensor):
|
| 226 |
+
box = torch.tensor(box, dtype=torch.float32, device=points.device)
|
| 227 |
+
box_coords = box.reshape(1, 2, 2)
|
| 228 |
+
box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
|
| 229 |
+
box_labels = box_labels.reshape(1, 2)
|
| 230 |
+
points = torch.cat([box_coords, points], dim=1)
|
| 231 |
+
labels = torch.cat([box_labels, labels], dim=1)
|
| 232 |
+
|
| 233 |
+
if normalize_coords:
|
| 234 |
+
video_H = inference_state["video_height"]
|
| 235 |
+
video_W = inference_state["video_width"]
|
| 236 |
+
points = points / torch.tensor([video_W, video_H]).to(points.device)
|
| 237 |
+
# scale the (normalized) coordinates by the model's internal image size
|
| 238 |
+
points = points * self.image_size
|
| 239 |
+
points = points.to(inference_state["device"])
|
| 240 |
+
labels = labels.to(inference_state["device"])
|
| 241 |
+
|
| 242 |
+
if not clear_old_points:
|
| 243 |
+
point_inputs = point_inputs_per_frame.get(frame_idx, None)
|
| 244 |
+
else:
|
| 245 |
+
point_inputs = None
|
| 246 |
+
point_inputs = concat_points(point_inputs, points, labels)
|
| 247 |
+
|
| 248 |
+
point_inputs_per_frame[frame_idx] = point_inputs
|
| 249 |
+
mask_inputs_per_frame.pop(frame_idx, None)
|
| 250 |
+
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
| 251 |
+
# frame, meaning that the input points are used to generate segments on this frame without
|
| 252 |
+
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
| 253 |
+
# the input points will be used to correct the already tracked masks.
|
| 254 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
| 255 |
+
# whether to track in reverse time order
|
| 256 |
+
if is_init_cond_frame:
|
| 257 |
+
reverse = False
|
| 258 |
+
else:
|
| 259 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
| 260 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 261 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 262 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
| 263 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
| 264 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
| 265 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 266 |
+
|
| 267 |
+
# Get any previously predicted mask logits on this object and feed it along with
|
| 268 |
+
# the new clicks into the SAM mask decoder.
|
| 269 |
+
prev_sam_mask_logits = None
|
| 270 |
+
# lookup temporary output dict first, which contains the most recent output
|
| 271 |
+
# (if not found, then lookup conditioning and non-conditioning frame output)
|
| 272 |
+
prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
|
| 273 |
+
if prev_out is None:
|
| 274 |
+
prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
|
| 275 |
+
if prev_out is None:
|
| 276 |
+
prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
|
| 277 |
+
|
| 278 |
+
if prev_out is not None and prev_out["pred_masks"] is not None:
|
| 279 |
+
device = inference_state["device"]
|
| 280 |
+
prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True)
|
| 281 |
+
# Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
|
| 282 |
+
prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
|
| 283 |
+
current_out, _ = self._run_single_frame_inference(
|
| 284 |
+
inference_state=inference_state,
|
| 285 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
| 286 |
+
frame_idx=frame_idx,
|
| 287 |
+
batch_size=1, # run on the slice of a single object
|
| 288 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 289 |
+
point_inputs=point_inputs,
|
| 290 |
+
mask_inputs=None,
|
| 291 |
+
reverse=reverse,
|
| 292 |
+
# Skip the memory encoder when adding clicks or a mask. We execute the memory encoder
|
| 293 |
+
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
|
| 294 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
| 295 |
+
# them into memory.
|
| 296 |
+
run_mem_encoder=False,
|
| 297 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
| 298 |
+
)
|
| 299 |
+
# Add the output to the output dict (to be used as future memory)
|
| 300 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
| 301 |
+
|
| 302 |
+
# Resize the output mask to the original video resolution
|
| 303 |
+
obj_ids = inference_state["obj_ids"]
|
| 304 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 305 |
+
inference_state,
|
| 306 |
+
frame_idx,
|
| 307 |
+
is_cond=is_cond,
|
| 308 |
+
run_mem_encoder=False,
|
| 309 |
+
consolidate_at_video_res=True,
|
| 310 |
+
)
|
| 311 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
| 312 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
| 313 |
+
)
|
| 314 |
+
return frame_idx, obj_ids, video_res_masks
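A short sketch of adding a click prompt with the method above, continuing the illustrative `predictor`/`state` session from the `init_state` example. Coordinates are given in original-video pixels, labels 1/0 mark foreground/background clicks, and a `box` argument would internally become the two corner points with labels 2 and 3; the frame index and object id are illustrative:

import numpy as np

frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
    inference_state=state,
    frame_idx=0,                                   # prompt the first frame
    obj_id=1,                                      # client-side id for this object
    points=np.array([[210.0, 350.0]], dtype=np.float32),
    labels=np.array([1], dtype=np.int32),          # 1 = positive click
)
binary_mask = (masks[0, 0] > 0.0).cpu().numpy()    # logits -> boolean mask at video resolution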
|
| 315 |
+
|
| 316 |
+
def add_new_points(self, *args, **kwargs):
|
| 317 |
+
"""Deprecated method. Please use `add_new_points_or_box` instead."""
|
| 318 |
+
return self.add_new_points_or_box(*args, **kwargs)
|
| 319 |
+
|
| 320 |
+
@torch.inference_mode()
|
| 321 |
+
def add_new_mask(
|
| 322 |
+
self,
|
| 323 |
+
inference_state,
|
| 324 |
+
frame_idx,
|
| 325 |
+
obj_id,
|
| 326 |
+
mask,
|
| 327 |
+
):
|
| 328 |
+
"""Add new mask to a frame."""
|
| 329 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
| 330 |
+
point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
|
| 331 |
+
mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
|
| 332 |
+
|
| 333 |
+
if not isinstance(mask, torch.Tensor):
|
| 334 |
+
mask = torch.tensor(mask, dtype=torch.bool)
|
| 335 |
+
assert mask.dim() == 2
|
| 336 |
+
mask_H, mask_W = mask.shape
|
| 337 |
+
mask_inputs_orig = mask[None, None] # add batch and channel dimension
|
| 338 |
+
mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
|
| 339 |
+
|
| 340 |
+
# resize the mask if it doesn't match the model's image size
|
| 341 |
+
if mask_H != self.image_size or mask_W != self.image_size:
|
| 342 |
+
mask_inputs = torch.nn.functional.interpolate(
|
| 343 |
+
mask_inputs_orig,
|
| 344 |
+
size=(self.image_size, self.image_size),
|
| 345 |
+
align_corners=False,
|
| 346 |
+
mode="bilinear",
|
| 347 |
+
antialias=True, # use antialias for downsampling
|
| 348 |
+
)
|
| 349 |
+
mask_inputs = (mask_inputs >= 0.5).float()
|
| 350 |
+
else:
|
| 351 |
+
mask_inputs = mask_inputs_orig
|
| 352 |
+
|
| 353 |
+
mask_inputs_per_frame[frame_idx] = mask_inputs
|
| 354 |
+
point_inputs_per_frame.pop(frame_idx, None)
|
| 355 |
+
# If this frame hasn't been tracked before, we treat it as an initial conditioning
|
| 356 |
+
# frame, meaning that the input points are used to generate segments on this frame without
|
| 357 |
+
# using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
|
| 358 |
+
# the input points will be used to correct the already tracked masks.
|
| 359 |
+
is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
|
| 360 |
+
# whether to track in reverse time order
|
| 361 |
+
if is_init_cond_frame:
|
| 362 |
+
reverse = False
|
| 363 |
+
else:
|
| 364 |
+
reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
|
| 365 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 366 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 367 |
+
# Add a frame to conditioning output if it's an initial conditioning frame or
|
| 368 |
+
# if the model sees all frames receiving clicks/mask as conditioning frames.
|
| 369 |
+
is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
|
| 370 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 371 |
+
|
| 372 |
+
current_out, _ = self._run_single_frame_inference(
|
| 373 |
+
inference_state=inference_state,
|
| 374 |
+
output_dict=obj_output_dict, # run on the slice of a single object
|
| 375 |
+
frame_idx=frame_idx,
|
| 376 |
+
batch_size=1, # run on the slice of a single object
|
| 377 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 378 |
+
point_inputs=None,
|
| 379 |
+
mask_inputs=mask_inputs,
|
| 380 |
+
reverse=reverse,
|
| 381 |
+
# Skip the memory encoder when adding clicks or a mask. We execute the memory encoder
|
| 382 |
+
# at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
|
| 383 |
+
# allows us to enforce non-overlapping constraints on all objects before encoding
|
| 384 |
+
# them into memory.
|
| 385 |
+
run_mem_encoder=False,
|
| 386 |
+
)
|
| 387 |
+
# Add the output to the output dict (to be used as future memory)
|
| 388 |
+
obj_temp_output_dict[storage_key][frame_idx] = current_out
|
| 389 |
+
|
| 390 |
+
# Resize the output mask to the original video resolution
|
| 391 |
+
obj_ids = inference_state["obj_ids"]
|
| 392 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 393 |
+
inference_state,
|
| 394 |
+
frame_idx,
|
| 395 |
+
is_cond=is_cond,
|
| 396 |
+
run_mem_encoder=False,
|
| 397 |
+
consolidate_at_video_res=True,
|
| 398 |
+
)
|
| 399 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
| 400 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
| 401 |
+
)
|
| 402 |
+
return frame_idx, obj_ids, video_res_masks
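`add_new_mask` accepts a 2-D boolean mask at any resolution and resizes it internally to the model's square input size, so a coarse mask exported from another tool can be used directly as a prompt. Continuing the same illustrative session (the region below is made up):

import numpy as np

rough_mask = np.zeros((state["video_height"], state["video_width"]), dtype=bool)
rough_mask[300:420, 180:260] = True  # hypothetical coarse region for a second object

frame_idx, obj_ids, masks = predictor.add_new_mask(
    inference_state=state, frame_idx=0, obj_id=2, mask=rough_mask
)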
|
| 403 |
+
|
| 404 |
+
def _get_orig_video_res_output(self, inference_state, any_res_masks):
|
| 405 |
+
"""
|
| 406 |
+
Resize the object scores to the original video resolution (video_res_masks)
|
| 407 |
+
and apply non-overlapping constraints for final output.
|
| 408 |
+
"""
|
| 409 |
+
device = inference_state["device"]
|
| 410 |
+
video_H = inference_state["video_height"]
|
| 411 |
+
video_W = inference_state["video_width"]
|
| 412 |
+
any_res_masks = any_res_masks.to(device, non_blocking=True)
|
| 413 |
+
if any_res_masks.shape[-2:] == (video_H, video_W):
|
| 414 |
+
video_res_masks = any_res_masks
|
| 415 |
+
else:
|
| 416 |
+
video_res_masks = torch.nn.functional.interpolate(
|
| 417 |
+
any_res_masks,
|
| 418 |
+
size=(video_H, video_W),
|
| 419 |
+
mode="bilinear",
|
| 420 |
+
align_corners=False,
|
| 421 |
+
)
|
| 422 |
+
if self.non_overlap_masks:
|
| 423 |
+
video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
|
| 424 |
+
return any_res_masks, video_res_masks
|
| 425 |
+
|
| 426 |
+
def _consolidate_temp_output_across_obj(
|
| 427 |
+
self,
|
| 428 |
+
inference_state,
|
| 429 |
+
frame_idx,
|
| 430 |
+
is_cond,
|
| 431 |
+
run_mem_encoder,
|
| 432 |
+
consolidate_at_video_res=False,
|
| 433 |
+
):
|
| 434 |
+
"""
|
| 435 |
+
Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
|
| 436 |
+
a frame into a single output for all objects, including
|
| 437 |
+
1) filling any missing objects either from `output_dict_per_obj` (if they exist in
|
| 438 |
+
`output_dict_per_obj` for this frame) or leave them as placeholder values
|
| 439 |
+
(if they don't exist in `output_dict_per_obj` for this frame);
|
| 440 |
+
2) if specified, rerunning the memory encoder after applying non-overlapping constraints
|
| 441 |
+
on the object scores.
|
| 442 |
+
"""
|
| 443 |
+
batch_size = self._get_obj_num(inference_state)
|
| 444 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 445 |
+
# Optionally, we allow consolidating the temporary outputs at the original
|
| 446 |
+
# video resolution (to provide a better editing experience for mask prompts).
|
| 447 |
+
if consolidate_at_video_res:
|
| 448 |
+
assert not run_mem_encoder, "memory encoder cannot run at video resolution"
|
| 449 |
+
consolidated_H = inference_state["video_height"]
|
| 450 |
+
consolidated_W = inference_state["video_width"]
|
| 451 |
+
consolidated_mask_key = "pred_masks_video_res"
|
| 452 |
+
else:
|
| 453 |
+
consolidated_H = consolidated_W = self.image_size // 4
|
| 454 |
+
consolidated_mask_key = "pred_masks"
|
| 455 |
+
|
| 456 |
+
# Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
|
| 457 |
+
# will be added when rerunning the memory encoder after applying non-overlapping
|
| 458 |
+
# constraints to object scores. Its "pred_masks" are prefilled with a large
|
| 459 |
+
# negative value (NO_OBJ_SCORE) to represent missing objects.
|
| 460 |
+
consolidated_out = {
|
| 461 |
+
"maskmem_features": None,
|
| 462 |
+
"maskmem_pos_enc": None,
|
| 463 |
+
consolidated_mask_key: torch.full(
|
| 464 |
+
size=(batch_size, 1, consolidated_H, consolidated_W),
|
| 465 |
+
fill_value=NO_OBJ_SCORE,
|
| 466 |
+
dtype=torch.float32,
|
| 467 |
+
device=inference_state["storage_device"],
|
| 468 |
+
),
|
| 469 |
+
"obj_ptr": torch.full(
|
| 470 |
+
size=(batch_size, self.hidden_dim),
|
| 471 |
+
fill_value=NO_OBJ_SCORE,
|
| 472 |
+
dtype=torch.float32,
|
| 473 |
+
device=inference_state["device"],
|
| 474 |
+
),
|
| 475 |
+
"object_score_logits": torch.full(
|
| 476 |
+
size=(batch_size, 1),
|
| 477 |
+
# default to 10.0 for object_score_logits, i.e. assuming the object is
|
| 478 |
+
# present as sigmoid(10)=1, same as in `predict_masks` of `MaskDecoder`
|
| 479 |
+
fill_value=10.0,
|
| 480 |
+
dtype=torch.float32,
|
| 481 |
+
device=inference_state["device"],
|
| 482 |
+
),
|
| 483 |
+
}
|
| 484 |
+
empty_mask_ptr = None
|
| 485 |
+
for obj_idx in range(batch_size):
|
| 486 |
+
obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
|
| 487 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
|
| 488 |
+
out = obj_temp_output_dict[storage_key].get(frame_idx, None)
|
| 489 |
+
# If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
|
| 490 |
+
# we fall back and look up its previous output in "output_dict_per_obj".
|
| 491 |
+
# We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
|
| 492 |
+
# "output_dict_per_obj" to find a previous output for this object.
|
| 493 |
+
if out is None:
|
| 494 |
+
out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
|
| 495 |
+
if out is None:
|
| 496 |
+
out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
|
| 497 |
+
# If the object doesn't appear in "output_dict_per_obj" either, we skip it
|
| 498 |
+
# and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
|
| 499 |
+
# placeholder above) and set its object pointer to be a dummy pointer.
|
| 500 |
+
if out is None:
|
| 501 |
+
# Fill in dummy object pointers for those objects without any inputs or
|
| 502 |
+
# tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
|
| 503 |
+
# i.e. when we need to build the memory for tracking).
|
| 504 |
+
if run_mem_encoder:
|
| 505 |
+
if empty_mask_ptr is None:
|
| 506 |
+
empty_mask_ptr = self._get_empty_mask_ptr(
|
| 507 |
+
inference_state, frame_idx
|
| 508 |
+
)
|
| 509 |
+
# fill object pointer with a dummy pointer (based on an empty mask)
|
| 510 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
|
| 511 |
+
continue
|
| 512 |
+
# Add the temporary object output mask to consolidated output mask
|
| 513 |
+
obj_mask = out["pred_masks"]
|
| 514 |
+
consolidated_pred_masks = consolidated_out[consolidated_mask_key]
|
| 515 |
+
if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
|
| 516 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
|
| 517 |
+
else:
|
| 518 |
+
# Resize first if temporary object mask has a different resolution
|
| 519 |
+
resized_obj_mask = torch.nn.functional.interpolate(
|
| 520 |
+
obj_mask,
|
| 521 |
+
size=consolidated_pred_masks.shape[-2:],
|
| 522 |
+
mode="bilinear",
|
| 523 |
+
align_corners=False,
|
| 524 |
+
)
|
| 525 |
+
consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
|
| 526 |
+
consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
|
| 527 |
+
consolidated_out["object_score_logits"][obj_idx : obj_idx + 1] = out[
|
| 528 |
+
"object_score_logits"
|
| 529 |
+
]
|
| 530 |
+
|
| 531 |
+
# Optionally, apply non-overlapping constraints on the consolidated scores
|
| 532 |
+
# and rerun the memory encoder
|
| 533 |
+
if run_mem_encoder:
|
| 534 |
+
device = inference_state["device"]
|
| 535 |
+
high_res_masks = torch.nn.functional.interpolate(
|
| 536 |
+
consolidated_out["pred_masks"].to(device, non_blocking=True),
|
| 537 |
+
size=(self.image_size, self.image_size),
|
| 538 |
+
mode="bilinear",
|
| 539 |
+
align_corners=False,
|
| 540 |
+
)
|
| 541 |
+
if self.non_overlap_masks_for_mem_enc:
|
| 542 |
+
high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
|
| 543 |
+
maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
|
| 544 |
+
inference_state=inference_state,
|
| 545 |
+
frame_idx=frame_idx,
|
| 546 |
+
batch_size=batch_size,
|
| 547 |
+
high_res_masks=high_res_masks,
|
| 548 |
+
object_score_logits=consolidated_out["object_score_logits"],
|
| 549 |
+
is_mask_from_pts=True, # these frames are what the user interacted with
|
| 550 |
+
)
|
| 551 |
+
consolidated_out["maskmem_features"] = maskmem_features
|
| 552 |
+
consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
|
| 553 |
+
|
| 554 |
+
return consolidated_out
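The consolidation above boils down to pre-filling one (num_objects, 1, H, W) tensor with the large negative `NO_OBJ_SCORE` placeholder and copying each object's mask logits into its own slice, resizing when the resolutions differ. A self-contained sketch of that pattern (the placeholder value and shapes are illustrative):

import torch
import torch.nn.functional as F

NO_OBJ_SCORE = -1024.0  # assumed placeholder score for missing objects

def consolidate(per_obj_masks, num_objs, hw):
    H, W = hw
    out = torch.full((num_objs, 1, H, W), NO_OBJ_SCORE)
    for obj_idx, mask in per_obj_masks.items():  # mask: (1, 1, h, w) logits
        if mask.shape[-2:] != (H, W):
            mask = F.interpolate(mask, size=(H, W), mode="bilinear", align_corners=False)
        out[obj_idx : obj_idx + 1] = mask
    return out

masks = {0: torch.randn(1, 1, 64, 64), 2: torch.randn(1, 1, 256, 256)}
print(consolidate(masks, num_objs=3, hw=(256, 256)).shape)  # object 1 keeps the placeholder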
|
| 555 |
+
|
| 556 |
+
def _get_empty_mask_ptr(self, inference_state, frame_idx):
|
| 557 |
+
"""Get a dummy object pointer based on an empty mask on the current frame."""
|
| 558 |
+
# A dummy (empty) mask with a single object
|
| 559 |
+
batch_size = 1
|
| 560 |
+
mask_inputs = torch.zeros(
|
| 561 |
+
(batch_size, 1, self.image_size, self.image_size),
|
| 562 |
+
dtype=torch.float32,
|
| 563 |
+
device=inference_state["device"],
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
# Retrieve correct image features
|
| 567 |
+
(
|
| 568 |
+
_,
|
| 569 |
+
_,
|
| 570 |
+
current_vision_feats,
|
| 571 |
+
current_vision_pos_embeds,
|
| 572 |
+
feat_sizes,
|
| 573 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
| 574 |
+
|
| 575 |
+
# Feed the empty mask and image feature above to get a dummy object pointer
|
| 576 |
+
current_out = self.track_step(
|
| 577 |
+
frame_idx=frame_idx,
|
| 578 |
+
is_init_cond_frame=True,
|
| 579 |
+
current_vision_feats=current_vision_feats,
|
| 580 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
| 581 |
+
feat_sizes=feat_sizes,
|
| 582 |
+
point_inputs=None,
|
| 583 |
+
mask_inputs=mask_inputs,
|
| 584 |
+
output_dict={},
|
| 585 |
+
num_frames=inference_state["num_frames"],
|
| 586 |
+
track_in_reverse=False,
|
| 587 |
+
run_mem_encoder=False,
|
| 588 |
+
prev_sam_mask_logits=None,
|
| 589 |
+
)
|
| 590 |
+
return current_out["obj_ptr"]
|
| 591 |
+
|
| 592 |
+
@torch.inference_mode()
|
| 593 |
+
def propagate_in_video_preflight(self, inference_state):
|
| 594 |
+
"""Prepare inference_state and consolidate temporary outputs before tracking."""
|
| 595 |
+
# Tracking has started and we don't allow adding new objects until the session is reset.
|
| 596 |
+
inference_state["tracking_has_started"] = True
|
| 597 |
+
batch_size = self._get_obj_num(inference_state)
|
| 598 |
+
|
| 599 |
+
# Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
|
| 600 |
+
# add them into "output_dict".
|
| 601 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
| 602 |
+
output_dict = inference_state["output_dict"]
|
| 603 |
+
# "consolidated_frame_inds" contains indices of those frames where consolidated
|
| 604 |
+
# temporary outputs have been added (either in this call or any previous calls
|
| 605 |
+
# to `propagate_in_video_preflight`).
|
| 606 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
| 607 |
+
for is_cond in [False, True]:
|
| 608 |
+
# Separately consolidate conditioning and non-conditioning temp outputs
|
| 609 |
+
storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
|
| 610 |
+
# Find all the frames that contain temporary outputs for any objects
|
| 611 |
+
# (these should be the frames that have just received clicks or mask inputs
|
| 612 |
+
# via `add_new_points_or_box` or `add_new_mask`)
|
| 613 |
+
temp_frame_inds = set()
|
| 614 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
| 615 |
+
temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
|
| 616 |
+
consolidated_frame_inds[storage_key].update(temp_frame_inds)
|
| 617 |
+
# consolidate the temporary output across all objects on this frame
|
| 618 |
+
for frame_idx in temp_frame_inds:
|
| 619 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 620 |
+
inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
|
| 621 |
+
)
|
| 622 |
+
# merge them into "output_dict" and also create per-object slices
|
| 623 |
+
output_dict[storage_key][frame_idx] = consolidated_out
|
| 624 |
+
self._add_output_per_object(
|
| 625 |
+
inference_state, frame_idx, consolidated_out, storage_key
|
| 626 |
+
)
|
| 627 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
| 628 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
| 629 |
+
)
|
| 630 |
+
if clear_non_cond_mem:
|
| 631 |
+
# clear non-conditioning memory of the surrounding frames
|
| 632 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
| 633 |
+
|
| 634 |
+
# clear temporary outputs in `temp_output_dict_per_obj`
|
| 635 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values():
|
| 636 |
+
obj_temp_output_dict[storage_key].clear()
|
| 637 |
+
|
| 638 |
+
# edge case: if an output is added to "cond_frame_outputs", we remove any prior
|
| 639 |
+
# output on the same frame in "non_cond_frame_outputs"
|
| 640 |
+
for frame_idx in output_dict["cond_frame_outputs"]:
|
| 641 |
+
output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
| 642 |
+
for obj_output_dict in inference_state["output_dict_per_obj"].values():
|
| 643 |
+
for frame_idx in obj_output_dict["cond_frame_outputs"]:
|
| 644 |
+
obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
|
| 645 |
+
for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
| 646 |
+
assert frame_idx in output_dict["cond_frame_outputs"]
|
| 647 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
| 648 |
+
|
| 649 |
+
# Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
|
| 650 |
+
# with either points or mask inputs (which should be true under a correct workflow).
|
| 651 |
+
all_consolidated_frame_inds = (
|
| 652 |
+
consolidated_frame_inds["cond_frame_outputs"]
|
| 653 |
+
| consolidated_frame_inds["non_cond_frame_outputs"]
|
| 654 |
+
)
|
| 655 |
+
input_frames_inds = set()
|
| 656 |
+
for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
|
| 657 |
+
input_frames_inds.update(point_inputs_per_frame.keys())
|
| 658 |
+
for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
|
| 659 |
+
input_frames_inds.update(mask_inputs_per_frame.keys())
|
| 660 |
+
assert all_consolidated_frame_inds == input_frames_inds
|
| 661 |
+
|
| 662 |
+
@torch.inference_mode()
|
| 663 |
+
def propagate_in_video(
|
| 664 |
+
self,
|
| 665 |
+
inference_state,
|
| 666 |
+
start_frame_idx=None,
|
| 667 |
+
max_frame_num_to_track=None,
|
| 668 |
+
reverse=False,
|
| 669 |
+
):
|
| 670 |
+
"""Propagate the input points across frames to track in the entire video."""
|
| 671 |
+
self.propagate_in_video_preflight(inference_state)
|
| 672 |
+
|
| 673 |
+
output_dict = inference_state["output_dict"]
|
| 674 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
| 675 |
+
obj_ids = inference_state["obj_ids"]
|
| 676 |
+
num_frames = inference_state["num_frames"]
|
| 677 |
+
batch_size = self._get_obj_num(inference_state)
|
| 678 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
| 679 |
+
raise RuntimeError("No points are provided; please add points first")
|
| 680 |
+
clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
|
| 681 |
+
self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
|
| 682 |
+
)
|
| 683 |
+
|
| 684 |
+
# set start index, end index, and processing order
|
| 685 |
+
if start_frame_idx is None:
|
| 686 |
+
# default: start from the earliest frame with input points
|
| 687 |
+
start_frame_idx = min(output_dict["cond_frame_outputs"])
|
| 688 |
+
if max_frame_num_to_track is None:
|
| 689 |
+
# default: track all the frames in the video
|
| 690 |
+
max_frame_num_to_track = num_frames
|
| 691 |
+
if reverse:
|
| 692 |
+
end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
|
| 693 |
+
if start_frame_idx > 0:
|
| 694 |
+
processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
|
| 695 |
+
else:
|
| 696 |
+
processing_order = [] # skip reverse tracking if starting from frame 0
|
| 697 |
+
else:
|
| 698 |
+
end_frame_idx = min(
|
| 699 |
+
start_frame_idx + max_frame_num_to_track, num_frames - 1
|
| 700 |
+
)
|
| 701 |
+
processing_order = range(start_frame_idx, end_frame_idx + 1)
|
| 702 |
+
|
| 703 |
+
for frame_idx in tqdm(processing_order, desc="propagate in video"):
|
| 704 |
+
# We skip those frames already in consolidated outputs (these are frames
|
| 705 |
+
# that received input clicks or mask). Note that we cannot directly run
|
| 706 |
+
# batched forward on them via `_run_single_frame_inference` because the
|
| 707 |
+
# number of clicks on each object might be different.
|
| 708 |
+
if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
|
| 709 |
+
storage_key = "cond_frame_outputs"
|
| 710 |
+
current_out = output_dict[storage_key][frame_idx]
|
| 711 |
+
pred_masks = current_out["pred_masks"]
|
| 712 |
+
if clear_non_cond_mem:
|
| 713 |
+
# clear non-conditioning memory of the surrounding frames
|
| 714 |
+
self._clear_non_cond_mem_around_input(inference_state, frame_idx)
|
| 715 |
+
elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
|
| 716 |
+
storage_key = "non_cond_frame_outputs"
|
| 717 |
+
current_out = output_dict[storage_key][frame_idx]
|
| 718 |
+
pred_masks = current_out["pred_masks"]
|
| 719 |
+
else:
|
| 720 |
+
storage_key = "non_cond_frame_outputs"
|
| 721 |
+
current_out, pred_masks = self._run_single_frame_inference(
|
| 722 |
+
inference_state=inference_state,
|
| 723 |
+
output_dict=output_dict,
|
| 724 |
+
frame_idx=frame_idx,
|
| 725 |
+
batch_size=batch_size,
|
| 726 |
+
is_init_cond_frame=False,
|
| 727 |
+
point_inputs=None,
|
| 728 |
+
mask_inputs=None,
|
| 729 |
+
reverse=reverse,
|
| 730 |
+
run_mem_encoder=True,
|
| 731 |
+
)
|
| 732 |
+
output_dict[storage_key][frame_idx] = current_out
|
| 733 |
+
# Create slices of per-object outputs for subsequent interaction with each
|
| 734 |
+
# individual object after tracking.
|
| 735 |
+
self._add_output_per_object(
|
| 736 |
+
inference_state, frame_idx, current_out, storage_key
|
| 737 |
+
)
|
| 738 |
+
inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
|
| 739 |
+
|
| 740 |
+
# Resize the output mask to the original video resolution (we directly use
|
| 741 |
+
# the mask scores on GPU for output to avoid any CPU conversion in between)
|
| 742 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
| 743 |
+
inference_state, pred_masks
|
| 744 |
+
)
|
| 745 |
+
yield frame_idx, obj_ids, video_res_masks
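Because `propagate_in_video` is a generator, results are consumed frame by frame; a typical loop thresholds the yielded logits into per-object boolean masks at video resolution. Continuing the illustrative session from the earlier sketches:

video_segments = {}  # frame_idx -> {obj_id: boolean mask}
for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
    video_segments[frame_idx] = {
        obj_id: (mask_logits[i, 0] > 0.0).cpu().numpy()
        for i, obj_id in enumerate(obj_ids)
    }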
|
| 746 |
+
|
| 747 |
+
def _add_output_per_object(
|
| 748 |
+
self, inference_state, frame_idx, current_out, storage_key
|
| 749 |
+
):
|
| 750 |
+
"""
|
| 751 |
+
Split a multi-object output into per-object output slices and add them into
|
| 752 |
+
`output_dict_per_obj`. The resulting slices share the same tensor storage.
|
| 753 |
+
"""
|
| 754 |
+
maskmem_features = current_out["maskmem_features"]
|
| 755 |
+
assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
|
| 756 |
+
|
| 757 |
+
maskmem_pos_enc = current_out["maskmem_pos_enc"]
|
| 758 |
+
assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
|
| 759 |
+
|
| 760 |
+
output_dict_per_obj = inference_state["output_dict_per_obj"]
|
| 761 |
+
for obj_idx, obj_output_dict in output_dict_per_obj.items():
|
| 762 |
+
obj_slice = slice(obj_idx, obj_idx + 1)
|
| 763 |
+
obj_out = {
|
| 764 |
+
"maskmem_features": None,
|
| 765 |
+
"maskmem_pos_enc": None,
|
| 766 |
+
"pred_masks": current_out["pred_masks"][obj_slice],
|
| 767 |
+
"obj_ptr": current_out["obj_ptr"][obj_slice],
|
| 768 |
+
"object_score_logits": current_out["object_score_logits"][obj_slice],
|
| 769 |
+
}
|
| 770 |
+
if maskmem_features is not None:
|
| 771 |
+
obj_out["maskmem_features"] = maskmem_features[obj_slice]
|
| 772 |
+
if maskmem_pos_enc is not None:
|
| 773 |
+
obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
|
| 774 |
+
obj_output_dict[storage_key][frame_idx] = obj_out
|
| 775 |
+
|
| 776 |
+
@torch.inference_mode()
|
| 777 |
+
def clear_all_prompts_in_frame(
|
| 778 |
+
self, inference_state, frame_idx, obj_id, need_output=True
|
| 779 |
+
):
|
| 780 |
+
"""Remove all input points or mask in a specific frame for a given object."""
|
| 781 |
+
obj_idx = self._obj_id_to_idx(inference_state, obj_id)
|
| 782 |
+
|
| 783 |
+
# Clear the conditioning information on the given frame
|
| 784 |
+
inference_state["point_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
| 785 |
+
inference_state["mask_inputs_per_obj"][obj_idx].pop(frame_idx, None)
|
| 786 |
+
|
| 787 |
+
temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
|
| 788 |
+
temp_output_dict_per_obj[obj_idx]["cond_frame_outputs"].pop(frame_idx, None)
|
| 789 |
+
temp_output_dict_per_obj[obj_idx]["non_cond_frame_outputs"].pop(frame_idx, None)
|
| 790 |
+
|
| 791 |
+
# Check and see if there are still any inputs left on this frame
|
| 792 |
+
batch_size = self._get_obj_num(inference_state)
|
| 793 |
+
frame_has_input = False
|
| 794 |
+
for obj_idx2 in range(batch_size):
|
| 795 |
+
if frame_idx in inference_state["point_inputs_per_obj"][obj_idx2]:
|
| 796 |
+
frame_has_input = True
|
| 797 |
+
break
|
| 798 |
+
if frame_idx in inference_state["mask_inputs_per_obj"][obj_idx2]:
|
| 799 |
+
frame_has_input = True
|
| 800 |
+
break
|
| 801 |
+
|
| 802 |
+
# If this frame has no remaining inputs for any objects, we further clear its
|
| 803 |
+
# conditioning frame status
|
| 804 |
+
if not frame_has_input:
|
| 805 |
+
output_dict = inference_state["output_dict"]
|
| 806 |
+
consolidated_frame_inds = inference_state["consolidated_frame_inds"]
|
| 807 |
+
consolidated_frame_inds["cond_frame_outputs"].discard(frame_idx)
|
| 808 |
+
consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
|
| 809 |
+
# Remove the frame's conditioning output (possibly downgrading it to non-conditioning)
|
| 810 |
+
out = output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
| 811 |
+
if out is not None:
|
| 812 |
+
# The frame is not a conditioning frame anymore since it's not receiving inputs,
|
| 813 |
+
# so we "downgrade" its output (if it exists) to a non-conditioning frame output.
|
| 814 |
+
output_dict["non_cond_frame_outputs"][frame_idx] = out
|
| 815 |
+
inference_state["frames_already_tracked"].pop(frame_idx, None)
|
| 816 |
+
# Similarly, do it for the sliced output on each object.
|
| 817 |
+
for obj_idx2 in range(batch_size):
|
| 818 |
+
obj_output_dict = inference_state["output_dict_per_obj"][obj_idx2]
|
| 819 |
+
obj_out = obj_output_dict["cond_frame_outputs"].pop(frame_idx, None)
|
| 820 |
+
if obj_out is not None:
|
| 821 |
+
obj_output_dict["non_cond_frame_outputs"][frame_idx] = obj_out
|
| 822 |
+
|
| 823 |
+
# If all the conditioning frames have been removed, we also clear the tracking outputs
|
| 824 |
+
if len(output_dict["cond_frame_outputs"]) == 0:
|
| 825 |
+
self._reset_tracking_results(inference_state)
|
| 826 |
+
|
| 827 |
+
if not need_output:
|
| 828 |
+
return
|
| 829 |
+
# Finally, output updated masks per object (after removing the inputs above)
|
| 830 |
+
obj_ids = inference_state["obj_ids"]
|
| 831 |
+
is_cond = any(
|
| 832 |
+
frame_idx in obj_temp_output_dict["cond_frame_outputs"]
|
| 833 |
+
for obj_temp_output_dict in temp_output_dict_per_obj.values()
|
| 834 |
+
)
|
| 835 |
+
consolidated_out = self._consolidate_temp_output_across_obj(
|
| 836 |
+
inference_state,
|
| 837 |
+
frame_idx,
|
| 838 |
+
is_cond=is_cond,
|
| 839 |
+
run_mem_encoder=False,
|
| 840 |
+
consolidate_at_video_res=True,
|
| 841 |
+
)
|
| 842 |
+
_, video_res_masks = self._get_orig_video_res_output(
|
| 843 |
+
inference_state, consolidated_out["pred_masks_video_res"]
|
| 844 |
+
)
|
| 845 |
+
return frame_idx, obj_ids, video_res_masks
|
| 846 |
+
|
| 847 |
+
@torch.inference_mode()
|
| 848 |
+
def reset_state(self, inference_state):
|
| 849 |
+
"""Remove all input points or mask in all frames throughout the video."""
|
| 850 |
+
self._reset_tracking_results(inference_state)
|
| 851 |
+
# Remove all object ids
|
| 852 |
+
inference_state["obj_id_to_idx"].clear()
|
| 853 |
+
inference_state["obj_idx_to_id"].clear()
|
| 854 |
+
inference_state["obj_ids"].clear()
|
| 855 |
+
inference_state["point_inputs_per_obj"].clear()
|
| 856 |
+
inference_state["mask_inputs_per_obj"].clear()
|
| 857 |
+
inference_state["output_dict_per_obj"].clear()
|
| 858 |
+
inference_state["temp_output_dict_per_obj"].clear()
|
| 859 |
+
|
| 860 |
+
def _reset_tracking_results(self, inference_state):
|
| 861 |
+
"""Reset all tracking inputs and results across the videos."""
|
| 862 |
+
for v in inference_state["point_inputs_per_obj"].values():
|
| 863 |
+
v.clear()
|
| 864 |
+
for v in inference_state["mask_inputs_per_obj"].values():
|
| 865 |
+
v.clear()
|
| 866 |
+
for v in inference_state["output_dict_per_obj"].values():
|
| 867 |
+
v["cond_frame_outputs"].clear()
|
| 868 |
+
v["non_cond_frame_outputs"].clear()
|
| 869 |
+
for v in inference_state["temp_output_dict_per_obj"].values():
|
| 870 |
+
v["cond_frame_outputs"].clear()
|
| 871 |
+
v["non_cond_frame_outputs"].clear()
|
| 872 |
+
inference_state["output_dict"]["cond_frame_outputs"].clear()
|
| 873 |
+
inference_state["output_dict"]["non_cond_frame_outputs"].clear()
|
| 874 |
+
inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
|
| 875 |
+
inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
|
| 876 |
+
inference_state["tracking_has_started"] = False
|
| 877 |
+
inference_state["frames_already_tracked"].clear()
|
| 878 |
+
|
| 879 |
+
def _get_image_feature(self, inference_state, frame_idx, batch_size):
|
| 880 |
+
"""Compute the image features on a given frame."""
|
| 881 |
+
# Look up in the cache first
|
| 882 |
+
image, backbone_out = inference_state["cached_features"].get(
|
| 883 |
+
frame_idx, (None, None)
|
| 884 |
+
)
|
| 885 |
+
if backbone_out is None:
|
| 886 |
+
# Cache miss -- we will run inference on a single image
|
| 887 |
+
device = inference_state["device"]
|
| 888 |
+
image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0)
|
| 889 |
+
backbone_out = self.forward_image(image)
|
| 890 |
+
# Cache the most recent frame's feature (for repeated interactions with
|
| 891 |
+
# a frame; we can use an LRU cache for more frames in the future).
|
| 892 |
+
inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
|
| 893 |
+
|
| 894 |
+
# expand the features to have the same dimension as the number of objects
|
| 895 |
+
expanded_image = image.expand(batch_size, -1, -1, -1)
|
| 896 |
+
expanded_backbone_out = {
|
| 897 |
+
"backbone_fpn": backbone_out["backbone_fpn"].copy(),
|
| 898 |
+
"vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
|
| 899 |
+
}
|
| 900 |
+
for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
|
| 901 |
+
expanded_backbone_out["backbone_fpn"][i] = feat.expand(
|
| 902 |
+
batch_size, -1, -1, -1
|
| 903 |
+
)
|
| 904 |
+
for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
|
| 905 |
+
pos = pos.expand(batch_size, -1, -1, -1)
|
| 906 |
+
expanded_backbone_out["vision_pos_enc"][i] = pos
|
| 907 |
+
|
| 908 |
+
features = self._prepare_backbone_features(expanded_backbone_out)
|
| 909 |
+
features = (expanded_image,) + features
|
| 910 |
+
return features
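The per-object "batch" above is produced with `Tensor.expand`, which broadcasts the single cached frame's features across objects as a view instead of copying memory; only a genuinely new frame triggers another backbone forward pass. A tiny sketch of that expand-as-view behaviour (shapes are illustrative):

import torch

feat = torch.randn(1, 256, 64, 64)                 # backbone feature map for one frame
batched = feat.expand(3, -1, -1, -1)               # three tracked objects share the same storage
print(batched.shape, batched.data_ptr() == feat.data_ptr())  # torch.Size([3, 256, 64, 64]) True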
|
| 911 |
+
|
| 912 |
+
def _run_single_frame_inference(
|
| 913 |
+
self,
|
| 914 |
+
inference_state,
|
| 915 |
+
output_dict,
|
| 916 |
+
frame_idx,
|
| 917 |
+
batch_size,
|
| 918 |
+
is_init_cond_frame,
|
| 919 |
+
point_inputs,
|
| 920 |
+
mask_inputs,
|
| 921 |
+
reverse,
|
| 922 |
+
run_mem_encoder,
|
| 923 |
+
prev_sam_mask_logits=None,
|
| 924 |
+
):
|
| 925 |
+
"""Run tracking on a single frame based on current inputs and previous memory."""
|
| 926 |
+
# Retrieve correct image features
|
| 927 |
+
(
|
| 928 |
+
_,
|
| 929 |
+
_,
|
| 930 |
+
current_vision_feats,
|
| 931 |
+
current_vision_pos_embeds,
|
| 932 |
+
feat_sizes,
|
| 933 |
+
) = self._get_image_feature(inference_state, frame_idx, batch_size)
|
| 934 |
+
|
| 935 |
+
# point and mask should not appear as input simultaneously on the same frame
|
| 936 |
+
assert point_inputs is None or mask_inputs is None
|
| 937 |
+
current_out = self.track_step(
|
| 938 |
+
frame_idx=frame_idx,
|
| 939 |
+
is_init_cond_frame=is_init_cond_frame,
|
| 940 |
+
current_vision_feats=current_vision_feats,
|
| 941 |
+
current_vision_pos_embeds=current_vision_pos_embeds,
|
| 942 |
+
feat_sizes=feat_sizes,
|
| 943 |
+
point_inputs=point_inputs,
|
| 944 |
+
mask_inputs=mask_inputs,
|
| 945 |
+
output_dict=output_dict,
|
| 946 |
+
num_frames=inference_state["num_frames"],
|
| 947 |
+
track_in_reverse=reverse,
|
| 948 |
+
run_mem_encoder=run_mem_encoder,
|
| 949 |
+
prev_sam_mask_logits=prev_sam_mask_logits,
|
| 950 |
+
)
|
| 951 |
+
|
        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = current_out["maskmem_features"]
        if maskmem_features is not None:
            maskmem_features = maskmem_features.to(torch.bfloat16)
            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        pred_masks_gpu = current_out["pred_masks"]
        # potentially fill holes in the predicted masks
        if self.fill_hole_area > 0:
            pred_masks_gpu = fill_holes_in_mask_scores(
                pred_masks_gpu, self.fill_hole_area
            )
        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
        obj_ptr = current_out["obj_ptr"]
        object_score_logits = current_out["object_score_logits"]
        # make a compact version of this frame's output to reduce the state size
        compact_current_out = {
            "maskmem_features": maskmem_features,
            "maskmem_pos_enc": maskmem_pos_enc,
            "pred_masks": pred_masks,
            "obj_ptr": obj_ptr,
            "object_score_logits": object_score_logits,
        }
        return compact_current_out, pred_masks_gpu
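
    # Design note: casting "maskmem_features" to bfloat16 halves its footprint relative to
    # float32, and the storage device may be the CPU when offloading is enabled, so the
    # per-frame state kept on the GPU stays small. `_run_memory_encoder` below follows the
    # same pattern.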
    def _run_memory_encoder(
        self,
        inference_state,
        frame_idx,
        batch_size,
        high_res_masks,
        object_score_logits,
        is_mask_from_pts,
    ):
        """
        Run the memory encoder on `high_res_masks`. This is usually done after applying
        non-overlapping constraints to the object scores. Since the scores have changed,
        the memory also needs to be recomputed with the memory encoder.
        """
        # Retrieve correct image features
        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
            inference_state, frame_idx, batch_size
        )
        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
            current_vision_feats=current_vision_feats,
            feat_sizes=feat_sizes,
            pred_masks_high_res=high_res_masks,
            object_score_logits=object_score_logits,
            is_mask_from_pts=is_mask_from_pts,
        )

        # optionally offload the output to CPU memory to save GPU space
        storage_device = inference_state["storage_device"]
        maskmem_features = maskmem_features.to(torch.bfloat16)
        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
        maskmem_pos_enc = self._get_maskmem_pos_enc(
            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
        )
        return maskmem_features, maskmem_pos_enc
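
    # Usage sketch for `_run_memory_encoder` above (illustrative only, not a verbatim call
    # site; `consolidated_out` is a stand-in for a consolidated per-frame output dict):
    #
    #   maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
    #       inference_state, frame_idx, batch_size, high_res_masks,
    #       object_score_logits, is_mask_from_pts=True,
    #   )
    #   consolidated_out["maskmem_features"] = maskmem_features
    #   consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc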
    def _get_maskmem_pos_enc(self, inference_state, current_out):
        """
        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
        a constant in the inference session to reduce session storage size.
        """
        model_constants = inference_state["constants"]
        # "out_maskmem_pos_enc" should be either a list of tensors or None
        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
        if out_maskmem_pos_enc is not None:
            if "maskmem_pos_enc" not in model_constants:
                assert isinstance(out_maskmem_pos_enc, list)
                # only take the slice for one object, since it's the same across objects
                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
            else:
                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
            # expand the cached maskmem_pos_enc to the actual batch size
            batch_size = out_maskmem_pos_enc[0].size(0)
            expanded_maskmem_pos_enc = [
                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
            ]
        else:
            expanded_maskmem_pos_enc = None
        return expanded_maskmem_pos_enc
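
    # Note on `_get_maskmem_pos_enc` above: only a single-object slice of "maskmem_pos_enc"
    # is stored in inference_state["constants"], and it is re-expanded (as a broadcasted
    # view) to the current batch size on every lookup, since the encoding is identical
    # across objects.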
    @torch.inference_mode()
    def remove_object(self, inference_state, obj_id, strict=False, need_output=True):
        """
        Remove an object id from the tracking state. If strict is True, we check whether
        the object id actually exists and raise an error if it doesn't exist.
        """
        old_obj_idx_to_rm = inference_state["obj_id_to_idx"].get(obj_id, None)
        updated_frames = []
        # Check whether this object_id to remove actually exists and possibly raise an error.
        if old_obj_idx_to_rm is None:
            if not strict:
                return inference_state["obj_ids"], updated_frames
            raise RuntimeError(
                f"Cannot remove object id {obj_id} as it doesn't exist. "
                f"All existing object ids: {inference_state['obj_ids']}."
            )

        # If this is the only remaining object id, we simply reset the state.
        if len(inference_state["obj_id_to_idx"]) == 1:
            self.reset_state(inference_state)
            return inference_state["obj_ids"], updated_frames

        # There are still remaining objects after removing this object id. In this case,
        # we need to delete the object storage from inference state tensors.
        # Step 0: clear the input on those frames where this object id has point or mask input
        # (note that this step is required as it might downgrade conditioning frames to
        # non-conditioning ones)
        obj_input_frames_inds = set()
        obj_input_frames_inds.update(
            inference_state["point_inputs_per_obj"][old_obj_idx_to_rm]
        )
        obj_input_frames_inds.update(
            inference_state["mask_inputs_per_obj"][old_obj_idx_to_rm]
        )
        for frame_idx in obj_input_frames_inds:
            self.clear_all_prompts_in_frame(
                inference_state, frame_idx, obj_id, need_output=False
            )

        # Step 1: Update the object id mapping (note that it must be done after Step 0,
        # since Step 0 still requires the old object id mappings in inference_state)
        old_obj_ids = inference_state["obj_ids"]
        old_obj_inds = list(range(len(old_obj_ids)))
        remain_old_obj_inds = old_obj_inds.copy()
        remain_old_obj_inds.remove(old_obj_idx_to_rm)
        new_obj_ids = [old_obj_ids[old_idx] for old_idx in remain_old_obj_inds]
        new_obj_inds = list(range(len(new_obj_ids)))
        # build new mappings
        old_idx_to_new_idx = dict(zip(remain_old_obj_inds, new_obj_inds))
        inference_state["obj_id_to_idx"] = dict(zip(new_obj_ids, new_obj_inds))
        inference_state["obj_idx_to_id"] = dict(zip(new_obj_inds, new_obj_ids))
        inference_state["obj_ids"] = new_obj_ids
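
        # Illustrative example of the remapping above: if old_obj_ids == [3, 7, 9] and
        # obj_id 7 (old index 1) is removed, then remain_old_obj_inds == [0, 2],
        # new_obj_ids == [3, 9], and old_idx_to_new_idx == {0: 0, 2: 1}.
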
        # Step 2: For per-object tensor storage, we shift their obj_idx in the dict keys.
        # (note that "consolidated_frame_inds" doesn't need to be updated in this step as
        # it's already handled in Step 0)
        def _map_keys(container):
            new_kvs = []
            for k in old_obj_inds:
                v = container.pop(k)
                if k in old_idx_to_new_idx:
                    new_kvs.append((old_idx_to_new_idx[k], v))
            container.update(new_kvs)

        _map_keys(inference_state["point_inputs_per_obj"])
        _map_keys(inference_state["mask_inputs_per_obj"])
        _map_keys(inference_state["output_dict_per_obj"])
        _map_keys(inference_state["temp_output_dict_per_obj"])

        # Step 3: For packed tensor storage, we index the remaining ids and rebuild the per-object slices.
        def _slice_state(output_dict, storage_key):
            for frame_idx, out in output_dict[storage_key].items():
                out["maskmem_features"] = out["maskmem_features"][remain_old_obj_inds]
                out["maskmem_pos_enc"] = [
                    x[remain_old_obj_inds] for x in out["maskmem_pos_enc"]
                ]
                # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
                out["maskmem_pos_enc"] = self._get_maskmem_pos_enc(inference_state, out)
                out["pred_masks"] = out["pred_masks"][remain_old_obj_inds]
                out["obj_ptr"] = out["obj_ptr"][remain_old_obj_inds]
                out["object_score_logits"] = out["object_score_logits"][
                    remain_old_obj_inds
                ]
                # also update the per-object slices
                self._add_output_per_object(
                    inference_state, frame_idx, out, storage_key
                )

        _slice_state(inference_state["output_dict"], "cond_frame_outputs")
        _slice_state(inference_state["output_dict"], "non_cond_frame_outputs")

        # Step 4: Further collect the outputs on those frames in `obj_input_frames_inds`, which
        # could show an updated mask for objects previously occluded by the object being removed
        if need_output:
            temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
            for frame_idx in obj_input_frames_inds:
                is_cond = any(
                    frame_idx in obj_temp_output_dict["cond_frame_outputs"]
                    for obj_temp_output_dict in temp_output_dict_per_obj.values()
                )
                consolidated_out = self._consolidate_temp_output_across_obj(
                    inference_state,
                    frame_idx,
                    is_cond=is_cond,
                    run_mem_encoder=False,
                    consolidate_at_video_res=True,
                )
                _, video_res_masks = self._get_orig_video_res_output(
                    inference_state, consolidated_out["pred_masks_video_res"]
                )
                updated_frames.append((frame_idx, video_res_masks))

        return inference_state["obj_ids"], updated_frames
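
    # Usage sketch for `remove_object` above (illustrative; `predictor` and the object id
    # are hypothetical):
    #
    #   obj_ids, updated_frames = predictor.remove_object(inference_state, obj_id=2)
    #   for frame_idx, video_res_masks in updated_frames:
    #       ...  # refresh any cached masks or visualizations for these frames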
    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
        """
        Remove the non-conditioning memory around the input frame. When users provide
        correction clicks, the surrounding frames' non-conditioning memories can still
        contain outdated object appearance information and could confuse the model.

        This method clears those non-conditioning memories surrounding the interacted
        frame to avoid giving the model both old and new information about the object.
        """
        r = self.memory_temporal_stride_for_eval
        frame_idx_begin = frame_idx - r * self.num_maskmem
        frame_idx_end = frame_idx + r * self.num_maskmem
        output_dict = inference_state["output_dict"]
        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
        for t in range(frame_idx_begin, frame_idx_end + 1):
            non_cond_frame_outputs.pop(t, None)
            for obj_output_dict in inference_state["output_dict_per_obj"].values():
                obj_output_dict["non_cond_frame_outputs"].pop(t, None)