wuhp committed on
Commit 32ae369 · verified · 1 Parent(s): a6ec168

Update app.py

Files changed (1)
  1. app.py +498 -221
app.py CHANGED
@@ -15,16 +15,15 @@ from collections import Counter
  # Gradio
  import gradio as gr

- # YOLO (Ultralytics)
  from ultralytics import YOLO
-
- # FaceNet
  from facenet_pytorch import InceptionResnetV1
  from torchvision import transforms
-
- # DeepSORT tracking
  from deep_sort_realtime.deepsort_tracker import DeepSort

  # --------------------------------------------------------------------
  # LOGGING
  # --------------------------------------------------------------------
@@ -35,8 +34,9 @@ logging.basicConfig(
  )
  logger = logging.getLogger(__name__)

- # Mute some debug logs from third-party libraries
  logging.getLogger('torch').setLevel(logging.ERROR)
  logging.getLogger('deep_sort_realtime').setLevel(logging.ERROR)

  # --------------------------------------------------------------------
@@ -47,12 +47,17 @@ DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
  MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
  CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")

- # If you still want blink detection or face mesh references, keep them
  LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
  RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]

  # --------------------------------------------------------------------
- # PIPELINE CONFIG
  # --------------------------------------------------------------------
  @dataclass
  class PipelineConfig:
@@ -65,8 +70,10 @@ class PipelineConfig:
      hand: Dict = field(default_factory=dict)
      eye_color: Dict = field(default_factory=dict)
      enabled_components: Dict = field(default_factory=dict)
      detection_conf_thres: float = 0.4
      recognition_conf_thres: float = 0.85
      bbox_color: Tuple[int, int, int] = (0, 255, 0)
      spoofed_bbox_color: Tuple[int, int, int] = (0, 0, 255)
      unknown_bbox_color: Tuple[int, int, int] = (0, 0, 255)
@@ -210,19 +217,19 @@ class YOLOFaceDetector:
          self.device = device
          try:
              if not os.path.exists(model_path):
-                 logger.info(f"Model file not found at {model_path}. Attempting to download...")
-                 response = requests.get(DEFAULT_MODEL_URL)
-                 response.raise_for_status()
                  os.makedirs(os.path.dirname(model_path), exist_ok=True)
                  with open(model_path, 'wb') as f:
-                     f.write(response.content)
                  logger.info(f"Downloaded YOLO model to {model_path}")

              self.model = YOLO(model_path)
              self.model.to(device)
              logger.info(f"Loaded YOLO model from {model_path}")
          except Exception as e:
-             logger.error(f"YOLO initialization failed: {str(e)}")
              raise

      def detect(self, image: np.ndarray, conf_thres: float) -> List[Tuple[int, int, int, int, float, int]]:
@@ -240,11 +247,11 @@ class YOLOFaceDetector:
              logger.debug(f"Detected {len(detections)} faces.")
              return detections
          except Exception as e:
-             logger.error(f"Detection failed: {str(e)}")
              return []

  # --------------------------------------------------------------------
- # FACE TRACKER
  # --------------------------------------------------------------------
  class FaceTracker:
      def __init__(self, max_age: int = 30):
@@ -260,7 +267,7 @@ class FaceTracker:
              logger.debug(f"Updated tracker with {len(tracks)} tracks.")
              return tracks
          except Exception as e:
-             logger.error(f"Tracking update failed: {str(e)}")
              return []

  # --------------------------------------------------------------------
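Note: FaceTracker is a thin wrapper around deep_sort_realtime. A minimal sketch of the underlying flow, assuming the library's documented detection format of ([left, top, width, height], confidence, class); the zero frame is a stand-in for a real BGR image:

import numpy as np
from deep_sort_realtime.deepsort_tracker import DeepSort

tracker = DeepSort(max_age=30)
frame = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder BGR frame
detections = [([100, 80, 50, 50], 0.92, 0)]       # one hypothetical face box
tracks = tracker.update_tracks(detections, frame=frame)
for t in tracks:
    if t.is_confirmed():
        print(t.track_id, t.to_tlbr())            # track id and [x1, y1, x2, y2]

Tracks only become confirmed after several consistent frames, which is why process_frame skips unconfirmed tracks.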
@@ -283,12 +290,229 @@ class FaceNetEmbedder:
              tens = self.transform(pil_img).unsqueeze(0).to(self.device)
              with torch.no_grad():
                  embedding = self.model(tens)[0].cpu().numpy()
-             logger.debug(f"Generated embedding: {embedding[:5]}...")
              return embedding
          except Exception as e:
-             logger.error(f"Embedding generation failed: {str(e)}")
              return None

  # --------------------------------------------------------------------
  # FACE PIPELINE
  # --------------------------------------------------------------------
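For reference, FaceNetEmbedder wraps facenet-pytorch's InceptionResnetV1, which maps a 160x160 aligned face crop to a 512-dimensional vector. A self-contained sketch, assuming the 'vggface2' pretrained weights that the library ships:

import torch
from facenet_pytorch import InceptionResnetV1

model = InceptionResnetV1(pretrained='vggface2').eval()
crop = torch.randn(1, 3, 160, 160)   # stands in for a normalized face crop
with torch.no_grad():
    emb = model(crop)[0]
print(emb.shape)                     # torch.Size([512])

These are the vectors the recognizer later compares with cosine similarity.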
@@ -299,155 +523,231 @@ class FacePipeline:
          self.tracker = None
          self.facenet = None
          self.db = None
          self._initialized = False

      def initialize(self):
          try:
              self.detector = YOLOFaceDetector(
                  model_path=self.config.detector['model_path'],
                  device=self.config.detector['device']
              )
              self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
              self.facenet = FaceNetEmbedder(device=self.config.detector['device'])
              self.db = FaceDatabase()
              self._initialized = True
              logger.info("FacePipeline initialized successfully.")
          except Exception as e:
-             logger.error(f"Pipeline initialization failed: {str(e)}")
              self._initialized = False
              raise

      def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
          if not self._initialized:
-             logger.error("Pipeline not initialized!")
              return frame, []

          try:
              detections = self.detector.detect(frame, self.config.detection_conf_thres)
-             tracked_objects = self.tracker.update(detections, frame)
-             annotated_frame = frame.copy()
              results = []

-             for obj in tracked_objects:
                  if not obj.is_confirmed():
                      continue
                  track_id = obj.track_id
-                 bbox = obj.to_tlbr()
-                 x1, y1, x2, y2 = bbox.astype(int)
                  conf = getattr(obj, 'score', 1.0)
                  cls = getattr(obj, 'class_id', 0)

                  face_roi = frame[y1:y2, x1:x2]
                  if face_roi.size == 0:
-                     logger.warning(f"Empty face ROI for track {track_id}")
                      continue

-                 # Anti-spoof
                  is_spoofed = False
-                 if self.config.anti_spoof['enable']:
                      is_spoofed = not self.is_real_face(face_roi)
                      if is_spoofed:
-                         cls = 1

                  if is_spoofed:
                      box_color_bgr = self.config.spoofed_bbox_color[::-1]
                      name = "Spoofed"
                      similarity = 0.0
                  else:
-                     embedding = self.facenet.get_embedding(face_roi)
-                     if embedding is not None and self.config.recognition['enable']:
-                         name, similarity = self.recognize_face(
-                             embedding, self.config.recognition_conf_thres
-                         )
                      else:
                          name = "Unknown"
                          similarity = 0.0

-                 box_color_rgb = (
-                     self.config.bbox_color
-                     if name != "Unknown"
-                     else self.config.unknown_bbox_color
-                 )
                  box_color_bgr = box_color_rgb[::-1]

-                 label_text = f"{name}"
-                 cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), box_color_bgr, 2)
-                 cv2.putText(
-                     annotated_frame, label_text, (x1, y1 - 10),
-                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color_bgr, 2
-                 )

                  detection_info = {
-                     'track_id': track_id,
-                     'bbox': (x1, y1, x2, y2),
-                     'confidence': float(conf),
-                     'class_id': cls,
-                     'name': name,
-                     'similarity': float(similarity),
                  }
                  results.append(detection_info)

-             return annotated_frame, results

          except Exception as e:
-             logger.error(f"Frame processing failed: {str(e)}", exc_info=True)
              return frame, []

      def is_real_face(self, face_roi: np.ndarray) -> bool:
          try:
              gray = cv2.cvtColor(face_roi, cv2.COLOR_BGR2GRAY)
-             lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()
-             return lap_var > self.config.anti_spoof['lap_thresh']
          except Exception as e:
-             logger.error(f"Anti-spoof check failed: {str(e)}")
              return False

-     def recognize_face(self, embedding: np.ndarray, recognition_threshold: float) -> Tuple[str, float]:
          try:
-             best_match = "Unknown"
-             best_similarity = 0.0
-             for label, embeddings in self.db.embeddings.items():
-                 for db_emb in embeddings:
-                     similarity = FacePipeline.cosine_similarity(embedding, db_emb)
-                     if similarity > best_similarity:
-                         best_similarity = similarity
-                         best_match = label
-             if best_similarity < recognition_threshold:
-                 best_match = "Unknown"
-             return best_match, best_similarity
          except Exception as e:
-             logger.error(f"Face recognition failed: {str(e)}")
-             return "Unknown", 0.0

      @staticmethod
      def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
-         return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))

  # --------------------------------------------------------------------
- # GLOBAL PIPELINE
  # --------------------------------------------------------------------
  pipeline = None
  def load_pipeline() -> FacePipeline:
      global pipeline
      if pipeline is None:
-         logger.info("Loading pipeline for the first time...")
          cfg = PipelineConfig.load(CONFIG_PATH)
          pipeline = FacePipeline(cfg)
          pipeline.initialize()
      return pipeline

  # --------------------------------------------------------------------
- # HELPER FUNCTIONS FOR GRADIO
  # --------------------------------------------------------------------
- def hex_to_bgr(hex_str: str) -> Tuple[int,int,int]:
-     if not hex_str.startswith('#'):
-         hex_str = f"#{hex_str}"
-     hex_str = hex_str.lstrip('#')
-     if len(hex_str) != 6:
          return (255, 0, 0)
-     r = int(hex_str[0:2], 16)
-     g = int(hex_str[2:4], 16)
-     b = int(hex_str[4:6], 16)
      return (b,g,r)

  def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
@@ -455,18 +755,18 @@ def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
      return f"#{r:02x}{g:02x}{b:02x}"

  # --------------------------------------------------------------------
- # UPDATE CONFIG
  # --------------------------------------------------------------------
  def update_config(
      enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
      show_tesselation, show_contours, show_irises,
-     detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
-     hand_det_conf, hand_track_conf,
-     bbox_hex, spoofed_hex, unknown_hex,
-     eye_hex, blink_hex,
-     hand_landmark_hex, hand_connection_hex, hand_text_hex,
-     mesh_hex, contour_hex, iris_hex,
-     eye_color_text_hex
  ):
      pl = load_pipeline()
      cfg = pl.config
@@ -498,7 +798,7 @@ def update_config(
      cfg.eye_outline_color = hex_to_bgr(eye_hex)[::-1]
      cfg.blink_text_color = hex_to_bgr(blink_hex)[::-1]
      cfg.hand_landmark_color = hex_to_bgr(hand_landmark_hex)[::-1]
-     cfg.hand_connection_color = hex_to_bgr(hand_connection_hex)[::-1]
      cfg.hand_text_color = hex_to_bgr(hand_text_hex)[::-1]
      cfg.mesh_color = hex_to_bgr(mesh_hex)[::-1]
      cfg.contour_color = hex_to_bgr(contour_hex)[::-1]
@@ -508,87 +808,68 @@ def update_config(
      cfg.save(CONFIG_PATH)
      return "Configuration saved successfully!"

- # --------------------------------------------------------------------
- # ENROLL MULTIPLE IMAGES (using gr.File)
- # --------------------------------------------------------------------
- def enroll_user(name: str, files: List[dict]) -> str:
-     """
-     Each item in `files` is a dict with keys:
-       - 'name': filename
-       - 'size': file size in bytes
-       - 'data': the binary file contents
-     We decode them into OpenCV images and process.
-     """
      pl = load_pipeline()
-     if not name:
          return "Please provide a user name."
-     if not files or len(files) == 0:
          return "No images provided."

-     count_enrolled = 0
-     for f in files:
-         file_bytes = f["data"]
-         np_array = np.frombuffer(file_bytes, np.uint8)
-         img_bgr = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
          if img_bgr is None:
              continue
-
-         detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
-         for x1, y1, x2, y2, conf, cls in detections:
-             face_roi = img_bgr[y1:y2, x1:x2]
-             if face_roi.size == 0:
                  continue
-             emb = pl.facenet.get_embedding(face_roi)
              if emb is not None:
-                 pl.db.add_embedding(name, emb)
-                 count_enrolled += 1

-     if count_enrolled > 0:
          pl.db.save()
-         return f"Enrolled {name} with {count_enrolled} face(s)!"
      else:
-         return "No faces were detected or embedded. Enrollment failed."

- # --------------------------------------------------------------------
- # SEARCH / USER MGMT
- # --------------------------------------------------------------------
  def search_by_name(name: str) -> str:
      pl = load_pipeline()
      if not name:
-         return "No name provided"
-     embeddings = pl.db.get_embeddings_by_label(name)
-     if embeddings:
-         return f"User '{name}' found with {len(embeddings)} embedding(s)."
      else:
-         return f"No embeddings found for user '{name}'."

- def search_by_image(image: np.ndarray) -> str:
      pl = load_pipeline()
-     if image is None:
          return "No image uploaded."
-
-     # Convert to BGR for pipeline
-     img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-     detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
-     if not detections:
          return "No faces detected in the uploaded image."
-
-     x1, y1, x2, y2, conf, cls = detections[0]
-     face_roi = img_bgr[y1:y2, x1:x2]
-     if face_roi.size == 0:
          return "Empty face ROI in the uploaded image."

-     emb = pl.facenet.get_embedding(face_roi)
      if emb is None:
-         return "Failed to generate embedding from the face."
-
-     search_results = pl.db.search_by_image(emb, pl.config.recognition_conf_thres)
-     if not search_results:
-         return "No matching users found in the database (under current threshold)."
-
-     lines = []
-     for label, sim in search_results:
-         lines.append(f" - {label}, similarity={sim:.3f}")
      return "Search results:\n" + "\n".join(lines)

  def remove_user(label: str) -> str:
@@ -606,75 +887,74 @@ def list_users() -> str:
          return "Enrolled users:\n" + ", ".join(labels)
      return "No users enrolled."

- # --------------------------------------------------------------------
- # PROCESS SINGLE IMAGE (IMAGE TEST)
- # --------------------------------------------------------------------
  def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
      if img is None:
          return None, "No image uploaded."
      pl = load_pipeline()
-     img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-     processed, detections = pl.process_frame(img_bgr)
-     out_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
-     return out_rgb, str(detections)

  # --------------------------------------------------------------------
- # BUILD THE GRADIO APP
  # --------------------------------------------------------------------
  def build_app():
      with gr.Blocks() as demo:
-         gr.Markdown("# Face Recognition System (Image-Only)")

-         # -- Tab: Image Test --
          with gr.Tab("Image Test"):
-             gr.Markdown("Upload a single image for face detection & recognition:")
-             image_input = gr.Image(type="numpy", label="Upload Image")
-             image_out = gr.Image()
-             image_info = gr.Textbox(label="Detections", interactive=False)
              process_btn = gr.Button("Process Image")

              process_btn.click(
                  fn=process_test_image,
-                 inputs=image_input,
-                 outputs=[image_out, image_info],
              )

-         # -- Tab: Configuration --
          with gr.Tab("Configuration"):
-             gr.Markdown("Modify pipeline settings & thresholds.")

              with gr.Row():
-                 enable_recognition = gr.Checkbox(label="Enable Face Recognition", value=True)
                  enable_antispoof = gr.Checkbox(label="Enable Anti-Spoof", value=True)
                  enable_blink = gr.Checkbox(label="Enable Blink Detection", value=True)
                  enable_hand = gr.Checkbox(label="Enable Hand Tracking", value=True)
                  enable_eyecolor = gr.Checkbox(label="Enable Eye Color Detection", value=False)
                  enable_facemesh = gr.Checkbox(label="Enable Face Mesh", value=False)

              gr.Markdown("**Face Mesh Options**")
              with gr.Row():
-                 show_tesselation = gr.Checkbox(label="Show Tesselation", value=False)
-                 show_contours = gr.Checkbox(label="Show Contours", value=False)
-                 show_irises = gr.Checkbox(label="Show Irises", value=False)

              gr.Markdown("**Thresholds**")
-             detection_conf = gr.Slider(0.0, 1.0, value=0.4, step=0.01, label="Detection Confidence Threshold")
-             recognition_thresh = gr.Slider(0.5, 1.0, value=0.85, step=0.01, label="Recognition Similarity Threshold")
-             antispoof_thresh = gr.Slider(0, 200, value=80, step=1, label="Anti-Spoof Laplacian Threshold")
-             blink_thresh = gr.Slider(0, 0.5, value=0.25, step=0.01, label="Blink EAR Threshold")
-             hand_det_conf = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Hand Detection Confidence")
-             hand_track_conf = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Hand Tracking Confidence")

              gr.Markdown("**Color Options (Hex)**")
              bbox_hex = gr.Textbox(label="Box Color (Recognized)", value="#00ff00")
              spoofed_hex = gr.Textbox(label="Box Color (Spoofed)", value="#ff0000")
              unknown_hex = gr.Textbox(label="Box Color (Unknown)", value="#ff0000")
-
              eye_hex = gr.Textbox(label="Eye Outline Color", value="#ffff00")
              blink_hex = gr.Textbox(label="Blink Text Color", value="#0000ff")

              hand_landmark_hex = gr.Textbox(label="Hand Landmark Color", value="#ffd24d")
-             hand_connection_hex = gr.Textbox(label="Hand Connection Color", value="#cc6600")
              hand_text_hex = gr.Textbox(label="Hand Text Color", value="#ffffff")

              mesh_hex = gr.Textbox(label="Mesh Color", value="#64ff64")
@@ -688,84 +968,81 @@ def build_app():
              save_btn.click(
                  fn=update_config,
                  inputs=[
-                     enable_recognition, enable_antispoof, enable_blink, enable_hand,
-                     enable_eyecolor, enable_facemesh,
                      show_tesselation, show_contours, show_irises,
-                     detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
-                     hand_det_conf, hand_track_conf,
-                     bbox_hex, spoofed_hex, unknown_hex,
-                     eye_hex, blink_hex,
-                     hand_landmark_hex, hand_connection_hex, hand_text_hex,
                      mesh_hex, contour_hex, iris_hex, eye_color_text_hex
                  ],
-                 outputs=[save_msg],
              )

-         # -- Tab: Database Management --
          with gr.Tab("Database Management"):
-             gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")

              with gr.Accordion("User Enrollment", open=False):
-                 enroll_name = gr.Textbox(label="Enter name for enrollment")
-                 # Use gr.File with multiple file support:
-                 enroll_files = gr.File(
-                     file_count="multiple",
-                     type="file",
-                     label="Upload Enrollment Images (multiple allowed)"
-                 )
                  enroll_btn = gr.Button("Enroll User")
-                 enroll_result = gr.Textbox(label="", interactive=False)
-                 enroll_btn.click(fn=enroll_user, inputs=[enroll_name, enroll_files], outputs=[enroll_result])

              with gr.Accordion("User Search", open=False):
-                 search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
-                 search_name_input = gr.Dropdown(label="Select User", choices=[], value=None, interactive=True)
-                 search_image_input = gr.Image(type="numpy", label="Upload Image", visible=False)
                  search_btn = gr.Button("Search")
-                 search_result = gr.Textbox(label="", interactive=False)

-                 def update_search_visibility(mode):
                      if mode == "Name":
                          return gr.update(visible=True), gr.update(visible=False)
                      else:
                          return gr.update(visible=False), gr.update(visible=True)

                  search_mode.change(
-                     fn=update_search_visibility,
                      inputs=[search_mode],
-                     outputs=[search_name_input, search_image_input]
                  )

-                 def search_user(mode, name, img):
                      if mode == "Name":
-                         return search_by_name(name)
                      else:
                          return search_by_image(img)

                  search_btn.click(
-                     fn=search_user,
-                     inputs=[search_mode, search_name_input, search_image_input],
-                     outputs=[search_result]
                  )

              with gr.Accordion("User Management Tools", open=False):
                  list_btn = gr.Button("List Enrolled Users")
-                 list_output = gr.Textbox(label="", interactive=False)
-                 list_btn.click(fn=lambda: list_users(), inputs=[], outputs=[list_output])

-                 def get_user_list():
                      pl = load_pipeline()
                      return gr.update(choices=pl.db.list_labels())

-                 refresh_users_btn = gr.Button("Refresh User List")
-                 refresh_users_btn.click(fn=get_user_list, inputs=[], outputs=[search_name_input])

-                 remove_user_select = gr.Dropdown(label="Select User to Remove", choices=[])
-                 remove_btn = gr.Button("Remove Selected User")
-                 remove_output = gr.Textbox(label="", interactive=False)

-                 remove_btn.click(fn=remove_user, inputs=[remove_user_select], outputs=[remove_output])
-                 refresh_users_btn.click(fn=get_user_list, inputs=[], outputs=[remove_user_select])

      return demo
@@ -774,5 +1051,5 @@ def build_app():
  # --------------------------------------------------------------------
  if __name__ == "__main__":
      app = build_app()
-     # If concurrency or GPU usage is high, you can use .queue():
      app.queue().launch(server_name="0.0.0.0", server_port=7860)
 
  # Gradio
  import gradio as gr

+ # PyTorch, YOLO, FaceNet, and deep_sort
  from ultralytics import YOLO
  from facenet_pytorch import InceptionResnetV1
  from torchvision import transforms
  from deep_sort_realtime.deepsort_tracker import DeepSort

+ # Mediapipe for face mesh, iris detection, blink detection, and hand tracking
+ import mediapipe as mp
+
  # --------------------------------------------------------------------
  # LOGGING
  # --------------------------------------------------------------------
 
  )
  logger = logging.getLogger(__name__)

+ # Mute debug logs from third-party libraries
  logging.getLogger('torch').setLevel(logging.ERROR)
+ logging.getLogger('mediapipe').setLevel(logging.ERROR)
  logging.getLogger('deep_sort_realtime').setLevel(logging.ERROR)

  # --------------------------------------------------------------------
 
  MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
  CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")

+ # Landmark indices for blink detection
  LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
  RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]

+ # Mediapipe references
+ mp_drawing = mp.solutions.drawing_utils
+ mp_face_mesh = mp.solutions.face_mesh
+ mp_hands = mp.solutions.hands
+
  # --------------------------------------------------------------------
+ # DATACLASS: PipelineConfig
  # --------------------------------------------------------------------
  @dataclass
  class PipelineConfig:
 
      hand: Dict = field(default_factory=dict)
      eye_color: Dict = field(default_factory=dict)
      enabled_components: Dict = field(default_factory=dict)
+
      detection_conf_thres: float = 0.4
      recognition_conf_thres: float = 0.85
+
      bbox_color: Tuple[int, int, int] = (0, 255, 0)
      spoofed_bbox_color: Tuple[int, int, int] = (0, 0, 255)
      unknown_bbox_color: Tuple[int, int, int] = (0, 0, 255)
 
          self.device = device
          try:
              if not os.path.exists(model_path):
+                 logger.info(f"Model not found at {model_path}. Downloading from GitHub...")
+                 resp = requests.get(DEFAULT_MODEL_URL)
+                 resp.raise_for_status()
                  os.makedirs(os.path.dirname(model_path), exist_ok=True)
                  with open(model_path, 'wb') as f:
+                     f.write(resp.content)
                  logger.info(f"Downloaded YOLO model to {model_path}")

              self.model = YOLO(model_path)
              self.model.to(device)
              logger.info(f"Loaded YOLO model from {model_path}")
          except Exception as e:
+             logger.error(f"YOLO init failed: {str(e)}")
              raise

      def detect(self, image: np.ndarray, conf_thres: float) -> List[Tuple[int, int, int, int, float, int]]:
 
              logger.debug(f"Detected {len(detections)} faces.")
              return detections
          except Exception as e:
+             logger.error(f"Detection error: {str(e)}")
              return []

  # --------------------------------------------------------------------
+ # FACE TRACKER (DeepSort)
  # --------------------------------------------------------------------
  class FaceTracker:
      def __init__(self, max_age: int = 30):
 
              logger.debug(f"Updated tracker with {len(tracks)} tracks.")
              return tracks
          except Exception as e:
+             logger.error(f"Tracking error: {str(e)}")
              return []

  # --------------------------------------------------------------------
 
              tens = self.transform(pil_img).unsqueeze(0).to(self.device)
              with torch.no_grad():
                  embedding = self.model(tens)[0].cpu().numpy()
+             logger.debug(f"Generated embedding sample: {embedding[:5]}...")
              return embedding
          except Exception as e:
+             logger.error(f"Embedding failed: {str(e)}")
              return None

+ # --------------------------------------------------------------------
+ # BLINK DETECTION
+ # --------------------------------------------------------------------
+ def detect_blink(face_roi: np.ndarray, threshold: float = 0.25) -> Tuple[bool, float, float, np.ndarray, np.ndarray]:
+     """
+     Returns:
+         (blink_bool, left_ear, right_ear, left_eye_points, right_eye_points).
+     """
+     try:
+         face_mesh_proc = mp_face_mesh.FaceMesh(
+             static_image_mode=True,
+             max_num_faces=1,
+             refine_landmarks=True,
+             min_detection_confidence=0.5
+         )
+         result = face_mesh_proc.process(cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB))
+         face_mesh_proc.close()
+
+         if not result.multi_face_landmarks:
+             return False, 0.0, 0.0, None, None
+
+         landmarks = result.multi_face_landmarks[0].landmark
+         h, w = face_roi.shape[:2]
+
+         def eye_aspect_ratio(indices):
+             pts = [(landmarks[i].x * w, landmarks[i].y * h) for i in indices]
+             vertical = np.linalg.norm(np.array(pts[1]) - np.array(pts[5])) + \
+                        np.linalg.norm(np.array(pts[2]) - np.array(pts[4]))
+             horizontal = np.linalg.norm(np.array(pts[0]) - np.array(pts[3]))
+             return vertical / (2.0 * horizontal + 1e-6)
+
+         left_ear = eye_aspect_ratio(LEFT_EYE_IDX)
+         right_ear = eye_aspect_ratio(RIGHT_EYE_IDX)
+
+         blink = (left_ear < threshold) and (right_ear < threshold)
+
+         left_eye_pts = np.array([(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in LEFT_EYE_IDX])
+         right_eye_pts = np.array([(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in RIGHT_EYE_IDX])
+
+         return blink, left_ear, right_ear, left_eye_pts, right_eye_pts
+
+     except Exception as e:
+         logger.error(f"Blink detection error: {str(e)}")
+         return False, 0.0, 0.0, None, None
+
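The eye_aspect_ratio helper above is the standard EAR measure (Soukupová & Čech, 2016): summed vertical lid distances over twice the horizontal eye width. A quick check with made-up eye coordinates (illustrative only, not Mediapipe output) shows why the 0.25 default separates open from closed:

import numpy as np

def ear(pts):
    # pts order matches LEFT_EYE_IDX: [outer, top1, top2, inner, bottom2, bottom1]
    v = np.linalg.norm(pts[1] - pts[5]) + np.linalg.norm(pts[2] - pts[4])
    h = np.linalg.norm(pts[0] - pts[3])
    return v / (2.0 * h + 1e-6)

open_eye = np.array([(0, 5), (3, 2), (7, 2), (10, 5), (7, 8), (3, 8)], dtype=float)
closed_eye = np.array([(0, 5), (3, 4.5), (7, 4.5), (10, 5), (7, 5.5), (3, 5.5)], dtype=float)
print(ear(open_eye))    # 0.6  -> comfortably above the 0.25 threshold
print(ear(closed_eye))  # 0.1  -> below threshold; counts toward a blink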
+ # --------------------------------------------------------------------
+ # FACE MESH + IRIS DETECTION / DRAWING
+ # --------------------------------------------------------------------
+ def process_face_mesh(face_roi: np.ndarray):
+     try:
+         fm_proc = mp_face_mesh.FaceMesh(
+             static_image_mode=True,
+             max_num_faces=1,
+             refine_landmarks=True,
+             min_detection_confidence=0.5
+         )
+         result = fm_proc.process(cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB))
+         fm_proc.close()
+         if result.multi_face_landmarks:
+             return result.multi_face_landmarks[0]
+         return None
+     except Exception as e:
+         logger.error(f"Face mesh error: {str(e)}")
+         return None
+
+ def draw_face_mesh(image: np.ndarray, face_landmarks, config: Dict, pipeline_config: PipelineConfig):
+     mesh_color_bgr = pipeline_config.mesh_color[::-1]
+     contour_color_bgr = pipeline_config.contour_color[::-1]
+     iris_color_bgr = pipeline_config.iris_color[::-1]
+
+     if config.get('tesselation'):
+         mp_drawing.draw_landmarks(
+             image,
+             face_landmarks,
+             mp_face_mesh.FACEMESH_TESSELATION,
+             landmark_drawing_spec=mp_drawing.DrawingSpec(color=mesh_color_bgr, thickness=1, circle_radius=1),
+             connection_drawing_spec=mp_drawing.DrawingSpec(color=mesh_color_bgr, thickness=1),
+         )
+     if config.get('contours'):
+         mp_drawing.draw_landmarks(
+             image,
+             face_landmarks,
+             mp_face_mesh.FACEMESH_CONTOURS,
+             landmark_drawing_spec=None,
+             connection_drawing_spec=mp_drawing.DrawingSpec(color=contour_color_bgr, thickness=2)
+         )
+     if config.get('irises'):
+         mp_drawing.draw_landmarks(
+             image,
+             face_landmarks,
+             mp_face_mesh.FACEMESH_IRISES,
+             landmark_drawing_spec=None,
+             connection_drawing_spec=mp_drawing.DrawingSpec(color=iris_color_bgr, thickness=2)
+         )
+
+ # --------------------------------------------------------------------
+ # EYE COLOR DETECTION
+ # --------------------------------------------------------------------
+ EYE_COLOR_RANGES = {
+     "amber": (255, 191, 0),
+     "blue": (0, 0, 255),
+     "brown": (139, 69, 19),
+     "green": (0, 128, 0),
+     "gray": (128, 128, 128),
+     "hazel": (102, 51, 0),
+ }
+
+ def classify_eye_color(rgb_color: Tuple[int,int,int]) -> str:
+     if rgb_color is None:
+         return "Unknown"
+     min_dist = float('inf')
+     best = "Unknown"
+     for color_name, ref_rgb in EYE_COLOR_RANGES.items():
+         dist = math.sqrt(sum([(a-b)**2 for a,b in zip(rgb_color, ref_rgb)]))
+         if dist < min_dist:
+             min_dist = dist
+             best = color_name
+     return best
+
+ def get_dominant_color(image_roi, k=3):
+     if image_roi.size == 0:
+         return None
+     pixels = np.float32(image_roi.reshape(-1, 3))
+     criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.1)
+     _, labels, palette = cv2.kmeans(pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
+     _, counts = np.unique(labels, return_counts=True)
+     dom_color = tuple(palette[np.argmax(counts)].astype(int).tolist())
+     return dom_color
+
+ def detect_eye_color(face_roi: np.ndarray, face_landmarks) -> Optional[str]:
+     if face_landmarks is None:
+         return None
+     h, w = face_roi.shape[:2]
+     iris_inds = set()
+     for conn in mp_face_mesh.FACEMESH_IRISES:
+         iris_inds.update(conn)
+
+     iris_points = []
+     for idx in iris_inds:
+         lm = face_landmarks.landmark[idx]
+         iris_points.append((int(lm.x * w), int(lm.y * h)))
+     if not iris_points:
+         return None
+
+     min_x = min(pt[0] for pt in iris_points)
+     max_x = max(pt[0] for pt in iris_points)
+     min_y = min(pt[1] for pt in iris_points)
+     max_y = max(pt[1] for pt in iris_points)
+
+     pad = 5
+     x1 = max(0, min_x - pad)
+     y1 = max(0, min_y - pad)
+     x2 = min(w, max_x + pad)
+     y2 = min(h, max_y + pad)
+
+     eye_roi = face_roi[y1:y2, x1:x2]
+     # Resize for more stable KMeans
+     eye_roi_resize = cv2.resize(eye_roi, (40, 40), interpolation=cv2.INTER_AREA)
+
+     if eye_roi_resize.size == 0:
+         return None
+
+     dom_rgb = get_dominant_color(eye_roi_resize)
+     if dom_rgb is not None:
+         return classify_eye_color(dom_rgb)
+     return None
+
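One caveat worth flagging: face_roi comes straight from OpenCV, so get_dominant_color sees BGR-ordered pixels while EYE_COLOR_RANGES is written as RGB; converting the ROI first avoids classifying blue eyes as brown. A standalone check of the k-means dominant-color step on a synthetic patch (values invented for illustration):

import cv2
import numpy as np

patch_bgr = np.full((40, 40, 3), (200, 30, 30), dtype=np.uint8)  # blue-ish in BGR
patch_bgr[:10, :10] = (30, 30, 200)                              # small red corner

patch_rgb = cv2.cvtColor(patch_bgr, cv2.COLOR_BGR2RGB)
pixels = np.float32(patch_rgb.reshape(-1, 3))
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.1)
_, labels, palette = cv2.kmeans(pixels, 3, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
_, counts = np.unique(labels, return_counts=True)
print(tuple(palette[np.argmax(counts)].astype(int)))  # ~(30, 30, 200): blue wins by pixel count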
+ # --------------------------------------------------------------------
+ # HAND TRACKER
+ # --------------------------------------------------------------------
+ class HandTracker:
+     def __init__(self, min_detection_confidence=0.5, min_tracking_confidence=0.5):
+         self.hands = mp_hands.Hands(
+             static_image_mode=True,
+             max_num_hands=2,
+             min_detection_confidence=min_detection_confidence,
+             min_tracking_confidence=min_tracking_confidence,
+         )
+         logger.info("Initialized Mediapipe HandTracking")
+
+     def detect_hands(self, image: np.ndarray):
+         try:
+             img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+             results = self.hands.process(img_rgb)
+             return results.multi_hand_landmarks, results.multi_handedness
+         except Exception as e:
+             logger.error(f"Hand detection error: {str(e)}")
+             return None, None
+
+     def draw_hands(self, image: np.ndarray, hand_landmarks, handedness, config):
+         if not hand_landmarks:
+             return image
+
+         mpdraw = mp_drawing
+         for i, hlms in enumerate(hand_landmarks):
+             # Convert user config colors from (R,G,B) to (B,G,R)
+             hl_color = config.hand_landmark_color[::-1]
+             hc_color = config.hand_connection_color[::-1]
+             mpdraw.draw_landmarks(
+                 image,
+                 hlms,
+                 mp_hands.HAND_CONNECTIONS,
+                 mpdraw.DrawingSpec(color=hl_color, thickness=2, circle_radius=4),
+                 mpdraw.DrawingSpec(color=hc_color, thickness=2, circle_radius=2),
+             )
+             if handedness and i < len(handedness):
+                 label = handedness[i].classification[0].label
+                 score = handedness[i].classification[0].score
+                 text = f"{label}: {score:.2f}"
+                 # We'll place text near the wrist
+                 wrist_lm = hlms.landmark[mp_hands.HandLandmark.WRIST]
+                 h, w, _ = image.shape
+                 cx, cy = int(wrist_lm.x * w), int(wrist_lm.y * h)
+                 ht_color = config.hand_text_color[::-1]
+                 cv2.putText(image, text, (cx, cy - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, ht_color, 2)
+         return image
+
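The handedness[i].classification[0] indexing in draw_hands reflects Mediapipe's output shape: multi_handedness is a list of ClassificationList protos, one per detected hand, each holding ranked label/score pairs. A minimal probe of that structure (the zero image is a placeholder; a real RGB photo is needed to get detections):

import mediapipe as mp
import numpy as np

hands = mp.solutions.hands.Hands(static_image_mode=True, max_num_hands=2)
img_rgb = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder RGB frame
results = hands.process(img_rgb)
if results.multi_handedness:
    for hand in results.multi_handedness:
        top = hand.classification[0]                # best label for this hand
        print(top.label, top.score)                 # e.g. "Left" 0.97
hands.close()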
  # --------------------------------------------------------------------
  # FACE PIPELINE
  # --------------------------------------------------------------------

          self.tracker = None
          self.facenet = None
          self.db = None
+         self.hand_tracker = None
          self._initialized = False

      def initialize(self):
          try:
+             # YOLO for face detection
              self.detector = YOLOFaceDetector(
                  model_path=self.config.detector['model_path'],
                  device=self.config.detector['device']
              )
+             # DeepSort tracking
              self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
+             # FaceNet embedder
              self.facenet = FaceNetEmbedder(device=self.config.detector['device'])
+             # Database
              self.db = FaceDatabase()
+
+             # Hand tracker if enabled
+             if self.config.hand['enable']:
+                 self.hand_tracker = HandTracker(
+                     min_detection_confidence=self.config.hand['min_detection_confidence'],
+                     min_tracking_confidence=self.config.hand['min_tracking_confidence']
+                 )
+
              self._initialized = True
              logger.info("FacePipeline initialized successfully.")
          except Exception as e:
+             logger.error(f"Initialization failed: {str(e)}")
              self._initialized = False
              raise

      def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
+         """
+         Main pipeline processing: detection, tracking, hand detection, face mesh, blink detection, etc.
+         Returns annotated_frame, detection_results.
+         """
          if not self._initialized:
+             logger.error("Pipeline not initialized.")
              return frame, []

          try:
+             # YOLO detection + DeepSort tracking
              detections = self.detector.detect(frame, self.config.detection_conf_thres)
+             tracked_objs = self.tracker.update(detections, frame)
+             annotated = frame.copy()
              results = []

+             # Hand detection if enabled
+             hand_landmarks_list = None
+             handedness_list = None
+             if self.config.hand['enable'] and self.hand_tracker:
+                 hand_landmarks_list, handedness_list = self.hand_tracker.detect_hands(annotated)
+                 annotated = self.hand_tracker.draw_hands(
+                     annotated, hand_landmarks_list, handedness_list, self.config
+                 )
+
+             for obj in tracked_objs:
                  if not obj.is_confirmed():
                      continue
+
                  track_id = obj.track_id
+                 bbox = obj.to_tlbr().astype(int)
+                 x1, y1, x2, y2 = bbox
                  conf = getattr(obj, 'score', 1.0)
                  cls = getattr(obj, 'class_id', 0)

                  face_roi = frame[y1:y2, x1:x2]
                  if face_roi.size == 0:
+                     logger.warning(f"Empty face ROI for track={track_id}")
                      continue

+                 # Anti-spoofing
                  is_spoofed = False
+                 if self.config.anti_spoof.get('enable', True):
                      is_spoofed = not self.is_real_face(face_roi)
                      if is_spoofed:
+                         cls = 1  # Mark as spoofed

                  if is_spoofed:
                      box_color_bgr = self.config.spoofed_bbox_color[::-1]
                      name = "Spoofed"
                      similarity = 0.0
                  else:
+                     # Face embedding + recognition
+                     emb = self.facenet.get_embedding(face_roi)
+                     if emb is not None and self.config.recognition.get('enable', True):
+                         name, similarity = self.recognize_face(emb, self.config.recognition_conf_thres)
                      else:
                          name = "Unknown"
                          similarity = 0.0

+                 box_color_rgb = (self.config.bbox_color if name != "Unknown"
+                                  else self.config.unknown_bbox_color)
                  box_color_bgr = box_color_rgb[::-1]

+                 label_text = name
+                 cv2.rectangle(annotated, (x1, y1), (x2, y2), box_color_bgr, 2)
+                 cv2.putText(annotated, label_text, (x1, y1 - 10),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color_bgr, 2)

+                 # Blink detection
+                 blink = False
+                 if self.config.blink.get('enable', False):
+                     blink, left_ear, right_ear, left_eye_pts, right_eye_pts = detect_blink(
+                         face_roi, threshold=self.config.blink.get('ear_thresh', 0.25)
+                     )
+                     if left_eye_pts is not None and right_eye_pts is not None:
+                         # Shift points to global coords
+                         le_g = left_eye_pts + np.array([x1, y1])
+                         re_g = right_eye_pts + np.array([x1, y1])
+                         # Outline eyes
+                         eye_outline_bgr = self.config.eye_outline_color[::-1]
+                         cv2.polylines(annotated, [le_g], True, eye_outline_bgr, 1)
+                         cv2.polylines(annotated, [re_g], True, eye_outline_bgr, 1)
+                         if blink:
+                             blink_msg_color = self.config.blink_text_color[::-1]
+                             cv2.putText(annotated, "Blink Detected",
+                                         (x1, y2 + 20),
+                                         cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                                         blink_msg_color, 2)
+
+                 # Face mesh + eye color
+                 face_mesh_landmarks = None
+                 eye_color_name = None
+                 if (self.config.face_mesh_options.get('enable') or
+                         self.config.eye_color.get('enable')):
+                     face_mesh_landmarks = process_face_mesh(face_roi)
+                     if face_mesh_landmarks:
+                         # If user wants to draw face mesh
+                         if self.config.face_mesh_options.get('enable', False):
+                             draw_face_mesh(
+                                 annotated[y1:y2, x1:x2],
+                                 face_mesh_landmarks,
+                                 self.config.face_mesh_options,
+                                 self.config
+                             )
+                         # Eye color
+                         if self.config.eye_color.get('enable', False):
+                             color_found = detect_eye_color(face_roi, face_mesh_landmarks)
+                             if color_found:
+                                 eye_color_name = color_found
+                                 text_col_bgr = self.config.eye_color_text_color[::-1]
+                                 cv2.putText(
+                                     annotated, f"Eye Color: {eye_color_name}",
+                                     (x1, y2 + 40),
+                                     cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                                     text_col_bgr, 2
+                                 )
+
+                 # Record result
                  detection_info = {
+                     "track_id": track_id,
+                     "bbox": (x1, y1, x2, y2),
+                     "confidence": float(conf),
+                     "class_id": cls,
+                     "name": name,
+                     "similarity": similarity,
+                     "blink": blink if self.config.blink.get('enable') else None,
+                     "face_mesh": bool(face_mesh_landmarks) if self.config.face_mesh_options.get('enable') else False,
+                     "hands_detected": bool(hand_landmarks_list),
+                     "hand_count": len(hand_landmarks_list) if hand_landmarks_list else 0,
+                     "eye_color": eye_color_name if self.config.eye_color.get('enable') else None
                  }
                  results.append(detection_info)

+             return annotated, results

          except Exception as e:
+             logger.error(f"Frame process error: {str(e)}")
              return frame, []

      def is_real_face(self, face_roi: np.ndarray) -> bool:
          try:
              gray = cv2.cvtColor(face_roi, cv2.COLOR_BGR2GRAY)
+             lapv = cv2.Laplacian(gray, cv2.CV_64F).var()
+             return lapv > self.config.anti_spoof.get('lap_thresh', 80.0)
          except Exception as e:
+             logger.error(f"Anti-spoof error: {str(e)}")
              return False
  return False
705
 
706
+ def recognize_face(self, embedding: np.ndarray, threshold: float) -> Tuple[str, float]:
707
  try:
708
+ best_name = "Unknown"
709
+ best_sim = 0.0
710
+ for lbl, embs in self.db.embeddings.items():
711
+ for db_emb in embs:
712
+ sim = FacePipeline.cosine_similarity(embedding, db_emb)
713
+ if sim > best_sim:
714
+ best_sim = sim
715
+ best_name = lbl
716
+ if best_sim < threshold:
717
+ best_name = "Unknown"
718
+ return best_name, best_sim
719
  except Exception as e:
720
+ logger.error(f"Recognition error: {str(e)}")
721
+ return ("Unknown", 0.0)
722
 
723
  @staticmethod
724
  def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
725
+ return float(np.dot(a, b) / ((np.linalg.norm(a)*np.linalg.norm(b)) + 1e-6))
726
 
727
  # --------------------------------------------------------------------
728
+ # GLOBAL LOADER
729
  # --------------------------------------------------------------------
730
  pipeline = None
731
  def load_pipeline() -> FacePipeline:
732
  global pipeline
733
  if pipeline is None:
 
734
  cfg = PipelineConfig.load(CONFIG_PATH)
735
  pipeline = FacePipeline(cfg)
736
  pipeline.initialize()
737
  return pipeline
738
 
739
  # --------------------------------------------------------------------
740
+ # UTILITY: HEX <-> BGR
741
  # --------------------------------------------------------------------
742
+ def hex_to_bgr(hexstr: str) -> Tuple[int,int,int]:
743
+ if not hexstr.startswith('#'):
744
+ hexstr = '#' + hexstr
745
+ h = hexstr.lstrip('#')
746
+ if len(h) != 6:
747
  return (255, 0, 0)
748
+ r = int(h[0:2], 16)
749
+ g = int(h[2:4], 16)
750
+ b = int(h[4:6], 16)
751
  return (b,g,r)
752
 
753
  def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
 
755
  return f"#{r:02x}{g:02x}{b:02x}"
756
 
757
  # --------------------------------------------------------------------
758
+ # GRADIO CALLBACKS
759
  # --------------------------------------------------------------------
760
  def update_config(
761
+ # toggles
762
  enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
763
  show_tesselation, show_contours, show_irises,
764
+ # thresholds
765
+ detection_conf, recognition_thresh, antispoof_thresh, blink_thresh, hand_det_conf, hand_track_conf,
766
+ # colors
767
+ bbox_hex, spoofed_hex, unknown_hex, eye_hex, blink_hex,
768
+ hand_landmark_hex, hand_connect_hex, hand_text_hex,
769
+ mesh_hex, contour_hex, iris_hex, eye_color_text_hex
 
770
  ):
771
  pl = load_pipeline()
772
  cfg = pl.config
 
      cfg.eye_outline_color = hex_to_bgr(eye_hex)[::-1]
      cfg.blink_text_color = hex_to_bgr(blink_hex)[::-1]
      cfg.hand_landmark_color = hex_to_bgr(hand_landmark_hex)[::-1]
+     cfg.hand_connection_color = hex_to_bgr(hand_connect_hex)[::-1]
      cfg.hand_text_color = hex_to_bgr(hand_text_hex)[::-1]
      cfg.mesh_color = hex_to_bgr(mesh_hex)[::-1]
      cfg.contour_color = hex_to_bgr(contour_hex)[::-1]

      cfg.save(CONFIG_PATH)
      return "Configuration saved successfully!"

+ def enroll_user(label_name: str, filepaths: List[str]) -> str:
+     """Enrolls a user by name using multiple image file paths."""
      pl = load_pipeline()
+     if not label_name:
          return "Please provide a user name."
+     if not filepaths or len(filepaths) == 0:
          return "No images provided."

+     enrolled_count = 0
+     for path in filepaths:
+         if not os.path.isfile(path):
+             continue
+         img_bgr = cv2.imread(path)
          if img_bgr is None:
              continue
+         # Detect face(s)
+         dets = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
+         for x1, y1, x2, y2, conf, cls in dets:
+             roi = img_bgr[y1:y2, x1:x2]
+             if roi.size == 0:
                  continue
+             emb = pl.facenet.get_embedding(roi)
              if emb is not None:
+                 pl.db.add_embedding(label_name, emb)
+                 enrolled_count += 1

+     if enrolled_count > 0:
          pl.db.save()
+         return f"Enrolled '{label_name}' with {enrolled_count} face(s)!"
      else:
+         return "No faces detected in provided images."
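Because the enrollment widget below is declared with type="filepath", Gradio hands enroll_user a list of plain temp-file paths that cv2.imread can open directly; no byte decoding is needed, unlike the dict-based gr.File handling this commit removes. A hypothetical direct call (paths invented):

msg = enroll_user("alice", ["/tmp/alice_1.jpg", "/tmp/alice_2.jpg"])
print(msg)   # e.g. "Enrolled 'alice' with 2 face(s)!"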

  def search_by_name(name: str) -> str:
      pl = load_pipeline()
      if not name:
+         return "No name entered."
+     embs = pl.db.get_embeddings_by_label(name)
+     if embs:
+         return f"'{name}' found with {len(embs)} embedding(s)."
      else:
+         return f"No embeddings found for '{name}'."

+ def search_by_image(img: np.ndarray) -> str:
      pl = load_pipeline()
+     if img is None:
          return "No image uploaded."
+     img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+     dets = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
+     if not dets:
          return "No faces detected in the uploaded image."
+     x1, y1, x2, y2, conf, cls = dets[0]
+     roi = img_bgr[y1:y2, x1:x2]
+     if roi.size == 0:
          return "Empty face ROI in the uploaded image."

+     emb = pl.facenet.get_embedding(roi)
      if emb is None:
+         return "Could not generate embedding from face."
+     results = pl.db.search_by_image(emb, pl.config.recognition_conf_thres)
+     if not results:
+         return "No matches in the database under current threshold."
+     lines = [f"- {lbl} (sim={sim:.3f})" for lbl, sim in results]
      return "Search results:\n" + "\n".join(lines)

  def remove_user(label: str) -> str:

          return "Enrolled users:\n" + ", ".join(labels)
      return "No users enrolled."

  def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
+     """Single-image test: run pipeline and return annotated image + JSON results."""
      if img is None:
          return None, "No image uploaded."
+
      pl = load_pipeline()
+     bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+     processed, detections = pl.process_frame(bgr)
+     result_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
+     return result_rgb, str(detections)

  # --------------------------------------------------------------------
+ # BUILD GRADIO APP
  # --------------------------------------------------------------------
  def build_app():
      with gr.Blocks() as demo:
+         gr.Markdown("# Complete Face Recognition System (Single-Image) with Mediapipe")

+         # Tab: Image Test
          with gr.Tab("Image Test"):
+             gr.Markdown("Upload a single image to detect faces, run blink detection, face mesh, hand tracking, etc.")
+             test_in = gr.Image(type="numpy", label="Upload Image")
+             test_out = gr.Image()
+             test_info = gr.Textbox(label="Detections")
              process_btn = gr.Button("Process Image")

              process_btn.click(
                  fn=process_test_image,
+                 inputs=test_in,
+                 outputs=[test_out, test_info],
              )

+         # Tab: Configuration
          with gr.Tab("Configuration"):
+             gr.Markdown("Adjust toggles, thresholds, and colors. Click Save to persist changes.")

              with gr.Row():
+                 enable_recognition = gr.Checkbox(label="Enable Recognition", value=True)
                  enable_antispoof = gr.Checkbox(label="Enable Anti-Spoof", value=True)
                  enable_blink = gr.Checkbox(label="Enable Blink Detection", value=True)
                  enable_hand = gr.Checkbox(label="Enable Hand Tracking", value=True)
                  enable_eyecolor = gr.Checkbox(label="Enable Eye Color Detection", value=False)
                  enable_facemesh = gr.Checkbox(label="Enable Face Mesh", value=False)

+             # Face Mesh sub-options
              gr.Markdown("**Face Mesh Options**")
              with gr.Row():
+                 show_tesselation = gr.Checkbox(label="Tesselation", value=False)
+                 show_contours = gr.Checkbox(label="Contours", value=False)
+                 show_irises = gr.Checkbox(label="Irises", value=False)

              gr.Markdown("**Thresholds**")
+             detection_conf = gr.Slider(0, 1, 0.4, step=0.01, label="Detection Confidence")
+             recognition_thresh = gr.Slider(0.5, 1.0, 0.85, step=0.01, label="Recognition Threshold")
+             antispoof_thresh = gr.Slider(0, 200, 80, step=1, label="Anti-Spoof Threshold")
+             blink_thresh = gr.Slider(0, 0.5, 0.25, step=0.01, label="Blink EAR Threshold")
+             hand_det_conf = gr.Slider(0, 1, 0.5, step=0.01, label="Hand Detection Confidence")
+             hand_track_conf = gr.Slider(0, 1, 0.5, step=0.01, label="Hand Tracking Confidence")

              gr.Markdown("**Color Options (Hex)**")
              bbox_hex = gr.Textbox(label="Box Color (Recognized)", value="#00ff00")
              spoofed_hex = gr.Textbox(label="Box Color (Spoofed)", value="#ff0000")
              unknown_hex = gr.Textbox(label="Box Color (Unknown)", value="#ff0000")
              eye_hex = gr.Textbox(label="Eye Outline Color", value="#ffff00")
              blink_hex = gr.Textbox(label="Blink Text Color", value="#0000ff")

              hand_landmark_hex = gr.Textbox(label="Hand Landmark Color", value="#ffd24d")
+             hand_connect_hex = gr.Textbox(label="Hand Connection Color", value="#cc6600")
              hand_text_hex = gr.Textbox(label="Hand Text Color", value="#ffffff")

              mesh_hex = gr.Textbox(label="Mesh Color", value="#64ff64")

              save_btn.click(
                  fn=update_config,
                  inputs=[
+                     enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
                      show_tesselation, show_contours, show_irises,
+                     detection_conf, recognition_thresh, antispoof_thresh, blink_thresh, hand_det_conf, hand_track_conf,
+                     bbox_hex, spoofed_hex, unknown_hex, eye_hex, blink_hex,
+                     hand_landmark_hex, hand_connect_hex, hand_text_hex,
                      mesh_hex, contour_hex, iris_hex, eye_color_text_hex
                  ],
+                 outputs=[save_msg]
              )
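One wiring detail that is easy to trip over: Gradio passes component values to the callback positionally, so this inputs list must line up with update_config's parameter order exactly; the grouping comments in the signature mirror the grouping here. A trimmed sketch of the same pattern (hypothetical two-parameter version):

import gradio as gr

def update_config(enable_recognition, detection_conf):   # trimmed signature
    return f"recognition={enable_recognition}, conf={detection_conf}"

with gr.Blocks() as demo:
    flag = gr.Checkbox(label="Enable Recognition", value=True)
    conf = gr.Slider(0, 1, 0.4, label="Detection Confidence")
    msg = gr.Textbox()
    gr.Button("Save").click(fn=update_config, inputs=[flag, conf], outputs=[msg])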
 
+         # Tab: Database Management
          with gr.Tab("Database Management"):
+             gr.Markdown("Enroll multiple images per user, search by name or image, remove users, list all users.")

              with gr.Accordion("User Enrollment", open=False):
+                 enroll_name = gr.Textbox(label="User Name")
+                 enroll_paths = gr.File(file_count="multiple", type="filepath", label="Upload Multiple Images")
                  enroll_btn = gr.Button("Enroll User")
+                 enroll_result = gr.Textbox()
+
+                 enroll_btn.click(
+                     fn=enroll_user,
+                     inputs=[enroll_name, enroll_paths],
+                     outputs=[enroll_result]
+                 )

              with gr.Accordion("User Search", open=False):
+                 search_mode = gr.Radio(["Name", "Image"], label="Search By", value="Name")
+                 search_name_box = gr.Dropdown(label="Select User", choices=[], value=None, visible=True)
+                 search_image_box = gr.Image(label="Upload Search Image", type="numpy", visible=False)
                  search_btn = gr.Button("Search")
+                 search_out = gr.Textbox()

+                 def toggle_search(mode):
                      if mode == "Name":
                          return gr.update(visible=True), gr.update(visible=False)
                      else:
                          return gr.update(visible=False), gr.update(visible=True)

                  search_mode.change(
+                     fn=toggle_search,
                      inputs=[search_mode],
+                     outputs=[search_name_box, search_image_box]
                  )

+                 def do_search(mode, uname, img):
                      if mode == "Name":
+                         return search_by_name(uname)
                      else:
                          return search_by_image(img)

                  search_btn.click(
+                     fn=do_search,
+                     inputs=[search_mode, search_name_box, search_image_box],
+                     outputs=[search_out]
                  )

              with gr.Accordion("User Management Tools", open=False):
                  list_btn = gr.Button("List Enrolled Users")
+                 list_out = gr.Textbox()
+                 list_btn.click(fn=lambda: list_users(), inputs=[], outputs=[list_out])

+                 def refresh_choices():
                      pl = load_pipeline()
                      return gr.update(choices=pl.db.list_labels())

+                 refresh_btn = gr.Button("Refresh User List")
+                 refresh_btn.click(fn=refresh_choices, inputs=[], outputs=[search_name_box])

+                 remove_box = gr.Dropdown(label="Select User to Remove", choices=[])
+                 remove_btn = gr.Button("Remove")
+                 remove_out = gr.Textbox()

+                 remove_btn.click(fn=remove_user, inputs=[remove_box], outputs=[remove_out])
+                 refresh_btn.click(fn=refresh_choices, inputs=[], outputs=[remove_box])

      return demo
 
  # --------------------------------------------------------------------
  if __name__ == "__main__":
      app = build_app()
+     # queue() is optional if concurrency is expected
      app.queue().launch(server_name="0.0.0.0", server_port=7860)