File size: 3,515 Bytes
07d1802 25c54ab 506ac4e 07d1802 b81c01d 07d1802 b8eded5 bbeba56 c4c158a bbeba56 07d1802 e57154d 506ac4e 07d1802 7592465 4846836 e57154d 07d1802 2fbdb0d e57154d 2fbdb0d e57154d 5236243 2fbdb0d 07d1802 e9d980d 0664949 07d1802 2fbdb0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO, RTDETR
import spaces
import os
from huggingface_hub import hf_hub_download
# Helper function to download models from Hugging Face
def get_model_path(model_name):
    """Fetch a model weights file from the Hugging Face Hub.

    Args:
        model_name: File name of the weights inside the
            ``atalaydenknalbant/budgerigar_models`` repository.

    Returns:
        The value returned by ``hf_hub_download`` for that file, which the
        caller passes straight to the model constructor as the weights path.
    """
    repo = "atalaydenknalbant/budgerigar_models"
    return hf_hub_download(repo_id=repo, filename=model_name)
def _message_image(message, width=640, height=480):
    """Return a white RGB canvas with ``message`` centred on it in black."""
    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default(size=40)
    # textbbox reports (left, top, right, bottom) of the rendered text; its
    # extent is what we need to centre the message on the canvas.
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    text_x = (width - (right - left)) / 2
    text_y = (height - (bottom - top)) / 2
    draw.text((text_x, text_y), message, fill="black", font=font)
    return canvas


@spaces.GPU
def yolo_inference(images, model_id, conf_threshold, iou_threshold, max_detection):
    """Run the selected detector on the input and return the annotated image.

    Args:
        images: Prediction source handed to ``model.predict``; ``None`` when
            the user has not supplied an image.
        model_id: Weights file name. A name containing ``"rtdetr"``
            (case-insensitive) is loaded with ``RTDETR``, anything else
            with ``YOLO``.
        conf_threshold: Forwarded to ``predict`` as ``conf``.
        iou_threshold: Forwarded to ``predict`` as ``iou``.
        max_detection: Forwarded to ``predict`` as ``max_det``.

    Returns:
        A PIL image: the plotted detections of the last result, or a
        placeholder image carrying a short message when there is no input
        or the model returned no results.
    """
    if images is None:
        return _message_image("No image provided")

    model_path = get_model_path(model_id)  # Download model
    model_cls = RTDETR if "rtdetr" in model_id.lower() else YOLO
    model = model_cls(model_path)

    # Materialise the results so the emptiness check below is valid for any
    # iterable that predict() may return.
    results = list(
        model.predict(
            source=images,
            conf=conf_threshold,
            iou=iou_threshold,
            imgsz=640,
            max_det=max_detection,
            show_labels=True,
            show_conf=True,
        )
    )
    if not results:
        # Previously this path left the output variable unbound; return a
        # placeholder image instead so the UI always receives an image.
        return _message_image("No results returned")

    # Only the final result is shown, so only that one is plotted.
    # The channel axis is reversed before handing the array to PIL
    # (plot() yields BGR per the Ultralytics docs, PIL expects RGB).
    annotated = results[-1].plot()
    return Image.fromarray(annotated[..., ::-1])
# Gradio UI: one image plus model choice and prediction thresholds in,
# one annotated image out. The input order below must match the positional
# parameters of yolo_inference and the columns of each row in `examples`.
interface = gr.Interface(
    fn=yolo_inference,
    inputs=[
        gr.Image(type="pil", label="Example Image", interactive=True),
        gr.Radio(
            choices=[
                'budgerigar_yolo11x.pt', 'budgerigar_yolov9e.pt',
                'budgerigar_yolo11l.pt', 'budgerigar_yolo11m.pt',
                'budgerigar_yolo11s.pt', 'budgerigar_yolo11n.pt',
                'budgerigar_rtdetr-x.pt'
            ],
            label="Model Name",
            value="budgerigar_yolo11x.pt",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold"),
        gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection"),
    ],
    outputs=gr.Image(type="pil", label="Annotated Image"),
    cache_examples=True,
    title="Budgerigar Gender Determination",
    # Adjacent string literals are concatenated with no separator, so every
    # sentence that is followed by another must end with its own space.
    # NOTE(review): the mailto target below reads "[email protected]", which
    # looks like an obfuscation placeholder rather than a real address --
    # confirm against the deployed app before relying on it.
    description=(
        "Pretrained object detection models for determining budgerigar gender based on cere color variations. "
        "Upload image(s) for inference. For more details, refer to the paper: "
        '<a href="https://ieeexplore.ieee.org/document/10773570" target="_blank">'
        '"Advanced Computer Vision Techniques for Reliable Gender Determination in Budgerigars (Melopsittacus Undulatus)"</a>'
        "<br><br>"
        "To help us improve, please report any incorrect gender determinations by sending the original image and details to -> <a href='mailto:[email protected]'>Email</a>. "
        "Your feedback is important for retraining and improving the model."
    ),
    # Each row: image file, model name, confidence, IoU, max detections.
    examples=[
        ["both.jpg", "budgerigar_rtdetr-x.pt", 0.25, 0.45, 300],
        ["Male.png", "budgerigar_yolov9e.pt", 0.25, 0.45, 300],
        ["Female.png", "budgerigar_yolo11x.pt", 0.25, 0.45, 300],
    ],
)
interface.launch()
|