import gradio as gr
from transformers import pipeline
import torch
from PIL import Image, ImageDraw
import numpy as np

CUSTOM_CSS = """
.output-panel {
    padding: 15px;
    border-radius: 8px;
    background-color: #f8f9fa;
}
"""

DESCRIPTION = """
# Zero-Shot Object Detection Demo

This demo uses the OWLv2 model to perform zero-shot object detection. You can:
- Upload an image or use your webcam
- Specify the objects you want to detect (comma-separated)
- Adjust the confidence threshold
- Get detection results drawn directly on the image

## Instructions
1. Upload an image or use the webcam
2. Enter objects to detect (e.g., "person, car, dog, chair")
3. Adjust the confidence threshold if needed
4. Click "Detect Objects" to process
"""


class ObjectDetector:
    def __init__(self):
        # Use the first GPU if available, otherwise fall back to CPU
        self.device = 0 if torch.cuda.is_available() else -1
        self.detector = pipeline(
            model="google/owlv2-base-patch16-ensemble",
            task="zero-shot-object-detection",
            device=self.device,
        )

    def process_image(
        self, image, objects_to_detect, confidence_threshold=0.3, progress=gr.Progress()
    ):
        if image is None or not objects_to_detect:
            return None

        progress(0.2, "Processing image...")

        # Convert a numpy array to a PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Parse the comma-separated labels, dropping empty entries
        candidate_labels = [
            obj.strip() for obj in objects_to_detect.split(",") if obj.strip()
        ]

        progress(0.5, "Detecting objects...")

        # Run zero-shot detection against the user-supplied labels
        predictions = self.detector(
            image, candidate_labels=candidate_labels, threshold=confidence_threshold
        )

        progress(0.8, "Drawing results...")

        # Draw predictions on a copy so the input image is left untouched
        image = image.copy()
        draw = ImageDraw.Draw(image)
        for prediction in predictions:
            box = prediction["box"]
            label = prediction["label"]
            score = prediction["score"]

            # Read the coordinates by key rather than relying on dict ordering
            xmin, ymin, xmax, ymax = box["xmin"], box["ymin"], box["xmax"], box["ymax"]
            draw.rectangle([xmin, ymin, xmax, ymax], outline="red", width=2)
            draw.text((xmin, ymin - 10), f"{label}: {score:.2f}", fill="red")

        progress(1.0, "Done!")
        return image


def create_demo():
    detector = ObjectDetector()

    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    label="Input Image", type="pil", sources=["upload", "webcam"]
                )
                objects_input = gr.Textbox(
                    label="Objects to Detect (comma-separated)",
                    placeholder="person, car, dog, chair",
                    value="person, face, phone, laptop",
                )
                confidence = gr.Slider(
                    label="Confidence Threshold",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                )
                with gr.Row():
                    clear_btn = gr.Button("Clear", variant="secondary")
                    detect_btn = gr.Button("Detect Objects", variant="primary")

            with gr.Column(scale=1, elem_classes=["output-panel"]):
                output_image = gr.Image(label="Detection Results", type="pil")

        # Event handlers
        detect_btn.click(
            fn=detector.process_image,
            inputs=[input_image, objects_input, confidence],
            outputs=[output_image],
        )
        clear_btn.click(
            fn=lambda: (None, None), inputs=[], outputs=[input_image, output_image]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860, show_error=True)
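
# A minimal sketch of calling the detector programmatically, without launching
# the UI ("sample.jpg" and "annotated.jpg" are placeholder paths, not files
# shipped with this demo):
#
#     from PIL import Image
#     detector = ObjectDetector()
#     annotated = detector.process_image(Image.open("sample.jpg"), "person, dog")
#     if annotated is not None:
#         annotated.save("annotated.jpg")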