samuellimabraz commited on
Commit
2840cb3
·
unverified ·
1 Parent(s): a285eb6

feat: Adicionar anotações de tipo e documentação para métodos na classe SignatureDetector

Browse files
Files changed (1) hide show
  1. detector.py +113 -24
detector.py CHANGED
@@ -3,9 +3,11 @@ import time
3
 
4
  import cv2
5
  import matplotlib.pyplot as plt
 
6
  import numpy as np
7
  import onnxruntime as ort
8
  import pandas as pd
 
9
  from huggingface_hub import hf_hub_download
10
 
11
  from constants import REPO_ID, FILENAME, MODEL_DIR, MODEL_PATH
@@ -48,7 +50,7 @@ def download_model():
48
 
49
 
50
  class SignatureDetector:
51
- def __init__(self, model_path):
52
  self.model_path = model_path
53
  self.classes = ["signature"]
54
  self.input_width = 640
@@ -57,19 +59,29 @@ class SignatureDetector:
57
  # Initialize ONNX Runtime session
58
  options = ort.SessionOptions()
59
  options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
60
- self.session = ort.InferenceSession(MODEL_PATH, options)
61
  self.session.set_providers(
62
  ["OpenVINOExecutionProvider"], [{"device_type": "CPU"}]
63
  )
64
 
65
  self.metrics_storage = MetricsStorage()
66
 
67
- def update_metrics(self, inference_time):
68
- """Update metrics in persistent storage"""
 
 
 
 
 
69
  self.metrics_storage.add_metric(inference_time)
70
 
71
- def get_metrics(self):
72
- """Get current metrics from storage"""
 
 
 
 
 
73
  times = self.metrics_storage.get_recent_metrics()
74
  total = self.metrics_storage.get_total_inferences()
75
  avg = self.metrics_storage.get_average_time()
@@ -80,17 +92,23 @@ class SignatureDetector:
80
  "times": times,
81
  "total_inferences": total,
82
  "avg_time": avg,
83
- "start_index": start_index, # Adicionar índice inicial
84
  }
85
 
86
- def load_initial_metrics(self):
87
- """Load initial metrics for display"""
 
 
 
 
 
 
 
88
  metrics = self.get_metrics()
89
 
90
- if not metrics["times"]: # Se não houver dados
91
  return None, None, None, None, None, None
92
 
93
- # Criar plots data
94
  hist_data = pd.DataFrame({"Tempo (ms)": metrics["times"]})
95
  indices = range(
96
  metrics["start_index"], metrics["start_index"] + len(metrics["times"])
@@ -104,7 +122,6 @@ class SignatureDetector:
104
  }
105
  )
106
 
107
- # Criar plots
108
  hist_fig, line_fig = self.create_plots(hist_data, line_data)
109
 
110
  return (
@@ -116,11 +133,22 @@ class SignatureDetector:
116
  f"{metrics['times'][-1]:.2f}",
117
  )
118
 
119
- def create_plots(self, hist_data, line_data):
120
- """Helper method to create plots"""
 
 
 
 
 
 
 
 
 
 
 
121
  plt.style.use("dark_background")
122
 
123
- # Histograma
124
  hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
125
  hist_ax.set_facecolor("#f0f0f5")
126
  hist_data.hist(
@@ -137,7 +165,7 @@ class SignatureDetector:
137
  hist_ax.tick_params(colors="#4b5563")
138
  hist_ax.grid(True, linestyle="--", alpha=0.3)
139
 
140
- # Gráfico de linha
141
  line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
142
  line_ax.set_facecolor("#f0f0f5")
143
  line_data.plot(
@@ -168,17 +196,24 @@ class SignatureDetector:
168
  hist_fig.tight_layout()
169
  line_fig.tight_layout()
170
 
171
- # Fechar as figuras para liberar memória
172
  plt.close(hist_fig)
173
  plt.close(line_fig)
174
 
175
  return hist_fig, line_fig
176
 
177
- def preprocess(self, img):
 
 
 
 
 
 
 
 
 
178
  # Convert PIL Image to cv2 format
179
  img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
180
 
181
- # Get image dimensions
182
  self.img_height, self.img_width = img_cv2.shape[:2]
183
 
184
  # Convert back to RGB for processing
@@ -194,7 +229,18 @@ class SignatureDetector:
194
 
195
  return image_data, img_cv2
196
 
197
- def draw_detections(self, img, box, score, class_id):
 
 
 
 
 
 
 
 
 
 
 
198
  x1, y1, w, h = box
199
  self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
200
  color = self.color_palette[class_id]
@@ -228,7 +274,25 @@ class SignatureDetector:
228
  cv2.LINE_AA,
229
  )
230
 
231
- def postprocess(self, input_image, output, conf_thres, iou_thres):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  outputs = np.transpose(np.squeeze(output[0]))
233
  rows = outputs.shape[0]
234
 
@@ -266,7 +330,20 @@ class SignatureDetector:
266
 
267
  return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
268
 
269
- def detect(self, image, conf_thres=0.25, iou_thres=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  # Preprocess the image
271
  img_data, original_image = self.preprocess(image)
272
 
@@ -282,7 +359,19 @@ class SignatureDetector:
282
 
283
  return output_image, self.get_metrics()
284
 
285
- def detect_example(self, image, conf_thres=0.25, iou_thres=0.5):
286
- """Wrapper method for examples that returns only the image"""
 
 
 
 
 
 
 
 
 
 
 
 
287
  output_image, _ = self.detect(image, conf_thres, iou_thres)
288
  return output_image
 
3
 
4
  import cv2
5
  import matplotlib.pyplot as plt
6
+ from PIL import Image
7
  import numpy as np
8
  import onnxruntime as ort
9
  import pandas as pd
10
+ from typing import Tuple
11
  from huggingface_hub import hf_hub_download
12
 
13
  from constants import REPO_ID, FILENAME, MODEL_DIR, MODEL_PATH
 
50
 
51
 
52
  class SignatureDetector:
53
+ def __init__(self, model_path: str = MODEL_PATH):
54
  self.model_path = model_path
55
  self.classes = ["signature"]
56
  self.input_width = 640
 
59
  # Initialize ONNX Runtime session
60
  options = ort.SessionOptions()
61
  options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
62
+ self.session = ort.InferenceSession(self.model_path, options)
63
  self.session.set_providers(
64
  ["OpenVINOExecutionProvider"], [{"device_type": "CPU"}]
65
  )
66
 
67
  self.metrics_storage = MetricsStorage()
68
 
69
+ def update_metrics(self, inference_time: float) -> None:
70
+ """
71
+ Updates metrics in persistent storage.
72
+
73
+ Args:
74
+ inference_time (float): The time taken for inference in milliseconds.
75
+ """
76
  self.metrics_storage.add_metric(inference_time)
77
 
78
+ def get_metrics(self) -> dict:
79
+ """
80
+ Retrieves current metrics from storage.
81
+
82
+ Returns:
83
+ dict: A dictionary containing times, total inferences, average time, and start index.
84
+ """
85
  times = self.metrics_storage.get_recent_metrics()
86
  total = self.metrics_storage.get_total_inferences()
87
  avg = self.metrics_storage.get_average_time()
 
92
  "times": times,
93
  "total_inferences": total,
94
  "avg_time": avg,
95
+ "start_index": start_index,
96
  }
97
 
98
+ def load_initial_metrics(
99
+ self,
100
+ ) -> Tuple[None, str, plt.Figure, plt.Figure, str, str]:
101
+ """
102
+ Loads initial metrics for display.
103
+
104
+ Returns:
105
+ tuple: A tuple containing None, total inferences, histogram figure, line figure, average time, and last time.
106
+ """
107
  metrics = self.get_metrics()
108
 
109
+ if not metrics["times"]:
110
  return None, None, None, None, None, None
111
 
 
112
  hist_data = pd.DataFrame({"Tempo (ms)": metrics["times"]})
113
  indices = range(
114
  metrics["start_index"], metrics["start_index"] + len(metrics["times"])
 
122
  }
123
  )
124
 
 
125
  hist_fig, line_fig = self.create_plots(hist_data, line_data)
126
 
127
  return (
 
133
  f"{metrics['times'][-1]:.2f}",
134
  )
135
 
136
+ def create_plots(
137
+ self, hist_data: pd.DataFrame, line_data: pd.DataFrame
138
+ ) -> Tuple[plt.Figure, plt.Figure]:
139
+ """
140
+ Helper method to create plots.
141
+
142
+ Args:
143
+ hist_data (pd.DataFrame): Data for histogram plot.
144
+ line_data (pd.DataFrame): Data for line plot.
145
+
146
+ Returns:
147
+ tuple: A tuple containing histogram figure and line figure.
148
+ """
149
  plt.style.use("dark_background")
150
 
151
+ # Histogram plot
152
  hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
153
  hist_ax.set_facecolor("#f0f0f5")
154
  hist_data.hist(
 
165
  hist_ax.tick_params(colors="#4b5563")
166
  hist_ax.grid(True, linestyle="--", alpha=0.3)
167
 
168
+ # Line plot
169
  line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
170
  line_ax.set_facecolor("#f0f0f5")
171
  line_data.plot(
 
196
  hist_fig.tight_layout()
197
  line_fig.tight_layout()
198
 
 
199
  plt.close(hist_fig)
200
  plt.close(line_fig)
201
 
202
  return hist_fig, line_fig
203
 
204
+ def preprocess(self, img: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
205
+ """
206
+ Preprocesses the image for inference.
207
+
208
+ Args:
209
+ img: The image to process.
210
+
211
+ Returns:
212
+ tuple: A tuple containing the processed image data and the original image.
213
+ """
214
  # Convert PIL Image to cv2 format
215
  img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
216
 
 
217
  self.img_height, self.img_width = img_cv2.shape[:2]
218
 
219
  # Convert back to RGB for processing
 
229
 
230
  return image_data, img_cv2
231
 
232
+ def draw_detections(
233
+ self, img: np.ndarray, box: list, score: float, class_id: int
234
+ ) -> None:
235
+ """
236
+ Draws the detections on the image.
237
+
238
+ Args:
239
+ img: The image to draw on.
240
+ box (list): The bounding box coordinates.
241
+ score (float): The confidence score.
242
+ class_id (int): The class ID.
243
+ """
244
  x1, y1, w, h = box
245
  self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
246
  color = self.color_palette[class_id]
 
274
  cv2.LINE_AA,
275
  )
276
 
277
+ def postprocess(
278
+ self,
279
+ input_image: np.ndarray,
280
+ output: np.ndarray,
281
+ conf_thres: float,
282
+ iou_thres: float,
283
+ ) -> np.ndarray:
284
+ """
285
+ Postprocesses the output from inference.
286
+
287
+ Args:
288
+ input_image: The input image.
289
+ output: The output from inference.
290
+ conf_thres (float): Confidence threshold for detection.
291
+ iou_thres (float): Intersection over Union threshold for detection.
292
+
293
+ Returns:
294
+ np.ndarray: The output image with detections drawn
295
+ """
296
  outputs = np.transpose(np.squeeze(output[0]))
297
  rows = outputs.shape[0]
298
 
 
330
 
331
  return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
332
 
333
+ def detect(
334
+ self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
335
+ ) -> Tuple[Image.Image, dict]:
336
+ """
337
+ Detects signatures in the given image.
338
+
339
+ Args:
340
+ image: The image to process.
341
+ conf_thres (float): Confidence threshold for detection.
342
+ iou_thres (float): Intersection over Union threshold for detection.
343
+
344
+ Returns:
345
+ tuple: A tuple containing the output image and metrics.
346
+ """
347
  # Preprocess the image
348
  img_data, original_image = self.preprocess(image)
349
 
 
359
 
360
  return output_image, self.get_metrics()
361
 
362
+ def detect_example(
363
+ self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
364
+ ) -> Image.Image:
365
+ """
366
+ Wrapper method for examples that returns only the image.
367
+
368
+ Args:
369
+ image: The image to process.
370
+ conf_thres (float): Confidence threshold for detection.
371
+ iou_thres (float): Intersection over Union threshold for detection.
372
+
373
+ Returns:
374
+ The output image.
375
+ """
376
  output_image, _ = self.detect(image, conf_thres, iou_thres)
377
  return output_image