# app.py by Gabolozano — commit 793cc29 ("Update app.py", verified), 3.17 kB.
# (Header reconstructed from web file-viewer residue.)
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Load the DETR checkpoint components once, at import time, so request
# handlers never pay the download/deserialisation cost.
# No detection threshold is configured here.
_CHECKPOINT = "facebook/detr-resnet-50"
base_config = DetrConfig.from_pretrained(_CHECKPOINT)
base_model = DetrForObjectDetection.from_pretrained(_CHECKPOINT, config=base_config)
# Pre-/post-processor shared by the pipelines built in load_model().
base_processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
def load_model(threshold):
    """Build an object-detection pipeline whose default score threshold is *threshold*.

    The model weights and image processor loaded at import time
    (``base_model`` / ``base_processor``) are reused, so calling this does not
    re-read the checkpoint.

    Args:
        threshold: default minimum detection score applied by the pipeline's
            post-processing when a call does not pass its own ``threshold``.

    Returns:
        A ``transformers`` object-detection pipeline.
    """
    # The score threshold belongs to the pipeline, not the model config:
    # ObjectDetectionPipeline.postprocess() takes ``threshold`` as a parameter
    # (default 0.5) and never reads a ``threshold`` attribute from DetrConfig.
    # Extra pipeline() kwargs are routed through _sanitize_parameters into the
    # post-processing defaults, so the value actually takes effect.
    return pipeline(
        task='object-detection',
        model=base_model,
        image_processor=base_processor,
        threshold=threshold,
    )
# Shared module-level pipeline, built once at import time.
# DEFAULT_THRESHOLD mirrors the initial value of the UI's threshold slider.
DEFAULT_THRESHOLD = 0.25
od_pipe = load_model(DEFAULT_THRESHOLD)
def draw_detections(image, detections):
    """Return a new PIL image with each detection's box and label drawn on it.

    Args:
        image: input image in RGB order — a ``PIL.Image.Image`` or anything
            ``np.array`` turns into an H x W x 3 uint8 array.
        detections: iterable of dicts shaped like the transformers
            object-detection pipeline output. Each has ``score`` (float),
            ``label`` (str) and ``box`` (dict with ``xmin``/``ymin``/``xmax``/``ymax``).

    Returns:
        A ``PIL.Image.Image`` with green boxes and white ``"<label> <score>"`` text.
    """
    # cv2.cvtColor(RGB2BGR) rejects single-channel input, so normalise PIL
    # images (e.g. mode "L" or "RGBA") to 3-channel RGB first.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    np_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    text_thickness = 4
    for detection in detections:
        box = detection['box']
        # cv2 drawing functions require integer pixel coordinates.
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        cv2.rectangle(np_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        # putText's origin is the text baseline. Aim 10 px above the box, but
        # keep the whole label inside the image: for boxes touching the top
        # edge, y_min - 10 would put the text off-canvas.
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, text_thickness)
        text_y = max(y_min - 10, text_height)
        cv2.putText(np_image, label_text, (x_min, text_y), font, font_scale,
                    (255, 255, 255), text_thickness)

    # Back to RGB for PIL / Gradio display.
    return Image.fromarray(cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB))
def get_pipeline_prediction(threshold, pil_image):
    """Detect objects in *pil_image* and return the annotated image plus raw results.

    Args:
        threshold: minimum detection score from the UI slider. It is passed to
            the pipeline call, which overrides the pipeline's default threshold.
        pil_image: the uploaded image — presumably a numpy array (gr.Image's
            default type) or a PIL image; ``None`` when nothing was uploaded.

    Returns:
        ``(annotated_image, detections)`` on success, or
        ``(input_image, {"error": message})`` on failure so the UI can show it.
    """
    if pil_image is None:
        return None, {"error": "Please upload an image first."}
    try:
        if not isinstance(pil_image, Image.Image):
            # Let Pillow infer the mode from the array shape, then normalise.
            # Forcing mode 'RGB' would garble grayscale or RGBA arrays.
            pil_image = Image.fromarray(np.array(pil_image).astype('uint8')).convert("RGB")
        # The threshold is a per-call pipeline option. Passing it here uses the
        # already-loaded model instead of rebuilding it from disk on every click.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        return processed_image, result
    except Exception as e:
        # UI boundary: report the failure in the JSON panel rather than crash.
        return pil_image, {"error": str(e)}
# Two-column UI: inputs/controls on the left, results (tabbed) on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: title, image upload, score threshold and trigger button.
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")
        # Right column: annotated image and raw detection JSON in separate tabs.
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
    # Inputs map positionally onto get_pipeline_prediction(threshold, pil_image).
    run_button.click(
        fn=get_pipeline_prediction,
        inputs=[threshold_slider, inp_image],
        outputs=[output_image, output_data],
    )
demo.launch()