Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection | |
| from ultralytics import YOLO | |
| from typing import Dict, List, Tuple, Union, Optional | |
| from dataclasses import dataclass | |
@dataclass
class SegmentationResult:
    """One segmented object: its text label, detector score, pixel mask and box.

    Instances are built by ObjectSegmenter._process_results.
    """

    label: str  # text label Grounding DINO reported for this box
    confidence: float  # Grounding DINO box score
    mask: np.ndarray  # matched YOLO mask; pixels > 0 belong to the object
    bounding_box: List[int]  # [x_min, y_min, x_max, y_max] in image pixels
class ObjectSegmenter:
    """Zero-shot object segmentation: Grounding DINO boxes paired with YOLO masks.

    Grounding DINO detects boxes for free-text object names. Each box is then
    matched to the YOLO segmentation mask whose bounding box has the highest
    IoU with it; boxes that overlap no mask are dropped.
    """

    # Checkpoint identifiers; override in a subclass to use other weights.
    DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
    YOLO_WEIGHTS = "yolov8n-seg.pt"

    def __init__(self, device: Optional[str] = None):
        """Select the torch device and load both models.

        Args:
            device: Torch device string. When falsy, "cuda" is used if
                available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Release cached CUDA memory before loading the models.
        torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Load the Grounding DINO processor/model and the YOLO segmentation model."""
        self.dino_processor = AutoProcessor.from_pretrained(self.DINO_MODEL_ID)
        # DINO is moved to the selected device and switched to eval mode.
        self.dino_model = (
            AutoModelForZeroShotObjectDetection.from_pretrained(self.DINO_MODEL_ID)
            .to(self.device)
            .eval()
        )
        self.yolo_model = YOLO(self.YOLO_WEIGHTS)

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
    ) -> List[SegmentationResult]:
        """Detect the named objects and return a mask for each detection.

        Args:
            image: PIL image, numpy array (as accepted by ``Image.fromarray``),
                or a path to an image file.
            objects: One prompt string, or a list of object names that are
                joined into a "name. name." prompt.
            box_threshold: Minimum DINO box score to keep a detection.
            text_threshold: Minimum DINO token score for a label to be assigned.

        Returns:
            One SegmentationResult per DINO box that overlaps a YOLO mask.
            Masks share the original image's resolution.

        Raises:
            ValueError: if ``objects`` contains no object name.
        """
        image = self._prepare_image(image)
        text_prompt = self._build_prompt(objects)

        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )
        # retina_masks=True returns masks at the original image size, so mask
        # coordinates match DINO's boxes (rescaled to image.size below). The
        # default returns masks at the letterboxed inference size instead.
        yolo_results = self.yolo_model(image, verbose=False, retina_masks=True)[0]

        return self._process_results(dino_results, yolo_results)

    @staticmethod
    def _prepare_image(image: Union[Image.Image, np.ndarray, str]) -> Image.Image:
        """Return ``image`` as an RGB PIL image, loading it first if it is a path."""
        if isinstance(image, str):
            # The context manager closes the file; convert() returns a loaded copy.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    @staticmethod
    def _build_prompt(objects: Union[str, List[str]]) -> str:
        """Build a lowercase Grounding DINO prompt in which every name ends with '.'."""
        if isinstance(objects, str):
            text_prompt = objects.strip()
        else:
            # Strip stray whitespace/dots so "cat." does not become "cat.. dog."
            names = [name.strip().rstrip(".").strip() for name in objects]
            text_prompt = ". ".join(name for name in names if name)
        if not text_prompt.strip(". "):
            raise ValueError("objects must contain at least one object name")
        # The Grounding DINO model card asks for lowercase, dot-terminated queries.
        text_prompt = text_prompt.lower()
        if not text_prompt.endswith("."):
            text_prompt += "."
        return text_prompt

    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float,
    ) -> dict:
        """Run Grounding DINO and return its post-processed detections for ``image``."""
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
        ).to(self.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.dino_model(**inputs)
        return self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL size is (width, height); target_sizes expects (height, width).
            target_sizes=[image.size[::-1]],
        )[0]

    def _process_results(
        self,
        dino_results: dict,
        yolo_results,
    ) -> List[SegmentationResult]:
        """Pair each DINO box with its best-overlapping YOLO mask."""
        # Recent transformers releases put string names under "text_labels";
        # older ones return them under "labels".
        if "text_labels" in dino_results:
            labels = dino_results["text_labels"]
        else:
            labels = dino_results["labels"]

        segmentation_results = []
        for box_tensor, score, label in zip(
            dino_results["boxes"], dino_results["scores"], labels
        ):
            box = [int(coord) for coord in box_tensor.tolist()]
            best_mask = self._find_best_mask(box, yolo_results)
            if best_mask is None:
                continue
            segmentation_results.append(
                SegmentationResult(
                    label=label,
                    confidence=float(score),
                    mask=best_mask,
                    bounding_box=box,
                )
            )
        return segmentation_results

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Return the YOLO mask whose bounding box has the highest IoU with ``box``.

        Returns None when YOLO produced no masks or none overlaps ``box``.
        """
        masks = getattr(yolo_results, "masks", None)
        # Ultralytics leaves `masks` as None when nothing was detected.
        if masks is None or len(masks) == 0:
            return None

        best_iou = 0.0
        best_mask = None
        # Move all masks to host memory once instead of once per mask.
        for mask_np in masks.data.cpu().numpy():
            ys, xs = np.nonzero(mask_np > 0)
            if ys.size == 0:
                continue
            # +1 turns the last covered pixel index into an exclusive edge, so
            # the box is in the same continuous coordinates as DINO's boxes.
            mask_box = [
                int(xs.min()),
                int(ys.min()),
                int(xs.max()) + 1,
                int(ys.max()) + 1,
            ]
            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask_np
        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Return the Intersection over Union of two [x1, y1, x2, y2] boxes.

        Degenerate boxes with a non-positive union give 0.0.
        """
        inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
        inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        intersection = inter_w * inter_h
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        if union <= 0:
            return 0.0
        return intersection / union
# Module-level ObjectSegmenter. Constructing it loads the Grounding DINO and
# YOLO models; device=None lets the class pick "cuda" or "cpu" itself.
segmenter = ObjectSegmenter(device=None)