File size: 3,515 Bytes
07d1802
25c54ab
506ac4e
07d1802
 
 
 
 
 
 
b81c01d
07d1802
 
 
 
 
 
b8eded5
bbeba56
 
 
 
c4c158a
bbeba56
 
 
 
 
 
 
 
 
07d1802
e57154d
506ac4e
07d1802
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7592465
4846836
e57154d
 
 
 
 
 
07d1802
 
 
 
 
 
 
 
 
 
2fbdb0d
 
e57154d
2fbdb0d
 
e57154d
5236243
 
2fbdb0d
07d1802
e9d980d
0664949
 
07d1802
 
2fbdb0d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO, RTDETR
import spaces
import os
from huggingface_hub import hf_hub_download

# Helper function to download models from Hugging Face
def get_model_path(model_name):
    """Download a model checkpoint from the Hugging Face Hub.

    Args:
        model_name: File name of the checkpoint inside the
            ``atalaydenknalbant/budgerigar_models`` repository.

    Returns:
        The value returned by ``hf_hub_download`` for that file, i.e. the
        location of the downloaded (cached) checkpoint on local disk.
    """
    return hf_hub_download(
        filename=model_name,
        repo_id="atalaydenknalbant/budgerigar_models",
    )

def _message_image(message, size=(640, 480)):
    """Return a white RGB image of *size* with *message* centred in black text."""
    width, height = size
    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default(size=40)
    # Measure the rendered text so it can be centred on the canvas.
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    text_x = (width - (right - left)) / 2
    text_y = (height - (bottom - top)) / 2
    draw.text((text_x, text_y), message, fill="black", font=font)
    return canvas


@spaces.GPU
def yolo_inference(images, model_id, conf_threshold, iou_threshold, max_detection):
    """Run object detection on the input and return the annotated image.

    Args:
        images: Prediction source handed to ``model.predict``; ``None`` means
            the user supplied nothing.
        model_id: Checkpoint file name. Names containing ``"rtdetr"``
            (case-insensitive) are loaded with ``RTDETR``, everything else
            with ``YOLO``.
        conf_threshold: Forwarded to ``predict`` as ``conf``.
        iou_threshold: Forwarded to ``predict`` as ``iou``.
        max_detection: Forwarded to ``predict`` as ``max_det``.

    Returns:
        A PIL image: the annotated plot of the last prediction result, or a
        white placeholder image carrying a message when there is no input or
        the model returned no results.
    """
    if images is None:
        return _message_image("No image provided")

    model_path = get_model_path(model_id)  # Download model
    model_type = RTDETR if 'rtdetr' in model_id.lower() else YOLO
    model = model_type(model_path)
    results = model.predict(
        source=images,
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=640,
        max_det=max_detection,
        show_labels=True,
        show_conf=True,
    )

    # Guard against an empty result list; previously this path fell through
    # to ``return image`` with ``image`` never assigned (UnboundLocalError).
    if not results:
        return _message_image("No results returned")

    # Only the last result was ever returned, so render just that one instead
    # of plotting every result and discarding all but the final image.
    annotated = results[-1].plot()
    # Reverse the channel axis before handing the array to PIL
    # (Ultralytics documents plot() output as BGR; PIL expects RGB).
    return Image.fromarray(annotated[..., ::-1])

# Gradio UI: one image input plus model/threshold controls, wired to
# yolo_inference, which returns the annotated image shown in the output pane.
interface = gr.Interface(
    fn=yolo_inference,
    # Input order must match the positional parameters of yolo_inference:
    # (images, model_id, conf_threshold, iou_threshold, max_detection).
    inputs=[
        gr.Image(type="pil", label="Example Image", interactive=True),
        gr.Radio(
            choices=[
                'budgerigar_yolo11x.pt', 'budgerigar_yolov9e.pt',
                'budgerigar_yolo11l.pt', 'budgerigar_yolo11m.pt',
                'budgerigar_yolo11s.pt', 'budgerigar_yolo11n.pt',
                'budgerigar_rtdetr-x.pt'
            ],
            label="Model Name",
            value="budgerigar_yolo11x.pt",
        ),
        gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
        gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold"),
        gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection"),
    ],
    outputs=gr.Image(type="pil", label="Annotated Image"),
    cache_examples=True,
    title="Budgerigar Gender Determination",
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own trailing space where a word break is intended.
    # NOTE(review): the mailto target below reads '[email protected]', which
    # looks like an obfuscation placeholder rather than a real address —
    # confirm and restore the intended contact address.
    description=(
        "Pretrained object detection models for determining budgerigar gender based on cere color variations. "
        "Upload image(s) for inference. For more details, refer to the paper: "
        '<a href="https://ieeexplore.ieee.org/document/10773570" target="_blank">'
        '"Advanced Computer Vision Techniques for Reliable Gender Determination in Budgerigars (Melopsittacus Undulatus)"</a>'
        "<br><br>"
        "To help us improve, please report any incorrect gender determinations by sending the original image and details to -> <a href='mailto:[email protected]'>Email</a>. "
        "Your feedback is important for retraining and improving the model."
    ),
    # Each example row supplies one value per input component, in order.
    examples=[
        ["both.jpg", "budgerigar_rtdetr-x.pt", 0.25, 0.45, 300],
        ["Male.png", "budgerigar_yolov9e.pt", 0.25, 0.45, 300],
        ["Female.png", "budgerigar_yolo11x.pt", 0.25, 0.45, 300],
    ],
)
interface.launch()