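"""Face recognition pipeline with a Gradio UI.

Combines YOLO face detection, Deep SORT tracking, FaceNet embeddings, a
pickle-backed face database, and a simple Laplacian-variance anti-spoof
check. Blink, face-mesh, hand, and eye-color settings are carried in the
config for components defined elsewhere; only detection, tracking,
anti-spoof, and recognition run in this module. Run the script directly
to launch the web app on port 7860.
"""
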
import os
import sys
import math
import logging
import pickle
from collections import Counter
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np
import requests
import torch
from PIL import Image

import gradio as gr
from ultralytics import YOLO
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from deep_sort_realtime.deepsort_tracker import DeepSort

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('face_pipeline.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Quiet noisy third-party loggers.
logging.getLogger('torch').setLevel(logging.ERROR)
logging.getLogger('deep_sort_realtime').setLevel(logging.ERROR)

DEFAULT_MODEL_URL = "https://github.com/wuhplaptop/face-11-n/blob/main/face2.pt?raw=true"
DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")

# MediaPipe Face Mesh landmark indices used for the eye-aspect-ratio (EAR)
# blink check.
LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]


@dataclass
class PipelineConfig:
    detector: Dict = field(default_factory=dict)
    tracker: Dict = field(default_factory=dict)
    recognition: Dict = field(default_factory=dict)
    anti_spoof: Dict = field(default_factory=dict)
    blink: Dict = field(default_factory=dict)
    face_mesh_options: Dict = field(default_factory=dict)
    hand: Dict = field(default_factory=dict)
    eye_color: Dict = field(default_factory=dict)
    enabled_components: Dict = field(default_factory=dict)
    detection_conf_thres: float = 0.4
    recognition_conf_thres: float = 0.85
    # Colors are stored as RGB tuples; the drawing code flips them to BGR
    # for OpenCV (see process_frame and update_config). The spoofed/unknown
    # defaults are red in RGB, matching the UI defaults.
    bbox_color: Tuple[int, int, int] = (0, 255, 0)
    spoofed_bbox_color: Tuple[int, int, int] = (255, 0, 0)
    unknown_bbox_color: Tuple[int, int, int] = (255, 0, 0)
    eye_outline_color: Tuple[int, int, int] = (255, 255, 0)
    blink_text_color: Tuple[int, int, int] = (0, 0, 255)
    hand_landmark_color: Tuple[int, int, int] = (255, 210, 77)
    hand_connection_color: Tuple[int, int, int] = (204, 102, 0)
    hand_text_color: Tuple[int, int, int] = (255, 255, 255)
    mesh_color: Tuple[int, int, int] = (100, 255, 100)
    contour_color: Tuple[int, int, int] = (200, 200, 0)
    iris_color: Tuple[int, int, int] = (255, 0, 255)
    eye_color_text_color: Tuple[int, int, int] = (255, 255, 255)

    def __post_init__(self):
        # Fill any empty sub-config with sensible defaults.
        self.detector = self.detector or {
            'model_path': os.path.join(MODEL_DIR, "face2.pt"),
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        }
        self.tracker = self.tracker or {'max_age': 30}
        self.recognition = self.recognition or {'enable': True}
        self.anti_spoof = self.anti_spoof or {'enable': True, 'lap_thresh': 80.0}
        self.blink = self.blink or {'enable': True, 'ear_thresh': 0.25}
        self.face_mesh_options = self.face_mesh_options or {
            'enable': False,
            'tesselation': False,
            'contours': False,
            'irises': False,
        }
        self.hand = self.hand or {
            'enable': True,
            'min_detection_confidence': 0.5,
            'min_tracking_confidence': 0.5,
        }
        self.eye_color = self.eye_color or {'enable': False}
        self.enabled_components = self.enabled_components or {
            'detection': True,
            'tracking': True,
            'anti_spoof': True,
            'recognition': True,
            'blink': True,
            'face_mesh': False,
            'hand': True,
            'eye_color': False,
        }

    def save(self, path: str):
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, 'wb') as f:
                pickle.dump(self.__dict__, f)
            logger.info(f"Saved config to {path}")
        except Exception as e:
            logger.error(f"Config save failed: {str(e)}")
            raise RuntimeError(f"Config save failed: {str(e)}") from e

    @classmethod
    def load(cls, path: str) -> 'PipelineConfig':
        # Fall back to defaults if the file is missing or unreadable.
        try:
            if os.path.exists(path):
                with open(path, 'rb') as f:
                    data = pickle.load(f)
                return cls(**data)
            return cls()
        except Exception as e:
            logger.error(f"Config load failed: {str(e)}")
            return cls()
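

# NOTE: PipelineConfig and FaceDatabase persist as plain pickles. Unpickling
# untrusted files can execute arbitrary code, so only load files written by
# this app.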


class FaceDatabase:
    """Pickle-backed store mapping a label to a list of face embeddings."""

    def __init__(self, db_path: str = DEFAULT_DB_PATH):
        self.db_path = db_path
        self.embeddings: Dict[str, List[np.ndarray]] = {}
        self._load()

    def _load(self):
        try:
            if os.path.exists(self.db_path):
                with open(self.db_path, 'rb') as f:
                    self.embeddings = pickle.load(f)
                logger.info(f"Loaded database from {self.db_path}")
        except Exception as e:
            logger.error(f"Database load failed: {str(e)}")
            self.embeddings = {}

    def save(self):
        try:
            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
            with open(self.db_path, 'wb') as f:
                pickle.dump(self.embeddings, f)
            logger.info(f"Saved database to {self.db_path}")
        except Exception as e:
            logger.error(f"Database save failed: {str(e)}")
            raise RuntimeError(f"Database save failed: {str(e)}") from e

    def add_embedding(self, label: str, embedding: np.ndarray):
        try:
            if not isinstance(embedding, np.ndarray) or embedding.ndim != 1:
                raise ValueError("Invalid embedding format")
            self.embeddings.setdefault(label, []).append(embedding)
            logger.debug(f"Added embedding for {label}")
        except Exception as e:
            logger.error(f"Add embedding failed: {str(e)}")
            raise

    def remove_label(self, label: str):
        try:
            if label in self.embeddings:
                del self.embeddings[label]
                logger.info(f"Removed {label}")
            else:
                logger.warning(f"Label {label} not found")
        except Exception as e:
            logger.error(f"Remove label failed: {str(e)}")
            raise

    def list_labels(self) -> List[str]:
        return list(self.embeddings.keys())

    def get_embeddings_by_label(self, label: str) -> Optional[List[np.ndarray]]:
        return self.embeddings.get(label)

    def search_by_image(self, query_embedding: np.ndarray, threshold: float = 0.7) -> List[Tuple[str, float]]:
        # Compare the query against every stored embedding and return all
        # matches above the threshold, best first.
        results = []
        for label, embeddings in self.embeddings.items():
            for db_emb in embeddings:
                similarity = FacePipeline.cosine_similarity(query_embedding, db_emb)
                if similarity >= threshold:
                    results.append((label, similarity))
        return sorted(results, key=lambda x: x[1], reverse=True)


class YOLOFaceDetector:
    def __init__(self, model_path: str, device: str = 'cpu'):
        self.model = None
        self.device = device
        try:
            if not os.path.exists(model_path):
                logger.info(f"Model file not found at {model_path}. Attempting to download...")
                # Timeout so a stalled download fails instead of hanging forever.
                response = requests.get(DEFAULT_MODEL_URL, timeout=60)
                response.raise_for_status()
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                with open(model_path, 'wb') as f:
                    f.write(response.content)
                logger.info(f"Downloaded YOLO model to {model_path}")

            self.model = YOLO(model_path)
            self.model.to(device)
            logger.info(f"Loaded YOLO model from {model_path}")
        except Exception as e:
            logger.error(f"YOLO initialization failed: {str(e)}")
            raise

    def detect(self, image: np.ndarray, conf_thres: float) -> List[Tuple[int, int, int, int, float, int]]:
        """Return detections as (x1, y1, x2, y2, confidence, class_id) tuples."""
        try:
            results = self.model.predict(
                source=image, conf=conf_thres, verbose=False, device=self.device
            )
            detections = []
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf[0].cpu().numpy())
                    cls = int(box.cls[0].cpu().numpy()) if box.cls is not None else 0
                    detections.append((int(x1), int(y1), int(x2), int(y2), conf, cls))
            logger.debug(f"Detected {len(detections)} faces.")
            return detections
        except Exception as e:
            logger.error(f"Detection failed: {str(e)}")
            return []


class FaceTracker:
    def __init__(self, max_age: int = 30):
        self.tracker = DeepSort(max_age=max_age, embedder='mobilenet')

    def update(self, detections: List[Tuple], frame: np.ndarray):
        try:
            # Deep SORT expects ([left, top, width, height], confidence, class).
            ds_detections = [
                ([x1, y1, x2 - x1, y2 - y1], conf, cls)
                for (x1, y1, x2, y2, conf, cls) in detections
            ]
            tracks = self.tracker.update_tracks(ds_detections, frame=frame)
            logger.debug(f"Updated tracker with {len(tracks)} tracks.")
            return tracks
        except Exception as e:
            logger.error(f"Tracking update failed: {str(e)}")
            return []


class FaceNetEmbedder:
    def __init__(self, device: str = 'cpu'):
        self.device = device
        self.model = InceptionResnetV1(pretrained='vggface2').eval().to(device)
        self.transform = transforms.Compose([
            transforms.Resize((160, 160)),
            transforms.ToTensor(),
            # Map pixel values to [-1, 1], as FaceNet expects.
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

    def get_embedding(self, face_bgr: np.ndarray) -> Optional[np.ndarray]:
        """Return a FaceNet embedding for a BGR face crop, or None on failure."""
        try:
            face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(face_rgb).convert('RGB')
            tens = self.transform(pil_img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                embedding = self.model(tens)[0].cpu().numpy()
            logger.debug(f"Generated embedding: {embedding[:5]}...")
            return embedding
        except Exception as e:
            logger.error(f"Embedding generation failed: {str(e)}")
            return None


class FacePipeline:
    def __init__(self, config: PipelineConfig):
        self.config = config
        self.detector = None
        self.tracker = None
        self.facenet = None
        self.db = None
        self._initialized = False

    def initialize(self):
        try:
            self.detector = YOLOFaceDetector(
                model_path=self.config.detector['model_path'],
                device=self.config.detector['device'],
            )
            self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
            self.facenet = FaceNetEmbedder(device=self.config.detector['device'])
            self.db = FaceDatabase()
            self._initialized = True
            logger.info("FacePipeline initialized successfully.")
        except Exception as e:
            logger.error(f"Pipeline initialization failed: {str(e)}")
            self._initialized = False
            raise

    def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
        """Detect, track, and recognize faces in a BGR frame.

        Returns the annotated frame plus one info dict per confirmed track.
        """
        if not self._initialized:
            logger.error("Pipeline not initialized!")
            return frame, []

        try:
            detections = self.detector.detect(frame, self.config.detection_conf_thres)
            tracked_objects = self.tracker.update(detections, frame)
            annotated_frame = frame.copy()
            results = []

            for obj in tracked_objects:
                if not obj.is_confirmed():
                    continue
                track_id = obj.track_id
                bbox = obj.to_tlbr()
                x1, y1, x2, y2 = bbox.astype(int)
                # Clamp to the frame so tracks near the border don't produce
                # negative indices when slicing the ROI.
                x1, y1 = max(x1, 0), max(y1, 0)
                # Track objects may not expose these attributes; use defaults.
                conf = getattr(obj, 'score', 1.0)
                cls = getattr(obj, 'class_id', 0)

                face_roi = frame[y1:y2, x1:x2]
                if face_roi.size == 0:
                    logger.warning(f"Empty face ROI for track {track_id}")
                    continue

                # Anti-spoof: a low-texture (blurry) crop is treated as spoofed.
                is_spoofed = False
                if self.config.anti_spoof['enable']:
                    is_spoofed = not self.is_real_face(face_roi)
                    if is_spoofed:
                        cls = 1

                if is_spoofed:
                    box_color_bgr = self.config.spoofed_bbox_color[::-1]
                    name = "Spoofed"
                    similarity = 0.0
                else:
                    embedding = self.facenet.get_embedding(face_roi)
                    if embedding is not None and self.config.recognition['enable']:
                        name, similarity = self.recognize_face(
                            embedding, self.config.recognition_conf_thres
                        )
                    else:
                        name = "Unknown"
                        similarity = 0.0

                    # Config colors are RGB; OpenCV draws in BGR.
                    box_color_rgb = (
                        self.config.bbox_color
                        if name != "Unknown"
                        else self.config.unknown_bbox_color
                    )
                    box_color_bgr = box_color_rgb[::-1]

                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), box_color_bgr, 2)
                cv2.putText(
                    annotated_frame, name, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color_bgr, 2,
                )

                results.append({
                    'track_id': track_id,
                    'bbox': (x1, y1, x2, y2),
                    'confidence': float(conf),
                    'class_id': cls,
                    'name': name,
                    'similarity': float(similarity),
                })

            return annotated_frame, results

        except Exception as e:
            logger.error(f"Frame processing failed: {str(e)}", exc_info=True)
            return frame, []

    def is_real_face(self, face_roi: np.ndarray) -> bool:
        """Crude anti-spoof check: Laplacian variance as a sharpness measure."""
        try:
            gray = cv2.cvtColor(face_roi, cv2.COLOR_BGR2GRAY)
            lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()
            return lap_var > self.config.anti_spoof['lap_thresh']
        except Exception as e:
            logger.error(f"Anti-spoof check failed: {str(e)}")
            return False

    def recognize_face(self, embedding: np.ndarray, recognition_threshold: float) -> Tuple[str, float]:
        """Return the best-matching label, or "Unknown" if below the threshold."""
        try:
            best_match = "Unknown"
            best_similarity = 0.0
            for label, embeddings in self.db.embeddings.items():
                for db_emb in embeddings:
                    similarity = FacePipeline.cosine_similarity(embedding, db_emb)
                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_match = label
            if best_similarity < recognition_threshold:
                best_match = "Unknown"
            return best_match, best_similarity
        except Exception as e:
            logger.error(f"Face recognition failed: {str(e)}")
            return "Unknown", 0.0

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        # The small epsilon guards against division by zero.
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-6))
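
# Cosine similarity lies in [-1, 1]; near-identical faces score close to 1,
# which is why recognition_conf_thres defaults to 0.85.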


# Lazily-initialized module-level singleton shared by all Gradio callbacks.
pipeline = None


def load_pipeline() -> FacePipeline:
    global pipeline
    if pipeline is None:
        logger.info("Loading pipeline for the first time...")
        cfg = PipelineConfig.load(CONFIG_PATH)
        pipeline = FacePipeline(cfg)
        pipeline.initialize()
    return pipeline


def hex_to_bgr(hex_str: str) -> Tuple[int, int, int]:
    """Convert a "#RRGGBB" hex string into a BGR tuple as used by OpenCV."""
    hex_str = hex_str.lstrip('#')
    if len(hex_str) != 6:
        # Fall back to blue (in BGR) for malformed input.
        return (255, 0, 0)
    r = int(hex_str[0:2], 16)
    g = int(hex_str[2:4], 16)
    b = int(hex_str[4:6], 16)
    return (b, g, r)


def bgr_to_hex(bgr: Tuple[int, int, int]) -> str:
    """Convert a BGR tuple back to a "#RRGGBB" hex string."""
    b, g, r = bgr
    return f"#{r:02x}{g:02x}{b:02x}"


def update_config(
    enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
    show_tesselation, show_contours, show_irises,
    detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
    hand_det_conf, hand_track_conf,
    bbox_hex, spoofed_hex, unknown_hex,
    eye_hex, blink_hex,
    hand_landmark_hex, hand_connection_hex, hand_text_hex,
    mesh_hex, contour_hex, iris_hex,
    eye_color_text_hex,
):
    pl = load_pipeline()
    cfg = pl.config

    # Component toggles.
    cfg.recognition['enable'] = enable_recognition
    cfg.anti_spoof['enable'] = enable_antispoof
    cfg.blink['enable'] = enable_blink
    cfg.hand['enable'] = enable_hand
    cfg.eye_color['enable'] = enable_eyecolor
    cfg.face_mesh_options['enable'] = enable_facemesh

    cfg.face_mesh_options['tesselation'] = show_tesselation
    cfg.face_mesh_options['contours'] = show_contours
    cfg.face_mesh_options['irises'] = show_irises

    # Thresholds.
    cfg.detection_conf_thres = detection_conf
    cfg.recognition_conf_thres = recognition_thresh
    cfg.anti_spoof['lap_thresh'] = antispoof_thresh
    cfg.blink['ear_thresh'] = blink_thresh
    cfg.hand['min_detection_confidence'] = hand_det_conf
    cfg.hand['min_tracking_confidence'] = hand_track_conf

    # hex_to_bgr returns BGR; reversing stores RGB in the config. Drawing
    # code (e.g. process_frame) flips back to BGR for OpenCV.
    cfg.bbox_color = hex_to_bgr(bbox_hex)[::-1]
    cfg.spoofed_bbox_color = hex_to_bgr(spoofed_hex)[::-1]
    cfg.unknown_bbox_color = hex_to_bgr(unknown_hex)[::-1]
    cfg.eye_outline_color = hex_to_bgr(eye_hex)[::-1]
    cfg.blink_text_color = hex_to_bgr(blink_hex)[::-1]
    cfg.hand_landmark_color = hex_to_bgr(hand_landmark_hex)[::-1]
    cfg.hand_connection_color = hex_to_bgr(hand_connection_hex)[::-1]
    cfg.hand_text_color = hex_to_bgr(hand_text_hex)[::-1]
    cfg.mesh_color = hex_to_bgr(mesh_hex)[::-1]
    cfg.contour_color = hex_to_bgr(contour_hex)[::-1]
    cfg.iris_color = hex_to_bgr(iris_hex)[::-1]
    cfg.eye_color_text_color = hex_to_bgr(eye_color_text_hex)[::-1]

    cfg.save(CONFIG_PATH)
    return "Configuration saved successfully!"


def enroll_user(name: str, files: Optional[List]) -> str:
    """
    Enroll a user by name from one or more uploaded image files.
    Files are read from disk with cv2.imread, which yields BGR arrays.
    """
    pl = load_pipeline()
    if not name:
        return "Please provide a user name."
    if not files:
        return "No images provided."

    count_enrolled = 0
    for f in files:
        # Gradio's File component yields tempfile objects; fall back to a
        # plain path string if one is passed directly.
        path = getattr(f, 'name', f)
        img_bgr = cv2.imread(path)
        if img_bgr is None:
            continue

        detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
        for x1, y1, x2, y2, conf, cls in detections:
            face_roi = img_bgr[y1:y2, x1:x2]
            if face_roi.size == 0:
                continue
            emb = pl.facenet.get_embedding(face_roi)
            if emb is not None:
                pl.db.add_embedding(name, emb)
                count_enrolled += 1

    if count_enrolled > 0:
        pl.db.save()
        return f"Enrolled {name} with {count_enrolled} face(s)!"
    return "No faces were detected or embedded. Enrollment failed."


def search_by_name(name: str) -> str:
    pl = load_pipeline()
    if not name:
        return "No name provided"
    embeddings = pl.db.get_embeddings_by_label(name)
    if embeddings:
        return f"User '{name}' found with {len(embeddings)} embedding(s)."
    return f"No embeddings found for user '{name}'."


def search_by_image(image: np.ndarray) -> str:
    """Search the database using the first face detected in the uploaded image."""
    pl = load_pipeline()
    if image is None:
        return "No image uploaded."

    img_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # Gradio supplies RGB
    detections = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
    if not detections:
        return "No faces detected in the uploaded image."

    x1, y1, x2, y2, conf, cls = detections[0]
    face_roi = img_bgr[y1:y2, x1:x2]
    if face_roi.size == 0:
        return "Empty face ROI in the uploaded image."

    emb = pl.facenet.get_embedding(face_roi)
    if emb is None:
        return "Failed to generate embedding from the face."

    search_results = pl.db.search_by_image(emb, pl.config.recognition_conf_thres)
    if not search_results:
        return "No matching users found in the database (under current threshold)."

    lines = [f" - {label}, similarity={sim:.3f}" for label, sim in search_results]
    return "Search results:\n" + "\n".join(lines)


def remove_user(label: str) -> str:
    pl = load_pipeline()
    if not label:
        return "No user label selected."
    pl.db.remove_label(label)
    pl.db.save()
    return f"User '{label}' removed."


def list_users() -> str:
    pl = load_pipeline()
    labels = pl.db.list_labels()
    if labels:
        return f"Enrolled users:\n{', '.join(labels)}"
    return "No users enrolled."


def process_webcam_frame(frame: np.ndarray) -> Tuple[np.ndarray, str]:
    """
    Called for every incoming webcam frame. Returns the annotated frame plus
    textual detection info. Gradio delivers frames in RGB.
    """
    if frame is None:
        return None, "No frame."
    pl = load_pipeline()
    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
    annotated_bgr, detections = pl.process_frame(frame_bgr)
    annotated_rgb = cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
    return annotated_rgb, str(detections)


def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
    if img is None:
        return None, "No image uploaded."
    pl = load_pipeline()
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    processed, detections = pl.process_frame(img_bgr)
    out_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
    return out_rgb, str(detections)


def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("# Face Recognition System (Gradio)")

        with gr.Tab("Real-Time Recognition"):
            gr.Markdown("Live face recognition from your webcam (roughly 'real-time').")
            # gr.Image with source="webcam" streams individual frames as numpy
            # arrays (Gradio 3.x API); gr.Video would deliver a file path, which
            # process_webcam_frame cannot consume.
            webcam_input = gr.Image(source="webcam", streaming=True, mirror_webcam=True)
            webcam_output = gr.Image()
            webcam_info = gr.Textbox(label="Detections", interactive=False)
            webcam_input.stream(
                fn=process_webcam_frame,
                inputs=webcam_input,
                outputs=[webcam_output, webcam_info],
            )

        with gr.Tab("Image Test"):
            gr.Markdown("Upload a single image for face detection and recognition.")
            image_input = gr.Image(type="numpy", label="Upload Image")
            image_out = gr.Image()
            image_info = gr.Textbox(label="Detections", interactive=False)
            process_btn = gr.Button("Process Image")

            process_btn.click(
                fn=process_test_image,
                inputs=image_input,
                outputs=[image_out, image_info],
            )

        with gr.Tab("Configuration"):
            gr.Markdown("Modify the pipeline settings and thresholds here.")

            with gr.Row():
                enable_recognition = gr.Checkbox(label="Enable Face Recognition", value=True)
                enable_antispoof = gr.Checkbox(label="Enable Anti-Spoof", value=True)
                enable_blink = gr.Checkbox(label="Enable Blink Detection", value=True)
                enable_hand = gr.Checkbox(label="Enable Hand Tracking", value=True)
                enable_eyecolor = gr.Checkbox(label="Enable Eye Color Detection", value=False)
                enable_facemesh = gr.Checkbox(label="Enable Face Mesh", value=False)

            gr.Markdown("**Face Mesh Options** (only if Face Mesh is enabled):")
            with gr.Row():
                show_tesselation = gr.Checkbox(label="Show Tesselation", value=False)
                show_contours = gr.Checkbox(label="Show Contours", value=False)
                show_irises = gr.Checkbox(label="Show Irises", value=False)

            gr.Markdown("**Thresholds**")
            detection_conf = gr.Slider(0.0, 1.0, value=0.4, step=0.01, label="Detection Confidence Threshold")
            recognition_thresh = gr.Slider(0.5, 1.0, value=0.85, step=0.01, label="Recognition Similarity Threshold")
            antispoof_thresh = gr.Slider(0, 200, value=80, step=1, label="Anti-Spoof Laplacian Threshold")
            blink_thresh = gr.Slider(0, 0.5, value=0.25, step=0.01, label="Blink EAR Threshold")
            hand_det_conf = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Hand Detection Confidence")
            hand_track_conf = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Hand Tracking Confidence")

            gr.Markdown("**Color Options (Hex)**")
            bbox_hex = gr.Textbox(label="Box Color (Recognized)", value="#00ff00")
            spoofed_hex = gr.Textbox(label="Box Color (Spoofed)", value="#ff0000")
            unknown_hex = gr.Textbox(label="Box Color (Unknown)", value="#ff0000")

            eye_hex = gr.Textbox(label="Eye Outline Color", value="#ffff00")
            blink_hex = gr.Textbox(label="Blink Text Color", value="#0000ff")

            hand_landmark_hex = gr.Textbox(label="Hand Landmark Color", value="#ffd24d")
            hand_connection_hex = gr.Textbox(label="Hand Connection Color", value="#cc6600")
            hand_text_hex = gr.Textbox(label="Hand Text Color", value="#ffffff")

            mesh_hex = gr.Textbox(label="Mesh Color", value="#64ff64")
            contour_hex = gr.Textbox(label="Contour Color", value="#c8c800")
            iris_hex = gr.Textbox(label="Iris Color", value="#ff00ff")

            eye_color_text_hex = gr.Textbox(label="Eye Color Text Color", value="#ffffff")

            save_btn = gr.Button("Save Configuration")
            save_msg = gr.Textbox(label="", interactive=False)

            save_btn.click(
                fn=update_config,
                inputs=[
                    enable_recognition, enable_antispoof, enable_blink,
                    enable_hand, enable_eyecolor, enable_facemesh,
                    show_tesselation, show_contours, show_irises,
                    detection_conf, recognition_thresh, antispoof_thresh, blink_thresh,
                    hand_det_conf, hand_track_conf,
                    bbox_hex, spoofed_hex, unknown_hex,
                    eye_hex, blink_hex,
                    hand_landmark_hex, hand_connection_hex, hand_text_hex,
                    mesh_hex, contour_hex, iris_hex, eye_color_text_hex,
                ],
                outputs=[save_msg],
            )

        with gr.Tab("Database Management"):
            gr.Markdown("Enroll Users, Search by Name or Image, Remove or List.")

            with gr.Accordion("User Enrollment", open=False):
                enroll_name = gr.Textbox(label="Enter name for enrollment")
                # gr.Image does not accept multiple uploads; a File input with
                # file_count="multiple" is used instead, and enroll_user reads
                # the images from disk.
                enroll_images = gr.File(file_count="multiple", label="Upload Enrollment Images")
                enroll_btn = gr.Button("Enroll User")
                enroll_result = gr.Textbox(label="", interactive=False)
                enroll_btn.click(fn=enroll_user, inputs=[enroll_name, enroll_images], outputs=[enroll_result])

            with gr.Accordion("User Search", open=False):
                search_mode = gr.Radio(["Name", "Image"], value="Name", label="Search Database By")
                search_name_input = gr.Dropdown(label="Select User", choices=[], value=None, interactive=True)
                search_image_input = gr.Image(type="numpy", label="Upload Image", visible=False)
                search_btn = gr.Button("Search")
                search_result = gr.Textbox(label="", interactive=False)

                def update_search_visibility(mode):
                    # Show the dropdown for name search, the image upload otherwise.
                    if mode == "Name":
                        return gr.update(visible=True), gr.update(visible=False)
                    return gr.update(visible=False), gr.update(visible=True)

                search_mode.change(fn=update_search_visibility,
                                   inputs=[search_mode],
                                   outputs=[search_name_input, search_image_input])

                def search_user(mode, name, img):
                    if mode == "Name":
                        return search_by_name(name)
                    return search_by_image(img)

                search_btn.click(fn=search_user,
                                 inputs=[search_mode, search_name_input, search_image_input],
                                 outputs=[search_result])

            with gr.Accordion("User Management Tools", open=False):
                list_btn = gr.Button("List Enrolled Users")
                list_output = gr.Textbox(label="", interactive=False)
                list_btn.click(fn=list_users, inputs=[], outputs=[list_output])

                def get_user_list():
                    pl = load_pipeline()
                    return gr.update(choices=pl.db.list_labels())

                refresh_users_btn = gr.Button("Refresh User List")
                refresh_users_btn.click(fn=get_user_list, inputs=[], outputs=[search_name_input])

                remove_user_select = gr.Dropdown(label="Select User to Remove", choices=[])
                remove_btn = gr.Button("Remove Selected User")
                remove_output = gr.Textbox(label="", interactive=False)

                remove_btn.click(fn=remove_user, inputs=[remove_user_select], outputs=[remove_output])
                refresh_users_btn.click(fn=get_user_list, inputs=[], outputs=[remove_user_select])

    return demo
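

# app.queue() enables request queuing, so long-running callbacks (e.g. the
# first pipeline load) are processed in order instead of timing out.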
if __name__ == "__main__":
    app = build_app()
    app.queue().launch(server_name="0.0.0.0", server_port=7860)