Aatricks's picture
Upload folder using huggingface_hub
d9a2e19 verified
import os
import numpy as np
from segment_anything import SamPredictor, sam_model_registry
import torch
from modules.AutoDetailer import mask_util
from modules.Device import Device
def sam_predict(
    predictor: SamPredictor, points: list, plabs: list, bbox: list, threshold: float
) -> list:
    """#### Predict masks using SAM.

    Issues a single `predictor.predict` call and keeps every candidate mask
    whose score reaches `threshold`. If no candidate qualifies, the single
    best-scoring candidate is returned instead, so the result is empty only
    when the predictor produced no candidates at all.

    #### Args:
        - `predictor` (SamPredictor): The SAM predictor to query.
        - `points` (list): List of prompt points; an empty list sends no points.
        - `plabs` (list): List of point labels; an empty list sends no labels.
        - `bbox` (list): Bounding box prompt, or `None` for no box.
        - `threshold` (float): Minimum score for a mask to be selected.

    #### Returns:
        - `list`: List of predicted masks.
    """
    point_coords = np.array(points) if points else None
    point_labels = np.array(plabs) if plabs else None
    # The box is wrapped in an extra dimension before being handed to SAM.
    box = np.array([bbox]) if bbox is not None else None

    cur_masks, scores, _ = predictor.predict(
        point_coords=point_coords, point_labels=point_labels, box=box
    )

    total_masks = []
    best_mask = None
    # Start below any real score so the fallback still works when every
    # candidate scores zero or less.
    best_score = float("-inf")
    for mask, score in zip(cur_masks, scores):
        if score > best_score:
            best_score, best_mask = score, mask
        if score >= threshold:
            total_masks.append(mask)

    # Nothing reached the threshold: fall back to the best candidate, if any.
    if not total_masks and best_mask is not None:
        total_masks.append(best_mask)
    return total_masks
def is_same_device(a: torch.device, b: torch.device) -> bool:
    """#### Check if two devices are the same.

    String arguments are converted with `torch.device` first. Two devices
    match only when both their type and their index are equal, so an
    unindexed device (index `None`) does not match an indexed one.

    #### Args:
        - `a` (torch.device): The first device (a device string is accepted).
        - `b` (torch.device): The second device (a device string is accepted).

    #### Returns:
        - `bool`: Whether the devices are the same.
    """
    first, second = (
        torch.device(dev) if isinstance(dev, str) else dev for dev in (a, b)
    )
    return (first.type, first.index) == (second.type, second.index)
class SafeToGPU:
    """#### Class to safely move objects to GPU.

    A move to a non-CPU device is attempted only when enough memory is
    reported free; otherwise a warning is printed and the object is left
    where it is.
    """

    def __init__(self, size: int = 0):
        """#### Initialize the helper.

        #### Args:
            - `size` (int, optional): Memory footprint of the object to move;
              1.3x this value must be free on the target. Defaults to 0.
        """
        self.size = size

    def to_device(self, obj: torch.nn.Module, device: torch.device) -> None:
        """#### Move an object to a device.

        Moving to the CPU is unconditional. Moving to any other device happens
        only when `obj` currently sits on the CPU and the target reports more
        than `size * 1.3` free memory. Failures print a warning rather than
        raise.

        #### Args:
            - `obj` (torch.nn.Module): The object to move. It must expose a
              `.device` attribute, which is read on the non-CPU path.
            - `device` (torch.device): The target device.
        """
        if is_same_device(device, "cpu"):
            obj.to(device)
            return

        if not is_same_device(obj.device, "cpu"):
            # Already off the CPU: nothing is moved.
            return

        # cpu to gpu: ask for the model size plus 30% headroom.
        required = self.size * 1.3
        Device.free_memory(required, device)
        if Device.get_free_memory(device) <= required:
            print(
                f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]"
            )
            return

        try:
            obj.to(device)
        except Exception:
            # Best-effort move: stay on the current device, but let
            # KeyboardInterrupt/SystemExit propagate.
            print(
                f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]"
            )
class SAMWrapper:
    """#### Wrapper class for SAM model.

    Bundles a SAM model with its device-placement policy: in auto mode the
    model is moved to the torch device by `prepare_device` and back to the CPU
    by `release_device`.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        is_auto_mode: bool,
        safe_to_gpu: "SafeToGPU | None" = None,
    ):
        """#### Initialize the wrapper.

        #### Args:
            - `model` (torch.nn.Module): The SAM model.
            - `is_auto_mode` (bool): Whether the model is moved on and off the
              device around each use.
            - `safe_to_gpu` (SafeToGPU, optional): Helper used to move the
              model. Defaults to a helper with a size of 0.
        """
        self.model = model
        # SafeToGPU needs a size argument; 0 asks for no memory headroom.
        self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU(0)
        self.is_auto_mode = is_auto_mode

    def prepare_device(self) -> None:
        """#### Prepare the device for the model."""
        if self.is_auto_mode:
            device = Device.get_torch_device()
            self.safe_to_gpu.to_device(self.model, device=device)

    def release_device(self) -> None:
        """#### Release the device from the model."""
        if self.is_auto_mode:
            self.model.to(device="cpu")

    def predict(
        self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float
    ) -> list:
        """#### Predict masks using the SAM model.

        A fresh `SamPredictor` is built for every call and the image is set on
        it in "RGB" format before prediction.

        #### Args:
            - `image` (np.ndarray): The input image.
            - `points` (list): List of points.
            - `plabs` (list): List of point labels.
            - `bbox` (list): Bounding box.
            - `threshold` (float): Threshold for mask selection.

        #### Returns:
            - `list`: List of predicted masks.
        """
        predictor = SamPredictor(self.model)
        predictor.set_image(image, "RGB")
        return sam_predict(predictor, points, plabs, bbox, threshold)
class SAMLoader:
    """#### Class to load SAM models."""

    def load_model(self, model_name: str, device_mode: str = "auto") -> tuple:
        """#### Load a SAM model.

        The checkpoint is read from `./_internal/yolos/<model_name>`. The
        registry entry is chosen from the file name: `vit_h` or `vit_l` when
        present, otherwise `vit_b`.

        `device_mode` is matched case-insensitively:
        - "Prefer GPU": the model is moved to the torch device right away.
        - "AUTO": the attached wrapper is put in auto mode.
        - anything else: no device move is requested.

        #### Args:
            - `model_name` (str): The name of the model.
            - `device_mode` (str, optional): The device mode. Defaults to "auto".

        #### Returns:
            - `tuple`: One-element tuple holding the loaded SAM model, with its
              `SAMWrapper` attached as the `sam_wrapper` attribute.
        """
        modelname = "./_internal/yolos/" + model_name

        model_kind = next(
            (kind for kind in ("vit_h", "vit_l") if kind in model_name), "vit_b"
        )
        sam = sam_model_registry[model_kind](checkpoint=modelname)

        # The checkpoint's size on disk stands in for the model's memory need.
        safe_to = SafeToGPU(os.path.getsize(modelname))

        # Normalise so the lowercase default "auto" matches the "AUTO" mode.
        mode = device_mode.strip().lower()
        if mode == "prefer gpu":
            safe_to.to_device(sam, Device.get_torch_device())

        sam_obj = SAMWrapper(sam, is_auto_mode=(mode == "auto"), safe_to_gpu=safe_to)
        sam.sam_wrapper = sam_obj

        print(f"Loads SAM model: {modelname} (device:{device_mode})")
        return (sam,)
def make_sam_mask(
    sam: SAMWrapper,
    segs: tuple,
    image: torch.Tensor,
    detection_hint: bool,
    dilation: int,
    threshold: float,
    bbox_expansion: int,
    mask_hint_threshold: float,
    mask_hint_use_negative: bool,
) -> torch.Tensor:
    """#### Create a SAM mask.

    Each segment in `segs[1]` is prompted with the centre of its bounding box
    as a foreground point, plus the box expanded by `bbox_expansion` and
    clamped to the image. All predicted masks are merged, dilated and returned
    as a 3D mask.

    #### Args:
        - `sam` (SAMWrapper): Object whose `sam_wrapper` attribute is the SAM wrapper.
        - `segs` (tuple): Segmentation information; only `segs[1]` is read.
        - `image` (torch.Tensor): The input image, scaled by 255 and clipped to uint8.
        - `detection_hint` (bool): Whether to use detection hint (not read here).
        - `dilation` (int): Dilation value.
        - `threshold` (float): Threshold for mask selection.
        - `bbox_expansion` (int): Bounding box expansion value.
        - `mask_hint_threshold` (float): Mask hint threshold (not read here).
        - `mask_hint_use_negative` (bool): Whether to use negative mask hint (not read here).

    #### Returns:
        - `torch.Tensor`: The created SAM mask, or `None` when merging yields no mask.
    """
    wrapper = sam.sam_wrapper
    wrapper.prepare_device()

    try:
        pixels = image.cpu().numpy().squeeze()
        frame = np.clip(255.0 * pixels, 0, 255).astype(np.uint8)

        collected = []
        for seg in segs[1]:
            bbox = seg.bbox
            center = mask_util.center_of_bbox(bbox)
            # Grow the box, clamped to the frame (axis 1 bounds x, axis 0 bounds y).
            expanded_bbox = [
                max(bbox[0] - bbox_expansion, 0),
                max(bbox[1] - bbox_expansion, 0),
                min(bbox[2] + bbox_expansion, frame.shape[1]),
                min(bbox[3] + bbox_expansion, frame.shape[0]),
            ]
            # Label 1 marks the centre as a foreground point (0 = background).
            collected.extend(
                wrapper.predict(frame, [center], [1], expanded_bbox, threshold)
            )

        # Merge every collected mask.
        mask = mask_util.combine_masks2(collected)
    finally:
        # Release the device even if prediction fails.
        wrapper.release_device()

    if mask is None:
        return None

    dilated = mask_util.dilate_mask(mask.float().cpu().numpy(), dilation)
    return mask_util.make_3d_mask(torch.from_numpy(dilated))
class SAMDetectorCombined:
    """#### Class to combine SAM detection."""

    def doit(
        self,
        sam_model: SAMWrapper,
        segs: tuple,
        image: torch.Tensor,
        detection_hint: bool,
        dilation: int,
        threshold: float,
        bbox_expansion: int,
        mask_hint_threshold: float,
        mask_hint_use_negative: bool,
    ) -> tuple:
        """#### Combine SAM detection.

        Passes every argument to `make_sam_mask` and wraps the resulting mask
        in a one-element tuple.

        #### Args:
            - `sam_model` (SAMWrapper): The SAM wrapper.
            - `segs` (tuple): Segmentation information.
            - `image` (torch.Tensor): The input image.
            - `detection_hint` (bool): Whether to use detection hint.
            - `dilation` (int): Dilation value.
            - `threshold` (float): Threshold for mask selection.
            - `bbox_expansion` (int): Bounding box expansion value.
            - `mask_hint_threshold` (float): Mask hint threshold.
            - `mask_hint_use_negative` (bool): Whether to use negative mask hint.

        #### Returns:
            - `tuple`: `(mask,)` when a mask was produced, otherwise `None`.
        """
        combined_mask = make_sam_mask(
            sam=sam_model,
            segs=segs,
            image=image,
            detection_hint=detection_hint,
            dilation=dilation,
            threshold=threshold,
            bbox_expansion=bbox_expansion,
            mask_hint_threshold=mask_hint_threshold,
            mask_hint_use_negative=mask_hint_use_negative,
        )
        return None if combined_mask is None else (combined_mask,)