File size: 3,856 Bytes
6498519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
from transformers import pipeline
import torch
from PIL import Image, ImageDraw
import numpy as np

# CSS injected into the Blocks app; ".output-panel" is attached to the
# results column via elem_classes in create_demo().
CUSTOM_CSS = """
.output-panel {
    padding: 15px;
    border-radius: 8px;
    background-color: #f8f9fa;
}
"""

# Markdown rendered at the top of the UI (title + usage instructions).
DESCRIPTION = """
# Zero-Shot Object Detection Demo

This demo uses OWL-ViT model to perform zero-shot object detection. You can:
- Upload an image or use your webcam
- Specify objects you want to detect (comma-separated)
- Adjust the confidence threshold
- Get real-time detection results

## Instructions
1. Upload an image or use webcam
2. Enter objects to detect (e.g., "person, car, dog, chair")
3. Adjust confidence threshold if needed
4. Click "Detect Objects" to process
"""


class ObjectDetector:
    """Zero-shot object detector backed by the OWLv2 HF pipeline.

    Loads the model once at construction; `process_image` runs detection
    and returns an annotated copy of the input image.
    """

    def __init__(self):
        # pipeline() device convention: 0 = first CUDA GPU, -1 = CPU.
        self.device = 0 if torch.cuda.is_available() else -1
        self.detector = pipeline(
            model="google/owlv2-base-patch16-ensemble",
            task="zero-shot-object-detection",
            device=self.device,
        )

    def process_image(
        self, image, objects_to_detect, confidence_threshold=0.3, progress=gr.Progress()
    ):
        """Detect the requested objects and draw boxes/labels on a copy.

        Args:
            image: PIL image or numpy array from the Gradio Image widget,
                or None when nothing was provided.
            objects_to_detect: comma-separated label string, e.g. "cat, dog".
            confidence_threshold: minimum score for a detection to be kept.
            progress: Gradio progress reporter (injected by Gradio).

        Returns:
            A new annotated PIL image, or None when input/labels are missing.
        """
        if image is None or not objects_to_detect:
            return None

        progress(0.2, "Processing image...")

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        else:
            # Draw on a copy so the caller's image (still displayed in the
            # input widget) is not mutated by ImageDraw below.
            image = image.copy()

        # Parse labels, dropping empties produced by stray/trailing commas
        # (e.g. "dog," -> ["dog", ""]).
        candidate_labels = [
            label
            for label in (part.strip() for part in objects_to_detect.split(","))
            if label
        ]
        if not candidate_labels:
            return None

        progress(0.5, "Detecting objects...")
        predictions = self.detector(
            image, candidate_labels=candidate_labels, threshold=confidence_threshold
        )

        progress(0.8, "Drawing results...")
        draw = ImageDraw.Draw(image)

        for prediction in predictions:
            box = prediction["box"]
            label = prediction["label"]
            score = prediction["score"]

            # Index the box keys explicitly instead of unpacking
            # box.values(), which silently depends on dict key order.
            xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
            draw.text((xmin, ymin - 10), f"{label}: {score:.2f}", fill="red")

        progress(1.0, "Done!")
        return image


def create_demo():
    """Assemble the Gradio Blocks UI and wire it to an ObjectDetector.

    Returns:
        The constructed (but not yet launched) gr.Blocks app.
    """
    model = ObjectDetector()

    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=1):
                image_in = gr.Image(
                    label="Input Image", type="pil", sources=["upload", "webcam"]
                )
                labels_box = gr.Textbox(
                    label="Objects to Detect (comma-separated)",
                    placeholder="person, car, dog, chair",
                    value="person, face, phone, laptop",
                )
                threshold_slider = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                )
                with gr.Row():
                    clear_button = gr.Button("Clear", variant="secondary")
                    run_button = gr.Button("Detect Objects", variant="primary")

            # Right column: annotated output (styled via CUSTOM_CSS).
            with gr.Column(scale=1, elem_classes=["output-panel"]):
                image_out = gr.Image(label="Detection Results", type="pil")

        # Wiring: run detection on click; Clear resets both image panes.
        run_button.click(
            fn=model.process_image,
            inputs=[image_in, labels_box, threshold_slider],
            outputs=[image_out],
        )
        clear_button.click(
            fn=lambda: (None, None), inputs=[], outputs=[image_in, image_out]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    # NOTE(review): server_name="0.0.0.0" binds to all interfaces and
    # share=True opens a public Gradio tunnel — confirm this exposure is
    # intended outside of a demo/Spaces environment.
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)