atalaydenknalbant committed on
Commit
c5e5663
·
verified ·
1 Parent(s): 6248745

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -72
app.py CHANGED
@@ -3,6 +3,9 @@ from PIL import Image, ImageDraw, ImageFont
3
  from ultralytics import YOLO, RTDETR
4
  import spaces
5
  import os
 
 
 
6
  from huggingface_hub import hf_hub_download
7
 
8
  def get_model_path(model_name):
@@ -26,18 +29,22 @@ def get_model_path(model_name):
26
  return model_cache_path
27
 
28
  @spaces.GPU
29
- def yolo_inference(images, model_id, conf_threshold, iou_threshold, max_detection):
30
  """
31
- Performs budgerigar gender determination inference on an image using a selected YOLO or RTDETR model.
 
32
 
33
- This function handles input images, loads the appropriate model (YOLO or RTDETR)
34
- based on the `model_id`, and then runs inference to detect budgerigars and
35
- determine their gender. The results are then plotted onto the original image.
36
- If no image is provided, it returns a blank image with a message.
37
 
38
  Args:
39
- images (PIL.Image.Image or None): The input image on which to perform detection.
40
- Can be None if no image is uploaded.
 
 
 
41
  model_id (str): The identifier of the model to use (e.g., 'budgerigar_yolo11x.pt',
42
  'budgerigar_rtdetr-x.pt').
43
  conf_threshold (float): The confidence threshold for filtering detections.
@@ -47,78 +54,213 @@ def yolo_inference(images, model_id, conf_threshold, iou_threshold, max_detectio
47
  max_detection (int): The maximum number of detections to return and display.
48
 
49
  Returns:
50
- PIL.Image.Image: The input image annotated with detection results, including
51
- bounding boxes and gender labels. Returns a blank image
52
- with a message if no input image is provided.
 
 
53
  """
54
- if images is None:
55
- # Create a blank image
56
- width, height = 640, 480
57
- blank_image = Image.new("RGB", (width, height), color="white")
58
- draw = ImageDraw.Draw(blank_image)
59
- message = "No image provided"
60
- font = ImageFont.load_default(size=40)
61
- bbox = draw.textbbox((0, 0), message, font=font)
62
- text_width = bbox[2] - bbox[0]
63
- text_height = bbox[3] - bbox[1]
64
- text_x = (width - text_width) / 2
65
- text_y = (height - text_height) / 2
66
- draw.text((text_x, text_y), message, fill="black", font=font)
67
- return blank_image
68
-
69
- model_path = get_model_path(model_id) # Download model
70
  model_type = RTDETR if 'rtdetr' in model_id.lower() else YOLO
71
  model = model_type(model_path)
72
- results = model.predict(
73
- source=images,
74
- conf=conf_threshold,
75
- iou=iou_threshold,
76
- imgsz=640,
77
- max_det=max_detection,
78
- show_labels=True,
79
- show_conf=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  )
 
81
 
82
- # Process results and convert to PIL Image
83
- for r in results:
84
- image_array = r.plot()
85
- image = Image.fromarray(image_array[..., ::-1])
86
- return image
87
-
88
- interface = gr.Interface(
89
- fn=yolo_inference,
90
- inputs=[
91
- gr.Image(type="pil", label="Example Image", interactive=True),
92
- gr.Radio(
93
- choices=[
94
- 'budgerigar_yolo11x.pt', 'budgerigar_yolov9e.pt',
95
- 'budgerigar_yolo11l.pt', 'budgerigar_yolo11m.pt',
96
- 'budgerigar_yolo11s.pt', 'budgerigar_yolo11n.pt',
97
- 'budgerigar_rtdetr-x.pt'
98
- ],
99
- label="Model Name",
100
- value="budgerigar_yolo11x.pt",
101
- ),
102
- gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
103
- gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold"),
104
- gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection"),
105
- ],
106
- outputs=gr.Image(type="pil", label="Annotated Image"),
107
- cache_examples=True,
108
- title="Budgerigar Gender Determination",
109
- description=(
110
  "Pretrained object detection models for determining budgerigar gender based on cere color variations. "
111
- "Upload image(s) for inference. For more details, refer to the paper: "
112
  '<a href="https://ieeexplore.ieee.org/document/10773570" target="_blank">'
113
  '"Advanced Computer Vision Techniques for Reliable Gender Determination in Budgerigars (Melopsittacus Undulatus)"</a>'
114
  "<br><br>"
115
  "To help us improve, please report any incorrect gender determinations by sending the original image and details to -> <a href='mailto:[email protected]'>Email</a>."
116
  "Your feedback is important for retraining and improving the model."
117
- ),
118
- examples=[
119
- ["both.jpg", "budgerigar_rtdetr-x.pt", 0.25, 0.45, 300],
120
- ["Male.png", "budgerigar_yolov9e.pt", 0.25, 0.45, 300],
121
- ["Female.png", "budgerigar_yolo11x.pt", 0.25, 0.45, 300],
122
- ],
123
- )
124
- interface.launch(mcp_server=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from ultralytics import YOLO, RTDETR
4
  import spaces
5
  import os
6
+ import cv2
7
+ import numpy as np
8
+ import tempfile
9
  from huggingface_hub import hf_hub_download
10
 
11
  def get_model_path(model_name):
 
29
  return model_cache_path
30
 
31
def _blank_message_image(message, width=640, height=480):
    """Return a white RGB image with `message` drawn centered in black."""
    canvas = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(canvas)
    font = ImageFont.load_default(size=40)
    left, top, right, bottom = draw.textbbox((0, 0), message, font=font)
    text_x = (width - (right - left)) / 2
    text_y = (height - (bottom - top)) / 2
    draw.text((text_x, text_y), message, fill="black", font=font)
    return canvas


def _new_temp_video_path():
    """Reserve a temporary ``.mp4`` path and close the handle before OpenCV writes to it."""
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        return tmp.name


@spaces.GPU
def yolo_inference(input_type, image, video, model_id, conf_threshold, iou_threshold, max_detection):
    """
    Performs budgerigar gender determination inference on an image or video
    using a selected YOLO or RTDETR model.

    For images, the model annotates the image and the result is returned as a
    PIL image. For videos, every frame is read with OpenCV, annotated, and
    written straight to a temporary ``.mp4`` file (frames are not buffered in
    memory). Missing inputs are handled before the model is downloaded or
    loaded, and produce a placeholder output carrying a message.

    Args:
        input_type (str): Specifies the input type, either "Image" or "Video".
        image (PIL.Image.Image or None): The input image if `input_type` is "Image".
                                         None otherwise.
        video (str or None): The path to the input video file if `input_type` is "Video".
                             None otherwise.
        model_id (str): The identifier of the model to use (e.g., 'budgerigar_yolo11x.pt',
                        'budgerigar_rtdetr-x.pt'). RTDETR is selected when the name
                        contains 'rtdetr' (case-insensitive), YOLO otherwise.
        conf_threshold (float): The confidence threshold for filtering detections.
        iou_threshold (float): The IoU threshold, forwarded to `model.predict` as `iou`.
        max_detection (int): The maximum number of detections to return and display.

    Returns:
        tuple: A tuple containing two elements:
            - PIL.Image.Image or None: The annotated image if `input_type` was "Image"
              (or a "No image provided" placeholder), otherwise None.
            - str or None: The path to the annotated video file if `input_type` was "Video"
              (or a one-frame "No video provided" placeholder video), otherwise None.
            Both elements are None for an unknown `input_type`, when prediction
            yields no result, or when no frame could be read from the video.
    """
    if input_type not in ("Image", "Video"):
        return None, None

    # Handle missing inputs first so no model is downloaded/loaded needlessly.
    if input_type == "Image" and image is None:
        return _blank_message_image("No image provided"), None

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

    if input_type == "Video" and video is None:
        placeholder = _blank_message_image("No video provided")
        placeholder_path = _new_temp_video_path()
        # PIL's .size is (width, height), the order cv2.VideoWriter expects.
        writer = cv2.VideoWriter(placeholder_path, fourcc, 1, placeholder.size)
        try:
            writer.write(cv2.cvtColor(np.array(placeholder), cv2.COLOR_RGB2BGR))
        finally:
            writer.release()
        return None, placeholder_path

    model_path = get_model_path(model_id)
    model_type = RTDETR if 'rtdetr' in model_id.lower() else YOLO
    model = model_type(model_path)
    predict_kwargs = dict(
        conf=conf_threshold,
        iou=iou_threshold,
        imgsz=640,
        max_det=max_detection,
        show_labels=True,
        show_conf=True,
    )

    if input_type == "Image":
        for r in model.predict(source=image, **predict_kwargs):
            # The plotted array is treated as BGR; reverse the channel axis for PIL (RGB).
            return Image.fromarray(r.plot()[..., ::-1]), None
        return None, None

    # input_type == "Video"
    cap = cv2.VideoCapture(video)
    writer = None
    output_path = None
    try:
        reported_fps = cap.get(cv2.CAP_PROP_FPS)
        fps = reported_fps if reported_fps > 0 else 25  # fall back when FPS is unknown
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            rgb_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            for r in model.predict(source=rgb_frame, **predict_kwargs):
                # Already BGR, which is what VideoWriter consumes: write it directly.
                annotated_bgr = np.ascontiguousarray(r.plot())
                if writer is None:
                    # Create the writer lazily so no file is left behind for an
                    # unreadable or empty video.
                    frame_height, frame_width = annotated_bgr.shape[:2]
                    output_path = _new_temp_video_path()
                    writer = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
                writer.write(annotated_bgr)
    finally:
        cap.release()
        if writer is not None:
            writer.release()

    return None, output_path
153
+
154
def update_visibility(input_type):
    """
    Toggle which input/output components are shown for the chosen input type.

    The image input and image output are shown together when "Image" is
    selected; for any other value the video input and video output are shown
    instead.

    Args:
        input_type (str): The selected input type, either "Image" or "Video".

    Returns:
        tuple: Four `gr.update` objects setting the visibility of, in order:
               (image input, video input, image output, video output).
    """
    show_image = input_type == "Image"
    show_video = not show_image
    return (
        gr.update(visible=show_image),
        gr.update(visible=show_video),
        gr.update(visible=show_image),
        gr.update(visible=show_video),
    )
172
+
173
def yolo_inference_for_examples(image, model_id, conf_threshold, iou_threshold, max_detection):
    """
    Image-only adapter around `yolo_inference` for the Gradio examples.

    Calls `yolo_inference` with `input_type` fixed to "Image" and `video` set
    to None, then discards the video element of the returned pair.

    Args:
        image (PIL.Image.Image): The input image for the example.
        model_id (str): The identifier of the model to use.
        conf_threshold (float): The confidence threshold.
        iou_threshold (float): The IoU threshold.
        max_detection (int): The maximum number of detections.

    Returns:
        PIL.Image.Image or None: The first element returned by `yolo_inference`.
    """
    inference_args = {
        "input_type": "Image",
        "image": image,
        "video": None,
        "model_id": model_id,
        "conf_threshold": conf_threshold,
        "iou_threshold": iou_threshold,
        "max_detection": max_detection,
    }
    image_result, _video_result = yolo_inference(**inference_args)
    return image_result
200
 
201
# Gradio UI: an input column (image/video input, input-type selector, model and
# threshold controls, run button) next to an output column (annotated image/video).
# NOTE(review): nesting below is reconstructed from component order — confirm
# against the deployed layout.
with gr.Blocks(title="Budgerigar Gender Determination") as app:
    gr.Markdown("# Budgerigar Gender Determination")
    # NOTE(review): the mailto address below appears redacted in this view;
    # make sure the real address is present in the deployed file.
    gr.Markdown(
        "Pretrained object detection models for determining budgerigar gender based on cere color variations. "
        "Upload image(s) or video(s) for inference. For more details, refer to the paper: "
        '<a href="https://ieeexplore.ieee.org/document/10773570" target="_blank">'
        '"Advanced Computer Vision Techniques for Reliable Gender Determination in Budgerigars (Melopsittacus Undulatus)"</a>'
        "<br><br>"
        # Trailing space added so this sentence does not run into the next one
        # when the adjacent literals are concatenated.
        "To help us improve, please report any incorrect gender determinations by sending the original image and details to -> <a href='mailto:[email protected]'>Email</a>. "
        "Your feedback is important for retraining and improving the model."
    )

    with gr.Row():
        with gr.Column():
            # Only one of the two inputs is visible at a time (see update_visibility).
            image = gr.Image(type="pil", label="Image Input", visible=True)
            video = gr.Video(label="Video Input", visible=False)
            input_type = gr.Radio(
                choices=["Image", "Video"],
                value="Image",
                label="Input Type",
            )

            model_id = gr.Radio(
                choices=[
                    'budgerigar_yolo11x.pt', 'budgerigar_yolov9e.pt',
                    'budgerigar_yolo11l.pt', 'budgerigar_yolo11m.pt',
                    'budgerigar_yolo11s.pt', 'budgerigar_yolo11n.pt',
                    'budgerigar_rtdetr-x.pt'
                ],
                label="Model Name",
                value="budgerigar_yolo11x.pt",
            )
            conf_threshold = gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold")
            iou_threshold = gr.Slider(minimum=0, maximum=1, value=0.45, label="IoU Threshold")
            max_detection = gr.Slider(minimum=1, maximum=300, step=1, value=300, label="Max Detection")
            infer_button = gr.Button("Detect Objects")
        with gr.Column():
            # Outputs mirror the inputs: image output by default, video output hidden.
            output_image = gr.Image(type="pil", label="Annotated Image", visible=True)
            output_video = gr.Video(label="Annotated Video", visible=False)
            gr.DeepLinkButton()

    # Switch the visible input/output pair whenever the input type changes.
    input_type.change(
        fn=update_visibility,
        inputs=input_type,
        outputs=[image, video, output_image, output_video],
    )

    # Input order must match the positional parameters of yolo_inference.
    infer_button.click(
        fn=yolo_inference,
        inputs=[input_type, image, video, model_id, conf_threshold, iou_threshold, max_detection],
        outputs=[output_image, output_video],
    )

    # Examples are image-only, so they go through the image-only wrapper and
    # populate just the image output.
    gr.Examples(
        examples=[
            ["both.jpg", "budgerigar_rtdetr-x.pt", 0.25, 0.45, 300],
            ["Male.png", "budgerigar_yolov9e.pt", 0.25, 0.45, 300],
            ["Female.png", "budgerigar_yolo11x.pt", 0.25, 0.45, 300],
        ],
        fn=yolo_inference_for_examples,
        inputs=[image, model_id, conf_threshold, iou_threshold, max_detection],
        outputs=[output_image],
        label="Examples (Images)",
    )

app.launch(mcp_server=True)