# Spaces: Running on Zero
import os | |
import numpy as np | |
from segment_anything import SamPredictor, sam_model_registry | |
import torch | |
from modules.AutoDetailer import mask_util | |
from modules.Device import Device | |
def sam_predict(
    predictor: "SamPredictor", points: list, plabs: list, bbox: list, threshold: float
) -> list:
    """#### Predict masks using SAM and keep those meeting a score threshold.

    Every mask whose score is at least `threshold` is returned, in the order
    the predictor produced them. When no mask qualifies, the single
    highest-scoring mask is returned instead, so the result is only empty when
    the predictor produced no masks at all.

    This function does not call `set_image`; the caller is expected to have
    prepared `predictor` beforehand.

    #### Args:
        - `predictor` (SamPredictor): The SAM predictor.
        - `points` (list): List of points; an empty list means "no point prompt".
        - `plabs` (list): List of point labels; an empty list means "no labels".
        - `bbox` (list): Bounding box, or `None` for "no box prompt".
        - `threshold` (float): Minimum score a mask needs to be selected.

    #### Returns:
        - `list`: List of predicted masks.
    """
    point_coords = np.array(points) if points else None
    point_labels = np.array(plabs) if plabs else None
    # The box is wrapped in an extra dimension before being handed over.
    box = np.array([bbox]) if bbox is not None else None

    cur_masks, scores, _ = predictor.predict(
        point_coords=point_coords, point_labels=point_labels, box=box
    )

    total_masks = [
        mask for mask, score in zip(cur_masks, scores) if score >= threshold
    ]

    # Fallback: nothing reached the threshold, so keep the best candidate.
    # argmax (rather than comparing against an initial score of 0) also works
    # when every score is zero or negative; ties resolve to the first mask.
    if not total_masks and len(scores) > 0:
        total_masks.append(cur_masks[int(np.argmax(scores))])

    return total_masks
def is_same_device(a: torch.device, b: torch.device) -> bool:
    """#### Tell whether two device specifications name the same device.

    Either argument may be a `torch.device` or a device string such as
    `"cpu"` or `"cuda:0"`; strings are converted with `torch.device(...)`.
    Both the device type and the device index have to match, so a device
    without an index is not considered equal to one with an explicit index.

    #### Args:
        - `a` (torch.device): The first device.
        - `b` (torch.device): The second device.

    #### Returns:
        - `bool`: Whether the devices are the same.
    """

    def _as_device(spec):
        # Only strings are converted; anything else is used as given.
        return torch.device(spec) if isinstance(spec, str) else spec

    first, second = _as_device(a), _as_device(b)
    return (first.type, first.index) == (second.type, second.index)
class SafeToGPU:
    """#### Class to safely move objects to GPU.

    Holds the size of a model and uses it to check, before a CPU-to-GPU move,
    that the target device reports more free memory than `size * 1.3`.
    """

    def __init__(self, size: int = 0):
        """#### Remember the size used for the memory check.

        #### Args:
            - `size` (int, optional): Model size; `load_model` in this file
              passes the checkpoint's file size in bytes. Defaults to 0, which
              makes the memory requirement zero.
        """
        self.size = size

    def to_device(self, obj: torch.nn.Module, device: torch.device) -> None:
        """#### Move an object to a device.

        Moving to the CPU is unconditional. Moving to any other device is only
        attempted when `obj` currently sits on the CPU and the memory check
        passes; otherwise a warning is printed and `obj` stays where it is.
        An object that is already on a non-CPU device is left untouched.

        #### Args:
            - `obj` (torch.nn.Module): The object to move; it must expose a
              `device` attribute.
            - `device` (torch.device): The target device.
        """
        if is_same_device(device, "cpu"):
            obj.to(device)
            return

        if not is_same_device(obj.device, "cpu"):
            return

        # cpu to gpu: request 30% more than the recorded size as headroom.
        required = self.size * 1.3
        Device.free_memory(required, device)
        if Device.get_free_memory(device) > required:
            try:
                obj.to(device)
            except Exception:
                # Best effort: a failed move leaves the model on the CPU.
                # `Exception` (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                print(
                    f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]"
                )
        else:
            print(
                f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]"
            )
class SAMWrapper:
    """#### Wrapper class for SAM model.

    Bundles a SAM model with its device-handling policy: in auto mode the
    model is moved to the torch device before use and back to the CPU after.
    """

    def __init__(
        self, model: torch.nn.Module, is_auto_mode: bool, safe_to_gpu: "SafeToGPU" = None
    ):
        """#### Store the model and its device policy.

        #### Args:
            - `model` (torch.nn.Module): The SAM model.
            - `is_auto_mode` (bool): Whether `prepare_device` / `release_device`
              should move the model between devices.
            - `safe_to_gpu` (SafeToGPU, optional): Helper used to move the model.
              When omitted, a helper with size 0 is created.
        """
        self.model = model
        # SafeToGPU needs a size; 0 is passed explicitly for the default helper
        # because calling it with no arguments would raise a TypeError.
        self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU(0)
        self.is_auto_mode = is_auto_mode

    def prepare_device(self) -> None:
        """#### Prepare the device for the model.

        In auto mode, asks the helper to move the model to the device reported
        by `Device.get_torch_device()`. Does nothing otherwise.
        """
        if self.is_auto_mode:
            device = Device.get_torch_device()
            self.safe_to_gpu.to_device(self.model, device=device)

    def release_device(self) -> None:
        """#### Release the device from the model.

        In auto mode, moves the model back to the CPU. Does nothing otherwise.
        """
        if self.is_auto_mode:
            self.model.to(device="cpu")

    def predict(
        self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float
    ) -> list:
        """#### Predict masks using the SAM model.

        A new `SamPredictor` is created and given `image` (as "RGB") on every
        call, then prediction is delegated to `sam_predict`.

        #### Args:
            - `image` (np.ndarray): The input image.
            - `points` (list): List of points.
            - `plabs` (list): List of point labels.
            - `bbox` (list): Bounding box.
            - `threshold` (float): Threshold for mask selection.

        #### Returns:
            - `list`: List of predicted masks.
        """
        predictor = SamPredictor(self.model)
        predictor.set_image(image, "RGB")
        return sam_predict(predictor, points, plabs, bbox, threshold)
class SAMLoader:
    """#### Class to load SAM models."""

    def load_model(self, model_name: str, device_mode: str = "auto") -> tuple:
        """#### Load a SAM model.

        The checkpoint is read from `./_internal/yolos/<model_name>`. The
        architecture is chosen from the file name: `vit_h` or `vit_l` when that
        text appears in `model_name`, otherwise `vit_b`.

        `device_mode` is matched case-insensitively:
        - `"Prefer GPU"`: the model is moved to the torch device right away
          (subject to the free-memory check in `SafeToGPU`).
        - `"AUTO"`: the attached wrapper is created in auto mode, so it moves
          the model to the device only around predictions.
        - anything else: the model is not moved by this loader.

        #### Args:
            - `model_name` (str): The checkpoint file name.
            - `device_mode` (str, optional): The device mode. Defaults to "auto".

        #### Returns:
            - `tuple`: One-element tuple holding the loaded SAM model. A
              `SAMWrapper` for it is attached as `sam.sam_wrapper`.
        """
        modelname = "./_internal/yolos/" + model_name

        if "vit_h" in model_name:
            model_kind = "vit_h"
        elif "vit_l" in model_name:
            model_kind = "vit_l"
        else:
            model_kind = "vit_b"

        sam = sam_model_registry[model_kind](checkpoint=modelname)
        # The checkpoint's size on disk is what the memory check is based on.
        safe_to = SafeToGPU(os.path.getsize(modelname))

        # Compare modes case-insensitively so the lowercase default "auto"
        # selects auto mode just like "AUTO" does.
        mode = device_mode.lower()
        if mode == "prefer gpu":
            safe_to.to_device(sam, Device.get_torch_device())

        sam_obj = SAMWrapper(sam, is_auto_mode=(mode == "auto"), safe_to_gpu=safe_to)
        sam.sam_wrapper = sam_obj

        print(f"Loads SAM model: {modelname} (device:{device_mode})")
        return (sam,)
def make_sam_mask(
    sam: SAMWrapper,
    segs: tuple,
    image: torch.Tensor,
    detection_hint: bool,
    dilation: int,
    threshold: float,
    bbox_expansion: int,
    mask_hint_threshold: float,
    mask_hint_use_negative: bool,
) -> torch.Tensor:
    """#### Build one combined SAM mask for all detected segments.

    For every segment in `segs[1]`, SAM is prompted with the centre of the
    segment's bounding box as a foreground point and with the bounding box
    grown by `bbox_expansion` (clamped to the image). All resulting masks are
    merged, dilated and converted with `mask_util.make_3d_mask`.

    `detection_hint`, `mask_hint_threshold` and `mask_hint_use_negative` are
    accepted but not read by this function.

    #### Args:
        - `sam` (SAMWrapper): Object exposing the wrapper as `sam.sam_wrapper`.
        - `segs` (tuple): Segmentation information; only `segs[1]` is used.
        - `image` (torch.Tensor): The input image; scaled by 255 and clipped to
          uint8, so values are presumably in [0, 1] (confirm with the caller).
        - `detection_hint` (bool): Whether to use detection hint (unused).
        - `dilation` (int): Dilation value.
        - `threshold` (float): Threshold for mask selection.
        - `bbox_expansion` (int): Bounding box expansion value.
        - `mask_hint_threshold` (float): Mask hint threshold (unused).
        - `mask_hint_use_negative` (bool): Whether to use negative mask hint (unused).

    #### Returns:
        - `torch.Tensor`: The created SAM mask, or `None` when the merge step
          yields no mask.
    """
    wrapper = sam.sam_wrapper
    wrapper.prepare_device()

    try:
        pixels = image.cpu().numpy().squeeze()
        frame = np.clip(255.0 * pixels, 0, 255).astype(np.uint8)

        collected = []
        for seg in segs[1]:
            box = seg.bbox
            center = mask_util.center_of_bbox(box)

            # Grow the box on every side without leaving the image.
            expanded_box = [
                max(box[0] - bbox_expansion, 0),
                max(box[1] - bbox_expansion, 0),
                min(box[2] + bbox_expansion, frame.shape[1]),
                min(box[3] + bbox_expansion, frame.shape[0]),
            ]

            # Single prompt point with label 1
            # (1 = foreground point, 0 = background point).
            collected.extend(
                wrapper.predict(frame, [center], [1], expanded_box, threshold)
            )

        # Merge every collected mask into one.
        mask = mask_util.combine_masks2(collected)
    finally:
        # Always hand the device back, even when prediction fails.
        wrapper.release_device()

    if mask is None:
        return None

    dilated = mask_util.dilate_mask(mask.float().cpu().numpy(), dilation)
    return mask_util.make_3d_mask(torch.from_numpy(dilated))
class SAMDetectorCombined:
    """#### Node-style entry point that produces one combined SAM mask."""

    def doit(
        self,
        sam_model: SAMWrapper,
        segs: tuple,
        image: torch.Tensor,
        detection_hint: bool,
        dilation: int,
        threshold: float,
        bbox_expansion: int,
        mask_hint_threshold: float,
        mask_hint_use_negative: bool,
    ) -> tuple:
        """#### Run `make_sam_mask` and wrap its result in a tuple.

        All arguments are forwarded unchanged, in the same order, to
        `make_sam_mask`.

        #### Args:
            - `sam_model` (SAMWrapper): The SAM wrapper.
            - `segs` (tuple): Segmentation information.
            - `image` (torch.Tensor): The input image.
            - `detection_hint` (bool): Whether to use detection hint.
            - `dilation` (int): Dilation value.
            - `threshold` (float): Threshold for mask selection.
            - `bbox_expansion` (int): Bounding box expansion value.
            - `mask_hint_threshold` (float): Mask hint threshold.
            - `mask_hint_use_negative` (bool): Whether to use negative mask hint.

        #### Returns:
            - `tuple`: One-element tuple holding the mask, or `None` when no
              mask was produced.
        """
        combined_mask = make_sam_mask(
            sam_model,
            segs,
            image,
            detection_hint,
            dilation,
            threshold,
            bbox_expansion,
            mask_hint_threshold,
            mask_hint_use_negative,
        )
        return None if combined_mask is None else (combined_mask,)