Spaces:
Sleeping
Sleeping
File size: 3,166 Bytes
d4c3acc 246f207 ceb95cf c21a752 6ed6f8e 648d371 b842d10 6ed6f8e e3205c6 6ed6f8e b842d10 6ed6f8e b842d10 6ed6f8e 648d371 ceb95cf d5aac08 aca9f11 6ed6f8e 793cc29 ceb95cf 793cc29 ceb95cf 793cc29 ceb95cf 793cc29 6564a3a b842d10 6ed6f8e cfd90f9 e3205c6 6ed6f8e cfd90f9 6ed6f8e cfd90f9 b3747be 5769b69 6ed6f8e 5769b69 6ed6f8e 5769b69 6ed6f8e e3205c6 6ed6f8e e14364d aca9f11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import os
import gradio as gr
from transformers import pipeline, DetrForObjectDetection, DetrConfig, DetrImageProcessor
import numpy as np
import cv2
from PIL import Image
# Load the DETR ResNet-50 checkpoint pieces once, at import time.
# base_processor is shared by the pipelines built in load_model.
base_config = DetrConfig.from_pretrained("facebook/detr-resnet-50")
base_model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50",
    config=base_config,
)
base_processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
def load_model(threshold):
    """Build an object-detection pipeline whose default score cutoff is *threshold*.

    The DETR model and image processor loaded at import time are reused, so this
    is cheap and never re-reads the checkpoint from disk.

    Args:
        threshold: Detection-score cutoff used by the pipeline's post-processing
            when a call does not pass its own ``threshold=`` argument.

    Returns:
        A ``transformers`` object-detection pipeline.
    """
    # The object-detection pipeline filters predictions in its post-processing
    # step using its own ``threshold`` parameter; an extra ``threshold``
    # attribute on DetrConfig is never consulted. So the cutoff is handed to the
    # pipeline itself, and the preloaded weights are reused instead of reloading.
    return pipeline(
        task="object-detection",
        model=base_model,
        image_processor=base_processor,
        threshold=threshold,
    )
# Shared pipeline used by the request handler; 0.25 mirrors the initial value
# of the threshold slider in the UI.
od_pipe = load_model(threshold=0.25)
def draw_detections(image, detections):
    """Return a new image with a box and a "label score" caption per detection.

    Args:
        image: PIL image or array-like that ``np.array`` turns into an RGB
            ``uint8`` array (assumed 3-channel; a 2-D grayscale array would be
            rejected by ``cv2.cvtColor``).
        detections: Iterable of dicts with ``'score'``, ``'label'`` and ``'box'``
            keys, where ``box`` holds ``xmin``/``ymin``/``xmax``/``ymax`` pixel
            coordinates usable as OpenCV points, i.e. the list returned by the
            object-detection pipeline.

    Returns:
        A new ``PIL.Image.Image`` built from the annotated RGB array.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1.5
    text_thickness = 4
    box_color = (0, 255, 0)
    text_color = (255, 255, 255)

    # Draw in OpenCV's BGR channel order; converted back to RGB at the end.
    canvas = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    for detection in detections:
        box = detection['box']
        x_min, y_min = box['xmin'], box['ymin']
        x_max, y_max = box['xmax'], box['ymax']
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), box_color, 2)

        label_text = f"{detection['label']} {detection['score']:.2f}"
        (_, text_height), _ = cv2.getTextSize(label_text, font, font_scale, text_thickness)
        # putText anchors at the text's bottom-left corner. Placing it at
        # y_min - 10 pushes the caption off-canvas for boxes at the top edge,
        # so keep the top of the text at or below row 0.
        text_y = max(y_min - 10, text_height)
        cv2.putText(canvas, label_text, (x_min, text_y), font, font_scale, text_color, text_thickness)

    # Back to RGB for display.
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))
def get_pipeline_prediction(threshold, pil_image):
    """Detect objects in an uploaded image and return it annotated.

    Gradio click handler: receives the slider value and the image component's
    value, and returns values for the annotated-image and JSON output components.

    Args:
        threshold: Score cutoff from the UI slider, applied to this request only.
        pil_image: Uploaded image as a PIL image or an RGB array-like; ``None``
            when nothing has been uploaded.

    Returns:
        ``(annotated_image, detections)`` on success, otherwise
        ``(pil_image, {"error": message})``.
    """
    if pil_image is None:
        return None, {"error": "No image provided."}
    # Errors are reported through the JSON output rather than raised so the UI
    # stays usable after a bad input.
    try:
        if not isinstance(pil_image, Image.Image):
            pil_image = Image.fromarray(np.array(pil_image).astype('uint8'), 'RGB')
        # Reuse the shared pipeline instead of reloading DETR on every click;
        # the call-time ``threshold`` is what the pipeline uses to filter results.
        result = od_pipe(pil_image, threshold=threshold)
        processed_image = draw_detections(pil_image, result)
        return processed_image, result
    except Exception as e:
        return pil_image, {"error": str(e)}
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: image upload, threshold control and trigger button.
        with gr.Column():
            gr.Markdown("## Object Detection")
            inp_image = gr.Image(label="Upload your image here")
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.25,
                label="Detection Threshold",
            )
            run_button = gr.Button("Detect Objects")
        # Right column: annotated image and raw detections in separate tabs.
        with gr.Column():
            with gr.Tab("Annotated Image"):
                output_image = gr.Image()
            with gr.Tab("Detection Results"):
                output_data = gr.JSON()
    # Inputs are listed in the handler's parameter order: (threshold, image).
    run_button.click(
        fn=get_pipeline_prediction,
        inputs=[threshold_slider, inp_image],
        outputs=[output_image, output_data],
    )
demo.launch()