Update app.py
Browse files
app.py
CHANGED
@@ -12,11 +12,17 @@ from typing import Optional, Dict, List, Tuple
|
|
12 |
from dataclasses import dataclass, field
|
13 |
from collections import Counter
|
14 |
|
15 |
-
#
|
16 |
import gradio as gr
|
|
|
|
|
17 |
from ultralytics import YOLO
|
|
|
|
|
18 |
from facenet_pytorch import InceptionResnetV1
|
19 |
from torchvision import transforms
|
|
|
|
|
20 |
from deep_sort_realtime.deepsort_tracker import DeepSort
|
21 |
|
22 |
# --------------------------------------------------------------------
|
@@ -41,7 +47,7 @@ DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
|
|
41 |
MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
|
42 |
CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")
|
43 |
|
44 |
-
#
|
45 |
LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
|
46 |
RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]
|
47 |
|
@@ -238,7 +244,7 @@ class YOLOFaceDetector:
|
|
238 |
return []
|
239 |
|
240 |
# --------------------------------------------------------------------
|
241 |
-
# FACE TRACKER
|
242 |
# --------------------------------------------------------------------
|
243 |
class FaceTracker:
|
244 |
def __init__(self, max_age: int = 30):
|
@@ -298,7 +304,7 @@ class FacePipeline:
|
|
298 |
def initialize(self):
|
299 |
try:
|
300 |
self.detector = YOLOFaceDetector(
|
301 |
-
model_path=self.config.detector['model_path'],
|
302 |
device=self.config.detector['device']
|
303 |
)
|
304 |
self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
|
@@ -358,8 +364,8 @@ class FacePipeline:
|
|
358 |
similarity = 0.0
|
359 |
|
360 |
box_color_rgb = (
|
361 |
-
self.config.bbox_color
|
362 |
-
if name != "Unknown"
|
363 |
else self.config.unknown_bbox_color
|
364 |
)
|
365 |
box_color_bgr = box_color_rgb[::-1]
|
@@ -418,7 +424,7 @@ class FacePipeline:
|
|
418 |
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
|
419 |
|
420 |
# --------------------------------------------------------------------
|
421 |
-
# GLOBAL
|
422 |
# --------------------------------------------------------------------
|
423 |
pipeline = None
|
424 |
def load_pipeline() -> FacePipeline:
|
@@ -431,7 +437,7 @@ def load_pipeline() -> FacePipeline:
|
|
431 |
return pipeline
|
432 |
|
433 |
# --------------------------------------------------------------------
|
434 |
-
# GRADIO
|
435 |
# --------------------------------------------------------------------
|
436 |
def hex_to_bgr(hex_str: str) -> Tuple[int,int,int]:
|
437 |
if not hex_str.startswith('#'):
|
@@ -449,7 +455,7 @@ def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
|
|
449 |
return f"#{r:02x}{g:02x}{b:02x}"
|
450 |
|
451 |
# --------------------------------------------------------------------
|
452 |
-
#
|
453 |
# --------------------------------------------------------------------
|
454 |
def update_config(
|
455 |
enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
|
@@ -503,22 +509,30 @@ def update_config(
|
|
503 |
return "Configuration saved successfully!"
|
504 |
|
505 |
# --------------------------------------------------------------------
|
506 |
-
#
|
507 |
# --------------------------------------------------------------------
|
508 |
-
def enroll_user(name: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
509 |
pl = load_pipeline()
|
510 |
if not name:
|
511 |
return "Please provide a user name."
|
512 |
-
|
513 |
-
if not images or len(images) == 0:
|
514 |
return "No images provided."
|
515 |
|
516 |
count_enrolled = 0
|
517 |
-
for
|
518 |
-
|
|
|
|
|
|
|
519 |
continue
|
520 |
-
|
521 |
-
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
522 |
detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
|
523 |
for x1, y1, x2, y2, conf, cls in detections:
|
524 |
face_roi = img_bgr[y1:y2, x1:x2]
|
@@ -535,6 +549,9 @@ def enroll_user(name: str, images: List[np.ndarray]) -> str:
|
|
535 |
else:
|
536 |
return "No faces were detected or embedded. Enrollment failed."
|
537 |
|
|
|
|
|
|
|
538 |
def search_by_name(name: str) -> str:
|
539 |
pl = load_pipeline()
|
540 |
if not name:
|
@@ -550,6 +567,7 @@ def search_by_image(image: np.ndarray) -> str:
|
|
550 |
if image is None:
|
551 |
return "No image uploaded."
|
552 |
|
|
|
553 |
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
554 |
detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
|
555 |
if not detections:
|
@@ -589,7 +607,7 @@ def list_users() -> str:
|
|
589 |
return "No users enrolled."
|
590 |
|
591 |
# --------------------------------------------------------------------
|
592 |
-
#
|
593 |
# --------------------------------------------------------------------
|
594 |
def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
|
595 |
if img is None:
|
@@ -607,6 +625,7 @@ def build_app():
|
|
607 |
with gr.Blocks() as demo:
|
608 |
gr.Markdown("# Face Recognition System (Image-Only)")
|
609 |
|
|
|
610 |
with gr.Tab("Image Test"):
|
611 |
gr.Markdown("Upload a single image for face detection & recognition:")
|
612 |
image_input = gr.Image(type="numpy", label="Upload Image")
|
@@ -620,6 +639,7 @@ def build_app():
|
|
620 |
outputs=[image_out, image_info],
|
621 |
)
|
622 |
|
|
|
623 |
with gr.Tab("Configuration"):
|
624 |
gr.Markdown("Modify pipeline settings & thresholds.")
|
625 |
|
@@ -668,7 +688,7 @@ def build_app():
|
|
668 |
save_btn.click(
|
669 |
fn=update_config,
|
670 |
inputs=[
|
671 |
-
enable_recognition, enable_antispoof, enable_blink, enable_hand,
|
672 |
enable_eyecolor, enable_facemesh,
|
673 |
show_tesselation, show_contours, show_irises,
|
674 |
detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
|
@@ -681,15 +701,21 @@ def build_app():
|
|
681 |
outputs=[save_msg],
|
682 |
)
|
683 |
|
|
|
684 |
with gr.Tab("Database Management"):
|
685 |
gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")
|
686 |
|
687 |
with gr.Accordion("User Enrollment", open=False):
|
688 |
enroll_name = gr.Textbox(label="Enter name for enrollment")
|
689 |
-
|
|
|
|
|
|
|
|
|
|
|
690 |
enroll_btn = gr.Button("Enroll User")
|
691 |
enroll_result = gr.Textbox(label="", interactive=False)
|
692 |
-
enroll_btn.click(fn=enroll_user, inputs=[enroll_name,
|
693 |
|
694 |
with gr.Accordion("User Search", open=False):
|
695 |
search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
|
@@ -699,15 +725,14 @@ def build_app():
|
|
699 |
search_result = gr.Textbox(label="", interactive=False)
|
700 |
|
701 |
def update_search_visibility(mode):
|
702 |
-
# Show name dropdown if "Name", else show image upload
|
703 |
if mode == "Name":
|
704 |
return gr.update(visible=True), gr.update(visible=False)
|
705 |
else:
|
706 |
return gr.update(visible=False), gr.update(visible=True)
|
707 |
|
708 |
search_mode.change(
|
709 |
-
fn=update_search_visibility,
|
710 |
-
inputs=[search_mode],
|
711 |
outputs=[search_name_input, search_image_input]
|
712 |
)
|
713 |
|
@@ -718,8 +743,8 @@ def build_app():
|
|
718 |
return search_by_image(img)
|
719 |
|
720 |
search_btn.click(
|
721 |
-
fn=search_user,
|
722 |
-
inputs=[search_mode, search_name_input, search_image_input],
|
723 |
outputs=[search_result]
|
724 |
)
|
725 |
|
@@ -749,5 +774,5 @@ def build_app():
|
|
749 |
# --------------------------------------------------------------------
|
750 |
if __name__ == "__main__":
|
751 |
app = build_app()
|
752 |
-
#
|
753 |
app.queue().launch(server_name="0.0.0.0", server_port=7860)
|
|
|
12 |
from dataclasses import dataclass, field
|
13 |
from collections import Counter
|
14 |
|
15 |
+
# Gradio
|
16 |
import gradio as gr
|
17 |
+
|
18 |
+
# YOLO (Ultralytics)
|
19 |
from ultralytics import YOLO
|
20 |
+
|
21 |
+
# FaceNet
|
22 |
from facenet_pytorch import InceptionResnetV1
|
23 |
from torchvision import transforms
|
24 |
+
|
25 |
+
# DeepSORT tracking
|
26 |
from deep_sort_realtime.deepsort_tracker import DeepSort
|
27 |
|
28 |
# --------------------------------------------------------------------
|
|
|
47 |
MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
|
48 |
CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")
|
49 |
|
50 |
+
# If you still want blink detection or face mesh references, keep them
|
51 |
LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
|
52 |
RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]
|
53 |
|
|
|
244 |
return []
|
245 |
|
246 |
# --------------------------------------------------------------------
|
247 |
+
# FACE TRACKER
|
248 |
# --------------------------------------------------------------------
|
249 |
class FaceTracker:
|
250 |
def __init__(self, max_age: int = 30):
|
|
|
304 |
def initialize(self):
|
305 |
try:
|
306 |
self.detector = YOLOFaceDetector(
|
307 |
+
model_path=self.config.detector['model_path'],
|
308 |
device=self.config.detector['device']
|
309 |
)
|
310 |
self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
|
|
|
364 |
similarity = 0.0
|
365 |
|
366 |
box_color_rgb = (
|
367 |
+
self.config.bbox_color
|
368 |
+
if name != "Unknown"
|
369 |
else self.config.unknown_bbox_color
|
370 |
)
|
371 |
box_color_bgr = box_color_rgb[::-1]
|
|
|
424 |
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
|
425 |
|
426 |
# --------------------------------------------------------------------
|
427 |
+
# GLOBAL PIPELINE
|
428 |
# --------------------------------------------------------------------
|
429 |
pipeline = None
|
430 |
def load_pipeline() -> FacePipeline:
|
|
|
437 |
return pipeline
|
438 |
|
439 |
# --------------------------------------------------------------------
|
440 |
+
# HELPER FUNCTIONS FOR GRADIO
|
441 |
# --------------------------------------------------------------------
|
442 |
def hex_to_bgr(hex_str: str) -> Tuple[int,int,int]:
|
443 |
if not hex_str.startswith('#'):
|
|
|
455 |
return f"#{r:02x}{g:02x}{b:02x}"
|
456 |
|
457 |
# --------------------------------------------------------------------
|
458 |
+
# UPDATE CONFIG
|
459 |
# --------------------------------------------------------------------
|
460 |
def update_config(
|
461 |
enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
|
|
|
509 |
return "Configuration saved successfully!"
|
510 |
|
511 |
# --------------------------------------------------------------------
|
512 |
+
# ENROLL MULTIPLE IMAGES (using gr.File)
|
513 |
# --------------------------------------------------------------------
|
514 |
+
def enroll_user(name: str, files: List[dict]) -> str:
|
515 |
+
"""
|
516 |
+
Each item in `files` is a dict with keys:
|
517 |
+
- 'name': filename
|
518 |
+
- 'size': file size in bytes
|
519 |
+
- 'data': the binary file contents
|
520 |
+
We decode them into OpenCV images and process.
|
521 |
+
"""
|
522 |
pl = load_pipeline()
|
523 |
if not name:
|
524 |
return "Please provide a user name."
|
525 |
+
if not files or len(files) == 0:
|
|
|
526 |
return "No images provided."
|
527 |
|
528 |
count_enrolled = 0
|
529 |
+
for f in files:
|
530 |
+
file_bytes = f["data"]
|
531 |
+
np_array = np.frombuffer(file_bytes, np.uint8)
|
532 |
+
img_bgr = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
|
533 |
+
if img_bgr is None:
|
534 |
continue
|
535 |
+
|
|
|
536 |
detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
|
537 |
for x1, y1, x2, y2, conf, cls in detections:
|
538 |
face_roi = img_bgr[y1:y2, x1:x2]
|
|
|
549 |
else:
|
550 |
return "No faces were detected or embedded. Enrollment failed."
|
551 |
|
552 |
+
# --------------------------------------------------------------------
|
553 |
+
# SEARCH / USER MGMT
|
554 |
+
# --------------------------------------------------------------------
|
555 |
def search_by_name(name: str) -> str:
|
556 |
pl = load_pipeline()
|
557 |
if not name:
|
|
|
567 |
if image is None:
|
568 |
return "No image uploaded."
|
569 |
|
570 |
+
# Convert to BGR for pipeline
|
571 |
img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
572 |
detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
|
573 |
if not detections:
|
|
|
607 |
return "No users enrolled."
|
608 |
|
609 |
# --------------------------------------------------------------------
|
610 |
+
# PROCESS SINGLE IMAGE (IMAGE TEST)
|
611 |
# --------------------------------------------------------------------
|
612 |
def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
|
613 |
if img is None:
|
|
|
625 |
with gr.Blocks() as demo:
|
626 |
gr.Markdown("# Face Recognition System (Image-Only)")
|
627 |
|
628 |
+
# -- Tab: Image Test --
|
629 |
with gr.Tab("Image Test"):
|
630 |
gr.Markdown("Upload a single image for face detection & recognition:")
|
631 |
image_input = gr.Image(type="numpy", label="Upload Image")
|
|
|
639 |
outputs=[image_out, image_info],
|
640 |
)
|
641 |
|
642 |
+
# -- Tab: Configuration --
|
643 |
with gr.Tab("Configuration"):
|
644 |
gr.Markdown("Modify pipeline settings & thresholds.")
|
645 |
|
|
|
688 |
save_btn.click(
|
689 |
fn=update_config,
|
690 |
inputs=[
|
691 |
+
enable_recognition, enable_antispoof, enable_blink, enable_hand,
|
692 |
enable_eyecolor, enable_facemesh,
|
693 |
show_tesselation, show_contours, show_irises,
|
694 |
detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
|
|
|
701 |
outputs=[save_msg],
|
702 |
)
|
703 |
|
704 |
+
# -- Tab: Database Management --
|
705 |
with gr.Tab("Database Management"):
|
706 |
gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")
|
707 |
|
708 |
with gr.Accordion("User Enrollment", open=False):
|
709 |
enroll_name = gr.Textbox(label="Enter name for enrollment")
|
710 |
+
# Use gr.File with multiple file support:
|
711 |
+
enroll_files = gr.File(
|
712 |
+
file_count="multiple",
|
713 |
+
type="file",
|
714 |
+
label="Upload Enrollment Images (multiple allowed)"
|
715 |
+
)
|
716 |
enroll_btn = gr.Button("Enroll User")
|
717 |
enroll_result = gr.Textbox(label="", interactive=False)
|
718 |
+
enroll_btn.click(fn=enroll_user, inputs=[enroll_name, enroll_files], outputs=[enroll_result])
|
719 |
|
720 |
with gr.Accordion("User Search", open=False):
|
721 |
search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
|
|
|
725 |
search_result = gr.Textbox(label="", interactive=False)
|
726 |
|
727 |
def update_search_visibility(mode):
|
|
|
728 |
if mode == "Name":
|
729 |
return gr.update(visible=True), gr.update(visible=False)
|
730 |
else:
|
731 |
return gr.update(visible=False), gr.update(visible=True)
|
732 |
|
733 |
search_mode.change(
|
734 |
+
fn=update_search_visibility,
|
735 |
+
inputs=[search_mode],
|
736 |
outputs=[search_name_input, search_image_input]
|
737 |
)
|
738 |
|
|
|
743 |
return search_by_image(img)
|
744 |
|
745 |
search_btn.click(
|
746 |
+
fn=search_user,
|
747 |
+
inputs=[search_mode, search_name_input, search_image_input],
|
748 |
outputs=[search_result]
|
749 |
)
|
750 |
|
|
|
774 |
# --------------------------------------------------------------------
|
775 |
if __name__ == "__main__":
|
776 |
app = build_app()
|
777 |
+
# If concurrency or GPU usage is high, you can use .queue():
|
778 |
app.queue().launch(server_name="0.0.0.0", server_port=7860)
|