wuhp commited on
Commit
40291e5
·
verified ·
1 Parent(s): 178b26a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +803 -0
app.py ADDED
@@ -0,0 +1,803 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # --------------------------------------------------------------------
3
+ # A Gradio-based face recognition system that mimics most features of
4
+ # your Streamlit app: real-time webcam, image tests, configuration,
5
+ # database enrollment, searching, user removal, etc.
6
+ # --------------------------------------------------------------------
7
+
8
+ import os
9
+ import sys
10
+ import math
11
+ import requests
12
+ import numpy as np
13
+ import cv2
14
+ import torch
15
+ import pickle
16
+ import logging
17
+ from PIL import Image
18
+ from typing import Optional, Dict, List, Tuple
19
+ from dataclasses import dataclass, field
20
+ from collections import Counter
21
+
22
+ # 3rd-party modules
23
+ import gradio as gr
24
+ from ultralytics import YOLO
25
+ from facenet_pytorch import InceptionResnetV1
26
+ from torchvision import transforms
27
+ from deep_sort_realtime.deepsort_tracker import DeepSort
28
+
29
+ # --------------------------------------------------------------------
30
+ # GLOBALS & CONSTANTS
31
+ # --------------------------------------------------------------------
32
+ logging.basicConfig(
33
+ level=logging.INFO,
34
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
35
+ handlers=[logging.FileHandler('face_pipeline.log'), logging.StreamHandler()],
36
+ )
37
+ logger = logging.getLogger(__name__)
38
+
39
+ logging.getLogger('torch').setLevel(logging.ERROR)
40
+ logging.getLogger('deep_sort_realtime').setLevel(logging.ERROR)
41
+
42
+ DEFAULT_MODEL_URL = "https://github.com/wuhplaptop/face-11-n/blob/main/face2.pt?raw=true"
43
+ DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
44
+ MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
45
+ CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")
46
+
47
+ # If you need blink detection or face mesh, keep or define your landmarks:
48
+ LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
49
+ RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]
50
+
51
+ # --------------------------------------------------------------------
52
+ # PIPELINE CONFIG DATACLASS
53
+ # --------------------------------------------------------------------
54
+ @dataclass
55
+ class PipelineConfig:
56
+ detector: Dict = field(default_factory=dict)
57
+ tracker: Dict = field(default_factory=dict)
58
+ recognition: Dict = field(default_factory=dict)
59
+ anti_spoof: Dict = field(default_factory=dict)
60
+ blink: Dict = field(default_factory=dict)
61
+ face_mesh_options: Dict = field(default_factory=dict)
62
+ hand: Dict = field(default_factory=dict)
63
+ eye_color: Dict = field(default_factory=dict)
64
+ enabled_components: Dict = field(default_factory=dict)
65
+ detection_conf_thres: float = 0.4
66
+ recognition_conf_thres: float = 0.85
67
+ bbox_color: Tuple[int, int, int] = (0, 255, 0)
68
+ spoofed_bbox_color: Tuple[int, int, int] = (0, 0, 255)
69
+ unknown_bbox_color: Tuple[int, int, int] = (0, 0, 255)
70
+ eye_outline_color: Tuple[int, int, int] = (255, 255, 0)
71
+ blink_text_color: Tuple[int, int, int] = (0, 0, 255)
72
+ hand_landmark_color: Tuple[int, int, int] = (255, 210, 77)
73
+ hand_connection_color: Tuple[int, int, int] = (204, 102, 0)
74
+ hand_text_color: Tuple[int, int, int] = (255, 255, 255)
75
+ mesh_color: Tuple[int, int, int] = (100, 255, 100)
76
+ contour_color: Tuple[int, int, int] = (200, 200, 0)
77
+ iris_color: Tuple[int, int, int] = (255, 0, 255)
78
+ eye_color_text_color: Tuple[int, int, int] = (255, 255, 255)
79
+
80
+ def __post_init__(self):
81
+ self.detector = self.detector or {
82
+ 'model_path': os.path.join(MODEL_DIR, "face2.pt"),
83
+ 'device': 'cuda' if torch.cuda.is_available() else 'cpu',
84
+ }
85
+ self.tracker = self.tracker or {'max_age': 30}
86
+ self.recognition = self.recognition or {'enable': True}
87
+ self.anti_spoof = self.anti_spoof or {'enable': True, 'lap_thresh': 80.0}
88
+ self.blink = self.blink or {'enable': True, 'ear_thresh': 0.25}
89
+ self.face_mesh_options = self.face_mesh_options or {
90
+ 'enable': False,
91
+ 'tesselation': False,
92
+ 'contours': False,
93
+ 'irises': False,
94
+ }
95
+ self.hand = self.hand or {
96
+ 'enable': True,
97
+ 'min_detection_confidence': 0.5,
98
+ 'min_tracking_confidence': 0.5,
99
+ }
100
+ self.eye_color = self.eye_color or {'enable': False}
101
+ self.enabled_components = self.enabled_components or {
102
+ 'detection': True,
103
+ 'tracking': True,
104
+ 'anti_spoof': True,
105
+ 'recognition': True,
106
+ 'blink': True,
107
+ 'face_mesh': False,
108
+ 'hand': True,
109
+ 'eye_color': False,
110
+ }
111
+
112
+ def save(self, path: str):
113
+ try:
114
+ os.makedirs(os.path.dirname(path), exist_ok=True)
115
+ with open(path, 'wb') as f:
116
+ pickle.dump(self.__dict__, f)
117
+ logger.info(f"Saved config to {path}")
118
+ except Exception as e:
119
+ logger.error(f"Config save failed: {str(e)}")
120
+ raise RuntimeError(f"Config save failed: {str(e)}") from e
121
+
122
+ @classmethod
123
+ def load(cls, path: str) -> 'PipelineConfig':
124
+ try:
125
+ if os.path.exists(path):
126
+ with open(path, 'rb') as f:
127
+ data = pickle.load(f)
128
+ return cls(**data)
129
+ return cls()
130
+ except Exception as e:
131
+ logger.error(f"Config load failed: {str(e)}")
132
+ return cls()
133
+
134
+ # --------------------------------------------------------------------
135
+ # FACE DATABASE
136
+ # --------------------------------------------------------------------
137
+ class FaceDatabase:
138
+ def __init__(self, db_path: str = DEFAULT_DB_PATH):
139
+ self.db_path = db_path
140
+ self.embeddings: Dict[str, List[np.ndarray]] = {}
141
+ self._load()
142
+
143
+ def _load(self):
144
+ try:
145
+ if os.path.exists(self.db_path):
146
+ with open(self.db_path, 'rb') as f:
147
+ self.embeddings = pickle.load(f)
148
+ logger.info(f"Loaded database from {self.db_path}")
149
+ except Exception as e:
150
+ logger.error(f"Database load failed: {str(e)}")
151
+ self.embeddings = {}
152
+
153
+ def save(self):
154
+ try:
155
+ os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
156
+ with open(self.db_path, 'wb') as f:
157
+ pickle.dump(self.embeddings, f)
158
+ logger.info(f"Saved database to {self.db_path}")
159
+ except Exception as e:
160
+ logger.error(f"Database save failed: {str(e)}")
161
+ raise RuntimeError(f"Database save failed: {str(e)}") from e
162
+
163
+ def add_embedding(self, label: str, embedding: np.ndarray):
164
+ try:
165
+ if not isinstance(embedding, np.ndarray) or embedding.ndim != 1:
166
+ raise ValueError("Invalid embedding format")
167
+ if label not in self.embeddings:
168
+ self.embeddings[label] = []
169
+ self.embeddings[label].append(embedding)
170
+ logger.debug(f"Added embedding for {label}")
171
+ except Exception as e:
172
+ logger.error(f"Add embedding failed: {str(e)}")
173
+ raise
174
+
175
+ def remove_label(self, label: str):
176
+ try:
177
+ if label in self.embeddings:
178
+ del self.embeddings[label]
179
+ logger.info(f"Removed {label}")
180
+ else:
181
+ logger.warning(f"Label {label} not found")
182
+ except Exception as e:
183
+ logger.error(f"Remove label failed: {str(e)}")
184
+ raise
185
+
186
+ def list_labels(self) -> List[str]:
187
+ return list(self.embeddings.keys())
188
+
189
+ def get_embeddings_by_label(self, label: str) -> Optional[List[np.ndarray]]:
190
+ return self.embeddings.get(label)
191
+
192
+ def search_by_image(self, query_embedding: np.ndarray, threshold: float = 0.7) -> List[Tuple[str, float]]:
193
+ results = []
194
+ for label, embeddings in self.embeddings.items():
195
+ for db_emb in embeddings:
196
+ similarity = FacePipeline.cosine_similarity(query_embedding, db_emb)
197
+ if similarity >= threshold:
198
+ results.append((label, similarity))
199
+ return sorted(results, key=lambda x: x[1], reverse=True)
200
+
201
+ # --------------------------------------------------------------------
202
+ # YOLO FACE DETECTOR
203
+ # --------------------------------------------------------------------
204
+ class YOLOFaceDetector:
205
+ def __init__(self, model_path: str, device: str = 'cpu'):
206
+ self.model = None
207
+ self.device = device
208
+ try:
209
+ if not os.path.exists(model_path):
210
+ logger.info(f"Model file not found at {model_path}. Attempting to download...")
211
+ response = requests.get(DEFAULT_MODEL_URL)
212
+ response.raise_for_status()
213
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
214
+ with open(model_path, 'wb') as f:
215
+ f.write(response.content)
216
+ logger.info(f"Downloaded YOLO model to {model_path}")
217
+
218
+ self.model = YOLO(model_path)
219
+ self.model.to(device)
220
+ logger.info(f"Loaded YOLO model from {model_path}")
221
+ except Exception as e:
222
+ logger.error(f"YOLO initialization failed: {str(e)}")
223
+ raise
224
+
225
+ def detect(self, image: np.ndarray, conf_thres: float) -> List[Tuple[int, int, int, int, float, int]]:
226
+ try:
227
+ results = self.model.predict(
228
+ source=image, conf=conf_thres, verbose=False, device=self.device
229
+ )
230
+ detections = []
231
+ for result in results:
232
+ for box in result.boxes:
233
+ x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
234
+ conf = float(box.conf[0].cpu().numpy())
235
+ cls = int(box.cls[0].cpu().numpy()) if box.cls is not None else 0
236
+ detections.append((int(x1), int(y1), int(x2), int(y2), conf, cls))
237
+ logger.debug(f"Detected {len(detections)} faces.")
238
+ return detections
239
+ except Exception as e:
240
+ logger.error(f"Detection failed: {str(e)}")
241
+ return []
242
+
243
+ # --------------------------------------------------------------------
244
+ # FACE TRACKER
245
+ # --------------------------------------------------------------------
246
+ class FaceTracker:
247
+ def __init__(self, max_age: int = 30):
248
+ self.tracker = DeepSort(max_age=max_age, embedder='mobilenet')
249
+
250
+ def update(self, detections: List[Tuple], frame: np.ndarray):
251
+ try:
252
+ ds_detections = [
253
+ ([x1, y1, x2 - x1, y2 - y1], conf, cls)
254
+ for (x1, y1, x2, y2, conf, cls) in detections
255
+ ]
256
+ tracks = self.tracker.update_tracks(ds_detections, frame=frame)
257
+ logger.debug(f"Updated tracker with {len(tracks)} tracks.")
258
+ return tracks
259
+ except Exception as e:
260
+ logger.error(f"Tracking update failed: {str(e)}")
261
+ return []
262
+
263
+ # --------------------------------------------------------------------
264
+ # FACENET EMBEDDER
265
+ # --------------------------------------------------------------------
266
+ class FaceNetEmbedder:
267
+ def __init__(self, device: str = 'cpu'):
268
+ self.device = device
269
+ self.model = InceptionResnetV1(pretrained='vggface2').eval().to(device)
270
+ self.transform = transforms.Compose([
271
+ transforms.Resize((160, 160)),
272
+ transforms.ToTensor(),
273
+ transforms.Normalize([0.5]*3, [0.5]*3),
274
+ ])
275
+
276
+ def get_embedding(self, face_bgr: np.ndarray) -> Optional[np.ndarray]:
277
+ try:
278
+ face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
279
+ pil_img = Image.fromarray(face_rgb).convert('RGB')
280
+ tens = self.transform(pil_img).unsqueeze(0).to(self.device)
281
+ with torch.no_grad():
282
+ embedding = self.model(tens)[0].cpu().numpy()
283
+ logger.debug(f"Generated embedding: {embedding[:5]}...")
284
+ return embedding
285
+ except Exception as e:
286
+ logger.error(f"Embedding generation failed: {str(e)}")
287
+ return None
288
+
289
+ # --------------------------------------------------------------------
290
+ # MAIN PIPELINE
291
+ # --------------------------------------------------------------------
292
+ class FacePipeline:
293
+ def __init__(self, config: PipelineConfig):
294
+ self.config = config
295
+ self.detector = None
296
+ self.tracker = None
297
+ self.facenet = None
298
+ self.db = None
299
+ self._initialized = False
300
+
301
+ def initialize(self):
302
+ try:
303
+ self.detector = YOLOFaceDetector(
304
+ model_path=self.config.detector['model_path'],
305
+ device=self.config.detector['device']
306
+ )
307
+ self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
308
+ self.facenet = FaceNetEmbedder(device=self.config.detector['device'])
309
+ self.db = FaceDatabase()
310
+ self._initialized = True
311
+ logger.info("FacePipeline initialized successfully.")
312
+ except Exception as e:
313
+ logger.error(f"Pipeline initialization failed: {str(e)}")
314
+ self._initialized = False
315
+ raise
316
+
317
+ def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
318
+ if not self._initialized:
319
+ logger.error("Pipeline not initialized!")
320
+ return frame, []
321
+
322
+ try:
323
+ detections = self.detector.detect(frame, self.config.detection_conf_thres)
324
+ tracked_objects = self.tracker.update(detections, frame)
325
+ annotated_frame = frame.copy()
326
+ results = []
327
+
328
+ for obj in tracked_objects:
329
+ if not obj.is_confirmed():
330
+ continue
331
+ track_id = obj.track_id
332
+ bbox = obj.to_tlbr()
333
+ x1, y1, x2, y2 = bbox.astype(int)
334
+ conf = getattr(obj, 'score', 1.0)
335
+ cls = getattr(obj, 'class_id', 0)
336
+
337
+ face_roi = frame[y1:y2, x1:x2]
338
+ if face_roi.size == 0:
339
+ logger.warning(f"Empty face ROI for track {track_id}")
340
+ continue
341
+
342
+ # Anti-spoof
343
+ is_spoofed = False
344
+ if self.config.anti_spoof['enable']:
345
+ is_spoofed = not self.is_real_face(face_roi)
346
+ if is_spoofed:
347
+ cls = 1
348
+
349
+ if is_spoofed:
350
+ box_color_bgr = self.config.spoofed_bbox_color[::-1]
351
+ name = "Spoofed"
352
+ similarity = 0.0
353
+ else:
354
+ # Face recognition
355
+ embedding = self.facenet.get_embedding(face_roi)
356
+ if embedding is not None and self.config.recognition['enable']:
357
+ name, similarity = self.recognize_face(
358
+ embedding, self.config.recognition_conf_thres
359
+ )
360
+ else:
361
+ name = "Unknown"
362
+ similarity = 0.0
363
+
364
+ box_color_rgb = (
365
+ self.config.bbox_color
366
+ if name != "Unknown"
367
+ else self.config.unknown_bbox_color
368
+ )
369
+ box_color_bgr = box_color_rgb[::-1]
370
+
371
+ label_text = f"{name}"
372
+ cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), box_color_bgr, 2)
373
+ cv2.putText(
374
+ annotated_frame, label_text, (x1, y1 - 10),
375
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color_bgr, 2
376
+ )
377
+
378
+ detection_info = {
379
+ 'track_id': track_id,
380
+ 'bbox': (x1, y1, x2, y2),
381
+ 'confidence': float(conf),
382
+ 'class_id': cls,
383
+ 'name': name,
384
+ 'similarity': float(similarity),
385
+ }
386
+ results.append(detection_info)
387
+
388
+ return annotated_frame, results
389
+
390
+ except Exception as e:
391
+ logger.error(f"Frame processing failed: {str(e)}", exc_info=True)
392
+ return frame, []
393
+
394
+ def is_real_face(self, face_roi: np.ndarray) -> bool:
395
+ try:
396
+ gray = cv2.cvtColor(face_roi, cv2.COLOR_BGR2GRAY)
397
+ lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()
398
+ return lap_var > self.config.anti_spoof['lap_thresh']
399
+ except Exception as e:
400
+ logger.error(f"Anti-spoof check failed: {str(e)}")
401
+ return False
402
+
403
+ def recognize_face(self, embedding: np.ndarray, recognition_threshold: float) -> Tuple[str, float]:
404
+ try:
405
+ best_match = "Unknown"
406
+ best_similarity = 0.0
407
+ for label, embeddings in self.db.embeddings.items():
408
+ for db_emb in embeddings:
409
+ similarity = FacePipeline.cosine_similarity(embedding, db_emb)
410
+ if similarity > best_similarity:
411
+ best_similarity = similarity
412
+ best_match = label
413
+ if best_similarity < recognition_threshold:
414
+ best_match = "Unknown"
415
+ return best_match, best_similarity
416
+ except Exception as e:
417
+ logger.error(f"Face recognition failed: {str(e)}")
418
+ return "Unknown", 0.0
419
+
420
+ @staticmethod
421
+ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
422
+ return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
423
+
424
+ # --------------------------------------------------------------------
425
+ # GLOBAL pipeline instance (we can store it in a lazy loader)
426
+ # --------------------------------------------------------------------
427
+ pipeline = None
428
+
429
+ def load_pipeline() -> FacePipeline:
430
+ global pipeline
431
+ if pipeline is None:
432
+ logger.info("Loading pipeline for the first time...")
433
+ cfg = PipelineConfig.load(CONFIG_PATH)
434
+ pipeline = FacePipeline(cfg)
435
+ pipeline.initialize()
436
+ return pipeline
437
+
438
+ # --------------------------------------------------------------------
439
+ # GRADIO HELPER FUNCTIONS
440
+ # --------------------------------------------------------------------
441
+ def hex_to_bgr(hex_str: str) -> Tuple[int,int,int]:
442
+ """
443
+ Convert a hex string (#RRGGBB) into a BGR tuple as used in OpenCV.
444
+ """
445
+ if not hex_str.startswith('#'):
446
+ hex_str = f"#{hex_str}"
447
+ hex_str = hex_str.lstrip('#')
448
+ if len(hex_str) != 6:
449
+ return (255, 0, 0) # fallback to something
450
+ r = int(hex_str[0:2], 16)
451
+ g = int(hex_str[2:4], 16)
452
+ b = int(hex_str[4:6], 16)
453
+ return (b,g,r)
454
+
455
+ def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
456
+ """
457
+ Convert a BGR tuple (as stored in pipeline config) to a #RRGGBB hex string.
458
+ """
459
+ b,g,r = bgr
460
+ return f"#{r:02x}{g:02x}{b:02x}"
461
+
462
+ # --------------------------------------------------------------------
463
+ # TAB: Configuration
464
+ # --------------------------------------------------------------------
465
+ def update_config(
466
+ enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
467
+ show_tesselation, show_contours, show_irises,
468
+ detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
469
+ hand_det_conf, hand_track_conf,
470
+ bbox_hex, spoofed_hex, unknown_hex,
471
+ eye_hex, blink_hex,
472
+ hand_landmark_hex, hand_connection_hex, hand_text_hex,
473
+ mesh_hex, contour_hex, iris_hex,
474
+ eye_color_text_hex
475
+ ):
476
+ # Load pipeline
477
+ pl = load_pipeline()
478
+ cfg = pl.config
479
+
480
+ # Update toggles
481
+ cfg.recognition['enable'] = enable_recognition
482
+ cfg.anti_spoof['enable'] = enable_antispoof
483
+ cfg.blink['enable'] = enable_blink
484
+ cfg.hand['enable'] = enable_hand
485
+ cfg.eye_color['enable'] = enable_eyecolor
486
+ cfg.face_mesh_options['enable'] = enable_facemesh
487
+
488
+ cfg.face_mesh_options['tesselation'] = show_tesselation
489
+ cfg.face_mesh_options['contours'] = show_contours
490
+ cfg.face_mesh_options['irises'] = show_irises
491
+
492
+ # Update thresholds
493
+ cfg.detection_conf_thres = detection_conf
494
+ cfg.recognition_conf_thres = recognition_thresh
495
+ cfg.anti_spoof['lap_thresh'] = antispoof_thresh
496
+ cfg.blink['ear_thresh'] = blink_thresh
497
+ cfg.hand['min_detection_confidence'] = hand_det_conf
498
+ cfg.hand['min_tracking_confidence'] = hand_track_conf
499
+
500
+ # Update color fields
501
+ cfg.bbox_color = hex_to_bgr(bbox_hex)[::-1] # store in (R,G,B)
502
+ cfg.spoofed_bbox_color = hex_to_bgr(spoofed_hex)[::-1]
503
+ cfg.unknown_bbox_color = hex_to_bgr(unknown_hex)[::-1]
504
+ cfg.eye_outline_color = hex_to_bgr(eye_hex)[::-1]
505
+ cfg.blink_text_color = hex_to_bgr(blink_hex)[::-1]
506
+ cfg.hand_landmark_color = hex_to_bgr(hand_landmark_hex)[::-1]
507
+ cfg.hand_connection_color = hex_to_bgr(hand_connection_hex)[::-1]
508
+ cfg.hand_text_color = hex_to_bgr(hand_text_hex)[::-1]
509
+ cfg.mesh_color = hex_to_bgr(mesh_hex)[::-1]
510
+ cfg.contour_color = hex_to_bgr(contour_hex)[::-1]
511
+ cfg.iris_color = hex_to_bgr(iris_hex)[::-1]
512
+ cfg.eye_color_text_color = hex_to_bgr(eye_color_text_hex)[::-1]
513
+
514
+ # Save config
515
+ cfg.save(CONFIG_PATH)
516
+
517
+ return "Configuration saved successfully!"
518
+
519
+ # --------------------------------------------------------------------
520
+ # TAB: Database Management
521
+ # --------------------------------------------------------------------
522
+ def enroll_user(name: str, images: List[np.ndarray]) -> str:
523
+ """
524
+ Enroll user by name using one or more images. images is a list of
525
+ NxMx3 numpy arrays in BGR or RGB depending on Gradio type.
526
+ """
527
+ pl = load_pipeline()
528
+ if not name:
529
+ return "Please provide a user name."
530
+
531
+ if not images or len(images) == 0:
532
+ return "No images provided."
533
+
534
+ count_enrolled = 0
535
+ for img in images:
536
+ if img is None:
537
+ continue
538
+ # Gradio provides images in RGB by default, let's ensure BGR for pipeline
539
+ if img.shape[-1] == 3: # RGB
540
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
541
+ else:
542
+ img_bgr = img
543
+
544
+ # Run YOLO detection on each image
545
+ detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
546
+ for x1, y1, x2, y2, conf, cls in detections:
547
+ face_roi = img_bgr[y1:y2, x1:x2]
548
+ if face_roi.size == 0:
549
+ continue
550
+ emb = pl.facenet.get_embedding(face_roi)
551
+ if emb is not None:
552
+ pl.db.add_embedding(name, emb)
553
+ count_enrolled += 1
554
+
555
+ if count_enrolled > 0:
556
+ pl.db.save()
557
+ return f"Enrolled {name} with {count_enrolled} face(s)!"
558
+ else:
559
+ return "No faces were detected or embedded. Enrollment failed."
560
+
561
+ def search_by_name(name: str) -> str:
562
+ pl = load_pipeline()
563
+ if not name:
564
+ return "No name provided"
565
+ embeddings = pl.db.get_embeddings_by_label(name)
566
+ if embeddings:
567
+ return f"User '{name}' found with {len(embeddings)} embedding(s)."
568
+ else:
569
+ return f"No embeddings found for user '{name}'."
570
+
571
+ def search_by_image(image: np.ndarray) -> str:
572
+ """
573
+ Search database by face in the uploaded image.
574
+ """
575
+ pl = load_pipeline()
576
+ if image is None:
577
+ return "No image uploaded."
578
+
579
+ img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
580
+ detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
581
+ if not detections:
582
+ return "No faces detected in the uploaded image."
583
+
584
+ x1, y1, x2, y2, conf, cls = detections[0]
585
+ face_roi = img_bgr[y1:y2, x1:x2]
586
+ if face_roi.size == 0:
587
+ return "Empty face ROI in the uploaded image."
588
+
589
+ emb = pl.facenet.get_embedding(face_roi)
590
+ if emb is None:
591
+ return "Failed to generate embedding from the face."
592
+
593
+ search_results = pl.db.search_by_image(emb, pl.config.recognition_conf_thres)
594
+ if not search_results:
595
+ return "No matching users found in the database (under current threshold)."
596
+
597
+ lines = []
598
+ for label, sim in search_results:
599
+ lines.append(f" - {label}, similarity={sim:.3f}")
600
+ return "Search results:\n" + "\n".join(lines)
601
+
602
+ def remove_user(label: str) -> str:
603
+ pl = load_pipeline()
604
+ if not label:
605
+ return "No user label selected."
606
+ pl.db.remove_label(label)
607
+ pl.db.save()
608
+ return f"User '{label}' removed."
609
+
610
+ def list_users() -> str:
611
+ pl = load_pipeline()
612
+ labels = pl.db.list_labels()
613
+ if labels:
614
+ return f"Enrolled users:\n{', '.join(labels)}"
615
+ return "No users enrolled."
616
+
617
+ # --------------------------------------------------------------------
618
+ # TAB: Real-Time Recognition
619
+ # --------------------------------------------------------------------
620
+ def process_webcam_frame(frame: np.ndarray) -> Tuple[np.ndarray, str]:
621
+ """
622
+ Called for every incoming webcam frame. Return annotated frame + textual info.
623
+ Gradio delivers frames in RGB.
624
+ """
625
+ if frame is None:
626
+ return None, "No frame."
627
+ pl = load_pipeline()
628
+ frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
629
+ annotated_bgr, detections = pl.process_frame(frame_bgr)
630
+ annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
631
+ return annotated_rgb, str(detections)
632
+
633
+ # --------------------------------------------------------------------
634
+ # TAB: Image Test
635
+ # --------------------------------------------------------------------
636
+ def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
637
+ if img is None:
638
+ return None, "No image uploaded."
639
+ pl = load_pipeline()
640
+ img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
641
+ processed, detections = pl.process_frame(img_bgr)
642
+ out_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
643
+ return out_rgb, str(detections)
644
+
645
+ # --------------------------------------------------------------------
646
+ # BUILD THE GRADIO APP
647
+ # --------------------------------------------------------------------
648
+ def build_app():
649
+ with gr.Blocks() as demo:
650
+ gr.Markdown("# Face Recognition System (Gradio)")
651
+
652
+ with gr.Tab("Real-Time Recognition"):
653
+ gr.Markdown("Live face recognition from your webcam (roughly 'real-time').")
654
+ webcam_input = gr.Video(source="webcam", mirror=True, streaming=True)
655
+ webcam_output = gr.Image()
656
+ webcam_info = gr.Textbox(label="Detections", interactive=False)
657
+ webcam_input.change(
658
+ fn=process_webcam_frame,
659
+ inputs=webcam_input,
660
+ outputs=[webcam_output, webcam_info],
661
+ )
662
+
663
+ with gr.Tab("Image Test"):
664
+ gr.Markdown("Upload a single image for face detection and recognition.")
665
+ image_input = gr.Image(type="numpy", label="Upload Image")
666
+ image_out = gr.Image()
667
+ image_info = gr.Textbox(label="Detections", interactive=False)
668
+ process_btn = gr.Button("Process Image")
669
+
670
+ process_btn.click(
671
+ fn=process_test_image,
672
+ inputs=image_input,
673
+ outputs=[image_out, image_info],
674
+ )
675
+
676
+ with gr.Tab("Configuration"):
677
+ gr.Markdown("Modify the pipeline settings and thresholds here.")
678
+
679
+ with gr.Row():
680
+ enable_recognition = gr.Checkbox(label="Enable Face Recognition", value=True)
681
+ enable_antispoof = gr.Checkbox(label="Enable Anti-Spoof", value=True)
682
+ enable_blink = gr.Checkbox(label="Enable Blink Detection", value=True)
683
+ enable_hand = gr.Checkbox(label="Enable Hand Tracking", value=True)
684
+ enable_eyecolor = gr.Checkbox(label="Enable Eye Color Detection", value=False)
685
+ enable_facemesh = gr.Checkbox(label="Enable Face Mesh", value=False)
686
+
687
+ gr.Markdown("**Face Mesh Options** (only if Face Mesh is enabled):")
688
+ with gr.Row():
689
+ show_tesselation = gr.Checkbox(label="Show Tesselation", value=False)
690
+ show_contours = gr.Checkbox(label="Show Contours", value=False)
691
+ show_irises = gr.Checkbox(label="Show Irises", value=False)
692
+
693
+ gr.Markdown("**Thresholds**")
694
+ detection_conf = gr.Slider(0.0, 1.0, value=0.4, step=0.01, label="Detection Confidence Threshold")
695
+ recognition_thresh = gr.Slider(0.5, 1.0, value=0.85, step=0.01, label="Recognition Similarity Threshold")
696
+ antispoof_thresh = gr.Slider(0, 200, value=80, step=1, label="Anti-Spoof Laplacian Threshold")
697
+ blink_thresh = gr.Slider(0, 0.5, value=0.25, step=0.01, label="Blink EAR Threshold")
698
+ hand_det_conf = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Hand Detection Confidence")
699
+ hand_track_conf = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Hand Tracking Confidence")
700
+
701
+ gr.Markdown("**Color Options (Hex)**")
702
+ bbox_hex = gr.Textbox(label="Box Color (Recognized)", value="#00ff00")
703
+ spoofed_hex = gr.Textbox(label="Box Color (Spoofed)", value="#ff0000")
704
+ unknown_hex = gr.Textbox(label="Box Color (Unknown)", value="#ff0000")
705
+
706
+ eye_hex = gr.Textbox(label="Eye Outline Color", value="#ffff00")
707
+ blink_hex = gr.Textbox(label="Blink Text Color", value="#0000ff")
708
+
709
+ hand_landmark_hex = gr.Textbox(label="Hand Landmark Color", value="#ffd24d")
710
+ hand_connection_hex = gr.Textbox(label="Hand Connection Color", value="#cc6600")
711
+ hand_text_hex = gr.Textbox(label="Hand Text Color", value="#ffffff")
712
+
713
+ mesh_hex = gr.Textbox(label="Mesh Color", value="#64ff64")
714
+ contour_hex = gr.Textbox(label="Contour Color", value="#c8c800")
715
+ iris_hex = gr.Textbox(label="Iris Color", value="#ff00ff")
716
+
717
+ eye_color_text_hex = gr.Textbox(label="Eye Color Text Color", value="#ffffff")
718
+
719
+ save_btn = gr.Button("Save Configuration")
720
+ save_msg = gr.Textbox(label="", interactive=False)
721
+
722
+ save_btn.click(
723
+ fn=update_config,
724
+ inputs=[
725
+ enable_recognition, enable_antispoof, enable_blink,
726
+ enable_hand, enable_eyecolor, enable_facemesh,
727
+ show_tesselation, show_contours, show_irises,
728
+ detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
729
+ hand_det_conf, hand_track_conf,
730
+ bbox_hex, spoofed_hex, unknown_hex,
731
+ eye_hex, blink_hex,
732
+ hand_landmark_hex, hand_connection_hex, hand_text_hex,
733
+ mesh_hex, contour_hex, iris_hex, eye_color_text_hex
734
+ ],
735
+ outputs=[save_msg],
736
+ )
737
+
738
+ with gr.Tab("Database Management"):
739
+ gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")
740
+
741
+ with gr.Accordion("User Enrollment", open=False):
742
+ enroll_name = gr.Textbox(label="Enter name for enrollment")
743
+ enroll_images = gr.Image(type="numpy", label="Upload Enrollment Images", multiple=True)
744
+ enroll_btn = gr.Button("Enroll User")
745
+ enroll_result = gr.Textbox(label="", interactive=False)
746
+ enroll_btn.click(fn=enroll_user, inputs=[enroll_name, enroll_images], outputs=[enroll_result])
747
+
748
+ with gr.Accordion("User Search", open=False):
749
+ search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
750
+ search_name_input = gr.Dropdown(label="Select User", choices=[], value=None, interactive=True)
751
+ search_image_input = gr.Image(type="numpy", label="Upload Image", visible=False)
752
+ search_btn = gr.Button("Search")
753
+ search_result = gr.Textbox(label="", interactive=False)
754
+
755
+ def update_search_visibility(mode):
756
+ if mode == "Name":
757
+ return gr.update(visible=True), gr.update(visible=False)
758
+ else:
759
+ return gr.update(visible=False), gr.update(visible=True)
760
+
761
+ search_mode.change(fn=update_search_visibility,
762
+ inputs=[search_mode],
763
+ outputs=[search_name_input, search_image_input])
764
+
765
+ def search_user(mode, name, img):
766
+ if mode == "Name":
767
+ return search_by_name(name)
768
+ else:
769
+ return search_by_image(img)
770
+
771
+ search_btn.click(fn=search_user,
772
+ inputs=[search_mode, search_name_input, search_image_input],
773
+ outputs=[search_result])
774
+
775
+ with gr.Accordion("User Management Tools", open=False):
776
+ list_btn = gr.Button("List Enrolled Users")
777
+ list_output = gr.Textbox(label="", interactive=False)
778
+ list_btn.click(fn=lambda: list_users(), inputs=[], outputs=[list_output])
779
+
780
+ # Reload user list dropdown
781
+ def get_user_list():
782
+ pl = load_pipeline()
783
+ return gr.update(choices=pl.db.list_labels())
784
+
785
+ # A dedicated button to refresh the dropdown
786
+ refresh_users_btn = gr.Button("Refresh User List")
787
+ refresh_users_btn.click(fn=get_user_list, inputs=[], outputs=[search_name_input])
788
+
789
+ remove_user_select = gr.Dropdown(label="Select User to Remove", choices=[])
790
+ remove_btn = gr.Button("Remove Selected User")
791
+ remove_output = gr.Textbox(label="", interactive=False)
792
+
793
+ remove_btn.click(fn=remove_user, inputs=[remove_user_select], outputs=[remove_output])
794
+ refresh_users_btn.click(fn=get_user_list, inputs=[], outputs=[remove_user_select])
795
+
796
+ return demo
797
+
798
+ # --------------------------------------------------------------------
799
+ # MAIN
800
+ # --------------------------------------------------------------------
801
+ if __name__ == "__main__":
802
+ app = build_app()
803
+ app.queue().launch(server_name="0.0.0.0", server_port=7860)