Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from PIL import Image | |
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection | |
from ultralytics import YOLO | |
from typing import Dict, List, Tuple, Union, Optional | |
from dataclasses import dataclass | |
@dataclass
class SegmentationResult:
    """One detected object paired with its segmentation mask.

    The ``@dataclass`` decorator generates the keyword ``__init__`` that
    callers use to build instances (``SegmentationResult(label=..., ...)``).

    Attributes:
        label: Label reported by the detector for this object.
        confidence: Detection score for the object's box.
        mask: Segmentation mask array matched to the detection.
        bounding_box: Integer pixel box ``[x_min, y_min, x_max, y_max]``.
    """

    label: str
    confidence: float
    mask: np.ndarray
    bounding_box: List[int]
class ObjectSegmenter:
    """Zero-shot object detection and segmentation.

    Grounding DINO finds boxes for free-text object names. Each box is then
    paired with the YOLOv8 segmentation mask whose pixel bounding box has
    the highest IoU with it.
    """

    def __init__(self, device: Optional[str] = None):
        """Choose a device and load both models.

        Args:
            device: Torch device for Grounding DINO. Defaults to ``"cuda"``
                when available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Release cached GPU memory before loading the models.
        torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Initialize the Grounding DINO and YOLO models."""
        # Grounding DINO setup: moved to the chosen device, in eval mode.
        self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-tiny"
        ).to(self.device).eval()
        # YOLO instance-segmentation model.
        self.yolo_model = YOLO('yolov8n-seg.pt')

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3
    ) -> List[SegmentationResult]:
        """Detect and segment the named objects in ``image``.

        Args:
            image: A PIL image, an array accepted by ``Image.fromarray``,
                or a path to an image file.
            objects: One object name (or a pre-built prompt string), or a
                list of names joined with ``". "``.
            box_threshold: Box score threshold for Grounding DINO
                post-processing.
            text_threshold: Text score threshold for Grounding DINO
                post-processing.

        Returns:
            One ``SegmentationResult`` per DINO detection that overlaps some
            YOLO mask. Detections with no overlapping mask are dropped.
        """
        pil_image = self._prepare_image(image)
        text_prompt = self._build_prompt(objects)

        # Get DINO detections (boxes in original-image pixel coordinates).
        dino_results = self._get_dino_detections(
            pil_image, text_prompt, box_threshold, text_threshold
        )
        # retina_masks=True makes YOLO return masks at the original image
        # resolution. Without it, masks.data is at the inference size and
        # the mask boxes would not share a coordinate frame with DINO's boxes.
        yolo_results = self.yolo_model(pil_image, verbose=False, retina_masks=True)[0]
        # Match detections with segmentations.
        return self._process_results(dino_results, yolo_results)

    @staticmethod
    def _prepare_image(image: Union[Image.Image, np.ndarray, str]) -> Image.Image:
        """Return ``image`` as an RGB PIL image, loading it from disk if it is a path."""
        if isinstance(image, str):
            # Image.open is lazy and keeps the file open; convert() loads the
            # pixels into a new image so the handle can be closed here.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    @staticmethod
    def _build_prompt(objects: Union[str, List[str]]) -> str:
        """Build a period-terminated Grounding DINO prompt from object names."""
        if isinstance(objects, (list, tuple)):
            text_prompt = ". ".join(objects)
        else:
            text_prompt = objects
        if not text_prompt.endswith('.'):
            text_prompt += '.'
        return text_prompt

    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float
    ) -> dict:
        """Run Grounding DINO on one image.

        Returns:
            The first (only) result of
            ``post_process_grounded_object_detection``. It is read downstream
            via its ``"boxes"``, ``"scores"`` and label keys.
        """
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.dino_model(**inputs)
        results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL size is (width, height); post-processing wants (height, width).
            target_sizes=[image.size[::-1]]
        )[0]
        return results

    def _process_results(
        self,
        dino_results: dict,
        yolo_results
    ) -> List[SegmentationResult]:
        """Pair each DINO detection with its best YOLO mask and wrap it in a result object."""
        # NOTE(review): newer transformers releases add a "text_labels" key
        # holding the string names and deprecate string values under
        # "labels". Prefer it when present; verify against the installed
        # version.
        if "text_labels" in dino_results:
            labels = dino_results["text_labels"]
        else:
            labels = dino_results["labels"]

        segmentation_results = []
        for box, score, label in zip(dino_results["boxes"], dino_results["scores"], labels):
            pixel_box = [int(coord) for coord in box.tolist()]
            best_mask = self._find_best_mask(pixel_box, yolo_results)
            if best_mask is None:
                # No YOLO mask overlaps this detection; drop it.
                continue
            segmentation_results.append(
                SegmentationResult(
                    label=label,
                    confidence=float(score),
                    mask=best_mask,
                    bounding_box=pixel_box,
                )
            )
        return segmentation_results

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Return the YOLO mask whose pixel bounding box has the highest IoU with ``box``.

        Returns ``None`` when YOLO produced no masks, or none of them overlaps
        ``box``.
        """
        # ultralytics sets Results.masks to None (not an empty container)
        # when nothing is detected.
        if yolo_results.masks is None or len(yolo_results.masks) == 0:
            return None

        best_iou = 0.0
        best_mask = None
        for mask in yolo_results.masks.data:
            mask_np = mask.cpu().numpy()
            ys, xs = np.nonzero(mask_np > 0)
            if ys.size == 0:
                # Empty mask: nothing to compare.
                continue
            mask_box = [xs.min(), ys.min(), xs.max(), ys.max()]
            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask_np
        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Return the Intersection over Union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0.0 when the union has no area (both boxes degenerate), rather
        than dividing by zero.
        """
        inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
        inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        intersection = inter_w * inter_h
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        if union <= 0:
            return 0.0
        return float(intersection / union)
# Shared module-level segmenter. Constructing it loads both models right away;
# device=None lets the class pick CUDA when available, otherwise CPU.
segmenter = ObjectSegmenter(device=None)