samuellimabraz committed
Commit eaa7aa4 · unverified · Parent: 53fda1d

refactor: separate modules and directories

Files changed (4)
  1. app.py +2 -348
  2. constants.py +8 -0
  3. detector.py +288 -0
  4. metrics_storage.py +64 -0
app.py CHANGED
@@ -1,357 +1,11 @@
 import os
-import sqlite3
-import time
 
-import cv2
 import gradio as gr
-import matplotlib.pyplot as plt
-import numpy as np
-import onnxruntime as ort
 import pandas as pd
-from huggingface_hub import hf_hub_download
 from PIL import Image
 
-# Model info
-REPO_ID = "tech4humans/yolov8s-signature-detector"
-FILENAME = "yolov8s.onnx"
-MODEL_DIR = "model"
-MODEL_PATH = os.path.join(MODEL_DIR, "model.onnx")
-DATABASE_DIR = os.path.join(os.getcwd(), "db")
-DATABASE_PATH = os.path.join(DATABASE_DIR, "metrics.db")
-
-
-def download_model():
-    """Download the model using Hugging Face Hub"""
-    # Ensure model directory exists
-    os.makedirs(MODEL_DIR, exist_ok=True)
-
-    try:
-        print(f"Downloading model from {REPO_ID}...")
-        # Download the model file from Hugging Face Hub
-        model_path = hf_hub_download(
-            repo_id=REPO_ID,
-            filename=FILENAME,
-            local_dir=MODEL_DIR,
-            force_download=True,
-            cache_dir=None,
-        )
-
-        # Move the file to the correct location if it's not there already
-        if os.path.exists(model_path) and model_path != MODEL_PATH:
-            os.rename(model_path, MODEL_PATH)
-
-        # Remove empty directories if they exist
-        empty_dir = os.path.join(MODEL_DIR, "tune")
-        if os.path.exists(empty_dir):
-            import shutil
-
-            shutil.rmtree(empty_dir)
-
-        print("Model downloaded successfully!")
-        return MODEL_PATH
-
-    except Exception as e:
-        print(f"Error downloading model: {e}")
-        raise e
-
-
-class MetricsStorage:
-    def __init__(self, db_path=DATABASE_PATH):
-        self.db_path = db_path
-        self.setup_database()
-
-    def setup_database(self):
-        """Initialize the SQLite database and create tables if they don't exist"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                """
-                CREATE TABLE IF NOT EXISTS inference_metrics (
-                    id INTEGER PRIMARY KEY AUTOINCREMENT,
-                    inference_time REAL,
-                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
-                )
-            """
-            )
-            conn.commit()
-
-    def add_metric(self, inference_time):
-        """Add a new inference time measurement to the database"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                "INSERT INTO inference_metrics (inference_time) VALUES (?)",
-                (inference_time,),
-            )
-            conn.commit()
-
-    def get_recent_metrics(self, limit=80):
-        """Get the most recent metrics from the database"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                "SELECT inference_time FROM inference_metrics ORDER BY timestamp DESC LIMIT ?",
-                (limit,),
-            )
-            results = cursor.fetchall()
-            return [r[0] for r in reversed(results)]
-
-    def get_total_inferences(self):
-        """Get the total number of inferences recorded"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute("SELECT COUNT(*) FROM inference_metrics")
-            return cursor.fetchone()[0]
-
-    def get_average_time(self, limit=80):
-        """Get the average inference time from the most recent entries"""
-        with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.cursor()
-            cursor.execute(
-                "SELECT AVG(inference_time) FROM (SELECT inference_time FROM inference_metrics ORDER BY timestamp DESC LIMIT ?)",
-                (limit,),
-            )
-            result = cursor.fetchone()[0]
-            return result if result is not None else 0
-
-
-class SignatureDetector:
-    def __init__(self, model_path):
-        self.model_path = model_path
-        self.classes = ["signature"]
-        self.input_width = 640
-        self.input_height = 640
-
-        # Initialize ONNX Runtime session
-        self.session = ort.InferenceSession(MODEL_PATH)
-        self.session.set_providers(
-            ["OpenVINOExecutionProvider"], [{"device_type": "CPU"}]
-        )
-
-        self.metrics_storage = MetricsStorage()
-
-    def update_metrics(self, inference_time):
-        """Update metrics in persistent storage"""
-        self.metrics_storage.add_metric(inference_time)
-
-    def get_metrics(self):
-        """Get current metrics from storage"""
-        times = self.metrics_storage.get_recent_metrics()
-        total = self.metrics_storage.get_total_inferences()
-        avg = self.metrics_storage.get_average_time()
-
-        start_index = max(0, total - len(times))
-
-        return {
-            "times": times,
-            "total_inferences": total,
-            "avg_time": avg,
-            "start_index": start_index,  # Include the starting index
-        }
-
-    def load_initial_metrics(self):
-        """Load initial metrics for display"""
-        metrics = self.get_metrics()
-
-        if not metrics["times"]:  # If there is no data yet
-            return None, None, None, None, None, None
-
-        # Build the plot data
-        hist_data = pd.DataFrame({"Tempo (ms)": metrics["times"]})
-        indices = range(
-            metrics["start_index"], metrics["start_index"] + len(metrics["times"])
-        )
-
-        line_data = pd.DataFrame(
-            {
-                "Inferência": indices,
-                "Tempo (ms)": metrics["times"],
-                "Média": [metrics["avg_time"]] * len(metrics["times"]),
-            }
-        )
-
-        # Create the plots
-        hist_fig, line_fig = self.create_plots(hist_data, line_data)
-
-        return (
-            None,
-            f"Total de Inferências: {metrics['total_inferences']}",
-            hist_fig,
-            line_fig,
-            f"{metrics['avg_time']:.2f}",
-            f"{metrics['times'][-1]:.2f}",
-        )
-
-    def create_plots(self, hist_data, line_data):
-        """Helper method to create plots"""
-        plt.style.use("dark_background")
-
-        # Histogram
-        hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
-        hist_ax.set_facecolor("#f0f0f5")
-        hist_data.hist(
-            bins=20, ax=hist_ax, color="#4F46E5", alpha=0.7, edgecolor="white"
-        )
-        hist_ax.set_title(
-            "Distribuição dos Tempos de Inferência",
-            pad=15,
-            fontsize=12,
-            color="#1f2937",
-        )
-        hist_ax.set_xlabel("Tempo (ms)", color="#374151")
-        hist_ax.set_ylabel("Frequência", color="#374151")
-        hist_ax.tick_params(colors="#4b5563")
-        hist_ax.grid(True, linestyle="--", alpha=0.3)
-
-        # Line chart
-        line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
-        line_ax.set_facecolor("#f0f0f5")
-        line_data.plot(
-            x="Inferência",
-            y="Tempo (ms)",
-            ax=line_ax,
-            color="#4F46E5",
-            alpha=0.7,
-            label="Tempo",
-        )
-        line_data.plot(
-            x="Inferência",
-            y="Média",
-            ax=line_ax,
-            color="#DC2626",
-            linestyle="--",
-            label="Média",
-        )
-        line_ax.set_title(
-            "Tempo de Inferência por Execução", pad=15, fontsize=12, color="#1f2937"
-        )
-        line_ax.set_xlabel("Número da Inferência", color="#374151")
-        line_ax.set_ylabel("Tempo (ms)", color="#374151")
-        line_ax.tick_params(colors="#4b5563")
-        line_ax.grid(True, linestyle="--", alpha=0.3)
-        line_ax.legend(frameon=True, facecolor="#f0f0f5", edgecolor="none")
-
-        hist_fig.tight_layout()
-        line_fig.tight_layout()
-
-        # Close the figures to free memory
-        plt.close(hist_fig)
-        plt.close(line_fig)
-
-        return hist_fig, line_fig
-
-    def preprocess(self, img):
-        # Convert PIL Image to cv2 format
-        img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
-        # Get image dimensions
-        self.img_height, self.img_width = img_cv2.shape[:2]
-
-        # Convert back to RGB for processing
-        img_rgb = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
-
-        # Resize
-        img_resized = cv2.resize(img_rgb, (self.input_width, self.input_height))
-
-        # Normalize and transpose
-        image_data = np.array(img_resized) / 255.0
-        image_data = np.transpose(image_data, (2, 0, 1))
-        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
-
-        return image_data, img_cv2
-
-    def draw_detections(self, img, box, score, class_id):
-        x1, y1, w, h = box
-        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
-        color = self.color_palette[class_id]
-
-        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
-
-        label = f"{self.classes[class_id]}: {score:.2f}"
-        (label_width, label_height), _ = cv2.getTextSize(
-            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
-        )
-
-        label_x = x1
-        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
-
-        cv2.rectangle(
-            img,
-            (int(label_x), int(label_y - label_height)),
-            (int(label_x + label_width), int(label_y + label_height)),
-            color,
-            cv2.FILLED,
-        )
-
-        cv2.putText(
-            img,
-            label,
-            (int(label_x), int(label_y)),
-            cv2.FONT_HERSHEY_SIMPLEX,
-            0.5,
-            (0, 0, 0),
-            1,
-            cv2.LINE_AA,
-        )
-
-    def postprocess(self, input_image, output, conf_thres, iou_thres):
-        outputs = np.transpose(np.squeeze(output[0]))
-        rows = outputs.shape[0]
-
-        boxes = []
-        scores = []
-        class_ids = []
-
-        x_factor = self.img_width / self.input_width
-        y_factor = self.img_height / self.input_height
-
-        for i in range(rows):
-            classes_scores = outputs[i][4:]
-            max_score = np.amax(classes_scores)
-
-            if max_score >= conf_thres:
-                class_id = np.argmax(classes_scores)
-                x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
-
-                left = int((x - w / 2) * x_factor)
-                top = int((y - h / 2) * y_factor)
-                width = int(w * x_factor)
-                height = int(h * y_factor)
-
-                class_ids.append(class_id)
-                scores.append(max_score)
-                boxes.append([left, top, width, height])
-
-        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
-
-        for i in indices:
-            box = boxes[i]
-            score = scores[i]
-            class_id = class_ids[i]
-            self.draw_detections(input_image, box, score, class_id)
-
-        return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
-
-    def detect(self, image, conf_thres=0.25, iou_thres=0.5):
-        # Preprocess the image
-        img_data, original_image = self.preprocess(image)
-
-        # Run inference
-        start_time = time.time()
-        outputs = self.session.run(None, {self.session.get_inputs()[0].name: img_data})
-        inference_time = (time.time() - start_time) * 1000  # Convert to milliseconds
-
-        # Postprocess the results
-        output_image = self.postprocess(original_image, outputs, conf_thres, iou_thres)
-
-        self.update_metrics(inference_time)
-
-        return output_image, self.get_metrics()
-
-    def detect_example(self, image, conf_thres=0.25, iou_thres=0.5):
-        """Wrapper method for examples that returns only the image"""
-        output_image, _ = self.detect(image, conf_thres, iou_thres)
-        return output_image
+from constants import MODEL_PATH, DATABASE_DIR, DATABASE_PATH
+from detector import SignatureDetector, download_model
 
 
 def create_gradio_interface():
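
With the detector and metrics code moved out, app.py keeps only the Gradio layer plus the imports shown above. The following is a minimal sketch of how the remaining (truncated) portion of app.py can wire the refactored modules together; the interface layout is illustrative only, not the actual UI in this repository:

    # Hypothetical wiring of the refactored modules (the real create_gradio_interface()
    # is omitted from this hunk; the UI below is only an illustration).
    import os

    import gradio as gr

    from constants import MODEL_PATH, DATABASE_DIR
    from detector import SignatureDetector, download_model


    def main():
        # The SQLite file lives under db/; ensure the directory exists
        # (assumed to be app.py's responsibility, not shown in this commit).
        os.makedirs(DATABASE_DIR, exist_ok=True)

        # Fetch model/model.onnx on first run.
        if not os.path.exists(MODEL_PATH):
            download_model()

        detector = SignatureDetector(MODEL_PATH)

        demo = gr.Interface(
            fn=detector.detect_example,
            inputs=gr.Image(type="pil"),
            outputs=gr.Image(type="numpy"),
            title="Signature Detection",
        )
        demo.launch()


    if __name__ == "__main__":
        main()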
constants.py ADDED
@@ -0,0 +1,8 @@
+import os
+
+REPO_ID = "tech4humans/yolov8s-signature-detector"
+FILENAME = "yolov8s.onnx"
+MODEL_DIR = "model"
+MODEL_PATH = os.path.join(MODEL_DIR, "model.onnx")
+DATABASE_DIR = os.path.join(os.getcwd(), "db")
+DATABASE_PATH = os.path.join(DATABASE_DIR, "metrics.db")
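
Both path constants are built at import time: MODEL_PATH stays relative to the working directory, while DATABASE_PATH is anchored to os.getcwd(). A small illustrative check (the printed values assume the process is started from the repository root, which this commit does not enforce):

    # Hypothetical check of how the constants resolve; the cwd is an assumption.
    import os

    from constants import MODEL_PATH, DATABASE_PATH

    print(MODEL_PATH)                    # model/model.onnx (relative path)
    print(DATABASE_PATH)                 # <cwd>/db/metrics.db (absolute path)
    print(os.path.isabs(MODEL_PATH))     # False
    print(os.path.isabs(DATABASE_PATH))  # True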
detector.py ADDED
@@ -0,0 +1,288 @@
+import os
+import time
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import onnxruntime as ort
+import pandas as pd
+from huggingface_hub import hf_hub_download
+
+from constants import REPO_ID, FILENAME, MODEL_DIR, MODEL_PATH
+from metrics_storage import MetricsStorage
+
+
+def download_model():
+    """Download the model using Hugging Face Hub"""
+    # Ensure model directory exists
+    os.makedirs(MODEL_DIR, exist_ok=True)
+
+    try:
+        print(f"Downloading model from {REPO_ID}...")
+        # Download the model file from Hugging Face Hub
+        model_path = hf_hub_download(
+            repo_id=REPO_ID,
+            filename=FILENAME,
+            local_dir=MODEL_DIR,
+            force_download=True,
+            cache_dir=None,
+        )
+
+        # Move the file to the correct location if it's not there already
+        if os.path.exists(model_path) and model_path != MODEL_PATH:
+            os.rename(model_path, MODEL_PATH)
+
+        # Remove empty directories if they exist
+        empty_dir = os.path.join(MODEL_DIR, "tune")
+        if os.path.exists(empty_dir):
+            import shutil
+
+            shutil.rmtree(empty_dir)
+
+        print("Model downloaded successfully!")
+        return MODEL_PATH
+
+    except Exception as e:
+        print(f"Error downloading model: {e}")
+        raise e
+
+
+class SignatureDetector:
+    def __init__(self, model_path):
+        self.model_path = model_path
+        self.classes = ["signature"]
+        self.input_width = 640
+        self.input_height = 640
+
+        # Initialize ONNX Runtime session
+        options = ort.SessionOptions()
+        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+        self.session = ort.InferenceSession(MODEL_PATH, options)
+        self.session.set_providers(
+            ["OpenVINOExecutionProvider"], [{"device_type": "CPU"}]
+        )
+
+        self.metrics_storage = MetricsStorage()
+
+    def update_metrics(self, inference_time):
+        """Update metrics in persistent storage"""
+        self.metrics_storage.add_metric(inference_time)
+
+    def get_metrics(self):
+        """Get current metrics from storage"""
+        times = self.metrics_storage.get_recent_metrics()
+        total = self.metrics_storage.get_total_inferences()
+        avg = self.metrics_storage.get_average_time()
+
+        start_index = max(0, total - len(times))
+
+        return {
+            "times": times,
+            "total_inferences": total,
+            "avg_time": avg,
+            "start_index": start_index,  # Include the starting index
+        }
+
+    def load_initial_metrics(self):
+        """Load initial metrics for display"""
+        metrics = self.get_metrics()
+
+        if not metrics["times"]:  # If there is no data yet
+            return None, None, None, None, None, None
+
+        # Build the plot data
+        hist_data = pd.DataFrame({"Tempo (ms)": metrics["times"]})
+        indices = range(
+            metrics["start_index"], metrics["start_index"] + len(metrics["times"])
+        )
+
+        line_data = pd.DataFrame(
+            {
+                "Inferência": indices,
+                "Tempo (ms)": metrics["times"],
+                "Média": [metrics["avg_time"]] * len(metrics["times"]),
+            }
+        )
+
+        # Create the plots
+        hist_fig, line_fig = self.create_plots(hist_data, line_data)
+
+        return (
+            None,
+            f"Total de Inferências: {metrics['total_inferences']}",
+            hist_fig,
+            line_fig,
+            f"{metrics['avg_time']:.2f}",
+            f"{metrics['times'][-1]:.2f}",
+        )
+
+    def create_plots(self, hist_data, line_data):
+        """Helper method to create plots"""
+        plt.style.use("dark_background")
+
+        # Histogram
+        hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
+        hist_ax.set_facecolor("#f0f0f5")
+        hist_data.hist(
+            bins=20, ax=hist_ax, color="#4F46E5", alpha=0.7, edgecolor="white"
+        )
+        hist_ax.set_title(
+            "Distribuição dos Tempos de Inferência",
+            pad=15,
+            fontsize=12,
+            color="#1f2937",
+        )
+        hist_ax.set_xlabel("Tempo (ms)", color="#374151")
+        hist_ax.set_ylabel("Frequência", color="#374151")
+        hist_ax.tick_params(colors="#4b5563")
+        hist_ax.grid(True, linestyle="--", alpha=0.3)
+
+        # Line chart
+        line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
+        line_ax.set_facecolor("#f0f0f5")
+        line_data.plot(
+            x="Inferência",
+            y="Tempo (ms)",
+            ax=line_ax,
+            color="#4F46E5",
+            alpha=0.7,
+            label="Tempo",
+        )
+        line_data.plot(
+            x="Inferência",
+            y="Média",
+            ax=line_ax,
+            color="#DC2626",
+            linestyle="--",
+            label="Média",
+        )
+        line_ax.set_title(
+            "Tempo de Inferência por Execução", pad=15, fontsize=12, color="#1f2937"
+        )
+        line_ax.set_xlabel("Número da Inferência", color="#374151")
+        line_ax.set_ylabel("Tempo (ms)", color="#374151")
+        line_ax.tick_params(colors="#4b5563")
+        line_ax.grid(True, linestyle="--", alpha=0.3)
+        line_ax.legend(frameon=True, facecolor="#f0f0f5", edgecolor="none")
+
+        hist_fig.tight_layout()
+        line_fig.tight_layout()
+
+        # Close the figures to free memory
+        plt.close(hist_fig)
+        plt.close(line_fig)
+
+        return hist_fig, line_fig
+
+    def preprocess(self, img):
+        # Convert PIL Image to cv2 format
+        img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
+
+        # Get image dimensions
+        self.img_height, self.img_width = img_cv2.shape[:2]
+
+        # Convert back to RGB for processing
+        img_rgb = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
+
+        # Resize
+        img_resized = cv2.resize(img_rgb, (self.input_width, self.input_height))
+
+        # Normalize and transpose
+        image_data = np.array(img_resized) / 255.0
+        image_data = np.transpose(image_data, (2, 0, 1))
+        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
+
+        return image_data, img_cv2
+
+    def draw_detections(self, img, box, score, class_id):
+        x1, y1, w, h = box
+        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
+        color = self.color_palette[class_id]
+
+        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
+
+        label = f"{self.classes[class_id]}: {score:.2f}"
+        (label_width, label_height), _ = cv2.getTextSize(
+            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
+        )
+
+        label_x = x1
+        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
+
+        cv2.rectangle(
+            img,
+            (int(label_x), int(label_y - label_height)),
+            (int(label_x + label_width), int(label_y + label_height)),
+            color,
+            cv2.FILLED,
+        )
+
+        cv2.putText(
+            img,
+            label,
+            (int(label_x), int(label_y)),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            (0, 0, 0),
+            1,
+            cv2.LINE_AA,
+        )
+
+    def postprocess(self, input_image, output, conf_thres, iou_thres):
+        outputs = np.transpose(np.squeeze(output[0]))
+        rows = outputs.shape[0]
+
+        boxes = []
+        scores = []
+        class_ids = []
+
+        x_factor = self.img_width / self.input_width
+        y_factor = self.img_height / self.input_height
+
+        for i in range(rows):
+            classes_scores = outputs[i][4:]
+            max_score = np.amax(classes_scores)
+
+            if max_score >= conf_thres:
+                class_id = np.argmax(classes_scores)
+                x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
+
+                left = int((x - w / 2) * x_factor)
+                top = int((y - h / 2) * y_factor)
+                width = int(w * x_factor)
+                height = int(h * y_factor)
+
+                class_ids.append(class_id)
+                scores.append(max_score)
+                boxes.append([left, top, width, height])
+
+        indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
+
+        for i in indices:
+            box = boxes[i]
+            score = scores[i]
+            class_id = class_ids[i]
+            self.draw_detections(input_image, box, score, class_id)
+
+        return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
+
+    def detect(self, image, conf_thres=0.25, iou_thres=0.5):
+        # Preprocess the image
+        img_data, original_image = self.preprocess(image)
+
+        # Run inference
+        start_time = time.time()
+        outputs = self.session.run(None, {self.session.get_inputs()[0].name: img_data})
+        inference_time = (time.time() - start_time) * 1000  # Convert to milliseconds
+
+        # Postprocess the results
+        output_image = self.postprocess(original_image, outputs, conf_thres, iou_thres)
+
+        self.update_metrics(inference_time)
+
+        return output_image, self.get_metrics()
+
+    def detect_example(self, image, conf_thres=0.25, iou_thres=0.5):
+        """Wrapper method for examples that returns only the image"""
+        output_image, _ = self.detect(image, conf_thres, iou_thres)
+        return output_image
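
The new module can be exercised on its own. Below is a minimal smoke-test sketch, assuming a local test image at example.png and that the db/ directory is created by the caller (neither is part of this commit):

    # Hypothetical standalone use of detector.py; the paths below are assumptions.
    import os

    from PIL import Image

    from constants import DATABASE_DIR
    from detector import SignatureDetector, download_model

    os.makedirs(DATABASE_DIR, exist_ok=True)   # MetricsStorage opens SQLite under db/
    model_path = download_model()              # places the weights at model/model.onnx
    detector = SignatureDetector(model_path)

    image = Image.open("example.png").convert("RGB")
    annotated, metrics = detector.detect(image, conf_thres=0.25, iou_thres=0.5)
    print(f"{metrics['total_inferences']} inferences, avg {metrics['avg_time']:.2f} ms")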
metrics_storage.py ADDED
@@ -0,0 +1,64 @@
+import os
+import sqlite3
+
+from constants import DATABASE_DIR, DATABASE_PATH
+
+
+class MetricsStorage:
+    def __init__(self, db_path=DATABASE_PATH):
+        self.db_path = db_path
+        self.setup_database()
+
+    def setup_database(self):
+        """Initialize the SQLite database and create tables if they don't exist"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute(
+                """
+                CREATE TABLE IF NOT EXISTS inference_metrics (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    inference_time REAL,
+                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
+                )
+            """
+            )
+            conn.commit()
+
+    def add_metric(self, inference_time):
+        """Add a new inference time measurement to the database"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute(
+                "INSERT INTO inference_metrics (inference_time) VALUES (?)",
+                (inference_time,),
+            )
+            conn.commit()
+
+    def get_recent_metrics(self, limit=80):
+        """Get the most recent metrics from the database"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute(
+                "SELECT inference_time FROM inference_metrics ORDER BY timestamp DESC LIMIT ?",
+                (limit,),
+            )
+            results = cursor.fetchall()
+            return [r[0] for r in reversed(results)]
+
+    def get_total_inferences(self):
+        """Get the total number of inferences recorded"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute("SELECT COUNT(*) FROM inference_metrics")
+            return cursor.fetchone()[0]
+
+    def get_average_time(self, limit=80):
+        """Get the average inference time from the most recent entries"""
+        with sqlite3.connect(self.db_path) as conn:
+            cursor = conn.cursor()
+            cursor.execute(
+                "SELECT AVG(inference_time) FROM (SELECT inference_time FROM inference_metrics ORDER BY timestamp DESC LIMIT ?)",
+                (limit,),
+            )
+            result = cursor.fetchone()[0]
+            return result if result is not None else 0
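
MetricsStorage has no dependency on the detector, so it can be checked in isolation. A short sketch against a throwaway database file (the temporary path is an assumption; the app itself uses DATABASE_PATH from constants.py):

    # Hypothetical standalone check of metrics_storage.py using a temp database.
    import os
    import tempfile

    from metrics_storage import MetricsStorage

    tmp_dir = tempfile.mkdtemp()
    storage = MetricsStorage(db_path=os.path.join(tmp_dir, "metrics.db"))

    for ms in (12.5, 15.1, 9.8):
        storage.add_metric(ms)

    print(storage.get_total_inferences())       # 3
    print(storage.get_recent_metrics(limit=2))  # oldest-first window of recent times
    print(f"{storage.get_average_time():.2f}")  # average over the recent window, in ms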