Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
from PIL import Image, ImageDraw, ImageFont | |
from ultralytics import YOLO, RTDETR | |
import spaces | |
import os | |
from huggingface_hub import hf_hub_download | |
# Helper: resolve a checkpoint in the budgerigar model repo to a local file.
def get_model_path(model_name):
    """Return the local path of *model_name* from the budgerigar model repo.

    Thin wrapper around ``huggingface_hub.hf_hub_download`` with the repo id
    fixed to ``atalaydenknalbant/budgerigar_models``.

    Args:
        model_name: File name of the checkpoint inside the repo,
            e.g. ``"budgerigar_yolo11x.pt"``.

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    return hf_hub_download(
        filename=model_name,
        repo_id="atalaydenknalbant/budgerigar_models",
    )
def _placeholder_image(message, width=640, height=480, font_size=40):
    """Return a white ``width`` x ``height`` RGB image with *message* centred in black."""
    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default(size=font_size)
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    # The box measured at the origin may start at a non-zero (left, top)
    # offset, so subtract it to centre the drawn glyphs themselves rather
    # than the anchor point passed to draw.text().
    text_x = (width - (right - left)) / 2 - left
    text_y = (height - (bottom - top)) / 2 - top
    draw.text((text_x, text_y), message, fill="black", font=font)
    return canvas


def yolo_inference(images, model_id, conf_threshold, iou_threshold, max_detection):
    """Run budgerigar detection on one image and return the annotated result.

    Args:
        images: Input image as delivered by the ``gr.Image(type="pil")``
            component, or ``None`` when nothing was uploaded.
        model_id: Checkpoint file name in the model repo. Names containing
            ``"rtdetr"`` (case-insensitive) are loaded with ``RTDETR``;
            all others with ``YOLO``.
        conf_threshold: Confidence threshold passed to ``predict(conf=...)``.
        iou_threshold: IoU threshold passed to ``predict(iou=...)``.
        max_detection: Maximum number of detections, passed as ``max_det``.

    Returns:
        A PIL image with the detections drawn on it. If ``images`` is None,
        returns a white 640x480 placeholder reading "No image provided".

    NOTE(review): ``spaces`` is imported at file level and the Space runs on
    ZeroGPU, but this function is not decorated with ``@spaces.GPU``. Confirm
    whether GPU allocation is intended here.
    """
    if images is None:
        return _placeholder_image("No image provided")

    model_path = get_model_path(model_id)  # downloads from the Hub if needed
    model_cls = RTDETR if "rtdetr" in model_id.lower() else YOLO
    # NOTE(review): the checkpoint is reloaded on every request. Caching loaded
    # models would be faster, but check ZeroGPU / thread-safety constraints first.
    model = model_cls(model_path)
    results = model.predict(
        source=images,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=640,
        # Cast defensively in case the slider delivers a float such as 300.0.
        max_det=int(max_detection),
        show_labels=True,
        show_conf=True,
    )
    if not results:
        # Guard instead of failing on an unbound variable when nothing comes back.
        return _placeholder_image("No result returned")

    # A single input image yields a single Results object. plot() gives a
    # BGR array, so reverse the channel axis to get RGB for PIL.
    annotated = results[0].plot()
    return Image.fromarray(annotated[..., ::-1])
# Gradio UI: one image, a model picker, and the three predict() knobs feed
# yolo_inference; the annotated image is shown as output.
interface = gr.Interface(
    fn=yolo_inference,
    inputs=[
        gr.Image(type="pil", label="Example Image", interactive=True),
        gr.Radio(
            choices=[
                'budgerigar_yolo11x.pt', 'budgerigar_yolov9e.pt',
                'budgerigar_yolo11l.pt', 'budgerigar_yolo11m.pt',
                'budgerigar_yolo11s.pt', 'budgerigar_yolo11n.pt',
                'budgerigar_rtdetr-x.pt',
            ],
            label="Model Name",
            value="budgerigar_yolo11x.pt",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold"),
        gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection"),
    ],
    outputs=gr.Image(type="pil", label="Annotated Image"),
    cache_examples=True,
    title="Budgerigar Gender Determination",
    description=(
        "Pretrained object detection models for determining budgerigar gender based on cere color variations. "
        "Upload image(s) for inference. For more details, refer to the paper: "
        '<a href="https://ieeexplore.ieee.org/document/10773570" target="_blank">'
        '"Advanced Computer Vision Techniques for Reliable Gender Determination in Budgerigars (Melopsittacus Undulatus)"</a>'
        "<br><br>"
        # NOTE(review): the mailto address below reads "[email protected]", which
        # looks like an obfuscation placeholder. Confirm the real contact address.
        # Trailing space added so the next sentence does not run into "Email</a>.".
        "To help us improve, please report any incorrect gender determinations by sending the original image and details to -> <a href='mailto:[email protected]'>Email</a>. "
        "Your feedback is important for retraining and improving the model."
    ),
    # Example files are presumably shipped alongside this script in the Space.
    examples=[
        ["both.jpg", "budgerigar_rtdetr-x.pt", 0.25, 0.45, 300],
        ["Male.png", "budgerigar_yolov9e.pt", 0.25, 0.45, 300],
        ["Female.png", "budgerigar_yolo11x.pt", 0.25, 0.45, 300],
    ],
)
interface.launch()