Upload 3 files
- app.py +71 -0
- main.py +164 -0
- requirements.txt +6 -0
app.py
ADDED
@@ -0,0 +1,71 @@
import gradio as gr
from PIL import Image
import base64
import io

from main import segmenter  # Import the shared segmenter instance


def process_image(image: Image.Image, objects_text: str) -> dict:
    """Process an image and return results as a JSON-serializable dict."""
    try:
        # Parse the dot-separated object list, e.g. "cat. dog. chair"
        objects = [obj.strip() for obj in objects_text.split('.') if obj.strip()]

        # Use the segmenter to process the image
        results = segmenter.segment_objects(image, objects)

        # TODO: draw masks/boxes on the image; for now, return the original image
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        # Format results for the response
        return {
            "success": True,
            "message": f"Processed image with objects: {objects}",
            "image": img_str,
            "results": [
                {
                    "label": r.label,
                    "confidence": float(r.confidence),
                    "bounding_box": r.bounding_box,
                }
                for r in results
            ],
        }
    except Exception as e:
        return {
            "success": False,
            "message": str(e),
            "image": None,
            "results": [],
        }


# Create Gradio interface with API mode enabled
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Objects (separate with dots)", placeholder="cat. dog. chair"),
    ],
    outputs=gr.JSON(label="API Response"),
    title="Zero Shot Segmentation",
    description="Upload an image and specify objects to detect.",
    allow_flagging="never",
    examples=[
        ["path/to/example.jpg", "cat. dog"]  # placeholder path; point at a real image in the repo
    ],
)

# Enable request queuing so the API endpoint can handle concurrent calls
demo.queue()

if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_api=True,
    )
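Once the Space is up, the JSON endpoint can also be called from another process. Below is a minimal sketch using gradio_client; the localhost URL, the default "/predict" endpoint name, and test.jpg are assumptions, not part of the committed code.

from gradio_client import Client, handle_file

# Assumed: the app is running locally on the default port, with the
# default gr.Interface endpoint name; "test.jpg" is a placeholder path.
client = Client("http://localhost:7860")
response = client.predict(
    handle_file("test.jpg"),  # Input Image
    "cat. dog",               # Objects (separate with dots)
    api_name="/predict",
)
print(response["message"])
for r in response["results"]:
    print(r["label"], r["confidence"], r["bounding_box"])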
main.py
ADDED
@@ -0,0 +1,164 @@
import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from ultralytics import YOLO
from typing import List, Union, Optional
from dataclasses import dataclass


@dataclass
class SegmentationResult:
    """Data class to store segmentation results"""
    label: str
    confidence: float
    mask: np.ndarray
    bounding_box: List[int]


class ObjectSegmenter:
    """A class for zero-shot object detection and segmentation"""

    def __init__(self, device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Initialize the Grounding DINO and YOLO models"""
        # Grounding DINO setup (zero-shot, text-prompted detection)
        self.dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-tiny"
        ).to(self.device).eval()

        # YOLO setup (instance segmentation masks)
        self.yolo_model = YOLO('yolov8n-seg.pt')

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
    ) -> List[SegmentationResult]:
        """Segment the specified objects in the image"""
        # Prepare image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Prepare text prompt: Grounding DINO expects a dot-separated
        # phrase list ending with a dot, e.g. "cat. dog."
        if isinstance(objects, list):
            text_prompt = ". ".join(objects)
        else:
            text_prompt = objects
        if not text_prompt.endswith('.'):
            text_prompt += '.'

        # Get DINO detections
        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )

        # Get YOLO segmentation; retina_masks=True returns masks at the
        # original image resolution so they line up with the DINO boxes
        yolo_results = self.yolo_model(image, verbose=False, retina_masks=True)[0]

        # Match detections with segmentations
        return self._process_results(dino_results, yolo_results)

    @torch.no_grad()
    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float,
    ) -> dict:
        """Get object detections from Grounding DINO"""
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.dino_model(**inputs)
        results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[image.size[::-1]]  # PIL size is (w, h); targets are (h, w)
        )[0]

        return results

    def _process_results(
        self,
        dino_results: dict,
        yolo_results,
    ) -> List[SegmentationResult]:
        """Match detections with segmentations and create result objects"""
        segmentation_results = []

        for box, score, label in zip(
            dino_results["boxes"],
            dino_results["scores"],
            dino_results["labels"],
        ):
            box = [int(x) for x in box.tolist()]

            # Find the best matching YOLO mask
            best_mask = self._find_best_mask(box, yolo_results)

            if best_mask is not None:
                result = SegmentationResult(
                    label=label,
                    confidence=float(score),
                    mask=best_mask,
                    bounding_box=box,
                )
                segmentation_results.append(result)

        return segmentation_results

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Find the best matching YOLO mask for a given bounding box"""
        # masks is None when YOLO detects no instances
        if yolo_results.masks is None or len(yolo_results.masks) == 0:
            return None

        best_iou = 0.0
        best_mask = None

        for mask in yolo_results.masks.data:
            mask_np = mask.cpu().numpy()
            y_indices, x_indices = np.where(mask_np > 0)
            if len(y_indices) == 0:
                continue

            # Tight bounding box around the mask pixels
            mask_box = [
                int(x_indices.min()),
                int(y_indices.min()),
                int(x_indices.max()),
                int(y_indices.max()),
            ]

            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask_np

        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Calculate Intersection over Union between two (x1, y1, x2, y2) boxes"""
        intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
                       max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        return intersection / union if union > 0 else 0.0


# Initialize a shared segmenter instance (imported by app.py)
segmenter = ObjectSegmenter()
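main.py can also be driven directly, without the Gradio front end. A minimal usage sketch (photo.jpg is a placeholder path):

from main import segmenter

# "photo.jpg" is a placeholder; segment_objects accepts a path,
# a PIL image, or a numpy array.
results = segmenter.segment_objects("photo.jpg", ["cat", "dog"])
for r in results:
    print(f"{r.label}: {r.confidence:.2f} at {r.bounding_box}")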
requirements.txt
ADDED
@@ -0,0 +1,6 @@
gradio==4.44.1
Pillow
numpy
torch
transformers
ultralytics
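To reproduce the environment outside of Spaces, a standard pip workflow should suffice:

pip install -r requirements.txt
python app.py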