Enhanced YOLOv11 SAHI Demo with Dynamic Model Loading, UI Controls and MCP Compatibility

#5
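
The diff below replaces the module-level, hard-coded yolo11s.pt detector with a load_yolo_model helper that rebuilds the SAHI AutoDetectionModel from the model name and confidence threshold chosen in the UI, adds max-detection filtering and extra Gradio controls, and launches the Blocks app with mcp_server=True. As a condensed, standalone sketch of the model-loading idea (illustrative only, not code from this PR; the PR's version also stores the result in a module-level "model" global, and IMAGE_SIZE below stands in for the constant already defined in app.py):

    from sahi import AutoDetectionModel

    IMAGE_SIZE = 640  # stand-in for the IMAGE_SIZE constant defined earlier in app.py

    def load_yolo_model(model_name, confidence_threshold=0.5):
        # Ultralytics resolves names like "yolo11n.pt" and downloads the weights on first use.
        return AutoDetectionModel.from_pretrained(
            model_type="ultralytics",
            model_path=model_name,
            device="cpu",
            confidence_threshold=confidence_threshold,
            image_size=IMAGE_SIZE,
        )
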
Files changed (1)
  1. app.py +178 -55
app.py CHANGED
@@ -6,10 +6,9 @@ import sahi.slicing
 from PIL import Image
 import numpy
 from ultralytics import YOLO
-
-
 import sys
 import types
+
 if 'huggingface_hub.utils._errors' not in sys.modules:
     mock_errors = types.ModuleType('_errors')
     mock_errors.RepositoryNotFoundError = Exception
@@ -37,15 +36,33 @@ sahi.utils.file.download_from_url(
     "highway3.jpg",
 )
 
-# Model
-model = AutoDetectionModel.from_pretrained(
-    model_type="ultralytics", model_path="yolo11s.pt", device="cpu", confidence_threshold=0.5, image_size=IMAGE_SIZE
-)
+# Global model variable
+model = None
 
+def load_yolo_model(model_name, confidence_threshold=0.5):
+    """
+    Loads a YOLOv11 detection model.
+
+    Args:
+        model_name (str): The name of the YOLOv11 model to load (e.g., "yolo11n.pt").
+        confidence_threshold (float): The confidence threshold for object detection.
 
+    Returns:
+        AutoDetectionModel: The loaded SAHI AutoDetectionModel.
+    """
+    global model
+    model_path = model_name
+    model = AutoDetectionModel.from_pretrained(
+        model_type="ultralytics", model_path=model_path, device="cpu",
+        confidence_threshold=confidence_threshold, image_size=IMAGE_SIZE
+    )
+    return model
 
 def sahi_yolo_inference(
     image,
+    yolo_model_name,
+    confidence_threshold,
+    max_detections,
     slice_height=512,
     slice_width=512,
     overlap_height_ratio=0.2,
@@ -55,6 +72,29 @@ def sahi_yolo_inference(
     postprocess_match_threshold=0.5,
     postprocess_class_agnostic=False,
 ):
+    """
+    Performs object detection using SAHI with a specified YOLOv11 model.
+
+    Args:
+        image (PIL.Image.Image): The input image for detection.
+        yolo_model_name (str): The name of the YOLOv11 model to use for inference.
+        confidence_threshold (float): The confidence threshold for object detection.
+        max_detections (int): The maximum number of detections to return.
+        slice_height (int): The height of each slice for sliced inference.
+        slice_width (int): The width of each slice for sliced inference.
+        overlap_height_ratio (float): The overlap ratio for slice height.
+        overlap_width_ratio (float): The overlap ratio for slice width.
+        postprocess_type (str): The type of postprocessing to apply ("NMS" or "GREEDYNMM").
+        postprocess_match_metric (str): The metric for postprocessing matching ("IOU" or "IOS").
+        postprocess_match_threshold (float): The threshold for postprocessing matching.
+        postprocess_class_agnostic (bool): Whether postprocessing should be class agnostic.
+
+    Returns:
+        tuple: A tuple containing two PIL.Image.Image objects:
+            - The image with standard YOLO inference results.
+            - The image with SAHI sliced YOLO inference results.
+    """
+    load_yolo_model(yolo_model_name, confidence_threshold)
 
     image_width, image_height = image.size
     sliced_bboxes = sahi.slicing.get_slice_bboxes(
@@ -71,18 +111,24 @@
             f"{len(sliced_bboxes)} slices are too much for huggingface spaces, try smaller slice size."
         )
 
-    # standard inference
+    # Standard inference
     prediction_result_1 = sahi.predict.get_prediction(
-        image=image, detection_model=model
+        image=image, detection_model=model,
    )
-    print(image)
+
+    # Filter by max_detections for standard inference
+    if max_detections is not None and len(prediction_result_1.object_prediction_list) > max_detections:
+        prediction_result_1.object_prediction_list = sorted(
+            prediction_result_1.object_prediction_list, key=lambda x: x.score.value, reverse=True
+        )[:max_detections]
+
     visual_result_1 = sahi.utils.cv.visualize_object_predictions(
         image=numpy.array(image),
         object_prediction_list=prediction_result_1.object_prediction_list,
     )
     output_1 = Image.fromarray(visual_result_1["image"])
 
-    # sliced inference
+    # Sliced inference
     prediction_result_2 = sahi.predict.get_sliced_prediction(
         image=image,
         detection_model=model,
@@ -95,6 +141,13 @@
         postprocess_match_threshold=postprocess_match_threshold,
         postprocess_class_agnostic=postprocess_class_agnostic,
     )
+
+    # Filter by max_detections for sliced inference
+    if max_detections is not None and len(prediction_result_2.object_prediction_list) > max_detections:
+        prediction_result_2.object_prediction_list = sorted(
+            prediction_result_2.object_prediction_list, key=lambda x: x.score.value, reverse=True
+        )[:max_detections]
+
     visual_result_2 = sahi.utils.cv.visualize_object_predictions(
         image=numpy.array(image),
         object_prediction_list=prediction_result_2.object_prediction_list,
@@ -105,48 +158,118 @@
     return output_1, output_2
 
 
-inputs = [
-    gr.Image(type="pil", label="Original Image"),
-    gr.Number(value=512, label="slice_height"),
-    gr.Number(value=512, label="slice_width"),
-    gr.Number(value=0.2, label="overlap_height_ratio"),
-    gr.Number(value=0.2, label="overlap_width_ratio"),
-    gr.Dropdown(
-        ["NMS", "GREEDYNMM"],
-        type="value",
-        value="NMS",
-        label="postprocess_type",
-    ),
-    gr.Dropdown(
-        ["IOU", "IOS"], type="value", value="IOU", label="postprocess_type"
-    ),
-    gr.Number(value=0.5, label="postprocess_match_threshold"),
-    gr.Checkbox(value=True, label="postprocess_class_agnostic"),
-]
-
-outputs = [
-    gr.Image(type="pil", label="YOLO11s Standard"),
-    gr.Image(type="pil", label="YOLO11s + SAHI Sliced"),
-]
-
-title = "Small Object Detection with SAHI + YOLO11"
-description = "SAHI + YOLO11 demo for small object detection. Upload your own image or click an example image to use."
-article = "<p style='text-align: center'>SAHI is a lightweight vision library for performing large scale object detection/ instance segmentation.. <a href='https://github.com/obss/sahi'>SAHI Github</a> | <a href='https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80'>SAHI Blog</a> </p>"
-examples = [
-    ["apple_tree.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
-    ["highway.jpg", 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
-    ["highway2.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
-    ["highway3.jpg", 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
-]
-
-gr.Interface(
-    sahi_yolo_inference,
-    inputs,
-    outputs,
-    title=title,
-    description=description,
-    article=article,
-    examples=examples,
-    theme="huggingface",
-    cache_examples=True,
-).launch(debug=True)
+with gr.Blocks() as app:
+    gr.Markdown("# Small Object Detection with SAHI + YOLOv11")
+    gr.Markdown(
+        "SAHI + YOLOv11 demo for small object detection. "
+        "Upload your own image or click an example image to use."
+    )
+
+    with gr.Row():
+        with gr.Column():
+            original_image_input = gr.Image(type="pil", label="Original Image")
+            yolo_model_dropdown = gr.Dropdown(
+                choices=["yolo11n.pt", "yolo11s.pt", "yolo11m.pt", "yolo11l.pt", "yolo11x.pt"],
+                value="yolo11s.pt",
+                label="YOLOv11 Model",
+            )
+            confidence_threshold_slider = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.01,
+                value=0.5,
+                label="Confidence Threshold",
+            )
+            max_detections_slider = gr.Slider(
+                minimum=1,
+                maximum=500,
+                step=1,
+                value=300,
+                label="Max Detections",
+            )
+            slice_height_input = gr.Number(value=512, label="Slice Height")
+            slice_width_input = gr.Number(value=512, label="Slice Width")
+            overlap_height_ratio_slider = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.01,
+                value=0.2,
+                label="Overlap Height Ratio",
+            )
+            overlap_width_ratio_slider = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.01,
+                value=0.2,
+                label="Overlap Width Ratio",
+            )
+            postprocess_type_dropdown = gr.Dropdown(
+                ["NMS", "GREEDYNMM"],
+                type="value",
+                value="NMS",
+                label="Postprocess Type",
+            )
+            postprocess_match_metric_dropdown = gr.Dropdown(
+                ["IOU", "IOS"], type="value", value="IOU", label="Postprocess Match Metric"
+            )
+            postprocess_match_threshold_slider = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.01,
+                value=0.5,
+                label="Postprocess Match Threshold",
+            )
+            postprocess_class_agnostic_checkbox = gr.Checkbox(value=True, label="Postprocess Class Agnostic")
+
+            submit_button = gr.Button("Run Inference")
+
+        with gr.Column():
+            output_standard = gr.Image(type="pil", label="YOLOv11 Standard")
+            output_sahi_sliced = gr.Image(type="pil", label="YOLOv11 + SAHI Sliced")
+
+    gr.Examples(
+        examples=[
+            ["apple_tree.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
+            ["highway.jpg", "yolo11s.pt", 0.5, 300, 256, 256, 0.2, 0.2, "NMS", "IOU", 0.4, True],
+            ["highway2.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
+            ["highway3.jpg", "yolo11s.pt", 0.5, 300, 512, 512, 0.2, 0.2, "NMS", "IOU", 0.4, True],
+        ],
+        inputs=[
+            original_image_input,
+            yolo_model_dropdown,
+            confidence_threshold_slider,
+            max_detections_slider,
+            slice_height_input,
+            slice_width_input,
+            overlap_height_ratio_slider,
+            overlap_width_ratio_slider,
+            postprocess_type_dropdown,
+            postprocess_match_metric_dropdown,
+            postprocess_match_threshold_slider,
+            postprocess_class_agnostic_checkbox,
+        ],
+        outputs=[output_standard, output_sahi_sliced],
+        fn=sahi_yolo_inference,
+        cache_examples=True,
+    )
+
+    submit_button.click(
+        fn=sahi_yolo_inference,
+        inputs=[
+            original_image_input,
+            yolo_model_dropdown,
+            confidence_threshold_slider,
+            max_detections_slider,
+            slice_height_input,
+            slice_width_input,
+            overlap_height_ratio_slider,
+            overlap_width_ratio_slider,
+            postprocess_type_dropdown,
+            postprocess_match_metric_dropdown,
+            postprocess_match_threshold_slider,
+            postprocess_class_agnostic_checkbox,
+        ],
+        outputs=[output_standard, output_sahi_sliced],
    )
+
+app.launch(mcp_server=True)
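
Beyond the UI, the Blocks app exposes the click handler over Gradio's regular client API, and launching with mcp_server=True additionally serves it as an MCP tool. A minimal sketch of calling the new inference endpoint from Python with gradio_client, assuming the updated app is running locally on the default port; the /sahi_yolo_inference endpoint name is an assumption, so check client.view_api() for the actual name:

    from gradio_client import Client, handle_file

    # Assumes a local launch of the updated app.py on Gradio's default port.
    client = Client("http://127.0.0.1:7860")

    standard_img, sliced_img = client.predict(
        handle_file("highway.jpg"),       # original image
        "yolo11s.pt",                     # YOLOv11 model
        0.5,                              # confidence threshold
        300,                              # max detections
        512,                              # slice height
        512,                              # slice width
        0.2,                              # overlap height ratio
        0.2,                              # overlap width ratio
        "NMS",                            # postprocess type
        "IOU",                            # postprocess match metric
        0.5,                              # postprocess match threshold
        True,                             # postprocess class agnostic
        api_name="/sahi_yolo_inference",  # assumed name; confirm with client.view_api()
    )
    print(standard_img, sliced_img)       # local paths to the two annotated result images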