Spaces:
Sleeping
Sleeping
File size: 3,445 Bytes
d4c3acc 246f207 ceb95cf c21a752 648d371 b842d10 e3205c6 b842d10 648d371 ceb95cf e3205c6 ceb95cf d5aac08 aca9f11 e3205c6 ceb95cf b6fa5d6 e3205c6 ceb95cf 6564a3a b842d10 cfd90f9 e3205c6 cfd90f9 e3205c6 cfd90f9 e3205c6 b3747be 5769b69 4ce9fc8 e3205c6 5769b69 e3205c6 b842d10 e14364d aca9f11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Module-level DETR objects, all loaded from the same Hub checkpoint.
# NOTE(review): load_model() below builds its own config/model/processor, so
# nothing visible in this file reads these three names afterwards.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
config = DetrConfig.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(
    _DETR_CHECKPOINT,
    config=config,
)
image_processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
def load_model(threshold):
    """Build a DETR object-detection pipeline with a default score threshold.

    Args:
        threshold: Minimum confidence score (0-1) a detection must reach to be
            returned by the pipeline when no per-call threshold is given.

    Returns:
        A transformers ``object-detection`` pipeline backed by the
        ``facebook/detr-resnet-50`` checkpoint.
    """
    checkpoint = "facebook/detr-resnet-50"

    # The detection threshold is a post-processing parameter of the pipeline,
    # not a field of DetrConfig: passing it to DetrConfig.from_pretrained was
    # silently ignored. It is forwarded to the pipeline instead so that it
    # actually takes effect.
    config = DetrConfig.from_pretrained(checkpoint)
    model = DetrForObjectDetection.from_pretrained(checkpoint, config=config)
    image_processor = DetrImageProcessor.from_pretrained(checkpoint)
    return pipeline(
        task='object-detection',
        model=model,
        image_processor=image_processor,
        threshold=threshold,
    )
# Shared pipeline instance used by get_pipeline_prediction(); built once at
# import time with the default threshold of 0.5.
od_pipe = load_model(threshold=0.5)
def draw_detections(image, detections):
    """Return a copy of ``image`` with detection boxes and captions drawn on it.

    Args:
        image: Image to annotate. A PIL image is converted to RGB first; the
            input object itself is never modified.
        detections: Iterable of dicts with ``score``, ``label`` and ``box``
            keys, where ``box`` holds ``xmin``/``ymin``/``xmax``/``ymax``
            pixel coordinates.

    Returns:
        A new PIL image with a green rectangle and a white
        "<label> <score>" caption for every detection.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    # Large, thick text so captions stay readable on high-resolution images.
    font_scale = 1.5
    text_thickness = 4

    # Normalise PIL inputs to three channels so grayscale or palette images
    # do not break the RGB -> BGR conversion below.
    if isinstance(image, Image.Image):
        image = image.convert('RGB')

    # OpenCV draws in BGR order, so convert before drawing.
    canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    for detection in detections:
        box = detection['box']
        # OpenCV drawing functions require plain integer pixel coordinates.
        x_min, y_min = int(box['xmin']), int(box['ymin'])
        x_max, y_max = int(box['xmax']), int(box['ymax'])
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(
            label_text, font, font_scale, text_thickness
        )
        # Put the caption just above the box, but never above the top edge of
        # the image, where it would be clipped and invisible.
        text_y = max(y_min - 10, text_height)
        cv2.putText(
            canvas,
            label_text,
            (x_min, text_y),
            font,
            font_scale,
            (255, 255, 255),
            text_thickness,
        )

    # Convert back to RGB for display.
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
def get_pipeline_prediction(threshold, pil_image):
    """Run object detection on an image and return it annotated.

    Args:
        threshold: Minimum confidence score (0-1) a detection must reach to be
            kept.
        pil_image: Input image, either a PIL image or array-like pixel data
            (e.g. a numpy array); ``None`` when no image was supplied.

    Returns:
        A ``(image, results)`` tuple. On success: the annotated PIL image and
        the list of detection dicts produced by the pipeline. On failure: the
        input image and an ``{"error": message}`` dict.
    """
    if pil_image is None:
        # Report a missing image explicitly instead of surfacing an opaque
        # conversion error from numpy/PIL.
        error_message = "An error occurred: no input image was provided"
        print(error_message)
        return pil_image, {"error": error_message}

    try:
        # Ensure input is an RGB PIL image. The array's own mode is inferred
        # rather than forced, so grayscale/RGBA arrays are converted correctly.
        if not isinstance(pil_image, Image.Image):
            pil_image = Image.fromarray(np.asarray(pil_image).astype('uint8'))
        pil_image = pil_image.convert('RGB')

        # The threshold is a per-call post-processing parameter of the
        # pipeline. Pipeline objects expose no ``config.threshold`` attribute,
        # and reloading the whole model to change it is unnecessary.
        pipeline_output = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, pipeline_output)
        return processed_image, pipeline_output
    except Exception as e:
        # UI boundary: report the failure in the JSON panel rather than
        # crashing the request.
        error_message = f"An error occurred: {str(e)}"
        print(error_message)
        return pil_image, {"error": error_message}
# Gradio interface: inputs and controls on the left, results on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Input column: image, threshold slider and run button.
        with gr.Column():
            inp_image = gr.Image(label="Input image")
            slider = gr.Slider(
                minimum=0,
                maximum=1,
                step=0.05,
                label="Detection Sensitivity",
                value=0.5,
            )
            gr.Markdown("Adjust the slider to change detection sensitivity.")
            btn_run = gr.Button("Run Detection")

        # Output column: one tab for the annotated image, one for raw results.
        with gr.Column():
            with gr.Tab("Annotated Image"):
                out_image = gr.Image()
            with gr.Tab("Detection Results"):
                out_json = gr.JSON()

    # Wire the button to the detection handler.
    btn_run.click(
        fn=get_pipeline_prediction,
        inputs=[slider, inp_image],
        outputs=[out_image, out_json],
    )

demo.launch()