danieaneta committed
Commit 945be6a · verified · 1 Parent(s): 03f73fb

Upload 3 files

Files changed (3)
  1. app.py +71 -0
  2. main.py +164 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,71 @@
+ import gradio as gr
+ from PIL import Image
+ import base64
+ import io
+ import numpy as np
+ from typing import List
+ from main import segmenter  # Import the segmenter instance
+
+ def process_image(image: Image.Image, objects_text: str) -> dict:
+     """Process image and return results"""
+     try:
+         # Parse objects
+         objects = [obj.strip() for obj in objects_text.split('.') if obj.strip()]
+
+         # Use the segmenter to process the image
+         results = segmenter.segment_objects(image, objects)
+
+         # Create visualization of results
+         # For now, just returning the original image
+         buffered = io.BytesIO()
+         image.save(buffered, format="PNG")
+         img_str = base64.b64encode(buffered.getvalue()).decode()
+
+         # Format results for response
+         return {
+             "success": True,
+             "message": f"Processed image with objects: {objects}",
+             "image": img_str,
+             "results": [
+                 {
+                     "label": r.label,
+                     "confidence": float(r.confidence),
+                     "bounding_box": r.bounding_box
+                 }
+                 for r in results
+             ]
+         }
+     except Exception as e:
+         return {
+             "success": False,
+             "message": str(e),
+             "image": None,
+             "results": []
+         }
+
+ # Create Gradio interface with API mode enabled
+ demo = gr.Interface(
+     fn=process_image,
+     inputs=[
+         gr.Image(type="pil", label="Input Image"),
+         gr.Textbox(label="Objects (separate with dots)", placeholder="cat. dog. chair")
+     ],
+     outputs=gr.JSON(label="API Response"),
+     title="Zero Shot Segmentation",
+     description="Upload an image and specify objects to detect.",
+     allow_flagging="never",
+     examples=[
+         ["path/to/example.jpg", "cat. dog"]  # placeholder path; examples need a real image file to load
+     ]
+ )
+
+ # Enable API access
+ demo.queue()
+
+ if __name__ == "__main__":
+     demo.launch(
+         share=True,
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_api=True
+     )
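
Because the interface is launched with queue() and show_api=True, the endpoint can also be called programmatically. A minimal client-side sketch, not part of this commit, assuming the default /predict endpoint of a single gr.Interface, a server at localhost:7860, and a hypothetical local image example.jpg:

    from gradio_client import Client, handle_file

    client = Client("http://localhost:7860")  # or the Space URL

    # The two positional arguments match the Interface inputs in order
    response = client.predict(
        handle_file("example.jpg"),  # hypothetical image path
        "cat. dog",                  # dot-separated object prompt
        api_name="/predict",
    )
    print(response["success"], response["message"])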
main.py ADDED
@@ -0,0 +1,164 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+ from ultralytics import YOLO
+ from typing import Dict, List, Tuple, Union, Optional
+ from dataclasses import dataclass
+
+ @dataclass
+ class SegmentationResult:
+     """Data class to store segmentation results"""
+     label: str
+     confidence: float
+     mask: np.ndarray
+     bounding_box: List[int]
+
+ class ObjectSegmenter:
+     """A class for zero-shot object detection and segmentation"""
+     def __init__(self, device: Optional[str] = None):
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         torch.cuda.empty_cache()
+         self._init_models()
+
+     def _init_models(self):
+         """Initialize DINO and YOLO models"""
+         # Grounding DINO setup
+         self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
+         self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
+             "IDEA-Research/grounding-dino-tiny"
+         ).to(self.device).eval()
+
+         # YOLO setup
+         self.yolo_model = YOLO('yolov8n-seg.pt')
+
+     def segment_objects(
+         self,
+         image: Union[Image.Image, np.ndarray, str],
+         objects: Union[str, List[str]],
+         box_threshold: float = 0.4,
+         text_threshold: float = 0.3
+     ) -> List[SegmentationResult]:
+         """Segment specified objects in the image"""
+         # Prepare image
+         if isinstance(image, str):
+             image = Image.open(image)
+         elif isinstance(image, np.ndarray):
+             image = Image.fromarray(image)
+
+         if image.mode != 'RGB':
+             image = image.convert('RGB')
+
+         # Prepare text prompt: Grounding DINO expects categories separated
+         # by periods, with a trailing period
+         if isinstance(objects, list):
+             text_prompt = ". ".join(objects)
+         else:
+             text_prompt = objects
+         if not text_prompt.endswith('.'):
+             text_prompt += '.'
+
+         # Get DINO detections
+         dino_results = self._get_dino_detections(
+             image, text_prompt, box_threshold, text_threshold
+         )
+
+         # Get YOLO segmentation
+         yolo_results = self.yolo_model(image, verbose=False)[0]
+
+         # Match detections with segmentations
+         return self._process_results(dino_results, yolo_results)
+
+     @torch.no_grad()
+     def _get_dino_detections(
+         self,
+         image: Image.Image,
+         text_prompt: str,
+         box_threshold: float,
+         text_threshold: float
+     ) -> dict:
+         """Get object detections from Grounding DINO"""
+         inputs = self.dino_processor(
+             images=image,
+             text=text_prompt,
+             return_tensors="pt"
+         ).to(self.device)
+
+         outputs = self.dino_model(**inputs)
+         results = self.dino_processor.post_process_grounded_object_detection(
+             outputs,
+             inputs.input_ids,
+             box_threshold=box_threshold,
+             text_threshold=text_threshold,
+             target_sizes=[image.size[::-1]]  # (height, width)
+         )[0]
+
+         return results
+
+     def _process_results(
+         self,
+         dino_results: dict,
+         yolo_results
+     ) -> List[SegmentationResult]:
+         """Match detections with segmentations and create result objects"""
+         segmentation_results = []
+
+         for box, score, label in zip(
+             dino_results["boxes"],
+             dino_results["scores"],
+             dino_results["labels"]
+         ):
+             box = [int(x) for x in box.tolist()]
+
+             # Find best matching YOLO mask
+             best_mask = self._find_best_mask(box, yolo_results)
+
+             if best_mask is not None:
+                 result = SegmentationResult(
+                     label=label,
+                     confidence=float(score),
+                     mask=best_mask,
+                     bounding_box=box
+                 )
+                 segmentation_results.append(result)
+
+         return segmentation_results
+
+     def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
+         """Find best matching YOLO mask for a given bounding box"""
+         # masks is None (not an empty collection) when YOLO detects nothing
+         if yolo_results.masks is None or len(yolo_results.masks) == 0:
+             return None
+
+         best_iou = 0.0
+         best_mask = None
+
+         # Caveat: mask tensors are at the model's inference resolution; if that
+         # differs from the original image size, the IoU matching is approximate
+         for mask in yolo_results.masks.data:
+             mask_np = mask.cpu().numpy()
+             y_indices, x_indices = np.where(mask_np > 0)
+             if len(y_indices) == 0:
+                 continue
+
+             mask_box = [
+                 x_indices.min(),
+                 y_indices.min(),
+                 x_indices.max(),
+                 y_indices.max()
+             ]
+
+             iou = self._calculate_iou(box, mask_box)
+             if iou > best_iou:
+                 best_iou = iou
+                 best_mask = mask_np
+
+         return best_mask
+
+     @staticmethod
+     def _calculate_iou(box1: List[int], box2: List[int]) -> float:
+         """Calculate Intersection over Union between two boxes"""
+         intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
+                        max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
+         box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+         box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+         union = box1_area + box2_area - intersection
+         return intersection / union if union > 0 else 0.0
+
+ # Initialize the segmenter
+ segmenter = ObjectSegmenter()
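
Outside the Gradio app, the module-level segmenter instance can be driven directly, since segment_objects accepts a PIL image, a numpy array, or a file path. A quick sketch with a hypothetical photo.jpg:

    from main import segmenter

    results = segmenter.segment_objects("photo.jpg", ["cat", "dog"])
    for r in results:
        print(f"{r.label}: {r.confidence:.2f} at {r.bounding_box}, mask shape {r.mask.shape}")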
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==4.44.1
+ Pillow
+ numpy
+ torch
+ transformers
+ ultralytics
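
Only gradio is pinned (4.44.1); the remaining packages are unpinned. Installing with pip install -r requirements.txt and starting with python app.py should suffice, but note that the first run downloads both the Grounding DINO weights from the Hugging Face Hub and yolov8n-seg.pt, since main.py instantiates ObjectSegmenter at import time.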