wuhp commited on
Commit
a6ec168
·
verified ·
1 Parent(s): 5a539d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -27
app.py CHANGED
@@ -12,11 +12,17 @@ from typing import Optional, Dict, List, Tuple
12
  from dataclasses import dataclass, field
13
  from collections import Counter
14
 
15
- # 3rd-party
16
  import gradio as gr
 
 
17
  from ultralytics import YOLO
 
 
18
  from facenet_pytorch import InceptionResnetV1
19
  from torchvision import transforms
 
 
20
  from deep_sort_realtime.deepsort_tracker import DeepSort
21
 
22
  # --------------------------------------------------------------------
@@ -41,7 +47,7 @@ DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
41
  MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
42
  CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")
43
 
44
- # Example eye indices if you still want blink detection somewhere
45
  LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
46
  RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]
47
 
@@ -238,7 +244,7 @@ class YOLOFaceDetector:
238
  return []
239
 
240
  # --------------------------------------------------------------------
241
- # FACE TRACKER (Used if you want tracking across frames - optional)
242
  # --------------------------------------------------------------------
243
  class FaceTracker:
244
  def __init__(self, max_age: int = 30):
@@ -298,7 +304,7 @@ class FacePipeline:
298
  def initialize(self):
299
  try:
300
  self.detector = YOLOFaceDetector(
301
- model_path=self.config.detector['model_path'],
302
  device=self.config.detector['device']
303
  )
304
  self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
@@ -358,8 +364,8 @@ class FacePipeline:
358
  similarity = 0.0
359
 
360
  box_color_rgb = (
361
- self.config.bbox_color
362
- if name != "Unknown"
363
  else self.config.unknown_bbox_color
364
  )
365
  box_color_bgr = box_color_rgb[::-1]
@@ -418,7 +424,7 @@ class FacePipeline:
418
  return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
419
 
420
  # --------------------------------------------------------------------
421
- # GLOBAL LOAD PIPELINE
422
  # --------------------------------------------------------------------
423
  pipeline = None
424
  def load_pipeline() -> FacePipeline:
@@ -431,7 +437,7 @@ def load_pipeline() -> FacePipeline:
431
  return pipeline
432
 
433
  # --------------------------------------------------------------------
434
- # GRADIO UTILS
435
  # --------------------------------------------------------------------
436
  def hex_to_bgr(hex_str: str) -> Tuple[int,int,int]:
437
  if not hex_str.startswith('#'):
@@ -449,7 +455,7 @@ def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
449
  return f"#{r:02x}{g:02x}{b:02x}"
450
 
451
  # --------------------------------------------------------------------
452
- # TAB: CONFIGURATION
453
  # --------------------------------------------------------------------
454
  def update_config(
455
  enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
@@ -503,22 +509,30 @@ def update_config(
503
  return "Configuration saved successfully!"
504
 
505
  # --------------------------------------------------------------------
506
- # TAB: DATABASE MANAGEMENT
507
  # --------------------------------------------------------------------
508
- def enroll_user(name: str, images: List[np.ndarray]) -> str:
 
 
 
 
 
 
 
509
  pl = load_pipeline()
510
  if not name:
511
  return "Please provide a user name."
512
-
513
- if not images or len(images) == 0:
514
  return "No images provided."
515
 
516
  count_enrolled = 0
517
- for img in images:
518
- if img is None:
 
 
 
519
  continue
520
- # Gradio typically supplies images in RGB
521
- img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
522
  detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
523
  for x1, y1, x2, y2, conf, cls in detections:
524
  face_roi = img_bgr[y1:y2, x1:x2]
@@ -535,6 +549,9 @@ def enroll_user(name: str, images: List[np.ndarray]) -> str:
535
  else:
536
  return "No faces were detected or embedded. Enrollment failed."
537
 
 
 
 
538
  def search_by_name(name: str) -> str:
539
  pl = load_pipeline()
540
  if not name:
@@ -550,6 +567,7 @@ def search_by_image(image: np.ndarray) -> str:
550
  if image is None:
551
  return "No image uploaded."
552
 
 
553
  img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
554
  detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
555
  if not detections:
@@ -589,7 +607,7 @@ def list_users() -> str:
589
  return "No users enrolled."
590
 
591
  # --------------------------------------------------------------------
592
- # TAB: IMAGE-BASED RECOGNITION
593
  # --------------------------------------------------------------------
594
  def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
595
  if img is None:
@@ -607,6 +625,7 @@ def build_app():
607
  with gr.Blocks() as demo:
608
  gr.Markdown("# Face Recognition System (Image-Only)")
609
 
 
610
  with gr.Tab("Image Test"):
611
  gr.Markdown("Upload a single image for face detection & recognition:")
612
  image_input = gr.Image(type="numpy", label="Upload Image")
@@ -620,6 +639,7 @@ def build_app():
620
  outputs=[image_out, image_info],
621
  )
622
 
 
623
  with gr.Tab("Configuration"):
624
  gr.Markdown("Modify pipeline settings & thresholds.")
625
 
@@ -668,7 +688,7 @@ def build_app():
668
  save_btn.click(
669
  fn=update_config,
670
  inputs=[
671
- enable_recognition, enable_antispoof, enable_blink, enable_hand,
672
  enable_eyecolor, enable_facemesh,
673
  show_tesselation, show_contours, show_irises,
674
  detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
@@ -681,15 +701,21 @@ def build_app():
681
  outputs=[save_msg],
682
  )
683
 
 
684
  with gr.Tab("Database Management"):
685
  gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")
686
 
687
  with gr.Accordion("User Enrollment", open=False):
688
  enroll_name = gr.Textbox(label="Enter name for enrollment")
689
- enroll_images = gr.Image(type="numpy", label="Upload Enrollment Images", multiple=True)
 
 
 
 
 
690
  enroll_btn = gr.Button("Enroll User")
691
  enroll_result = gr.Textbox(label="", interactive=False)
692
- enroll_btn.click(fn=enroll_user, inputs=[enroll_name, enroll_images], outputs=[enroll_result])
693
 
694
  with gr.Accordion("User Search", open=False):
695
  search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
@@ -699,15 +725,14 @@ def build_app():
699
  search_result = gr.Textbox(label="", interactive=False)
700
 
701
  def update_search_visibility(mode):
702
- # Show name dropdown if "Name", else show image upload
703
  if mode == "Name":
704
  return gr.update(visible=True), gr.update(visible=False)
705
  else:
706
  return gr.update(visible=False), gr.update(visible=True)
707
 
708
  search_mode.change(
709
- fn=update_search_visibility,
710
- inputs=[search_mode],
711
  outputs=[search_name_input, search_image_input]
712
  )
713
 
@@ -718,8 +743,8 @@ def build_app():
718
  return search_by_image(img)
719
 
720
  search_btn.click(
721
- fn=search_user,
722
- inputs=[search_mode, search_name_input, search_image_input],
723
  outputs=[search_result]
724
  )
725
 
@@ -749,5 +774,5 @@ def build_app():
749
  # --------------------------------------------------------------------
750
  if __name__ == "__main__":
751
  app = build_app()
752
- # queue() is optional if you expect concurrency
753
  app.queue().launch(server_name="0.0.0.0", server_port=7860)
 
12
  from dataclasses import dataclass, field
13
  from collections import Counter
14
 
15
+ # Gradio
16
  import gradio as gr
17
+
18
+ # YOLO (Ultralytics)
19
  from ultralytics import YOLO
20
+
21
+ # FaceNet
22
  from facenet_pytorch import InceptionResnetV1
23
  from torchvision import transforms
24
+
25
+ # DeepSORT tracking
26
  from deep_sort_realtime.deepsort_tracker import DeepSort
27
 
28
  # --------------------------------------------------------------------
 
47
  MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
48
  CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")
49
 
50
+ # If you still want blink detection or face mesh references, keep them
51
  LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
52
  RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]
53
 
 
244
  return []
245
 
246
  # --------------------------------------------------------------------
247
+ # FACE TRACKER
248
  # --------------------------------------------------------------------
249
  class FaceTracker:
250
  def __init__(self, max_age: int = 30):
 
304
  def initialize(self):
305
  try:
306
  self.detector = YOLOFaceDetector(
307
+ model_path=self.config.detector['model_path'],
308
  device=self.config.detector['device']
309
  )
310
  self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
 
364
  similarity = 0.0
365
 
366
  box_color_rgb = (
367
+ self.config.bbox_color
368
+ if name != "Unknown"
369
  else self.config.unknown_bbox_color
370
  )
371
  box_color_bgr = box_color_rgb[::-1]
 
424
  return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
425
 
426
  # --------------------------------------------------------------------
427
+ # GLOBAL PIPELINE
428
  # --------------------------------------------------------------------
429
  pipeline = None
430
  def load_pipeline() -> FacePipeline:
 
437
  return pipeline
438
 
439
  # --------------------------------------------------------------------
440
+ # HELPER FUNCTIONS FOR GRADIO
441
  # --------------------------------------------------------------------
442
  def hex_to_bgr(hex_str: str) -> Tuple[int,int,int]:
443
  if not hex_str.startswith('#'):
 
455
  return f"#{r:02x}{g:02x}{b:02x}"
456
 
457
  # --------------------------------------------------------------------
458
+ # UPDATE CONFIG
459
  # --------------------------------------------------------------------
460
  def update_config(
461
  enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
 
509
  return "Configuration saved successfully!"
510
 
511
  # --------------------------------------------------------------------
512
+ # ENROLL MULTIPLE IMAGES (using gr.File)
513
  # --------------------------------------------------------------------
514
+ def enroll_user(name: str, files: List[dict]) -> str:
515
+ """
516
+ Each item in `files` is a dict with keys:
517
+ - 'name': filename
518
+ - 'size': file size in bytes
519
+ - 'data': the binary file contents
520
+ We decode them into OpenCV images and process.
521
+ """
522
  pl = load_pipeline()
523
  if not name:
524
  return "Please provide a user name."
525
+ if not files or len(files) == 0:
 
526
  return "No images provided."
527
 
528
  count_enrolled = 0
529
+ for f in files:
530
+ file_bytes = f["data"]
531
+ np_array = np.frombuffer(file_bytes, np.uint8)
532
+ img_bgr = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
533
+ if img_bgr is None:
534
  continue
535
+
 
536
  detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
537
  for x1, y1, x2, y2, conf, cls in detections:
538
  face_roi = img_bgr[y1:y2, x1:x2]
 
549
  else:
550
  return "No faces were detected or embedded. Enrollment failed."
551
 
552
+ # --------------------------------------------------------------------
553
+ # SEARCH / USER MGMT
554
+ # --------------------------------------------------------------------
555
  def search_by_name(name: str) -> str:
556
  pl = load_pipeline()
557
  if not name:
 
567
  if image is None:
568
  return "No image uploaded."
569
 
570
+ # Convert to BGR for pipeline
571
  img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
572
  detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
573
  if not detections:
 
607
  return "No users enrolled."
608
 
609
  # --------------------------------------------------------------------
610
+ # PROCESS SINGLE IMAGE (IMAGE TEST)
611
  # --------------------------------------------------------------------
612
  def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
613
  if img is None:
 
625
  with gr.Blocks() as demo:
626
  gr.Markdown("# Face Recognition System (Image-Only)")
627
 
628
+ # -- Tab: Image Test --
629
  with gr.Tab("Image Test"):
630
  gr.Markdown("Upload a single image for face detection & recognition:")
631
  image_input = gr.Image(type="numpy", label="Upload Image")
 
639
  outputs=[image_out, image_info],
640
  )
641
 
642
+ # -- Tab: Configuration --
643
  with gr.Tab("Configuration"):
644
  gr.Markdown("Modify pipeline settings & thresholds.")
645
 
 
688
  save_btn.click(
689
  fn=update_config,
690
  inputs=[
691
+ enable_recognition, enable_antispoof, enable_blink, enable_hand,
692
  enable_eyecolor, enable_facemesh,
693
  show_tesselation, show_contours, show_irises,
694
  detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
 
701
  outputs=[save_msg],
702
  )
703
 
704
+ # -- Tab: Database Management --
705
  with gr.Tab("Database Management"):
706
  gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")
707
 
708
  with gr.Accordion("User Enrollment", open=False):
709
  enroll_name = gr.Textbox(label="Enter name for enrollment")
710
+ # Use gr.File with multiple file support:
711
+ enroll_files = gr.File(
712
+ file_count="multiple",
713
+ type="file",
714
+ label="Upload Enrollment Images (multiple allowed)"
715
+ )
716
  enroll_btn = gr.Button("Enroll User")
717
  enroll_result = gr.Textbox(label="", interactive=False)
718
+ enroll_btn.click(fn=enroll_user, inputs=[enroll_name, enroll_files], outputs=[enroll_result])
719
 
720
  with gr.Accordion("User Search", open=False):
721
  search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
 
725
  search_result = gr.Textbox(label="", interactive=False)
726
 
727
  def update_search_visibility(mode):
 
728
  if mode == "Name":
729
  return gr.update(visible=True), gr.update(visible=False)
730
  else:
731
  return gr.update(visible=False), gr.update(visible=True)
732
 
733
  search_mode.change(
734
+ fn=update_search_visibility,
735
+ inputs=[search_mode],
736
  outputs=[search_name_input, search_image_input]
737
  )
738
 
 
743
  return search_by_image(img)
744
 
745
  search_btn.click(
746
+ fn=search_user,
747
+ inputs=[search_mode, search_name_input, search_image_input],
748
  outputs=[search_result]
749
  )
750
 
 
774
  # --------------------------------------------------------------------
775
  if __name__ == "__main__":
776
  app = build_app()
777
+ # If concurrency or GPU usage is high, you can use .queue():
778
  app.queue().launch(server_name="0.0.0.0", server_port=7860)