# zero-shot-seg/main.py
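"""Zero-shot object segmentation.

Grounding DINO detects bounding boxes for free-text object prompts, and each
detection is matched by IoU against YOLOv8 instance masks to produce
SegmentationResult objects.
"""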
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from ultralytics import YOLO
from typing import Dict, List, Tuple, Union, Optional
from dataclasses import dataclass


@dataclass
class SegmentationResult:
    """Data class to store segmentation results"""
    label: str
    confidence: float
    mask: np.ndarray
    bounding_box: List[int]


class ObjectSegmenter:
    """A class for zero-shot object detection and segmentation"""

    def __init__(self, device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Initialize Grounding DINO and YOLO models"""
        # Grounding DINO setup
        self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-tiny"
        ).to(self.device).eval()
        # YOLO setup
        self.yolo_model = YOLO('yolov8n-seg.pt')

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3
    ) -> List[SegmentationResult]:
        """Segment the specified objects in the image"""
        # Prepare image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Prepare text prompt: Grounding DINO expects lowercased queries
        # separated by ". " and terminated with a period
        if isinstance(objects, list):
            text_prompt = ". ".join(objects)
        else:
            text_prompt = objects
        text_prompt = text_prompt.lower()
        if not text_prompt.endswith('.'):
            text_prompt += '.'

        # Get DINO detections
        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )

        # Get YOLO segmentation
        yolo_results = self.yolo_model(image, verbose=False)[0]

        # Match detections with segmentations
        return self._process_results(dino_results, yolo_results)

    @torch.no_grad()
    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float
    ) -> dict:
        """Get object detections from Grounding DINO"""
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)
        outputs = self.dino_model(**inputs)
        # Post-process to boxes/scores/labels in original image coordinates
        # (PIL's size is (W, H); target_sizes expects (H, W))
        results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[image.size[::-1]]
        )[0]
        return results

    def _process_results(
        self,
        dino_results: dict,
        yolo_results
    ) -> List[SegmentationResult]:
        """Match detections with segmentations and create result objects"""
        segmentation_results = []
        for box, score, label in zip(
            dino_results["boxes"],
            dino_results["scores"],
            dino_results["labels"]
        ):
            box = [int(x) for x in box.tolist()]
            # Find best matching YOLO mask
            best_mask = self._find_best_mask(box, yolo_results)
            if best_mask is not None:
                result = SegmentationResult(
                    label=label,
                    confidence=float(score),
                    mask=best_mask,
                    bounding_box=box
                )
                segmentation_results.append(result)
        return segmentation_results

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Find the best matching YOLO mask for a given bounding box"""
        # YOLO reports masks as None when nothing is detected
        if yolo_results.masks is None or len(yolo_results.masks) == 0:
            return None
        best_iou = 0.0
        best_mask = None
        # Compare against the YOLO detection boxes (boxes.xyxy is reported in
        # original image coordinates, matching the Grounding DINO boxes); the
        # mask itself is kept at the model's output resolution
        for mask, yolo_box in zip(yolo_results.masks.data, yolo_results.boxes.xyxy):
            mask_np = mask.cpu().numpy()
            if not mask_np.any():
                continue
            mask_box = [int(x) for x in yolo_box.tolist()]
            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask_np
        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Calculate Intersection over Union between two boxes"""
        intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
                       max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        # Guard against degenerate boxes with zero area
        return intersection / union if union > 0 else 0.0


# Initialize the segmenter
segmenter = ObjectSegmenter()
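

# Example usage (a minimal sketch): "example.jpg" and the object names below
# are illustrative placeholders, not files or labels shipped with this repo.
if __name__ == "__main__":
    results = segmenter.segment_objects(
        image="example.jpg",
        objects=["cat", "dog"],
        box_threshold=0.4,
        text_threshold=0.3,
    )
    for r in results:
        print(f"{r.label}: confidence={r.confidence:.2f}, "
              f"box={r.bounding_box}, mask pixels={int(r.mask.sum())}")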