# Source: facerec/app.py by wuhp — commit 40291e5 ("Create app.py").
# app.py
# --------------------------------------------------------------------
# A Gradio-based face recognition system that mirrors most features of
# the earlier Streamlit app: real-time webcam recognition, single-image
# tests, configuration, database enrollment, search, and user removal.
# Note: the blink, hand, eye-color and face-mesh settings are saved in the
# configuration, but this pipeline does not act on them yet.
# --------------------------------------------------------------------
import os
import sys
import math
import requests
import numpy as np
import cv2
import torch
import pickle
import logging
from PIL import Image
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, field
from collections import Counter
# 3rd-party modules
import gradio as gr
from ultralytics import YOLO
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from deep_sort_realtime.deepsort_tracker import DeepSort
# --------------------------------------------------------------------
# GLOBALS & CONSTANTS
# --------------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('face_pipeline.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Third-party libraries: only surface their errors.
for _noisy_logger_name in ('torch', 'deep_sort_realtime'):
    logging.getLogger(_noisy_logger_name).setLevel(logging.ERROR)
del _noisy_logger_name

# Where the YOLO face weights are downloaded from when no local copy exists.
DEFAULT_MODEL_URL = "https://github.com/wuhplaptop/face-11-n/blob/main/face2.pt?raw=true"

# Persistent per-user state (face DB, model weights, saved settings) under ~/.face_pipeline.
DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")

# Eye landmark indices kept for blink detection / face mesh features;
# the pipeline in this file does not reference them yet.
LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]
# --------------------------------------------------------------------
# PIPELINE CONFIG DATACLASS
# --------------------------------------------------------------------
@dataclass
class PipelineConfig:
    """
    Tunable settings for FacePipeline and the Gradio UI, persisted with pickle.

    Colour fields hold (R, G, B) tuples; they are reversed to BGR right
    before OpenCV drawing. Each dict-valued section is completed with
    default values for any key the caller (or an older saved config) did
    not supply.
    """
    detector: Dict = field(default_factory=dict)
    tracker: Dict = field(default_factory=dict)
    recognition: Dict = field(default_factory=dict)
    anti_spoof: Dict = field(default_factory=dict)
    blink: Dict = field(default_factory=dict)
    face_mesh_options: Dict = field(default_factory=dict)
    hand: Dict = field(default_factory=dict)
    eye_color: Dict = field(default_factory=dict)
    enabled_components: Dict = field(default_factory=dict)
    detection_conf_thres: float = 0.4
    recognition_conf_thres: float = 0.85
    bbox_color: Tuple[int, int, int] = (0, 255, 0)
    # Red in RGB order. (0, 0, 255) would render blue, because colours are
    # reversed to BGR before drawing.
    spoofed_bbox_color: Tuple[int, int, int] = (255, 0, 0)
    unknown_bbox_color: Tuple[int, int, int] = (255, 0, 0)
    eye_outline_color: Tuple[int, int, int] = (255, 255, 0)
    blink_text_color: Tuple[int, int, int] = (0, 0, 255)
    hand_landmark_color: Tuple[int, int, int] = (255, 210, 77)
    hand_connection_color: Tuple[int, int, int] = (204, 102, 0)
    hand_text_color: Tuple[int, int, int] = (255, 255, 255)
    mesh_color: Tuple[int, int, int] = (100, 255, 100)
    contour_color: Tuple[int, int, int] = (200, 200, 0)
    iris_color: Tuple[int, int, int] = (255, 0, 255)
    eye_color_text_color: Tuple[int, int, int] = (255, 255, 255)

    @staticmethod
    def _with_defaults(current: Optional[Dict], defaults: Dict) -> Dict:
        """Return a new dict: *defaults* overridden by whatever *current* supplies."""
        merged = dict(defaults)
        merged.update(current or {})
        return merged

    def __post_init__(self):
        self.detector = self._with_defaults(self.detector, {
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        })
        if 'model_path' not in self.detector:
            # Only computed when the caller did not choose a model path.
            self.detector['model_path'] = os.path.join(MODEL_DIR, "face2.pt")
        self.tracker = self._with_defaults(self.tracker, {'max_age': 30})
        self.recognition = self._with_defaults(self.recognition, {'enable': True})
        self.anti_spoof = self._with_defaults(self.anti_spoof, {'enable': True, 'lap_thresh': 80.0})
        self.blink = self._with_defaults(self.blink, {'enable': True, 'ear_thresh': 0.25})
        self.face_mesh_options = self._with_defaults(self.face_mesh_options, {
            'enable': False,
            'tesselation': False,
            'contours': False,
            'irises': False,
        })
        self.hand = self._with_defaults(self.hand, {
            'enable': True,
            'min_detection_confidence': 0.5,
            'min_tracking_confidence': 0.5,
        })
        self.eye_color = self._with_defaults(self.eye_color, {'enable': False})
        self.enabled_components = self._with_defaults(self.enabled_components, {
            'detection': True,
            'tracking': True,
            'anti_spoof': True,
            'recognition': True,
            'blink': True,
            'face_mesh': False,
            'hand': True,
            'eye_color': False,
        })

    def save(self, path: str):
        """
        Pickle this config's fields to *path*.

        The data goes to a temporary file first and is then swapped in with
        os.replace, so an interrupted write never corrupts an existing config.

        Raises:
            RuntimeError: if the directory or file cannot be written.
        """
        try:
            directory = os.path.dirname(path)
            if directory:  # bare filename -> current directory; os.makedirs('') would fail
                os.makedirs(directory, exist_ok=True)
            tmp_path = f"{path}.tmp"
            with open(tmp_path, 'wb') as f:
                pickle.dump(self.__dict__, f)
            os.replace(tmp_path, path)
            logger.info(f"Saved config to {path}")
        except Exception as e:
            logger.error(f"Config save failed: {str(e)}")
            raise RuntimeError(f"Config save failed: {str(e)}") from e

    @classmethod
    def load(cls, path: str) -> 'PipelineConfig':
        """
        Load a config previously written by save(); return defaults if it is
        missing or unreadable.

        Keys that are not fields of this class (e.g. options removed in a
        newer version) are ignored, rather than discarding the whole file.
        """
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only load
        # config files this application wrote itself.
        from dataclasses import fields  # file level imports only `dataclass`/`field`
        try:
            if not os.path.exists(path):
                return cls()
            with open(path, 'rb') as fh:
                data = pickle.load(fh)
            known = {fld.name for fld in fields(cls)}
            return cls(**{key: value for key, value in data.items() if key in known})
        except Exception as e:
            logger.error(f"Config load failed: {str(e)}")
            return cls()
# --------------------------------------------------------------------
# FACE DATABASE
# --------------------------------------------------------------------
class FaceDatabase:
    """
    Pickle-backed store mapping a person's label to their enrolled embeddings.

    Each label maps to a list of 1-D numpy arrays (one per enrolled face).
    """

    def __init__(self, db_path: str = DEFAULT_DB_PATH):
        self.db_path = db_path
        self.embeddings: Dict[str, List[np.ndarray]] = {}
        self._load()

    def _load(self):
        """Load embeddings from ``db_path`` if it exists; on any error start empty."""
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only point
        # db_path at files this application wrote itself.
        try:
            if os.path.exists(self.db_path):
                with open(self.db_path, 'rb') as f:
                    data = pickle.load(f)
                if not isinstance(data, dict):
                    raise TypeError(f"expected a dict of embeddings, got {type(data).__name__}")
                self.embeddings = data
                logger.info(f"Loaded database from {self.db_path}")
        except Exception as e:
            logger.error(f"Database load failed: {str(e)}")
            self.embeddings = {}

    def save(self):
        """
        Persist all embeddings to ``db_path``.

        The data goes to a temporary file first and is then swapped in with
        os.replace, so a crash mid-write cannot corrupt the existing database.

        Raises:
            RuntimeError: if the directory or file cannot be written.
        """
        try:
            directory = os.path.dirname(self.db_path)
            if directory:  # bare filename -> current directory; os.makedirs('') would fail
                os.makedirs(directory, exist_ok=True)
            tmp_path = f"{self.db_path}.tmp"
            with open(tmp_path, 'wb') as f:
                pickle.dump(self.embeddings, f)
            os.replace(tmp_path, self.db_path)
            logger.info(f"Saved database to {self.db_path}")
        except Exception as e:
            logger.error(f"Database save failed: {str(e)}")
            raise RuntimeError(f"Database save failed: {str(e)}") from e

    def add_embedding(self, label: str, embedding: np.ndarray):
        """
        Append one embedding under ``label``, creating the label if needed.

        Raises:
            ValueError: if ``embedding`` is not a 1-D numpy array.
        """
        try:
            if not isinstance(embedding, np.ndarray) or embedding.ndim != 1:
                raise ValueError("Invalid embedding format")
            self.embeddings.setdefault(label, []).append(embedding)
            logger.debug(f"Added embedding for {label}")
        except Exception as e:
            logger.error(f"Add embedding failed: {str(e)}")
            raise

    def remove_label(self, label: str) -> bool:
        """
        Delete every embedding stored under ``label``.

        Returns:
            True if the label existed and was removed, False if it was not
            found (a warning is logged in that case).
        """
        try:
            if label not in self.embeddings:
                logger.warning(f"Label {label} not found")
                return False
            del self.embeddings[label]
            logger.info(f"Removed {label}")
            return True
        except Exception as e:
            logger.error(f"Remove label failed: {str(e)}")
            raise

    def list_labels(self) -> List[str]:
        """Return all enrolled labels in insertion order."""
        return list(self.embeddings.keys())

    def get_embeddings_by_label(self, label: str) -> Optional[List[np.ndarray]]:
        """Return the embeddings stored for ``label``, or None if it is unknown."""
        return self.embeddings.get(label)

    def search_by_image(self, query_embedding: np.ndarray, threshold: float = 0.7) -> List[Tuple[str, float]]:
        """
        Compare ``query_embedding`` against every stored embedding.

        Returns:
            (label, cosine similarity) pairs with similarity >= ``threshold``,
            one per matching stored embedding, sorted from best to worst.
        """
        results = [
            (label, similarity)
            for label, stored in self.embeddings.items()
            for db_emb in stored
            if (similarity := FacePipeline.cosine_similarity(query_embedding, db_emb)) >= threshold
        ]
        return sorted(results, key=lambda item: item[1], reverse=True)
# --------------------------------------------------------------------
# YOLO FACE DETECTOR
# --------------------------------------------------------------------
class YOLOFaceDetector:
    """
    Face detector backed by an Ultralytics YOLO model.

    If ``model_path`` does not exist yet, the weights are first downloaded
    from DEFAULT_MODEL_URL.
    """

    def __init__(self, model_path: str, device: str = 'cpu', download_timeout: float = 60.0):
        """
        Args:
            model_path: Local path of the YOLO ``.pt`` weights.
            device: Torch device string the model is moved to and run on.
            download_timeout: ``requests`` timeout in seconds for fetching
                missing weights. Without one, a stalled server would hang
                start-up indefinitely.

        Raises:
            Whatever downloading or loading the model raises (after logging it).
        """
        self.model = None
        self.device = device
        try:
            if not os.path.exists(model_path):
                self._download_weights(model_path, download_timeout)
            self.model = YOLO(model_path)
            self.model.to(device)
            logger.info(f"Loaded YOLO model from {model_path}")
        except Exception as e:
            logger.error(f"YOLO initialization failed: {str(e)}")
            raise

    @staticmethod
    def _download_weights(model_path: str, timeout: float) -> None:
        """Fetch DEFAULT_MODEL_URL into ``model_path`` via a temp file, so a failed write never leaves a truncated model."""
        logger.info(f"Model file not found at {model_path}. Attempting to download...")
        response = requests.get(DEFAULT_MODEL_URL, timeout=timeout)
        response.raise_for_status()
        directory = os.path.dirname(model_path)
        if directory:  # bare filename -> current directory; os.makedirs('') would fail
            os.makedirs(directory, exist_ok=True)
        tmp_path = f"{model_path}.part"
        with open(tmp_path, 'wb') as f:
            f.write(response.content)
        os.replace(tmp_path, model_path)
        logger.info(f"Downloaded YOLO model to {model_path}")

    def detect(self, image: np.ndarray, conf_thres: float) -> List[Tuple[int, int, int, int, float, int]]:
        """
        Run the detector on one image (callers in this file pass BGR frames).

        Returns:
            (x1, y1, x2, y2, confidence, class_id) tuples, with pixel
            coordinates clipped to the image bounds. Returns [] if inference fails.
        """
        try:
            results = self.model.predict(
                source=image, conf=conf_thres, verbose=False, device=self.device
            )
            height, width = image.shape[:2]
            detections = []
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf[0].cpu().numpy())
                    cls = int(box.cls[0].cpu().numpy()) if box.cls is not None else 0
                    # Clip defensively so callers can slice image[y1:y2, x1:x2]
                    # without negative indices wrapping around.
                    detections.append((
                        min(max(int(x1), 0), width),
                        min(max(int(y1), 0), height),
                        min(max(int(x2), 0), width),
                        min(max(int(y2), 0), height),
                        conf,
                        cls,
                    ))
            logger.debug(f"Detected {len(detections)} faces.")
            return detections
        except Exception as e:
            logger.error(f"Detection failed: {str(e)}")
            return []
# --------------------------------------------------------------------
# FACE TRACKER
# --------------------------------------------------------------------
class FaceTracker:
    """Thin wrapper around deep_sort_realtime's DeepSort with a MobileNet appearance embedder."""

    def __init__(self, max_age: int = 30):
        self.tracker = DeepSort(max_age=max_age, embedder='mobilenet')

    def update(self, detections: List[Tuple], frame: np.ndarray):
        """
        Feed one frame's detections to DeepSort and return its current tracks.

        ``detections`` are (x1, y1, x2, y2, conf, cls) tuples, as produced by
        YOLOFaceDetector.detect. They are converted to DeepSort's
        ([left, top, width, height], confidence, class) form. Returns [] if
        tracking raises.
        """
        try:
            ltwh_detections = []
            for left, top, right, bottom, score, label in detections:
                ltwh_detections.append(([left, top, right - left, bottom - top], score, label))
            tracks = self.tracker.update_tracks(ltwh_detections, frame=frame)
            logger.debug(f"Updated tracker with {len(tracks)} tracks.")
            return tracks
        except Exception as e:
            logger.error(f"Tracking update failed: {str(e)}")
            return []
# --------------------------------------------------------------------
# FACENET EMBEDDER
# --------------------------------------------------------------------
class FaceNetEmbedder:
    """Computes FaceNet embeddings (InceptionResnetV1, VGGFace2 weights) for BGR face crops."""

    def __init__(self, device: str = 'cpu'):
        self.device = device
        network = InceptionResnetV1(pretrained='vggface2')
        self.model = network.eval().to(device)
        # 160x160 input, pixel values scaled from [0, 1] to [-1, 1].
        preprocessing_steps = [
            transforms.Resize((160, 160)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def get_embedding(self, face_bgr: np.ndarray) -> Optional[np.ndarray]:
        """
        Embed one BGR face crop.

        Returns:
            The first (only) row of the model output as a numpy array, or
            None if preprocessing or inference raises.
        """
        try:
            rgb_crop = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
            pil_crop = Image.fromarray(rgb_crop).convert('RGB')
            batch = self.transform(pil_crop).unsqueeze(0).to(self.device)
            with torch.no_grad():
                output = self.model(batch)
                vector = output[0].cpu().numpy()
            logger.debug(f"Generated embedding: {vector[:5]}...")
            return vector
        except Exception as e:
            logger.error(f"Embedding generation failed: {str(e)}")
            return None
# --------------------------------------------------------------------
# MAIN PIPELINE
# --------------------------------------------------------------------
class FacePipeline:
    """
    End-to-end face pipeline: YOLO detection -> DeepSORT tracking ->
    Laplacian-variance anti-spoof check -> FaceNet recognition against a
    FaceDatabase.

    Call initialize() before process_frame(); until then, frames are returned
    unchanged.
    """

    def __init__(self, config: 'PipelineConfig'):
        self.config = config
        self.detector = None
        self.tracker = None
        self.facenet = None
        self.db = None
        self._initialized = False

    def initialize(self):
        """
        Build the detector, tracker, embedder and database from ``self.config``.

        Raises:
            Whatever a component constructor raises; ``_initialized`` stays False.
        """
        try:
            device = self.config.detector['device']
            self.detector = YOLOFaceDetector(
                model_path=self.config.detector['model_path'],
                device=device,
            )
            self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
            self.facenet = FaceNetEmbedder(device=device)
            self.db = FaceDatabase()
            self._initialized = True
            logger.info("FacePipeline initialized successfully.")
        except Exception as e:
            logger.error(f"Pipeline initialization failed: {str(e)}")
            self._initialized = False
            raise

    @staticmethod
    def _clamp_box(x1, y1, x2, y2, frame_shape) -> Tuple[int, int, int, int]:
        """Truncate a box to ints and clip it to the frame (tracker predictions can leave it)."""
        height, width = frame_shape[:2]

        def clip(value, upper):
            return min(max(int(value), 0), upper)

        return clip(x1, width), clip(y1, height), clip(x2, width), clip(y2, height)

    @staticmethod
    def _track_detection_info(track) -> Tuple[float, int]:
        """
        Return (confidence, class) of the detection associated with ``track``.

        NOTE(review): deep_sort_realtime's Track stores these as det_conf /
        det_class (None when the track was not matched this frame). The
        previous score / class_id lookups, defaulting to 1.0 / 0, are kept as
        fallbacks.
        """
        conf = getattr(track, 'det_conf', None)
        if conf is None:
            conf = getattr(track, 'score', 1.0)
        cls = getattr(track, 'det_class', None)
        if cls is None:
            cls = getattr(track, 'class_id', 0)
        return float(conf), cls

    def _identify(self, face_roi: np.ndarray) -> Tuple[str, float, bool]:
        """Return (name, similarity, is_spoofed) for one BGR face crop."""
        if self.config.anti_spoof['enable'] and not self.is_real_face(face_roi):
            return "Spoofed", 0.0, True
        if not self.config.recognition['enable']:
            # Skip the FaceNet forward pass entirely when recognition is off.
            return "Unknown", 0.0, False
        embedding = self.facenet.get_embedding(face_roi)
        if embedding is None:
            return "Unknown", 0.0, False
        name, similarity = self.recognize_face(embedding, self.config.recognition_conf_thres)
        return name, similarity, False

    def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
        """
        Detect, track, spoof-check and recognize every face in a BGR frame.

        Returns:
            (annotated copy of the frame, one dict per confirmed track with
            keys track_id, bbox, confidence, class_id, name, similarity).
            If the pipeline is not initialized, or any step raises, the input
            frame is returned unchanged with an empty list.
        """
        if not self._initialized:
            logger.error("Pipeline not initialized!")
            return frame, []
        try:
            detections = self.detector.detect(frame, self.config.detection_conf_thres)
            tracked_objects = self.tracker.update(detections, frame)
            annotated_frame = frame.copy()
            results = []
            for obj in tracked_objects:
                if not obj.is_confirmed():
                    continue
                track_id = obj.track_id
                # Kalman-predicted boxes may extend past the frame; unclipped
                # negative indices would make the slice wrap around.
                x1, y1, x2, y2 = self._clamp_box(*obj.to_tlbr(), frame.shape)
                face_roi = frame[y1:y2, x1:x2]
                if face_roi.size == 0:
                    logger.warning(f"Empty face ROI for track {track_id}")
                    continue

                conf, cls = self._track_detection_info(obj)
                name, similarity, is_spoofed = self._identify(face_roi)
                if is_spoofed:
                    cls = 1
                    box_color_rgb = self.config.spoofed_bbox_color
                elif name != "Unknown":
                    box_color_rgb = self.config.bbox_color
                else:
                    box_color_rgb = self.config.unknown_bbox_color
                # Config colours are RGB; OpenCV draws in BGR.
                box_color_bgr = tuple(box_color_rgb)[::-1]

                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), box_color_bgr, 2)
                # Keep the label on-screen for faces touching the top edge.
                cv2.putText(
                    annotated_frame, name, (x1, max(y1 - 10, 15)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color_bgr, 2
                )
                results.append({
                    'track_id': track_id,
                    'bbox': (x1, y1, x2, y2),
                    'confidence': float(conf),
                    'class_id': cls,
                    'name': name,
                    'similarity': float(similarity),
                })
            return annotated_frame, results
        except Exception as e:
            logger.error(f"Frame processing failed: {str(e)}", exc_info=True)
            return frame, []

    def is_real_face(self, face_roi: np.ndarray) -> bool:
        """
        Heuristic liveness check: True if the crop's Laplacian variance
        (a sharpness measure) exceeds ``anti_spoof['lap_thresh']``.

        Returns False (treat as spoofed) if the check itself fails.
        """
        try:
            gray = cv2.cvtColor(face_roi, cv2.COLOR_BGR2GRAY)
            lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            return lap_var > self.config.anti_spoof['lap_thresh']
        except Exception as e:
            logger.error(f"Anti-spoof check failed: {str(e)}")
            return False

    def recognize_face(self, embedding: np.ndarray, recognition_threshold: float) -> Tuple[str, float]:
        """
        Find the most similar stored embedding across all labels.

        Returns:
            (label, similarity) of the best match. The label is "Unknown" when
            the best similarity is below ``recognition_threshold``; the best
            similarity is still reported in that case.
        """
        try:
            best_match = "Unknown"
            best_similarity = 0.0
            for label, embeddings in self.db.embeddings.items():
                for db_emb in embeddings:
                    similarity = FacePipeline.cosine_similarity(embedding, db_emb)
                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_match = label
            if best_similarity < recognition_threshold:
                best_match = "Unknown"
            return best_match, best_similarity
        except Exception as e:
            logger.error(f"Face recognition failed: {str(e)}")
            return "Unknown", 0.0

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; the 1e-6 term guards against zero norms."""
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
# --------------------------------------------------------------------
# GLOBAL pipeline instance (we can store it in a lazy loader)
# --------------------------------------------------------------------
# Process-wide pipeline; stays None until load_pipeline() succeeds.
pipeline = None


def load_pipeline() -> FacePipeline:
    """
    Return the shared FacePipeline, creating and initializing it on first use.

    The global is assigned only after initialize() succeeds. A failed start-up
    (e.g. a model download error) therefore propagates to the caller and is
    retried on the next call, instead of caching a half-built pipeline forever.
    """
    global pipeline
    if pipeline is None:
        logger.info("Loading pipeline for the first time...")
        candidate = FacePipeline(PipelineConfig.load(CONFIG_PATH))
        candidate.initialize()
        pipeline = candidate
    return pipeline
# --------------------------------------------------------------------
# GRADIO HELPER FUNCTIONS
# --------------------------------------------------------------------
def hex_to_bgr(hex_str: str) -> Tuple[int, int, int]:
    """
    Convert a hex colour string into a (B, G, R) tuple as used by OpenCV.

    Accepts "#RRGGBB" or "RRGGBB", plus the shorthand "#RGB" / "RGB". Case
    and surrounding whitespace are ignored. Anything else (wrong length,
    non-hex characters, None) returns the fallback (255, 0, 0), which is
    blue in BGR order.
    """
    fallback = (255, 0, 0)
    if not isinstance(hex_str, str):
        return fallback
    digits = hex_str.strip().lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    # Validate explicitly: int(..., 16) would also accept signs/whitespace like "+f".
    if len(digits) != 6 or any(ch not in "0123456789abcdefABCDEF" for ch in digits):
        return fallback
    r = int(digits[0:2], 16)
    g = int(digits[2:4], 16)
    b = int(digits[4:6], 16)
    return (b, g, r)
def bgr_to_hex(bgr: Tuple[int,int,int]) -> str:
    """
    Format a 3-element (B, G, R) colour as a lowercase "#rrggbb" string.

    Counterpart of hex_to_bgr. PipelineConfig colour fields are stored as
    (R, G, B) (update_config saves hex_to_bgr(...)[::-1]), so reverse those
    before passing them here.
    """
    blue, green, red = bgr
    return "#{:02x}{:02x}{:02x}".format(red, green, blue)
# --------------------------------------------------------------------
# TAB: Configuration
# --------------------------------------------------------------------
def update_config(
    enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
    show_tesselation, show_contours, show_irises,
    detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
    hand_det_conf, hand_track_conf,
    bbox_hex, spoofed_hex, unknown_hex,
    eye_hex, blink_hex,
    hand_landmark_hex, hand_connection_hex, hand_text_hex,
    mesh_hex, contour_hex, iris_hex,
    eye_color_text_hex
):
    """
    Apply the Configuration tab's values to the live pipeline config and save it.

    The changes take effect immediately on the shared pipeline. Colours arrive
    as hex strings and are stored as (R, G, B) tuples.

    Returns:
        A status message for the UI. If saving to CONFIG_PATH fails, the
        message says the settings apply only to this session.
    """
    cfg = load_pipeline().config

    # Feature toggles.
    cfg.recognition['enable'] = enable_recognition
    cfg.anti_spoof['enable'] = enable_antispoof
    cfg.blink['enable'] = enable_blink
    cfg.hand['enable'] = enable_hand
    cfg.eye_color['enable'] = enable_eyecolor
    cfg.face_mesh_options['enable'] = enable_facemesh
    cfg.face_mesh_options['tesselation'] = show_tesselation
    cfg.face_mesh_options['contours'] = show_contours
    cfg.face_mesh_options['irises'] = show_irises

    # Thresholds.
    cfg.detection_conf_thres = detection_conf
    cfg.recognition_conf_thres = recognition_thresh
    cfg.anti_spoof['lap_thresh'] = antispoof_thresh
    cfg.blink['ear_thresh'] = blink_thresh
    cfg.hand['min_detection_confidence'] = hand_det_conf
    cfg.hand['min_tracking_confidence'] = hand_track_conf

    # Colours: hex_to_bgr yields (B, G, R); reversing gives the (R, G, B) the config stores.
    colour_fields = {
        'bbox_color': bbox_hex,
        'spoofed_bbox_color': spoofed_hex,
        'unknown_bbox_color': unknown_hex,
        'eye_outline_color': eye_hex,
        'blink_text_color': blink_hex,
        'hand_landmark_color': hand_landmark_hex,
        'hand_connection_color': hand_connection_hex,
        'hand_text_color': hand_text_hex,
        'mesh_color': mesh_hex,
        'contour_color': contour_hex,
        'iris_color': iris_hex,
        'eye_color_text_color': eye_color_text_hex,
    }
    for attr_name, hex_value in colour_fields.items():
        setattr(cfg, attr_name, hex_to_bgr(hex_value)[::-1])

    try:
        cfg.save(CONFIG_PATH)
    except RuntimeError as e:
        # The settings are already live in memory; they just won't survive a restart.
        return f"Settings applied for this session, but saving failed: {e}"
    return "Configuration saved successfully!"
# --------------------------------------------------------------------
# TAB: Database Management
# --------------------------------------------------------------------
def _image_to_bgr(img: np.ndarray) -> np.ndarray:
    """Convert a Gradio-style image (RGB, RGBA or grayscale) to 3-channel BGR for the pipeline."""
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[-1] == 4:
        return cv2.cvtColor(img, cv2.COLOR_RGBA2BGR)
    if img.shape[-1] == 3:
        return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    return img


def enroll_user(name: str, images: List[np.ndarray]) -> str:
    """
    Enroll a user under ``name`` from one or more face images.

    ``images`` may be a single numpy image (what a plain gr.Image delivers)
    or a list of them; None entries are skipped. Inputs are treated as RGB
    (Gradio's convention), with RGBA and grayscale also accepted. Only the
    most confident face in each image is enrolled, so bystanders in a photo
    are not stored under this user's name.

    Returns:
        A status message for the UI.
    """
    name = (name or "").strip()
    if not name:
        return "Please provide a user name."
    # A single array must be wrapped: iterating it directly would walk its pixel rows.
    if isinstance(images, np.ndarray):
        images = [images]
    images = [img for img in (images or []) if img is not None]
    if not images:
        return "No images provided."

    pl = load_pipeline()
    count_enrolled = 0
    for img in images:
        img_bgr = _image_to_bgr(img)
        detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
        if not detections:
            continue
        x1, y1, x2, y2, _conf, _cls = max(detections, key=lambda det: det[4])
        face_roi = img_bgr[y1:y2, x1:x2]
        if face_roi.size == 0:
            continue
        emb = pl.facenet.get_embedding(face_roi)
        if emb is not None:
            pl.db.add_embedding(name, emb)
            count_enrolled += 1

    if count_enrolled == 0:
        return "No faces were detected or embedded. Enrollment failed."
    try:
        pl.db.save()
    except RuntimeError as e:
        return f"Enrolled {name} with {count_enrolled} face(s) for this session, but saving failed: {e}"
    return f"Enrolled {name} with {count_enrolled} face(s)!"
def search_by_name(name: str) -> str:
    """Report how many embeddings are enrolled under ``name`` (for the UI)."""
    pl = load_pipeline()
    if not name:
        return "No name provided"
    stored = pl.db.get_embeddings_by_label(name)
    if not stored:
        return f"No embeddings found for user '{name}'."
    return f"User '{name}' found with {len(stored)} embedding(s)."
def search_by_image(image: np.ndarray) -> str:
    """
    Search the database using the most confident face in an uploaded RGB image.

    Matches below the recognition threshold are excluded. Each user is listed
    once, with their best-matching embedding's similarity, best first.

    Returns:
        A status message / result listing for the UI.
    """
    pl = load_pipeline()
    if image is None:
        return "No image uploaded."
    img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
    if not detections:
        return "No faces detected in the uploaded image."
    # Choose the most confident face explicitly instead of relying on result order.
    x1, y1, x2, y2, _conf, _cls = max(detections, key=lambda det: det[4])
    face_roi = img_bgr[y1:y2, x1:x2]
    if face_roi.size == 0:
        return "Empty face ROI in the uploaded image."
    emb = pl.facenet.get_embedding(face_roi)
    if emb is None:
        return "Failed to generate embedding from the face."
    search_results = pl.db.search_by_image(emb, pl.config.recognition_conf_thres)
    if not search_results:
        return "No matching users found in the database (under current threshold)."
    # Results hold one entry per matching embedding, sorted best first, so the
    # first hit seen for each user is that user's best.
    best_per_user: Dict[str, float] = {}
    for label, sim in search_results:
        best_per_user.setdefault(label, sim)
    lines = [f" - {label}, similarity={sim:.3f}" for label, sim in best_per_user.items()]
    return "Search results:\n" + "\n".join(lines)
def remove_user(label: str) -> str:
    """
    Delete all embeddings for ``label`` and persist the database.

    Returns:
        A status message. An unknown label is reported as not found, and the
        database is not rewritten in that case.
    """
    pl = load_pipeline()
    if not label:
        return "No user label selected."
    if label not in pl.db.list_labels():
        return f"User '{label}' not found."
    pl.db.remove_label(label)
    pl.db.save()
    return f"User '{label}' removed."
def list_users() -> str:
    """Return the enrolled labels as a comma-separated listing for the UI."""
    enrolled = load_pipeline().db.list_labels()
    if not enrolled:
        return "No users enrolled."
    return "Enrolled users:\n" + ", ".join(enrolled)
# --------------------------------------------------------------------
# TAB: Real-Time Recognition
# --------------------------------------------------------------------
def process_webcam_frame(frame: np.ndarray) -> Tuple[np.ndarray, str]:
    """
    Annotate one RGB webcam frame.

    Returns:
        (annotated RGB frame, stringified list of per-face result dicts), or
        (None, "No frame.") when no frame was received.
    """
    if frame is None:
        return None, "No frame."
    engine = load_pipeline()
    # The pipeline works in OpenCV's BGR order; Gradio delivers RGB.
    annotated_bgr, faces = engine.process_frame(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB), str(faces)
# --------------------------------------------------------------------
# TAB: Image Test
# --------------------------------------------------------------------
def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
    """
    Run the full pipeline on one uploaded RGB image.

    Returns:
        (annotated RGB image, stringified list of per-face result dicts), or
        (None, "No image uploaded.") when nothing was uploaded.
    """
    if img is None:
        return None, "No image uploaded."
    engine = load_pipeline()
    bgr_input = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    annotated_bgr, faces = engine.process_frame(bgr_input)
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB), str(faces)
# --------------------------------------------------------------------
# BUILD THE GRADIO APP
# --------------------------------------------------------------------
def _gradio_major_version() -> int:
    """Best-effort major version of the installed Gradio (assumes 3 if unknown)."""
    try:
        return int(str(gr.__version__).split(".")[0])
    except (AttributeError, ValueError):
        return 3


def _webcam_stream_input():
    """Create a streaming webcam gr.Image whose callbacks receive numpy RGB frames."""
    # Gradio 3.x spells the webcam source `source="webcam"`; 4.x+ uses `sources=[...]`.
    if _gradio_major_version() >= 4:
        return gr.Image(sources=["webcam"], streaming=True, type="numpy", label="Webcam")
    return gr.Image(source="webcam", streaming=True, type="numpy", label="Webcam")


def _uploads_to_rgb_arrays(files) -> List[np.ndarray]:
    """Decode gr.File uploads (path strings or temp-file wrappers with .name) into RGB arrays."""
    if files is None:
        return []
    if not isinstance(files, (list, tuple)):
        files = [files]
    arrays = []
    for item in files:
        path = item if isinstance(item, str) else getattr(item, "name", None)
        if not path:
            continue
        try:
            with Image.open(path) as picture:
                arrays.append(np.array(picture.convert("RGB")))
        except OSError as e:
            logger.warning(f"Skipping unreadable upload {path}: {e}")
    return arrays


def _rgb_to_hex(rgb) -> str:
    """Hex string for an (R, G, B) config colour; bgr_to_hex expects (B, G, R)."""
    return bgr_to_hex(tuple(rgb)[::-1])


def build_app():
    """
    Assemble the Gradio UI: real-time webcam, single-image test,
    configuration, and database-management tabs.
    """
    # Reading the saved config loads no models. Seeding the Configuration tab
    # from it means saving one change does not silently reset the others to
    # hard-coded defaults.
    cfg = PipelineConfig.load(CONFIG_PATH)
    mesh_opts = cfg.face_mesh_options

    with gr.Blocks() as demo:
        gr.Markdown("# Face Recognition System (Gradio)")

        with gr.Tab("Real-Time Recognition"):
            gr.Markdown("Live face recognition from your webcam (roughly 'real-time').")
            webcam_input = _webcam_stream_input()
            webcam_output = gr.Image()
            webcam_info = gr.Textbox(label="Detections", interactive=False)
            # .stream delivers each captured frame as a numpy array. A gr.Video
            # input would instead pass a recorded file path, which
            # process_webcam_frame cannot handle.
            webcam_input.stream(
                fn=process_webcam_frame,
                inputs=webcam_input,
                outputs=[webcam_output, webcam_info],
            )

        with gr.Tab("Image Test"):
            gr.Markdown("Upload a single image for face detection and recognition.")
            image_input = gr.Image(type="numpy", label="Upload Image")
            image_out = gr.Image()
            image_info = gr.Textbox(label="Detections", interactive=False)
            process_btn = gr.Button("Process Image")
            process_btn.click(
                fn=process_test_image,
                inputs=image_input,
                outputs=[image_out, image_info],
            )

        with gr.Tab("Configuration"):
            gr.Markdown("Modify the pipeline settings and thresholds here.")
            with gr.Row():
                enable_recognition = gr.Checkbox(label="Enable Face Recognition", value=cfg.recognition.get('enable', True))
                enable_antispoof = gr.Checkbox(label="Enable Anti-Spoof", value=cfg.anti_spoof.get('enable', True))
                enable_blink = gr.Checkbox(label="Enable Blink Detection", value=cfg.blink.get('enable', True))
                enable_hand = gr.Checkbox(label="Enable Hand Tracking", value=cfg.hand.get('enable', True))
                enable_eyecolor = gr.Checkbox(label="Enable Eye Color Detection", value=cfg.eye_color.get('enable', False))
                enable_facemesh = gr.Checkbox(label="Enable Face Mesh", value=mesh_opts.get('enable', False))
            gr.Markdown("**Face Mesh Options** (only if Face Mesh is enabled):")
            with gr.Row():
                show_tesselation = gr.Checkbox(label="Show Tesselation", value=mesh_opts.get('tesselation', False))
                show_contours = gr.Checkbox(label="Show Contours", value=mesh_opts.get('contours', False))
                show_irises = gr.Checkbox(label="Show Irises", value=mesh_opts.get('irises', False))
            gr.Markdown("**Thresholds**")
            detection_conf = gr.Slider(0.0, 1.0, value=cfg.detection_conf_thres, step=0.01, label="Detection Confidence Threshold")
            recognition_thresh = gr.Slider(0.5, 1.0, value=cfg.recognition_conf_thres, step=0.01, label="Recognition Similarity Threshold")
            antispoof_thresh = gr.Slider(0, 200, value=cfg.anti_spoof.get('lap_thresh', 80), step=1, label="Anti-Spoof Laplacian Threshold")
            blink_thresh = gr.Slider(0, 0.5, value=cfg.blink.get('ear_thresh', 0.25), step=0.01, label="Blink EAR Threshold")
            hand_det_conf = gr.Slider(0.0, 1.0, value=cfg.hand.get('min_detection_confidence', 0.5), step=0.01, label="Hand Detection Confidence")
            hand_track_conf = gr.Slider(0.0, 1.0, value=cfg.hand.get('min_tracking_confidence', 0.5), step=0.01, label="Hand Tracking Confidence")
            gr.Markdown("**Color Options (Hex)**")
            bbox_hex = gr.Textbox(label="Box Color (Recognized)", value=_rgb_to_hex(cfg.bbox_color))
            spoofed_hex = gr.Textbox(label="Box Color (Spoofed)", value=_rgb_to_hex(cfg.spoofed_bbox_color))
            unknown_hex = gr.Textbox(label="Box Color (Unknown)", value=_rgb_to_hex(cfg.unknown_bbox_color))
            eye_hex = gr.Textbox(label="Eye Outline Color", value=_rgb_to_hex(cfg.eye_outline_color))
            blink_hex = gr.Textbox(label="Blink Text Color", value=_rgb_to_hex(cfg.blink_text_color))
            hand_landmark_hex = gr.Textbox(label="Hand Landmark Color", value=_rgb_to_hex(cfg.hand_landmark_color))
            hand_connection_hex = gr.Textbox(label="Hand Connection Color", value=_rgb_to_hex(cfg.hand_connection_color))
            hand_text_hex = gr.Textbox(label="Hand Text Color", value=_rgb_to_hex(cfg.hand_text_color))
            mesh_hex = gr.Textbox(label="Mesh Color", value=_rgb_to_hex(cfg.mesh_color))
            contour_hex = gr.Textbox(label="Contour Color", value=_rgb_to_hex(cfg.contour_color))
            iris_hex = gr.Textbox(label="Iris Color", value=_rgb_to_hex(cfg.iris_color))
            eye_color_text_hex = gr.Textbox(label="Eye Color Text Color", value=_rgb_to_hex(cfg.eye_color_text_color))
            save_btn = gr.Button("Save Configuration")
            save_msg = gr.Textbox(label="", interactive=False)
            save_btn.click(
                fn=update_config,
                inputs=[
                    enable_recognition, enable_antispoof, enable_blink,
                    enable_hand, enable_eyecolor, enable_facemesh,
                    show_tesselation, show_contours, show_irises,
                    detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
                    hand_det_conf, hand_track_conf,
                    bbox_hex, spoofed_hex, unknown_hex,
                    eye_hex, blink_hex,
                    hand_landmark_hex, hand_connection_hex, hand_text_hex,
                    mesh_hex, contour_hex, iris_hex, eye_color_text_hex
                ],
                outputs=[save_msg],
            )

        with gr.Tab("Database Management"):
            gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")
            with gr.Accordion("User Enrollment", open=False):
                enroll_name = gr.Textbox(label="Enter name for enrollment")
                # gr.Image has no multi-upload option; gr.File(file_count="multiple") does.
                enroll_files = gr.File(
                    label="Upload Enrollment Images",
                    file_count="multiple",
                    file_types=["image"],
                )
                enroll_btn = gr.Button("Enroll User")
                enroll_result = gr.Textbox(label="", interactive=False)
                enroll_btn.click(
                    fn=lambda name, files: enroll_user(name, _uploads_to_rgb_arrays(files)),
                    inputs=[enroll_name, enroll_files],
                    outputs=[enroll_result],
                )

            with gr.Accordion("User Search", open=False):
                search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
                search_name_input = gr.Dropdown(label="Select User", choices=[], value=None, interactive=True)
                search_image_input = gr.Image(type="numpy", label="Upload Image", visible=False)
                search_btn = gr.Button("Search")
                search_result = gr.Textbox(label="", interactive=False)

                def update_search_visibility(mode):
                    by_name = mode == "Name"
                    return gr.update(visible=by_name), gr.update(visible=not by_name)

                search_mode.change(
                    fn=update_search_visibility,
                    inputs=[search_mode],
                    outputs=[search_name_input, search_image_input],
                )

                def search_user(mode, name, img):
                    return search_by_name(name) if mode == "Name" else search_by_image(img)

                search_btn.click(
                    fn=search_user,
                    inputs=[search_mode, search_name_input, search_image_input],
                    outputs=[search_result],
                )

            with gr.Accordion("User Management Tools", open=False):
                list_btn = gr.Button("List Enrolled Users")
                list_output = gr.Textbox(label="", interactive=False)
                list_btn.click(fn=list_users, inputs=[], outputs=[list_output])

                # One button refreshes both user dropdowns.
                refresh_users_btn = gr.Button("Refresh User List")
                remove_user_select = gr.Dropdown(label="Select User to Remove", choices=[])
                remove_btn = gr.Button("Remove Selected User")
                remove_output = gr.Textbox(label="", interactive=False)

                def refresh_user_lists():
                    labels = load_pipeline().db.list_labels()
                    return gr.update(choices=labels), gr.update(choices=labels)

                refresh_users_btn.click(
                    fn=refresh_user_lists,
                    inputs=[],
                    outputs=[search_name_input, remove_user_select],
                )
                remove_btn.click(fn=remove_user, inputs=[remove_user_select], outputs=[remove_output])

    return demo
# --------------------------------------------------------------------
# MAIN
# --------------------------------------------------------------------
if __name__ == "__main__":
    # Listen on all interfaces so the UI is reachable from outside a container/VM.
    demo_app = build_app()
    demo_app.queue()
    demo_app.launch(server_name="0.0.0.0", server_port=7860)