Spaces:
Sleeping
Sleeping
File size: 2,948 Bytes
d4c3acc 246f207 ceb95cf c21a752 6ed6f8e 648d371 b842d10 6ed6f8e e3205c6 6ed6f8e b842d10 6ed6f8e b842d10 6ed6f8e 648d371 ceb95cf d5aac08 aca9f11 6ed6f8e ceb95cf 6ed6f8e ceb95cf 6ed6f8e 6564a3a b842d10 6ed6f8e cfd90f9 e3205c6 6ed6f8e cfd90f9 6ed6f8e cfd90f9 b3747be 5769b69 6ed6f8e 5769b69 6ed6f8e 5769b69 6ed6f8e e3205c6 6ed6f8e e14364d aca9f11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Load the DETR ResNet-50 checkpoint once at import time: its configuration,
# the detection model built from that configuration, and its image processor.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
base_config = DetrConfig.from_pretrained(_DETR_CHECKPOINT)
base_model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT, config=base_config)
base_processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
def load_model(threshold):
    """Build an object-detection pipeline around the pre-loaded DETR model.

    Args:
        threshold: Default minimum score a detection needs in order to be
            returned. It is handed to the pipeline as a post-processing
            parameter; a ``threshold=`` passed when calling the pipeline
            still overrides it.

    Returns:
        A transformers ``object-detection`` pipeline using ``base_model`` and
        ``base_processor``.
    """
    # Putting ``threshold`` on DetrConfig only adds an unused config attribute.
    # The pipeline takes its threshold from its own parameters, so that value
    # was silently ignored. Re-reading the checkpoint on every call was also
    # costly when the same weights are already in memory as ``base_model``.
    return pipeline(
        task="object-detection",
        model=base_model,
        image_processor=base_processor,
        threshold=threshold,
    )
# Module-level pipeline created at start-up with a 0.25 threshold, the same
# value the detection-threshold slider starts at.
od_pipe = load_model(threshold=0.25)
def draw_detections(image, detections):
    """Return a copy of *image* with every detection's box and label drawn on it.

    Args:
        image: A PIL image, or anything ``np.array`` turns into an image array.
            PIL images are converted to RGB first, so RGBA / greyscale /
            palette inputs are accepted.
        detections: Iterable of dicts with ``'score'``, ``'label'`` and
            ``'box'`` keys. ``'box'`` holds ``xmin``/``ymin``/``xmax``/``ymax``
            pixel coordinates. These are the keys this function reads from the
            object-detection pipeline's output.

    Returns:
        A new ``PIL.Image.Image``, with each detection drawn as a green box
        and a white "label score" caption.
    """
    if isinstance(image, Image.Image):
        # cv2 drawing with 3-component colours needs an H x W x 3 array.
        image = image.convert("RGB")
    # np.array copies, so the caller's image is left untouched.
    canvas = np.array(image)

    # Both colours are symmetric under an RGB<->BGR swap, so drawing straight
    # onto the RGB array gives the same pixels as converting to BGR and back.
    box_color = (0, 255, 0)
    text_color = (255, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 1, 2

    for detection in detections:
        box = detection["box"]
        # cv2 drawing functions require integer pixel coordinates.
        x_min, y_min = int(box["xmin"]), int(box["ymin"])
        x_max, y_max = int(box["xmax"]), int(box["ymax"])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), box_color, thickness)

        text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
        # putText anchors at the text's bottom-left corner. Above the box is
        # preferred, but for boxes touching the top edge that would clip the
        # label, so draw it just inside the box instead.
        text_y = y_min - 10
        if text_y < text_height:
            text_y = y_min + text_height + 10
        cv2.putText(canvas, text, (x_min, text_y), font, font_scale, text_color, thickness)

    return Image.fromarray(canvas)
def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an image and return the annotated image and raw results.

    Args:
        threshold: Minimum score for a detection to be kept. It is forwarded to
            the pipeline call, so the slider takes effect without reloading
            the model.
        pil_image: The uploaded image, either a PIL image or anything
            ``np.array`` can turn into an image array. ``None`` when nothing
            was uploaded.

    Returns:
        ``(annotated_image, detections)`` on success. On failure it returns
        ``(pil_image, {"error": message})``, so the UI shows the problem
        instead of raising.
    """
    if pil_image is None:
        return None, {"error": "No image provided."}
    try:
        if not isinstance(pil_image, Image.Image):
            # Let PIL infer the mode from the array shape (L / RGB / RGBA ...),
            # then normalise to RGB. Forcing mode 'RGB' breaks on other shapes.
            pil_image = Image.fromarray(np.asarray(pil_image).astype("uint8")).convert("RGB")
        # The object-detection pipeline accepts ``threshold`` per call, so the
        # existing module-level pipeline is reused. Nothing is reloaded, and no
        # global is rebound while other requests may be using it.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        return processed_image, result
    except Exception as e:  # UI boundary: report any failure to the user.
        return pil_image, {"error": str(e)}
# Two-column Gradio UI: inputs (image, threshold slider, run button) on the
# left, and the annotated image / raw JSON detections on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Object Detection")
            # type="pil" hands the handler a PIL image, matching its parameter name.
            inp_image = gr.Image(label="Upload your image here", type="pil")
            threshold_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.25, label="Detection Threshold")
            run_button = gr.Button("Detect Objects")
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
    # Event listeners must be registered inside the Blocks context.
    run_button.click(get_pipeline_prediction, inputs=[threshold_slider, inp_image], outputs=[output_image, output_data])

# Start the server only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()