# Zero-shot object detection demo (Gradio + OWLv2).
import gradio as gr
from transformers import pipeline
import torch
from PIL import Image, ImageDraw
import numpy as np
# CSS injected into the Gradio Blocks app; styles the results column,
# which opts in via elem_classes=["output-panel"].
CUSTOM_CSS = """
.output-panel {
padding: 15px;
border-radius: 8px;
background-color: #f8f9fa;
}
"""
# Markdown rendered at the top of the UI (title + usage instructions).
DESCRIPTION = """
# Zero-Shot Object Detection Demo
This demo uses OWL-ViT model to perform zero-shot object detection. You can:
- Upload an image or use your webcam
- Specify objects you want to detect (comma-separated)
- Adjust the confidence threshold
- Get real-time detection results
## Instructions
1. Upload an image or use webcam
2. Enter objects to detect (e.g., "person, car, dog, chair")
3. Adjust confidence threshold if needed
4. Click "Detect Objects" to process
"""
class ObjectDetector:
    """Zero-shot object detector backed by the OWLv2 HF pipeline."""

    def __init__(self):
        # HF pipeline device convention: GPU index 0 if available, else -1 (CPU).
        self.device = 0 if torch.cuda.is_available() else -1
        self.detector = pipeline(
            model="google/owlv2-base-patch16-ensemble",
            task="zero-shot-object-detection",
            device=self.device,
        )

    def process_image(
        self, image, objects_to_detect, confidence_threshold=0.3, progress=gr.Progress()
    ):
        """Detect the requested objects in *image* and return an annotated copy.

        Args:
            image: PIL.Image or numpy array from the Gradio input component.
            objects_to_detect: comma-separated label string, e.g. "person, car".
            confidence_threshold: minimum score for a detection to be kept.
            progress: Gradio progress tracker (the gr.Progress() default is the
                Gradio-documented injection pattern, not a shared mutable default).

        Returns:
            A new PIL.Image with boxes and labels drawn on it, or None when
            there is no image or no usable labels.
        """
        if image is None or not objects_to_detect:
            return None
        progress(0.2, "Processing image...")
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        else:
            # Draw on a copy so the caller's (Gradio's) input image is not
            # mutated in place.
            image = image.copy()
        # Parse labels; drop empties produced by trailing or doubled commas.
        candidate_labels = [
            label
            for label in (part.strip() for part in objects_to_detect.split(","))
            if label
        ]
        if not candidate_labels:
            return None
        progress(0.5, "Detecting objects...")
        predictions = self.detector(
            image, candidate_labels=candidate_labels, threshold=confidence_threshold
        )
        progress(0.8, "Drawing results...")
        draw = ImageDraw.Draw(image)
        for prediction in predictions:
            box = prediction["box"]
            label = prediction["label"]
            score = prediction["score"]
            # Index by key instead of relying on the box dict's insertion order.
            xmin, ymin = box["xmin"], box["ymin"]
            xmax, ymax = box["xmax"], box["ymax"]
            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
            draw.text((xmin, ymin - 10), f"{label}: {score:.2f}", fill="red")
        progress(1.0, "Done!")
        return image
def create_demo():
    """Assemble and return the Gradio Blocks UI for the detection demo."""
    detector = ObjectDetector()

    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: image source, label list, threshold, action buttons.
            with gr.Column(scale=1):
                image_in = gr.Image(
                    label="Input Image", type="pil", sources=["upload", "webcam"]
                )
                labels_box = gr.Textbox(
                    label="Objects to Detect (comma-separated)",
                    placeholder="person, car, dog, chair",
                    value="person, face, phone, laptop",
                )
                threshold_slider = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                )
                with gr.Row():
                    reset_button = gr.Button("Clear", variant="secondary")
                    run_button = gr.Button("Detect Objects", variant="primary")

            # Right column: annotated result image.
            with gr.Column(scale=1, elem_classes=["output-panel"]):
                image_out = gr.Image(label="Detection Results", type="pil")

        # Wire events: run detection on click; Clear wipes input and output.
        run_button.click(
            fn=detector.process_image,
            inputs=[image_in, labels_box, threshold_slider],
            outputs=[image_out],
        )
        reset_button.click(
            fn=lambda: (None, None), inputs=[], outputs=[image_in, image_out]
        )

    return demo
if __name__ == "__main__":
    demo = create_demo()
    # share=True requests a public Gradio link; 0.0.0.0 binds all interfaces
    # (the standard setup for a hosted Space); show_error surfaces tracebacks
    # in the UI.
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)