# File size: 5,566 Bytes
# Commit: 945be6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from ultralytics import YOLO
from typing import Dict, List, Tuple, Union, Optional
from dataclasses import dataclass

@dataclass
class SegmentationResult:
    """A single segmented object produced by ObjectSegmenter.

    Pairs one Grounding DINO detection with the YOLO instance mask matched
    to it.

    Attributes:
        label: Text label Grounding DINO assigned to the detected box.
        confidence: Detection score of the box, stored as a Python float.
        mask: 2-D mask array; pixels with a value > 0 belong to the object.
        bounding_box: Integer box corners in pixels, ordered
            [x_min, y_min, x_max, y_max].
    """

    label: str  # object phrase taken from the text prompt
    confidence: float  # detector score for this box
    mask: np.ndarray  # matched instance mask
    bounding_box: List[int]  # [x_min, y_min, x_max, y_max]

class ObjectSegmenter:
    """Zero-shot object segmentation driven by a free-text prompt.

    Grounding DINO (``IDEA-Research/grounding-dino-tiny``) detects boxes for
    the requested objects and a YOLOv8 segmentation model (``yolov8n-seg.pt``)
    proposes instance masks. Each DINO box is paired with the YOLO mask whose
    bounding box has the highest positive IoU with it; DINO boxes that overlap
    no YOLO mask are dropped from the output.
    """

    def __init__(self, device: Optional[str] = None):
        """Load both models.

        Args:
            device: Torch device for Grounding DINO (e.g. ``"cuda"`` or
                ``"cpu"``). Defaults to ``"cuda"`` when available, else ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # Release cached, unused GPU memory before loading the models.
        torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Initialize the Grounding DINO processor/model and the YOLO model."""
        # Grounding DINO setup: moved to self.device and switched to eval mode.
        self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-tiny"
        ).to(self.device).eval()

        # YOLO setup.
        # NOTE(review): self.device is not passed to YOLO, so Ultralytics picks
        # its own device at predict time — confirm this is intended.
        self.yolo_model = YOLO('yolov8n-seg.pt')

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
    ) -> List[SegmentationResult]:
        """Segment the requested objects in an image.

        Args:
            image: A PIL image, a numpy array accepted by ``Image.fromarray``,
                or a path to an image file.
            objects: A prompt string or a list of object names. A list is
                joined into one ". "-separated Grounding DINO prompt.
            box_threshold: Box score threshold passed to Grounding DINO
                post-processing.
            text_threshold: Text score threshold passed to Grounding DINO
                post-processing.

        Returns:
            One SegmentationResult per Grounding DINO box that overlaps at
            least one YOLO mask.

        Raises:
            ValueError: If ``objects`` contains no non-empty object name.
        """
        # Build the prompt first so invalid input fails before any image I/O.
        text_prompt = self._build_prompt(objects)
        image = self._prepare_image(image)

        # Get DINO detections (boxes in original-image pixel coordinates).
        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )

        # retina_masks=True makes Ultralytics return masks at the original
        # image resolution. Without it, masks are at the letterboxed inference
        # size and would not line up with the DINO boxes.
        yolo_results = self.yolo_model(image, verbose=False, retina_masks=True)[0]

        # Match detections with segmentations.
        return self._process_results(dino_results, yolo_results)

    @staticmethod
    def _prepare_image(image: Union[Image.Image, np.ndarray, str]) -> Image.Image:
        """Return ``image`` as an RGB PIL image, loading it from disk if given a path."""
        if isinstance(image, str):
            # Use a context manager so the file handle is closed. convert()
            # forces the pixel data to load before the file is closed.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image

    @staticmethod
    def _build_prompt(objects: Union[str, List[str]]) -> str:
        """Build a Grounding DINO prompt: lowercased, with a terminating '.'.

        The Grounding DINO model card asks for lowercase, dot-terminated text
        queries. List items are stripped of surrounding whitespace and trailing
        dots, so ``["Cat", "dog."]`` becomes ``"cat. dog."`` rather than
        ``"Cat. dog.."``.

        Raises:
            ValueError: If no non-empty object name is given.
        """
        if isinstance(objects, str):
            text_prompt = objects.strip()
        else:
            phrases = [obj.strip().rstrip(".").strip() for obj in objects]
            text_prompt = ". ".join(phrase for phrase in phrases if phrase)
        if not text_prompt.replace(".", "").strip():
            raise ValueError("objects must contain at least one non-empty object name")
        text_prompt = text_prompt.lower()
        if not text_prompt.endswith("."):
            text_prompt += "."
        return text_prompt

    @torch.no_grad()
    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float,
    ) -> dict:
        """Run Grounding DINO on ``image`` and return its post-processed detections.

        Returns the first (only) entry of the processor's
        ``post_process_grounded_object_detection`` output. Boxes are rescaled
        to the image's pixel size.
        """
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt",
        ).to(self.device)

        outputs = self.dino_model(**inputs)
        return self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL's size is (width, height); target_sizes expects (height, width).
            target_sizes=[image.size[::-1]],
        )[0]

    def _process_results(
        self,
        dino_results: dict,
        yolo_results,
    ) -> List[SegmentationResult]:
        """Match DINO detections with YOLO masks and build result objects."""
        # Compute mask bounding boxes once per image, not once per DINO box.
        candidates = self._mask_candidates(yolo_results)
        # Newer transformers releases put string labels under "text_labels"
        # and deprecate string values under "labels"; prefer "text_labels" when present.
        labels = (
            dino_results["text_labels"]
            if "text_labels" in dino_results
            else dino_results["labels"]
        )

        segmentation_results = []
        for box, score, label in zip(dino_results["boxes"], dino_results["scores"], labels):
            box = [int(x) for x in box.tolist()]
            best_mask = self._best_match(box, candidates)
            if best_mask is not None:
                segmentation_results.append(
                    SegmentationResult(
                        label=label,
                        confidence=float(score),
                        mask=best_mask,
                        bounding_box=box,
                    )
                )
        return segmentation_results

    @staticmethod
    def _mask_candidates(yolo_results) -> List[Tuple[np.ndarray, List[int]]]:
        """Return ``(mask, [x_min, y_min, x_max, y_max])`` for each non-empty YOLO mask."""
        masks = yolo_results.masks
        # Ultralytics sets `masks` to None when no instance is detected.
        if masks is None:
            return []

        candidates = []
        for mask in masks.data:
            mask_np = mask.cpu().numpy()
            y_indices, x_indices = np.where(mask_np > 0)
            if y_indices.size == 0:
                continue
            # +1 turns the inclusive last-pixel index into an exclusive edge,
            # matching the continuous xyxy convention of the DINO boxes.
            mask_box = [
                int(x_indices.min()),
                int(y_indices.min()),
                int(x_indices.max()) + 1,
                int(y_indices.max()) + 1,
            ]
            candidates.append((mask_np, mask_box))
        return candidates

    def _best_match(
        self,
        box: List[int],
        candidates: List[Tuple[np.ndarray, List[int]]],
    ) -> Optional[np.ndarray]:
        """Return the candidate mask with the highest positive IoU against ``box``, or None."""
        best_iou = 0.0
        best_mask = None
        for mask_np, mask_box in candidates:
            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask_np
        return best_mask

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Find the YOLO mask best matching ``box``; None if no mask overlaps it."""
        return self._best_match(box, self._mask_candidates(yolo_results))

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Return the Intersection over Union of two [x_min, y_min, x_max, y_max] boxes.

        Returns 0.0 when the union area is not positive (e.g. both boxes are
        degenerate) instead of dividing by zero.
        """
        inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
        inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        intersection = inter_w * inter_h
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        if union <= 0:
            return 0.0
        return intersection / union

# Shared module-level segmenter instance. Building it runs
# ObjectSegmenter._init_models(), which loads both the Grounding DINO and YOLO
# models, so this work happens when the module is imported.
segmenter = ObjectSegmenter(device=None)