# Hugging Face Space page residue (author avatar alt text, commit message, commit hash):
# atalaydenknalbant's picture
# Update app.py
# c4c158a verified
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO, RTDETR
import spaces
import os
from huggingface_hub import hf_hub_download
def get_model_path(model_name):
    """Fetch a budgerigar model checkpoint from the Hugging Face Hub.

    Args:
        model_name: Filename of the checkpoint inside the
            ``atalaydenknalbant/budgerigar_models`` repository.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    return hf_hub_download(
        filename=model_name,
        repo_id="atalaydenknalbant/budgerigar_models",
    )
def _placeholder_image(message, width=640, height=480):
    """Return a white ``width`` x ``height`` RGB image with ``message`` centred in black."""
    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default(size=40)
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    # textbbox may start at a non-zero origin (glyph bearing / ascent gap);
    # subtract it so the drawn text is actually centred on the canvas.
    text_x = (width - (right - left)) / 2 - left
    text_y = (height - (bottom - top)) / 2 - top
    draw.text((text_x, text_y), message, fill="black", font=font)
    return canvas


def _load_model(model_id):
    """Download checkpoint ``model_id`` and wrap it in the matching Ultralytics class."""
    model_path = get_model_path(model_id)
    # Checkpoints whose name contains "rtdetr" need the RTDETR wrapper; all others are YOLO.
    model_cls = RTDETR if "rtdetr" in model_id.lower() else YOLO
    # NOTE(review): the model is re-downloaded/re-loaded on every request; caching it
    # could help but interacts with @spaces.GPU device allocation — evaluate before adding.
    return model_cls(model_path)


@spaces.GPU
def yolo_inference(images, model_id, conf_threshold, iou_threshold, max_detection):
    """Run budgerigar detection on an image and return the annotated result.

    Args:
        images: Input image from the ``gr.Image(type="pil")`` component, or
            ``None`` when nothing was uploaded.
        model_id: Checkpoint filename in the budgerigar_models Hub repository;
            names containing "rtdetr" load with ``RTDETR``, others with ``YOLO``.
        conf_threshold: Minimum confidence for a detection to be kept.
        iou_threshold: IoU threshold passed to the predictor's NMS.
        max_detection: Maximum number of detections per image (coerced to int).

    Returns:
        A PIL RGB image with the predictions drawn on it, or a white
        placeholder image carrying a message when there is nothing to show.
    """
    if images is None:
        return _placeholder_image("No image provided")

    model = _load_model(model_id)
    results = model.predict(
        source=images,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=640,
        # Ultralytics validates max_det as an integer; slider values may arrive as floats.
        max_det=int(max_detection),
        show_labels=True,
        show_conf=True,
    )

    # Guard against an empty result list instead of hitting an unbound local.
    if not results:
        return _placeholder_image("No results returned")

    # The UI passes a single image, so the first Results object is the one to render.
    # Results.plot() yields a BGR array; reverse the channel axis to get RGB for PIL.
    annotated_bgr = results[0].plot()
    return Image.fromarray(annotated_bgr[..., ::-1])
# Gradio UI: one input image, a checkpoint selector and three inference knobs,
# wired to yolo_inference; examples are pre-computed because cache_examples=True.
interface = gr.Interface(
    fn=yolo_inference,
    inputs=[
        gr.Image(type="pil", label="Example Image", interactive=True),
        gr.Radio(
            choices=[
                'budgerigar_yolo11x.pt', 'budgerigar_yolov9e.pt',
                'budgerigar_yolo11l.pt', 'budgerigar_yolo11m.pt',
                'budgerigar_yolo11s.pt', 'budgerigar_yolo11n.pt',
                'budgerigar_rtdetr-x.pt'
            ],
            label="Model Name",
            value="budgerigar_yolo11x.pt",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold"),
        gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection"),
    ],
    outputs=gr.Image(type="pil", label="Annotated Image"),
    cache_examples=True,
    title="Budgerigar Gender Determination",
    # Adjacent literals are concatenated, so each sentence-ending piece needs a
    # trailing space before the next one starts.
    description=(
        "Pretrained object detection models for determining budgerigar gender based on cere color variations. "
        "Upload image(s) for inference. For more details, refer to the paper: "
        '<a href="https://ieeexplore.ieee.org/document/10773570" target="_blank">'
        '"Advanced Computer Vision Techniques for Reliable Gender Determination in Budgerigars (Melopsittacus Undulatus)"</a>'
        "<br><br>"
        # NOTE(review): the mailto address below looks like an obfuscated placeholder
        # ("[email protected]"), not a real address — confirm the intended contact.
        "To help us improve, please report any incorrect gender determinations by sending the original image and details to -> <a href='mailto:[email protected]'>Email</a>. "
        "Your feedback is important for retraining and improving the model."
    ),
    # Each row: [image file, model checkpoint, conf, iou, max detections].
    examples=[
        ["both.jpg", "budgerigar_rtdetr-x.pt", 0.25, 0.45, 300],
        ["Male.png", "budgerigar_yolov9e.pt", 0.25, 0.45, 300],
        ["Female.png", "budgerigar_yolo11x.pt", 0.25, 0.45, 300],
    ],
)
interface.launch()