# app.py — last commit 6ed6f8e ("Update app.py") by Gabolozano, 2.95 kB.
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Load the DETR checkpoint (config, weights, image processor) once at import
# time. The image processor is handed to every pipeline that load_model()
# builds, so it is never re-read from disk per request.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
base_config = DetrConfig.from_pretrained(_DETR_CHECKPOINT)
base_model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT, config=base_config)
base_processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
def load_model(threshold):
    """Build an object-detection pipeline whose default score cut-off is *threshold*.

    The pipeline reuses the DETR model and image processor loaded once at
    import time (``base_model`` / ``base_processor``). It does not read the
    checkpoint from disk again on every call.

    Args:
        threshold: Default minimum confidence score. Detections scoring below
            it are dropped by the pipeline's post-processing. A ``threshold=``
            argument passed when calling the returned pipeline still
            overrides it.

    Returns:
        A transformers object-detection pipeline.
    """
    # A ``threshold`` attribute on DetrConfig is not consulted at inference.
    # The cut-off is a post-processing parameter of the pipeline, so it is
    # given to the pipeline directly.
    return pipeline(
        task='object-detection',
        model=base_model,
        image_processor=base_processor,
        threshold=threshold,
    )
# Build the initial detection pipeline at import time. The 0.25 default
# mirrors the starting value of the threshold slider in the UI.
_DEFAULT_THRESHOLD = 0.25
od_pipe = load_model(threshold=_DEFAULT_THRESHOLD)
def draw_detections(image, detections):
    """Draw bounding boxes and ``label score`` captions onto a copy of *image*.

    Args:
        image: The image the detections were computed on. This is a PIL image
            when called from the app; array-likes are also accepted.
        detections: Iterable of dicts with ``score``, ``label`` and ``box`` keys.
            ``box`` holds ``xmin``/``ymin``/``xmax``/``ymax`` pixel coordinates,
            as produced by the transformers object-detection pipeline.

    Returns:
        A new ``PIL.Image.Image`` with the annotations drawn. The input image
        is not modified.
    """
    if isinstance(image, Image.Image):
        # The cv2 drawing calls below expect a 3-channel uint8 array.
        # Grayscale/RGBA inputs would otherwise break.
        image = image.convert("RGB")
    # np.array copies, so the caller's image is never mutated. Both colours
    # used are identical in RGB and BGR order. We therefore draw directly on
    # the RGB array instead of round-tripping through BGR.
    canvas = np.array(image)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 1, 2
    for detection in detections:
        box = detection['box']
        # cv2 requires integer pixel coordinates.
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        caption = f"{detection['label']} {detection['score']:.2f}"
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        # Put the caption just above the box. For boxes near the top edge,
        # clamp the baseline so the text stays on-canvas instead of being
        # drawn at a negative y.
        (_, text_height), _ = cv2.getTextSize(caption, font, font_scale, thickness)
        text_y = max(y_min - 10, text_height)
        cv2.putText(canvas, caption, (x_min, text_y), font, font_scale, (255, 255, 255), thickness)
    return Image.fromarray(canvas)
def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an uploaded image and return it annotated.

    Args:
        threshold: Minimum confidence score a detection must reach to be kept.
            It is forwarded to the pipeline call for this request only.
        pil_image: The uploaded image, either a PIL image or an array-like
            (Gradio's ``gr.Image`` supplies a NumPy array by default).
            May be ``None`` when nothing was uploaded.

    Returns:
        A ``(image, result)`` pair. On success this is the annotated PIL image
        and the list of detection dicts. On failure it is the (possibly
        converted) input image and an ``{"error": message}`` dict.
    """
    if pil_image is None:
        return None, {"error": "No image provided. Please upload an image first."}
    try:
        if not isinstance(pil_image, Image.Image):
            # Let PIL infer the mode from the array shape instead of forcing
            # 'RGB'. Forcing it mislabels grayscale/RGBA arrays.
            pil_image = Image.fromarray(np.asarray(pil_image).astype('uint8'))
        # Normalise to 3-channel RGB for both the model and the drawing code.
        pil_image = pil_image.convert("RGB")
        # Apply the slider value at call time on the shared pipeline. This
        # avoids reloading the model per request and mutating a global that
        # concurrent requests also read. Without this argument the
        # pipeline's own default cut-off would apply.
        result = od_pipe(pil_image, threshold=float(threshold))
        processed_image = draw_detections(pil_image, result)
        return processed_image, result
    except Exception as e:
        # UI boundary: surface the failure in the JSON panel, don't crash.
        return pil_image, {"error": str(e)}
# --- User interface -----------------------------------------------------------
# Left column: title, image upload, threshold slider and trigger button.
# Right column: tabs showing the annotated image and the raw detection JSON.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
    # The inputs order matches get_pipeline_prediction(threshold, pil_image).
    # The outputs order matches its (annotated image, detections) return value.
    run_button.click(
        fn=get_pipeline_prediction,
        inputs=[threshold_slider, inp_image],
        outputs=[output_image, output_data],
    )

demo.launch()