mbar0075 committed · Commit 8dd8474 · 1 Parent(s): 6909e18
Files changed (1):
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
+ from typing import Optional, Tuple
+
+ import gradio as gr
+ import numpy as np
+ import supervision as sv
+ from inference import get_model
+
+ MARKDOWN = """
+ <h1 style='text-align: center'>Segment Something 💫</h1>
+ Welcome to Segment Something! Your on-the-go demo for instance segmentation. 🚀
+
+ <h2 style='text-align: center'>Matthias Bartolo</h2>
+
+ Powered by Roboflow [Inference](https://github.com/roboflow/inference) and
+ [Supervision](https://github.com/roboflow/supervision). 🔥
+ """
+
+ # Each example row supplies one value per input component: the image, the
+ # three confidence thresholds (v8n, v8s, v8m), and the IoU threshold.
+ IMAGE_EXAMPLES = [
+     ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 0.3, 0.3, 0.1],
+     ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 0.3, 0.3, 0.1],
+     ['https://media.roboflow.com/supervision/image-examples/basketball-1.png', 0.3, 0.3, 0.3, 0.1],
+ ]
+
+ # Three sizes of the YOLOv8 segmentation model (nano, small, medium), loaded
+ # once at startup through Roboflow Inference.
+ YOLO_V8N_MODEL = get_model(model_id="yolov8n-seg-640")
+ YOLO_V8S_MODEL = get_model(model_id="yolov8s-seg-640")
+ YOLO_V8M_MODEL = get_model(model_id="yolov8m-seg-640")
+
+ # Shared annotators so all three outputs use a consistent style: masks for
+ # the segmentation overlay, labels for class names and confidences.
+ LABEL_ANNOTATOR = sv.LabelAnnotator(text_color=sv.Color.BLACK)
+ MASK_ANNOTATOR = sv.MaskAnnotator()
+
+
+ def detect_and_annotate(
+     model,
+     input_image: np.ndarray,
+     confidence_threshold: float,
+     iou_threshold: float,
+     class_id_mapping: Optional[dict] = None
+ ) -> np.ndarray:
+     # Run inference on a single image and wrap the first (and only) result.
+     result = model.infer(
+         input_image,
+         confidence=confidence_threshold,
+         iou_threshold=iou_threshold
+     )[0]
+     detections = sv.Detections.from_inference(result)
+
+     # Optionally remap raw class ids, e.g. to merge model-specific id spaces.
+     if class_id_mapping:
+         detections.class_id = np.array([
+             class_id_mapping[class_id]
+             for class_id
+             in detections.class_id
+         ])
+
+     labels = [
+         f"{class_name} ({confidence:.2f})"
+         for class_name, confidence
+         in zip(detections['class_name'], detections.confidence)
+     ]
+
+     # Draw masks first, then labels on top, on a copy of the input image.
+     annotated_image = input_image.copy()
+     annotated_image = MASK_ANNOTATOR.annotate(
+         scene=annotated_image, detections=detections)
+     annotated_image = LABEL_ANNOTATOR.annotate(
+         scene=annotated_image, detections=detections, labels=labels)
+     return annotated_image
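+
+
+ # Hypothetical usage sketch (not exercised by this app): `class_id_mapping`
+ # remaps raw class ids before labelling, e.g. to merge two classes:
+ #     detect_and_annotate(
+ #         YOLO_V8N_MODEL, image, confidence_threshold=0.3, iou_threshold=0.5,
+ #         class_id_mapping={0: 0, 1: 0},
+ #     )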
+
+
+ def process_image(
+     input_image: np.ndarray,
+     yolo_v8n_confidence_threshold: float,
+     yolo_v8s_confidence_threshold: float,
+     yolo_v8m_confidence_threshold: float,
+     iou_threshold: float
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+     # Run all three model sizes on the same image so the outputs can be
+     # compared side by side.
+     yolo_v8n_annotated_image = detect_and_annotate(
+         YOLO_V8N_MODEL, input_image, yolo_v8n_confidence_threshold, iou_threshold)
+     yolo_v8s_annotated_image = detect_and_annotate(
+         YOLO_V8S_MODEL, input_image, yolo_v8s_confidence_threshold, iou_threshold)
+     yolo_v8m_annotated_image = detect_and_annotate(
+         YOLO_V8M_MODEL, input_image, yolo_v8m_confidence_threshold, iou_threshold)
+
+     return (
+         yolo_v8n_annotated_image,
+         yolo_v8s_annotated_image,
+         yolo_v8m_annotated_image
+     )
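+
+
+ # Illustrative call (an assumption for clarity, not part of the UI wiring;
+ # `image` stands for any RGB numpy array):
+ #     v8n_img, v8s_img, v8m_img = process_image(image, 0.3, 0.3, 0.3, 0.5)
+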
+ yolo_v8n_confidence_threshold_component = gr.Slider(
+     minimum=0,
+     maximum=1.0,
+     value=0.3,
+     step=0.01,
+     label="YOLOv8N Confidence Threshold",
+     info=(
+         "The confidence threshold for the YOLO model. Lower the threshold to "
+         "reduce false negatives, enhancing the model's sensitivity to detect "
+         "sought-after objects. Conversely, increase the threshold to minimize false "
+         "positives, preventing the model from identifying objects it shouldn't."
+     ))
+
+ yolo_v8s_confidence_threshold_component = gr.Slider(
+     minimum=0,
+     maximum=1.0,
+     value=0.3,
+     step=0.01,
+     label="YOLOv8S Confidence Threshold",
+     info=(
+         "The confidence threshold for the YOLO model. Lower the threshold to "
+         "reduce false negatives, enhancing the model's sensitivity to detect "
+         "sought-after objects. Conversely, increase the threshold to minimize false "
+         "positives, preventing the model from identifying objects it shouldn't."
+     ))
+
+ yolo_v8m_confidence_threshold_component = gr.Slider(
+     minimum=0,
+     maximum=1.0,
+     value=0.3,
+     step=0.01,
+     label="YOLOv8M Confidence Threshold",
+     info=(
+         "The confidence threshold for the YOLO model. Lower the threshold to "
+         "reduce false negatives, enhancing the model's sensitivity to detect "
+         "sought-after objects. Conversely, increase the threshold to minimize false "
+         "positives, preventing the model from identifying objects it shouldn't."
+     ))
+
+ iou_threshold_component = gr.Slider(
+     minimum=0,
+     maximum=1.0,
+     value=0.5,
+     step=0.01,
+     label="IoU Threshold",
+     info=(
+         "The Intersection over Union (IoU) threshold for non-maximum suppression. "
+         "Decrease the value to lessen the occurrence of overlapping bounding boxes, "
+         "making the detection process stricter. On the other hand, increase the value "
+         "to allow more overlapping bounding boxes, accommodating a broader range of "
+         "detections."
+     ))
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
+     with gr.Accordion("Configuration", open=False):
+         with gr.Row():
+             yolo_v8n_confidence_threshold_component.render()
+             yolo_v8s_confidence_threshold_component.render()
+             yolo_v8m_confidence_threshold_component.render()
+         iou_threshold_component.render()
+     with gr.Row():
+         input_image_component = gr.Image(
+             type='numpy',
+             label='Input'
+         )
+         yolo_v8n_output_image_component = gr.Image(
+             type='numpy',
+             label='YOLOv8N'
+         )
+     with gr.Row():
+         yolo_v8s_output_image_component = gr.Image(
+             type='numpy',
+             label='YOLOv8S'
+         )
+         yolo_v8m_output_image_component = gr.Image(
+             type='numpy',
+             label='YOLOv8M'
+         )
+     submit_button_component = gr.Button(
+         value='Submit',
+         scale=1,
+         variant='primary'
+     )
+     gr.Examples(
+         fn=process_image,
+         examples=IMAGE_EXAMPLES,
+         inputs=[
+             input_image_component,
+             yolo_v8n_confidence_threshold_component,
+             yolo_v8s_confidence_threshold_component,
+             yolo_v8m_confidence_threshold_component,
+             iou_threshold_component
+         ],
+         outputs=[
+             yolo_v8n_output_image_component,
+             yolo_v8s_output_image_component,
+             yolo_v8m_output_image_component
+         ]
+     )
+
+     submit_button_component.click(
+         fn=process_image,
+         inputs=[
+             input_image_component,
+             yolo_v8n_confidence_threshold_component,
+             yolo_v8s_confidence_threshold_component,
+             yolo_v8m_confidence_threshold_component,
+             iou_threshold_component
+         ],
+         outputs=[
+             yolo_v8n_output_image_component,
+             yolo_v8s_output_image_component,
+             yolo_v8m_output_image_component
+         ]
+     )
+
+ demo.launch(debug=False, show_error=True, max_threads=1)
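+
+ # To run this locally (a sketch; package names are the public PyPI ones, and
+ # the Roboflow-hosted models may additionally require a ROBOFLOW_API_KEY
+ # environment variable):
+ #     pip install gradio supervision inference
+ #     python app.py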