# signature-detection / detector.py
# Author: samuellimabraz
# Last change: "fix: Update labels and titles to English in plots"
# Commit: 8c88f9c (unverified)
import os
import time
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import onnxruntime as ort
import pandas as pd
from typing import Tuple
from huggingface_hub import hf_hub_download
from constants import REPO_ID, FILENAME, MODEL_DIR, MODEL_PATH
from metrics_storage import MetricsStorage
def download_model():
    """Download the model file from the Hugging Face Hub into ``MODEL_PATH``.

    ``FILENAME`` is fetched from ``REPO_ID`` into ``MODEL_DIR``. It is always
    re-downloaded (``force_download=True``). If the Hub places it somewhere
    other than ``MODEL_PATH``, it is moved there. A ``tune`` subdirectory
    left inside ``MODEL_DIR`` is removed afterwards.

    Returns:
        str: ``MODEL_PATH``, the final location of the model file.

    Raises:
        Exception: Any error raised while downloading or moving the file is
            printed and then re-raised with its original traceback.
    """
    import shutil

    # Ensure model directory exists
    os.makedirs(MODEL_DIR, exist_ok=True)
    try:
        print(f"Downloading model from {REPO_ID}...")
        # Download the model file from Hugging Face Hub
        model_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=FILENAME,
            local_dir=MODEL_DIR,
            force_download=True,
            cache_dir=None,
        )
        # Normalise both sides so a relative and an absolute spelling of the
        # same file are not mistaken for two different locations.
        same_location = os.path.abspath(model_path) == os.path.abspath(MODEL_PATH)
        if os.path.exists(model_path) and not same_location:
            # os.replace (unlike os.rename) also overwrites an existing
            # MODEL_PATH on Windows, which happens on every forced re-download.
            os.replace(model_path, MODEL_PATH)
        # Remove the leftover "tune" directory, if any (rmtree needs a directory).
        empty_dir = os.path.join(MODEL_DIR, "tune")
        if os.path.isdir(empty_dir):
            shutil.rmtree(empty_dir)
        print("Model downloaded successfully!")
        return MODEL_PATH
    except Exception as e:
        print(f"Error downloading model: {e}")
        raise
class SignatureDetector:
    """Detect signatures in images with an ONNX model and record timing metrics.

    Images are resized to ``input_width`` x ``input_height`` (640x640) for
    inference. The raw model output is decoded, filtered by confidence,
    de-duplicated with non-maximum suppression and drawn onto the original
    image. Every inference time (in milliseconds) is stored in a
    ``MetricsStorage`` instance so it can be plotted later.
    """

    def __init__(self, model_path: str = MODEL_PATH):
        """Create the ONNX Runtime session and the metrics store.

        Args:
            model_path (str): Path to the ONNX model file.
        """
        self.model_path = model_path
        self.classes = ["signature"]
        self.input_width = 640
        self.input_height = 640

        # All ONNX Runtime graph optimizations are disabled for this session.
        options = ort.SessionOptions()
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
        # Pass providers at construction time instead of calling set_providers
        # afterwards. This loads the model only once, and ORT builds that ship
        # several execution providers require an explicit providers list.
        self.session = ort.InferenceSession(
            self.model_path,
            sess_options=options,
            providers=self._select_providers(),
        )
        self.metrics_storage = MetricsStorage()

    @staticmethod
    def _select_providers() -> list:
        """Return execution providers in priority order.

        The OpenVINO provider (on CPU) is used when this onnxruntime build
        offers it. The plain CPU provider is always appended as a fallback.
        """
        providers = []
        if "OpenVINOExecutionProvider" in ort.get_available_providers():
            providers.append(("OpenVINOExecutionProvider", {"device_type": "CPU"}))
        providers.append("CPUExecutionProvider")
        return providers

    def update_metrics(self, inference_time: float) -> None:
        """
        Updates metrics in persistent storage.
        Args:
            inference_time (float): The time taken for inference in milliseconds.
        """
        self.metrics_storage.add_metric(inference_time)

    def get_metrics(self) -> dict:
        """
        Retrieves current metrics from storage.
        Returns:
            dict: A dictionary containing times, total inferences, average time, and start index.
                ``start_index`` is the global inference number of the first entry in ``times``.
        """
        times = self.metrics_storage.get_recent_metrics()
        total = self.metrics_storage.get_total_inferences()
        avg = self.metrics_storage.get_average_time()
        start_index = max(0, total - len(times))
        return {
            "times": times,
            "total_inferences": total,
            "avg_time": avg,
            "start_index": start_index,
        }

    def load_initial_metrics(
        self,
    ) -> Tuple[None, str, plt.Figure, plt.Figure, str, str]:
        """
        Loads initial metrics for display.
        Returns:
            tuple: A tuple containing None, total inferences, histogram figure, line figure,
                average time, and last time. When no inference times are stored yet, all
                six elements are None.
        """
        metrics = self.get_metrics()
        if not metrics["times"]:
            return None, None, None, None, None, None

        hist_data = pd.DataFrame({"Time (ms)": metrics["times"]})
        indices = range(
            metrics["start_index"], metrics["start_index"] + len(metrics["times"])
        )
        line_data = pd.DataFrame(
            {
                "Inference": indices,
                "Time (ms)": metrics["times"],
                "Mean": [metrics["avg_time"]] * len(metrics["times"]),
            }
        )
        hist_fig, line_fig = self.create_plots(hist_data, line_data)
        return (
            None,
            f"{metrics['total_inferences']}",
            hist_fig,
            line_fig,
            f"{metrics['avg_time']:.2f}",
            f"{metrics['times'][-1]:.2f}",
        )

    def create_plots(
        self, hist_data: pd.DataFrame, line_data: pd.DataFrame
    ) -> Tuple[plt.Figure, plt.Figure]:
        """
        Helper method to create plots.
        Args:
            hist_data (pd.DataFrame): Data for histogram plot (column "Time (ms)").
            line_data (pd.DataFrame): Data for line plot (columns "Inference", "Time (ms)", "Mean").
        Returns:
            tuple: A tuple containing histogram figure and line figure. Both figures are
                closed in pyplot's figure manager before being returned.
        """
        # NOTE: plt.style.use updates matplotlib's global rcParams for the whole process.
        plt.style.use("dark_background")

        # Histogram plot
        hist_fig, hist_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
        hist_ax.set_facecolor("#f0f0f5")
        hist_data.hist(
            bins=20, ax=hist_ax, color="#4F46E5", alpha=0.7, edgecolor="white"
        )
        hist_ax.set_title(
            "Distribution of Inference Times",
            pad=15,
            fontsize=12,
            color="#1f2937",
        )
        hist_ax.set_xlabel("Time (ms)", color="#374151")
        hist_ax.set_ylabel("Frequency", color="#374151")
        hist_ax.tick_params(colors="#4b5563")
        hist_ax.grid(True, linestyle="--", alpha=0.3)

        # Line plot
        line_fig, line_ax = plt.subplots(figsize=(8, 4), facecolor="#f0f0f5")
        line_ax.set_facecolor("#f0f0f5")
        line_data.plot(
            x="Inference",
            y="Time (ms)",
            ax=line_ax,
            color="#4F46E5",
            alpha=0.7,
            label="Time",
        )
        line_data.plot(
            x="Inference",
            y="Mean",
            ax=line_ax,
            color="#DC2626",
            linestyle="--",
            label="Mean",
        )
        line_ax.set_title(
            "Inference Time per Execution", pad=15, fontsize=12, color="#1f2937"
        )
        line_ax.set_xlabel("Inference Number", color="#374151")
        line_ax.set_ylabel("Time (ms)", color="#374151")
        line_ax.tick_params(colors="#4b5563")
        line_ax.grid(True, linestyle="--", alpha=0.3)
        line_ax.legend(
            frameon=True, facecolor="#f0f0f5", edgecolor="white", labelcolor="black"
        )

        hist_fig.tight_layout()
        line_fig.tight_layout()
        # Release pyplot's references; the Figure objects themselves stay usable.
        plt.close(hist_fig)
        plt.close(line_fig)
        return hist_fig, line_fig

    def preprocess(self, img: Image.Image) -> Tuple[np.ndarray, np.ndarray]:
        """
        Preprocesses the image for inference.
        Args:
            img: The image to process (PIL image; a numpy RGB array also works).
        Returns:
            tuple: A tuple containing the processed image data (float32, shape
                (1, 3, input_height, input_width), values in [0, 1]) and the
                original image as a BGR array.
        """
        # Grayscale ("L"), palette ("P") and other non-RGB PIL modes would make
        # the 3-channel colour conversion below fail, so normalise them first.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")
        img_rgb = np.array(img)
        # BGR copy: this is what detections are drawn on and what postprocess returns.
        img_cv2 = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        # Kept for backward compatibility; postprocess derives the size from its input.
        self.img_height, self.img_width = img_cv2.shape[:2]

        # The model consumes RGB input: resize, scale to [0, 1], HWC -> NCHW.
        img_resized = cv2.resize(img_rgb, (self.input_width, self.input_height))
        image_data = img_resized / 255.0
        image_data = np.transpose(image_data, (2, 0, 1))
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
        return image_data, img_cv2

    def _class_color(self, class_id: int) -> Tuple[float, float, float]:
        """Return the drawing colour for ``class_id``.

        The random palette is generated once (or again if ``self.classes``
        changed size) and reused. Every box of a class therefore gets the same
        colour instead of a fresh random one per box.
        """
        palette = getattr(self, "color_palette", None)
        if palette is None or len(palette) != len(self.classes):
            palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
            self.color_palette = palette
        return tuple(float(c) for c in palette[class_id])

    def draw_detections(
        self, img: np.ndarray, box: list, score: float, class_id: int
    ) -> None:
        """
        Draws the detections on the image (in place).
        Args:
            img: The image to draw on.
            box (list): The bounding box as [left, top, width, height] in pixels.
            score (float): The confidence score.
            class_id (int): The class ID.
        """
        x1, y1, w, h = box
        color = self._class_color(class_id)
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        label = f"{self.classes[class_id]}: {score:.2f}"
        (label_width, label_height), _ = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
        )
        label_x = x1
        # Place the label above the box unless it would run off the top edge.
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
        cv2.rectangle(
            img,
            (int(label_x), int(label_y - label_height)),
            (int(label_x + label_width), int(label_y + label_height)),
            color,
            cv2.FILLED,
        )
        cv2.putText(
            img,
            label,
            (int(label_x), int(label_y)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )

    def postprocess(
        self,
        input_image: np.ndarray,
        output: np.ndarray,
        conf_thres: float,
        iou_thres: float,
    ) -> np.ndarray:
        """
        Postprocesses the output from inference.
        Args:
            input_image: The original BGR image; detections are drawn on it in place.
            output: The output from inference (the list returned by session.run).
            conf_thres (float): Confidence threshold for detection.
            iou_thres (float): Intersection over Union threshold for detection.
        Returns:
            np.ndarray: The output image with detections drawn, converted to RGB.
        """
        # One row per candidate: [cx, cy, w, h, class scores...] in model-input pixels.
        predictions = np.transpose(np.squeeze(output[0]))

        # Scale back to the size of the image actually being annotated instead
        # of relying on state left behind by a previous preprocess() call.
        img_height, img_width = input_image.shape[:2]
        x_factor = img_width / self.input_width
        y_factor = img_height / self.input_height

        # Vectorised confidence filter; only surviving rows are visited below.
        class_scores = predictions[:, 4:]
        max_scores = np.amax(class_scores, axis=1)

        boxes = []
        scores = []
        class_ids = []
        for row in np.flatnonzero(max_scores >= conf_thres):
            x, y, w, h = predictions[row, :4]
            # Convert centre-based boxes to [left, top, width, height].
            boxes.append(
                [
                    int((x - w / 2) * x_factor),
                    int((y - h / 2) * y_factor),
                    int(w * x_factor),
                    int(h * y_factor),
                ]
            )
            scores.append(max_scores[row])
            class_ids.append(int(np.argmax(class_scores[row])))

        if boxes:
            keep = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
            # Depending on the OpenCV version, NMSBoxes may return a flat array,
            # an (N, 1) array or an empty tuple; normalise to flat ints.
            for i in np.asarray(keep, dtype=int).reshape(-1):
                self.draw_detections(input_image, boxes[i], scores[i], class_ids[i])

        return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

    def detect(
        self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
    ) -> Tuple[np.ndarray, dict]:
        """
        Detects signatures in the given image.
        Args:
            image: The image to process.
            conf_thres (float): Confidence threshold for detection.
            iou_thres (float): Intersection over Union threshold for detection.
        Returns:
            tuple: A tuple containing the annotated RGB image (numpy array) and the
                metrics dictionary from get_metrics().
        """
        img_data, original_image = self.preprocess(image)
        # Look up the input name outside the timed region so only session.run is measured.
        input_name = self.session.get_inputs()[0].name

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        outputs = self.session.run(None, {input_name: img_data})
        inference_time = (time.perf_counter() - start_time) * 1000  # milliseconds

        output_image = self.postprocess(original_image, outputs, conf_thres, iou_thres)
        self.update_metrics(inference_time)
        return output_image, self.get_metrics()

    def detect_example(
        self, image: Image.Image, conf_thres: float = 0.25, iou_thres: float = 0.5
    ) -> np.ndarray:
        """
        Wrapper method for examples that returns only the image.
        Args:
            image: The image to process.
            conf_thres (float): Confidence threshold for detection.
            iou_thres (float): Intersection over Union threshold for detection.
        Returns:
            The annotated RGB image as a numpy array.
        """
        output_image, _ = self.detect(image, conf_thres, iou_thres)
        return output_image