Spaces:
Sleeping
Sleeping
File size: 5,566 Bytes
945be6a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from ultralytics import YOLO
from typing import Dict, List, Tuple, Union, Optional
from dataclasses import dataclass
@dataclass
class SegmentationResult:
    """A single detected object together with the mask matched to it.

    Instances are built from one Grounding DINO detection (label, score, box)
    plus the YOLO mask selected for that detection.
    """

    # Text label reported by the detector for this object.
    label: str
    # Detector score for this detection, stored as a plain Python float.
    confidence: float
    # Segmentation mask array; NOTE(review): presumably a 2-D 0/1 map — confirm
    # its resolution against the image you overlay it on.
    mask: np.ndarray
    # Integer box as [x_min, y_min, x_max, y_max] in pixel coordinates.
    bounding_box: List[int]
class ObjectSegmenter:
    """Zero-shot object detection (Grounding DINO) paired with YOLO instance masks.

    Grounding DINO finds boxes for free-text object names. Each box is then
    paired with the YOLO segmentation mask whose tight bounding box overlaps
    it most (highest IoU). Detections that overlap no mask are dropped.

    Args:
        device: Torch device string used for both models. Defaults to
            ``"cuda"`` when available, otherwise ``"cpu"``.
        dino_model_id: Hugging Face model id for Grounding DINO.
        yolo_weights: Ultralytics segmentation weights (file or model name).
    """

    def __init__(
        self,
        device: Optional[str] = None,
        *,
        dino_model_id: str = "IDEA-Research/grounding-dino-tiny",
        yolo_weights: str = "yolov8n-seg.pt",
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.dino_model_id = dino_model_id
        self.yolo_weights = yolo_weights
        # Release cached GPU memory before loading models; nothing to do without CUDA.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Load the Grounding DINO processor/model and the YOLO segmentation model."""
        # Grounding DINO setup
        self.dino_processor = AutoProcessor.from_pretrained(self.dino_model_id)
        self.dino_model = (
            AutoModelForZeroShotObjectDetection.from_pretrained(self.dino_model_id)
            .to(self.device)
            .eval()
        )
        # YOLO setup. Its device is passed per prediction call in segment_objects.
        self.yolo_model = YOLO(self.yolo_weights)

    def segment_objects(
        self,
        image: Union["Image.Image", np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
    ) -> List["SegmentationResult"]:
        """Detect the named objects in ``image`` and return their masks.

        Args:
            image: PIL image, numpy array (passed to ``Image.fromarray``), or a
                path to an image file. Always converted to RGB.
            objects: A single prompt string (e.g. ``"cat. dog"``) or a list of
                object names; see ``_build_prompt``.
            box_threshold: Grounding DINO box-score threshold.
            text_threshold: Grounding DINO text-score threshold.

        Returns:
            One ``SegmentationResult`` per DINO detection that overlaps some
            YOLO mask. Empty when ``objects`` contains no non-blank name.
        """
        text_prompt = self._build_prompt(objects)
        if not text_prompt:
            # Nothing to look for; avoid running both models on an empty prompt.
            return []
        image = self._prepare_image(image)
        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )
        # retina_masks=True returns masks at the original image resolution,
        # i.e. the same pixel space as the DINO boxes. Without it the masks are
        # at the letterboxed inference size and IoU matching is skewed.
        yolo_results = self.yolo_model(
            image,
            verbose=False,
            retina_masks=True,
            device=self._yolo_device(),
        )[0]
        return self._process_results(dino_results, yolo_results)

    @staticmethod
    def _build_prompt(objects: Union[str, List[str]]) -> str:
        """Build a Grounding DINO text prompt: lowercase, '. '-separated, ending in '.'.

        A list is joined after stripping whitespace and trailing dots from each
        name; blank names are skipped. Returns "" when nothing remains.
        """
        if isinstance(objects, str):
            prompt = objects.strip()
        else:
            phrases = [obj.strip().rstrip(".").strip() for obj in objects]
            prompt = ". ".join(p for p in phrases if p)
        # Grounding DINO expects lowercased queries terminated by a dot.
        prompt = prompt.lower()
        if prompt and not prompt.endswith("."):
            prompt += "."
        return prompt

    @staticmethod
    def _prepare_image(image: Union["Image.Image", np.ndarray, str]) -> "Image.Image":
        """Return ``image`` as an RGB PIL image, loading it from disk if given a path."""
        if isinstance(image, str):
            # convert() returns a fully loaded copy, so the file can be closed here.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def _yolo_device(self) -> Union[str, int]:
        """Translate ``self.device`` into a device spec for the Ultralytics predict call.

        CUDA devices become an integer index (``"cuda"`` -> 0, ``"cuda:1"`` -> 1);
        other devices are passed by type name (e.g. ``"cpu"``, ``"mps"``).
        """
        dev = torch.device(self.device)
        if dev.type == "cuda":
            return dev.index if dev.index is not None else 0
        return dev.type

    @torch.no_grad()
    def _get_dino_detections(
        self,
        image: "Image.Image",
        text_prompt: str,
        box_threshold: float,
        text_threshold: float,
    ) -> dict:
        """Run Grounding DINO and return the post-processed detections for ``image``.

        Boxes are rescaled to the image's own (height, width), so they are in
        original-image pixel coordinates.
        """
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
        ).to(self.device)
        outputs = self.dino_model(**inputs)
        results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL size is (width, height); target_sizes expects (height, width).
            target_sizes=[image.size[::-1]],
        )[0]
        return results

    def _process_results(
        self,
        dino_results: dict,
        yolo_results,
    ) -> List["SegmentationResult"]:
        """Pair each DINO detection with its best-overlapping YOLO mask.

        Detections for which no mask overlaps (IoU 0) are omitted.
        """
        # Newer transformers releases expose string labels under "text_labels"
        # (with "labels" becoming integer ids); older ones only have "labels".
        if "text_labels" in dino_results:
            labels = dino_results["text_labels"]
        else:
            labels = dino_results["labels"]
        # Extract masks and their tight boxes once, not once per detection.
        candidates = self._mask_candidates(yolo_results)
        segmentation_results = []
        for box, score, label in zip(dino_results["boxes"], dino_results["scores"], labels):
            int_box = [int(x) for x in box.tolist()]
            best_mask = self._match_mask(int_box, candidates)
            if best_mask is not None:
                segmentation_results.append(
                    SegmentationResult(
                        label=str(label),
                        confidence=float(score),
                        mask=best_mask,
                        bounding_box=int_box,
                    )
                )
        return segmentation_results

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Return the YOLO mask whose tight box has the highest IoU with ``box``.

        Returns None when YOLO found no masks (``masks`` is None or empty) or
        when no mask overlaps ``box``.
        """
        return self._match_mask(box, self._mask_candidates(yolo_results))

    @staticmethod
    def _mask_candidates(yolo_results) -> List[Tuple[np.ndarray, List[int]]]:
        """Return ``(mask, [x_min, y_min, x_max, y_max])`` for every non-empty YOLO mask."""
        masks = getattr(yolo_results, "masks", None)
        # Ultralytics sets ``masks`` to None when nothing was detected.
        if masks is None or len(masks.data) == 0:
            return []
        candidates = []
        # One device->host transfer for the whole (N, H, W) stack.
        for mask_np in masks.data.cpu().numpy():
            ys, xs = np.nonzero(mask_np > 0)
            if ys.size == 0:
                continue
            # +1 on the maxima: pixel i spans [i, i + 1), matching the
            # continuous box coordinates reported by DINO.
            tight_box = [int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1]
            candidates.append((mask_np, tight_box))
        return candidates

    def _match_mask(
        self,
        box: List[int],
        candidates: List[Tuple[np.ndarray, List[int]]],
    ) -> Optional[np.ndarray]:
        """Pick the candidate mask with the highest strictly positive IoU against ``box``."""
        best_iou = 0.0
        best_mask = None
        for mask_np, mask_box in candidates:
            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou, best_mask = iou, mask_np
        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Intersection over Union of two ``[x_min, y_min, x_max, y_max]`` boxes.

        Returns 0.0 when the union is empty (e.g. both boxes are degenerate)
        instead of raising ZeroDivisionError.
        """
        inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
        inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        intersection = inter_w * inter_h
        area1 = max(0, box1[2] - box1[0]) * max(0, box1[3] - box1[1])
        area2 = max(0, box2[2] - box2[0]) * max(0, box2[3] - box2[1])
        union = area1 + area2 - intersection
        if union <= 0:
            return 0.0
        return float(intersection / union)
# Shared module-level segmenter. Constructing it loads the Grounding DINO and
# YOLO models (which may download weights on first use); device=None means
# "cuda" when available, otherwise "cpu".
segmenter = ObjectSegmenter(device=None)
|