# Gabolozano's picture
# Update app.py
# a897877 verified
# raw
# history blame
# 3.87 kB
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
import warnings
import logging
# Install an "ignore" filter that matches every warning category.
warnings.filterwarnings(action='ignore')
# Raise the Transformers logger threshold so only ERROR and above are emitted.
logging.getLogger("transformers").setLevel(level=logging.ERROR)
def model_is_panoptic(model_name):
    """Return True when ``model_name`` contains the text "panoptic".

    The check is a plain, case-sensitive containment test.
    """
    marker = "panoptic"
    return marker in model_name
def load_model(model_name, threshold):
    """Build an object-detection pipeline for the given DETR checkpoint.

    Args:
        model_name: Checkpoint name handed to ``from_pretrained`` for both
            the model and its image processor.
        threshold: Minimum confidence score a detection must reach to be
            returned by the pipeline.

    Returns:
        A Transformers ``object-detection`` pipeline whose default score
        cut-off is ``threshold``.
    """
    model = DetrForObjectDetection.from_pretrained(model_name)
    image_processor = DetrImageProcessor.from_pretrained(model_name)
    # The score cut-off is a post-processing argument of the pipeline, not a
    # DetrConfig field. Passing it to DetrConfig.from_pretrained (as was done
    # previously) had no effect and left the pipeline at its built-in default,
    # so it is given to the pipeline directly.
    return pipeline(
        task='object-detection',
        model=model,
        image_processor=image_processor,
        threshold=threshold,
    )
# Pipeline built at import time with the default checkpoint and threshold.
od_pipe = load_model(
    model_name="facebook/detr-resnet-101",
    threshold=0.25,
)
def draw_detections(image, detections, model_name):
    """Draw detection results on a copy of ``image``.

    For panoptic models a detection's ``'mask'`` (when present) is blended
    into the image with a random colour. In every other case the detection's
    ``'box'`` is outlined in green and labelled with ``"<label> <score>"``.

    Args:
        image: Input picture (PIL image or array-like); it is not modified.
        detections: Iterable of dicts. Box entries need ``'box'`` (with
            ``'xmin'``, ``'ymin'``, ``'xmax'``, ``'ymax'``), ``'label'`` and
            ``'score'``; mask entries need ``'mask'``.
        model_name: Name of the model that produced ``detections``; used to
            decide whether masks are expected.

    Returns:
        A new PIL image with the annotations drawn on it.
    """
    # np.array copies the pixels, so the caller's image is left untouched.
    # The colours used below (green, white, random) do not depend on channel
    # order, so no RGB<->BGR round-trip is needed.
    canvas = np.array(image)
    height, width = canvas.shape[:2]

    # Loop-invariant: decide once whether masks should be looked for.
    expects_masks = model_is_panoptic(model_name)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    font_thickness = 4

    for detection in detections:
        # Object-detection results carry no 'mask' key even for panoptic
        # checkpoints, so fall back to the bounding box instead of failing.
        mask = detection.get('mask') if expects_masks else None
        if mask is not None:
            mask_values = np.asarray(mask, dtype=np.float32)
            if mask_values.size == 0:
                continue
            # Accept both 0..1 masks and masks already scaled to 0..255.
            if mask_values.max() <= 1.0:
                mask_values = mask_values * 255
            mask_u8 = np.round(mask_values).astype(np.uint8)
            mask_u8 = cv2.resize(mask_u8, (width, height))
            selected = mask_u8 == 255
            color = np.random.randint(0, 255, size=3)
            # 50/50 blend of the original pixels with the overlay colour.
            canvas[selected] = (canvas[selected] * 0.5 + color * 0.5).astype(np.uint8)
            continue

        box = detection.get('box')
        if box is None:
            # Nothing drawable in this entry.
            continue

        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, font_thickness)
        # Prefer the label just above the box; when that would fall outside
        # the top of the image, place it just inside the box instead.
        text_y = y_min - 10
        if text_y < text_height:
            text_y = y_min + text_height + 10
        cv2.putText(canvas, label_text, (x_min, text_y), font, font_scale, (255, 255, 255), font_thickness)

    return Image.fromarray(canvas)
def get_pipeline_prediction(model_name, threshold, pil_image):
    """Run object detection on an image and return the annotated result.

    Loads the pipeline for ``model_name`` (also stored in the module-level
    ``od_pipe``), runs it on the image, and draws the detections.

    Args:
        model_name: Checkpoint name to load.
        threshold: Minimum confidence score for a detection to be kept.
        pil_image: PIL image or array-like of pixel values; ``None`` when no
            image was supplied.

    Returns:
        A tuple ``(image, data, description)``. On success these are the
        annotated PIL image, the raw pipeline results, and a summary string.
        On failure they are the input image, ``{"error": <message>}`` and
        ``"Failed to process image"``.
    """
    global od_pipe

    # Nothing to do without an image; avoid loading a model just to fail.
    if pil_image is None:
        return None, {"error": "No image provided"}, "Failed to process image"

    try:
        # Loading happens inside the try so a load failure is reported the
        # same way as any other processing error.
        od_pipe = load_model(model_name, threshold)
        if not isinstance(pil_image, Image.Image):
            # Let Pillow infer the mode from the array shape, then normalise
            # to RGB; forcing 'RGB' misreads grayscale or RGBA arrays.
            pil_image = Image.fromarray(np.asarray(pil_image).astype('uint8')).convert('RGB')
        # Pass the threshold at call time so the requested value is applied.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result, model_name)
        description = f'Model used: {model_name}, Detection Threshold: {threshold}'
        return processed_image, result, description
    except Exception as e:
        # Top-level UI boundary: record the traceback, then report the error.
        logging.getLogger(__name__).exception("Object detection failed for model %s", model_name)
        return pil_image, {"error": str(e)}, "Failed to process image"
# Two-column layout: inputs on the left, tabbed outputs on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Input column: image, model choice, threshold and the run button.
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            model_dropdown = gr.Dropdown(
                choices=[
                    "facebook/detr-resnet-50",
                    "facebook/detr-resnet-50-panoptic",
                    "facebook/detr-resnet-101",
                    "facebook/detr-resnet-101-panoptic",
                ],
                value="facebook/detr-resnet-101",
                label="Select Model",
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")

        # Output column: one tab per value returned by the handler.
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
            with gr.Tab("Description"):
                description_output = gr.Textbox()

    # Wire the button to the prediction handler.
    run_button.click(
        get_pipeline_prediction,
        inputs=[model_dropdown, threshold_slider, inp_image],
        outputs=[output_image, output_data, description_output],
    )

demo.launch()