samuellimabraz commited on
Commit
e17c6a5
·
unverified ·
1 Parent(s): 1a2f154

feat: Add signature detection model with Gradio interface

Browse files
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ import gradio as gr
5
+ import os
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ # Model info
9
+ REPO_ID = "tech4humans/yolov8s-signature-detector"
10
+ FILENAME = "tune/trial_10/weights/best.onnx"
11
+ MODEL_DIR = "model"
12
+ MODEL_PATH = os.path.join(MODEL_DIR, "model.onnx")
13
+
14
+ def download_model():
15
+ """Download the model using Hugging Face Hub"""
16
+ # Ensure model directory exists
17
+ os.makedirs(MODEL_DIR, exist_ok=True)
18
+
19
+ try:
20
+ print(f"Downloading model from {REPO_ID}...")
21
+ # Download the model file from Hugging Face Hub
22
+ model_path = hf_hub_download(
23
+ repo_id=REPO_ID,
24
+ filename=FILENAME,
25
+ local_dir=MODEL_DIR,
26
+ local_dir_use_symlinks=False,
27
+ force_download=True,
28
+ cache_dir=None
29
+ )
30
+
31
+ # Move the file to the correct location if it's not there already
32
+ if os.path.exists(model_path) and model_path != MODEL_PATH:
33
+ os.rename(model_path, MODEL_PATH)
34
+
35
+ # Remove empty directories if they exist
36
+ empty_dir = os.path.join(MODEL_DIR, "tune")
37
+ if os.path.exists(empty_dir):
38
+ import shutil
39
+ shutil.rmtree(empty_dir)
40
+
41
+ print("Model downloaded successfully!")
42
+ return MODEL_PATH
43
+
44
+ except Exception as e:
45
+ print(f"Error downloading model: {str(e)}")
46
+ raise e
47
+
48
+ class SignatureDetector:
49
+ def __init__(self, model_path):
50
+ self.model_path = model_path
51
+ self.classes = ["signature"]
52
+ self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
53
+ self.input_width = 640
54
+ self.input_height = 640
55
+
56
+ # Initialize ONNX Runtime session
57
+ self.session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
58
+
59
+ def preprocess(self, img):
60
+ # Convert PIL Image to cv2 format
61
+ img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
62
+
63
+ # Get image dimensions
64
+ self.img_height, self.img_width = img_cv2.shape[:2]
65
+
66
+ # Convert back to RGB for processing
67
+ img_rgb = cv2.cvtColor(img_cv2, cv2.COLOR_BGR2RGB)
68
+
69
+ # Resize
70
+ img_resized = cv2.resize(img_rgb, (self.input_width, self.input_height))
71
+
72
+ # Normalize and transpose
73
+ image_data = np.array(img_resized) / 255.0
74
+ image_data = np.transpose(image_data, (2, 0, 1))
75
+ image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
76
+
77
+ return image_data, img_cv2
78
+
79
+ def draw_detections(self, img, box, score, class_id):
80
+ x1, y1, w, h = box
81
+ color = self.color_palette[class_id]
82
+
83
+ cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
84
+
85
+ label = f"{self.classes[class_id]}: {score:.2f}"
86
+ (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
87
+
88
+ label_x = x1
89
+ label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
90
+
91
+ cv2.rectangle(
92
+ img,
93
+ (int(label_x), int(label_y - label_height)),
94
+ (int(label_x + label_width), int(label_y + label_height)),
95
+ color,
96
+ cv2.FILLED
97
+ )
98
+
99
+ cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
100
+
101
+ def postprocess(self, input_image, output, conf_thres, iou_thres):
102
+ outputs = np.transpose(np.squeeze(output[0]))
103
+ rows = outputs.shape[0]
104
+
105
+ boxes = []
106
+ scores = []
107
+ class_ids = []
108
+
109
+ x_factor = self.img_width / self.input_width
110
+ y_factor = self.img_height / self.input_height
111
+
112
+ for i in range(rows):
113
+ classes_scores = outputs[i][4:]
114
+ max_score = np.amax(classes_scores)
115
+
116
+ if max_score >= conf_thres:
117
+ class_id = np.argmax(classes_scores)
118
+ x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
119
+
120
+ left = int((x - w / 2) * x_factor)
121
+ top = int((y - h / 2) * y_factor)
122
+ width = int(w * x_factor)
123
+ height = int(h * y_factor)
124
+
125
+ class_ids.append(class_id)
126
+ scores.append(max_score)
127
+ boxes.append([left, top, width, height])
128
+
129
+ indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thres, iou_thres)
130
+
131
+ for i in indices:
132
+ box = boxes[i]
133
+ score = scores[i]
134
+ class_id = class_ids[i]
135
+ self.draw_detections(input_image, box, score, class_id)
136
+
137
+ return cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
138
+
139
+ def detect(self, image, conf_thres, iou_thres):
140
+ # Preprocess the image
141
+ img_data, original_image = self.preprocess(image)
142
+
143
+ # Run inference
144
+ outputs = self.session.run(None, {self.session.get_inputs()[0].name: img_data})
145
+
146
+ # Postprocess the results
147
+ output_image = self.postprocess(original_image, outputs, conf_thres, iou_thres)
148
+
149
+ return output_image
150
+
151
+ def create_gradio_interface():
152
+ # Download model if it doesn't exist
153
+ if not os.path.exists(MODEL_PATH):
154
+ download_model()
155
+
156
+ # Initialize the detector
157
+ detector = SignatureDetector(MODEL_PATH)
158
+
159
+ # Create Gradio interface
160
+ iface = gr.Interface(
161
+ fn=detector.detect,
162
+ inputs=[
163
+ gr.Image(label="Upload your Document", type="pil"),
164
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05,
165
+ label="Confidence Threshold",
166
+ info="Adjust the minimum confidence score required for detection"),
167
+ gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05,
168
+ label="IoU Threshold",
169
+ info="Adjust the Intersection over Union threshold for NMS")
170
+ ],
171
+ outputs=gr.Image(label="Detection Results"),
172
+ title="Signature Detector",
173
+ description="Upload an image to detect signatures using YOLOv8. Use the sliders to adjust detection sensitivity.",
174
+ examples=[
175
+ ["assets/images/example_1.jpg", 0.2, 0.5],
176
+ ["assets/images/example_2.jpg", 0.2, 0.5],
177
+ ["assets/images/example_3.jpg", 0.2, 0.5],
178
+ ["assets/images/example_4.jpg", 0.2, 0.5]
179
+ ]
180
+ )
181
+
182
+ return iface
183
+
184
+ if __name__ == "__main__":
185
+ iface = create_gradio_interface()
186
+ iface.launch()
assets/images/example_1.jpg ADDED
assets/images/example_2.jpg ADDED
assets/images/example_3.jpg ADDED
assets/images/example_4.jpg ADDED
requirements.txt ADDED
File without changes