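# Web-page header captured along with the file (not part of the program):
#   uploader: Gabolozano
#   commit:   b842d10 (verified) - "Update app.py"
#   size:     3.23 kB
#   ("raw" and "history blame" were navigation links on the page.)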
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Module-level DETR objects, loaded once when the module is imported.
# NOTE(review): nothing else in the visible file reads `config`, `model` or
# `image_processor` (load_model assigns its own locals with the same names),
# so this looks like a second, redundant model load - confirm before removing.
_CHECKPOINT = "facebook/detr-resnet-50"
config = DetrConfig.from_pretrained(_CHECKPOINT)
image_processor = DetrImageProcessor.from_pretrained(_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_CHECKPOINT, config=config)
def load_model(threshold):
    """Build the DETR object-detection pipeline with a default score threshold.

    The detection threshold is a post-processing parameter of the
    object-detection pipeline rather than a ``DetrConfig`` field, so it is
    handed to ``pipeline(...)`` instead of the model config. The pipeline
    keeps it as the default for calls that do not pass their own
    ``threshold``.

    Args:
        threshold: Minimum confidence score (between 0 and 1) a detection
            must reach to be returned by default.

    Returns:
        A ``transformers`` object-detection pipeline backed by the
        ``facebook/detr-resnet-50`` checkpoint.
    """
    checkpoint = "facebook/detr-resnet-50"
    config = DetrConfig.from_pretrained(checkpoint, num_labels=91)
    model = DetrForObjectDetection.from_pretrained(checkpoint, config=config)
    image_processor = DetrImageProcessor.from_pretrained(checkpoint)
    return pipeline(
        task='object-detection',
        model=model,
        image_processor=image_processor,
        threshold=threshold,
    )
# Shared pipeline instance used by get_pipeline_prediction; built once at
# import time with the default detection threshold.
od_pipe = load_model(threshold=0.5)
def draw_detections(image, detections):
    """Return a copy of ``image`` with every detection drawn on it.

    Args:
        image: Image to annotate (PIL image or RGB array); it is not modified.
        detections: Iterable of dicts with ``score``, ``label`` and ``box``
            keys, where ``box`` holds ``xmin``/``ymin``/``xmax``/``ymax``.

    Returns:
        A new PIL image with a green rectangle and a white
        ``"<label> <score>"`` caption for each detection.
    """
    # Normalise PIL inputs to RGB so the array is always 3-channel.
    source = image.convert("RGB") if isinstance(image, Image.Image) else image
    # np.array copies the pixels, so drawing never touches the caller's image.
    # Both colours used below read the same in RGB and BGR channel order, so
    # the array can stay in RGB and no channel swapping is needed.
    canvas = np.array(source, order="C")

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.8
    thickness = 2

    for detection in detections:
        box = detection['box']
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])

        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, thickness)
        # Caption normally sits just above the box; when the box touches the
        # top edge there is no room, so place it just inside the box instead
        # of letting it fall off the canvas.
        text_y = y_min - 10
        if text_y < text_height:
            text_y = y_min + text_height + 10
        cv2.putText(
            canvas, label_text, (x_min, text_y),
            font, font_scale, (255, 255, 255), thickness,
        )

    return Image.fromarray(canvas)
def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an uploaded image and annotate the result.

    The score threshold is passed to the pipeline on each call, so moving the
    slider never requires rebuilding the pipeline or reloading the model.

    Args:
        threshold: Minimum confidence score (between 0 and 1) for a detection
            to be reported.
        pil_image: The uploaded image (PIL image or array), or ``None`` when
            nothing was uploaded.

    Returns:
        A ``(image, results)`` tuple: the annotated PIL image and the list of
        detections on success, or the input image and an ``{"error": ...}``
        dict on failure.
    """
    if pil_image is None:
        return None, {"error": "No image provided."}

    # UI boundary: report any failure back to the interface rather than
    # letting the exception escape the click handler.
    try:
        # Accept either an array or a PIL image and normalise to RGB.
        pil_image = Image.fromarray(np.array(pil_image)).convert("RGB")
        pipeline_output = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, pipeline_output)
        return processed_image, pipeline_output
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return pil_image, {"error": str(e)}
# Gradio UI: one row with an input column (image, sensitivity slider, run
# button) and an output column (annotated image tab, raw detections tab).
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            inp_image = gr.Image(label="Input image")
            slider = gr.Slider(
                label="Adjust Detection Sensitivity",
                minimum=0,
                maximum=1,
                step=0.05,
                value=0.5,
            )
            gr.Markdown("Adjust the slider to change the detection sensitivity.")
            btn_run = gr.Button('Run Detection')
        with gr.Column():
            with gr.Tab("Annotated Image"):
                out_image = gr.Image()
            with gr.Tab("Detection Results"):
                out_json = gr.JSON()

    # Input order matches get_pipeline_prediction(threshold, pil_image).
    btn_run.click(
        fn=get_pipeline_prediction,
        inputs=[slider, inp_image],
        outputs=[out_image, out_json],
    )

demo.launch()