import os
import sys
import math
import requests
import numpy as np
import cv2
import torch
import pickle
import logging
from PIL import Image
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass, field
from collections import Counter
import io
import tempfile

import gradio as gr

from ultralytics import YOLO
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from deep_sort_realtime.deepsort_tracker import DeepSort

import mediapipe as mp


logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler('face_pipeline.log'), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# Quiet noisy third-party loggers.
logging.getLogger('torch').setLevel(logging.ERROR)
logging.getLogger('mediapipe').setLevel(logging.ERROR)
logging.getLogger('deep_sort_realtime').setLevel(logging.ERROR)

DEFAULT_MODEL_URL = "https://github.com/wuhplaptop/face-11-n/blob/main/face2.pt?raw=true"
DEFAULT_DB_PATH = os.path.expanduser("~/.face_pipeline/known_faces.pkl")
MODEL_DIR = os.path.expanduser("~/.face_pipeline/models")
CONFIG_PATH = os.path.expanduser("~/.face_pipeline/config.pkl")

# MediaPipe Face Mesh indices of the six landmarks around each eye, in the
# order expected by the eye-aspect-ratio (EAR) blink test below.
LEFT_EYE_IDX = [33, 160, 158, 133, 153, 144]
RIGHT_EYE_IDX = [263, 387, 385, 362, 380, 373]

mp_drawing = mp.solutions.drawing_utils
mp_face_mesh = mp.solutions.face_mesh
mp_hands = mp.solutions.hands

@dataclass
class PipelineConfig:
    detector: Dict = field(default_factory=dict)
    tracker: Dict = field(default_factory=dict)
    recognition: Dict = field(default_factory=dict)
    anti_spoof: Dict = field(default_factory=dict)
    blink: Dict = field(default_factory=dict)
    face_mesh_options: Dict = field(default_factory=dict)
    hand: Dict = field(default_factory=dict)
    eye_color: Dict = field(default_factory=dict)
    enabled_components: Dict = field(default_factory=dict)

    detection_conf_thres: float = 0.4
    recognition_conf_thres: float = 0.85

    # Colors are interpreted as RGB tuples; the drawing code reverses them to
    # BGR for OpenCV at draw time.
    bbox_color: Tuple[int, int, int] = (0, 255, 0)
    spoofed_bbox_color: Tuple[int, int, int] = (0, 0, 255)
    unknown_bbox_color: Tuple[int, int, int] = (0, 0, 255)
    eye_outline_color: Tuple[int, int, int] = (255, 255, 0)
    blink_text_color: Tuple[int, int, int] = (0, 0, 255)
    hand_landmark_color: Tuple[int, int, int] = (255, 210, 77)
    hand_connection_color: Tuple[int, int, int] = (204, 102, 0)
    hand_text_color: Tuple[int, int, int] = (255, 255, 255)
    mesh_color: Tuple[int, int, int] = (100, 255, 100)
    contour_color: Tuple[int, int, int] = (200, 200, 0)
    iris_color: Tuple[int, int, int] = (255, 0, 255)
    eye_color_text_color: Tuple[int, int, int] = (255, 255, 255)

    def __post_init__(self):
        # Fill any dict left empty by the caller with sensible defaults.
        self.detector = self.detector or {
            'model_path': os.path.join(MODEL_DIR, "face2.pt"),
            'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        }
        self.tracker = self.tracker or {'max_age': 30}
        self.recognition = self.recognition or {'enable': True}
        self.anti_spoof = self.anti_spoof or {'enable': True, 'lap_thresh': 80.0}
        self.blink = self.blink or {'enable': True, 'ear_thresh': 0.25}
        self.face_mesh_options = self.face_mesh_options or {
            'enable': False,
            'tesselation': False,
            'contours': False,
            'irises': False,
        }
        self.hand = self.hand or {
            'enable': True,
            'min_detection_confidence': 0.5,
            'min_tracking_confidence': 0.5,
        }
        self.eye_color = self.eye_color or {'enable': False}
        self.enabled_components = self.enabled_components or {
            'detection': True,
            'tracking': True,
            'anti_spoof': True,
            'recognition': True,
            'blink': True,
            'face_mesh': False,
            'hand': True,
            'eye_color': False,
        }

    def save(self, path: str):
        """Save this config to a pickle file."""
        try:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, 'wb') as f:
                pickle.dump(self.__dict__, f)
            logger.info(f"Saved config to {path}")
            logger.debug(f"Config data saved: {self.__dict__}")
        except Exception as e:
            logger.error(f"Config save failed: {str(e)}")
            raise RuntimeError(f"Config save failed: {str(e)}") from e

    @classmethod
    def load(cls, path: str) -> 'PipelineConfig':
        """Load a config from a pickle file, falling back to defaults."""
        try:
            if os.path.exists(path):
                with open(path, 'rb') as f:
                    data = pickle.load(f)
                logger.info(f"Loaded config from {path}")
                logger.debug(f"Config data loaded: {data}")
                return cls(**data)
            logger.info("No config file found, using default config.")
            return cls()
        except Exception as e:
            logger.error(f"Config load failed: {str(e)}")
            return cls()

    def export_config(self) -> bytes:
        """Serialize this config to bytes."""
        try:
            config_data = self.__dict__
            buf = io.BytesIO()
            pickle.dump(config_data, buf)
            buf.seek(0)
            return buf.read()
        except Exception as e:
            logger.error(f"Export config failed: {str(e)}")
            raise RuntimeError(f"Export config failed: {str(e)}") from e

    @classmethod
    def import_config(cls, config_bytes: bytes) -> 'PipelineConfig':
        """Deserialize a config from bytes."""
        try:
            buf = io.BytesIO(config_bytes)
            data = pickle.load(buf)
            return cls(**data)
        except Exception as e:
            logger.error(f"Import config failed: {str(e)}")
            raise RuntimeError(f"Import config failed: {str(e)}") from e

class FaceDatabase:
    """Pickle-backed store of face embeddings, keyed by user label."""

    def __init__(self, db_path: str = DEFAULT_DB_PATH):
        self.db_path = db_path
        self.embeddings: Dict[str, List[np.ndarray]] = {}
        self._load()

    def _load(self):
        try:
            if os.path.exists(self.db_path):
                with open(self.db_path, 'rb') as f:
                    self.embeddings = pickle.load(f)
                logger.info(f"Loaded database from {self.db_path}")
        except Exception as e:
            logger.error(f"Database load failed: {str(e)}")
            self.embeddings = {}

    def save(self):
        try:
            os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
            with open(self.db_path, 'wb') as f:
                pickle.dump(self.embeddings, f)
            logger.info(f"Saved database to {self.db_path}")
        except Exception as e:
            logger.error(f"Database save failed: {str(e)}")
            raise RuntimeError(f"Database save failed: {str(e)}") from e

    def export_database(self) -> bytes:
        """Export the entire face embeddings DB to bytes."""
        try:
            db_data = self.embeddings
            buf = io.BytesIO()
            pickle.dump(db_data, buf)
            buf.seek(0)
            return buf.read()
        except Exception as e:
            logger.error(f"Export database failed: {str(e)}")
            raise RuntimeError(f"Export database failed: {str(e)}") from e

    def import_database(self, db_bytes: bytes, merge: bool = True):
        """
        Import embeddings from bytes.
        If merge=True, merges with the current DB; if False, overwrites it.
        """
        try:
            buf = io.BytesIO(db_bytes)
            imported_data = pickle.load(buf)
            if not isinstance(imported_data, dict):
                raise ValueError("Imported data is not a dictionary!")

            if merge:
                for label, emb_list in imported_data.items():
                    if label not in self.embeddings:
                        self.embeddings[label] = []
                    self.embeddings[label].extend(emb_list)
            else:
                self.embeddings = imported_data

            self.save()
            logger.info(f"Imported face database, merge={merge}")
        except Exception as e:
            logger.error(f"Import database failed: {str(e)}")
            raise RuntimeError(f"Import database failed: {str(e)}") from e

    def add_embedding(self, label: str, embedding: np.ndarray):
        try:
            if not isinstance(embedding, np.ndarray) or embedding.ndim != 1:
                raise ValueError("Invalid embedding format")
            if label not in self.embeddings:
                self.embeddings[label] = []
            self.embeddings[label].append(embedding)
            logger.debug(f"Added embedding for {label}")
        except Exception as e:
            logger.error(f"Add embedding failed: {str(e)}")
            raise

    def remove_label(self, label: str):
        try:
            if label in self.embeddings:
                del self.embeddings[label]
                logger.info(f"Removed {label}")
            else:
                logger.warning(f"Label {label} not found")
        except Exception as e:
            logger.error(f"Remove label failed: {str(e)}")
            raise

    def list_labels(self) -> List[str]:
        return list(self.embeddings.keys())

    def get_embeddings_by_label(self, label: str) -> Optional[List[np.ndarray]]:
        return self.embeddings.get(label)

    def search_by_image(self, query_embedding: np.ndarray, threshold: float = 0.7) -> List[Tuple[str, float]]:
        """Return (label, similarity) pairs meeting the threshold, best first."""
        results = []
        for lbl, embs in self.embeddings.items():
            for db_emb in embs:
                sim = FacePipeline.cosine_similarity(query_embedding, db_emb)
                if sim >= threshold:
                    results.append((lbl, sim))
        return sorted(results, key=lambda x: x[1], reverse=True)
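
# Minimal usage sketch for FaceDatabase (illustrative only): the random 512-d
# vector stands in for a real FaceNet embedding, and the path is hypothetical.
def _demo_face_database(path: str = "/tmp/face_pipeline_demo_db.pkl") -> List[Tuple[str, float]]:
    db = FaceDatabase(db_path=path)
    emb = np.random.rand(512).astype(np.float32)
    db.add_embedding("alice", emb)
    # Querying with the same vector should return ("alice", ~1.0) first.
    return db.search_by_image(emb, threshold=0.9)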

class YOLOFaceDetector:
    def __init__(self, model_path: str, device: str = 'cpu'):
        self.model = None
        self.device = device
        try:
            if not os.path.exists(model_path):
                logger.info(f"Model not found at {model_path}. Downloading from GitHub...")
                # A timeout keeps a stalled download from hanging startup forever.
                resp = requests.get(DEFAULT_MODEL_URL, timeout=120)
                resp.raise_for_status()
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                with open(model_path, 'wb') as f:
                    f.write(resp.content)
                logger.info(f"Downloaded YOLO model to {model_path}")

            self.model = YOLO(model_path)
            self.model.to(device)
            logger.info(f"Loaded YOLO model from {model_path}")
        except Exception as e:
            logger.error(f"YOLO init failed: {str(e)}")
            raise

    def detect(self, image: np.ndarray, conf_thres: float) -> List[Tuple[int, int, int, int, float, int]]:
        """Return detections as (x1, y1, x2, y2, confidence, class_id) tuples."""
        try:
            results = self.model.predict(
                source=image, conf=conf_thres, verbose=False, device=self.device
            )
            detections = []
            for result in results:
                for box in result.boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf[0].cpu().numpy())
                    cls = int(box.cls[0].cpu().numpy()) if box.cls is not None else 0
                    detections.append((int(x1), int(y1), int(x2), int(y2), conf, cls))
            logger.debug(f"Detected {len(detections)} faces.")
            return detections
        except Exception as e:
            logger.error(f"Detection error: {str(e)}")
            return []

class FaceTracker:
    def __init__(self, max_age: int = 30):
        self.tracker = DeepSort(max_age=max_age, embedder='mobilenet')

    def update(self, detections: List[Tuple], frame: np.ndarray):
        try:
            # DeepSort expects ([left, top, width, height], confidence, class).
            ds_detections = [
                ([x1, y1, x2 - x1, y2 - y1], conf, cls)
                for (x1, y1, x2, y2, conf, cls) in detections
            ]
            tracks = self.tracker.update_tracks(ds_detections, frame=frame)
            logger.debug(f"Updated tracker with {len(tracks)} tracks.")
            return tracks
        except Exception as e:
            logger.error(f"Tracking error: {str(e)}")
            return []


class FaceNetEmbedder:
    def __init__(self, device: str = 'cpu'):
        self.device = device
        self.model = InceptionResnetV1(pretrained='vggface2').eval().to(device)
        self.transform = transforms.Compose([
            transforms.Resize((160, 160)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ])

    def get_embedding(self, face_bgr: np.ndarray) -> Optional[np.ndarray]:
        try:
            face_rgb = cv2.cvtColor(face_bgr, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(face_rgb).convert('RGB')
            tens = self.transform(pil_img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                embedding = self.model(tens)[0].cpu().numpy()
            logger.debug(f"Generated embedding sample: {embedding[:5]}...")
            return embedding
        except Exception as e:
            logger.error(f"Embedding failed: {str(e)}")
            return None
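
# Illustrative shape check: InceptionResnetV1 ("vggface2") maps any BGR crop to
# a 512-d vector. A black square stands in for a real face; never called here.
def _demo_embedding_shape() -> Optional[Tuple[int, ...]]:
    embedder = FaceNetEmbedder(device='cpu')
    dummy_face = np.zeros((160, 160, 3), dtype=np.uint8)
    emb = embedder.get_embedding(dummy_face)
    return emb.shape if emb is not None else None  # expected: (512,)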

def detect_blink(face_roi: np.ndarray, threshold: float = 0.25) -> Tuple[bool, float, float, Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Returns:
        (blink_bool, left_ear, right_ear, left_eye_points, right_eye_points).
    """
    try:
        # Note: a fresh FaceMesh is constructed (and torn down) per call, which
        # is simple but costly; reuse an instance if this becomes a bottleneck.
        face_mesh_proc = mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5
        )
        result = face_mesh_proc.process(cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB))
        face_mesh_proc.close()

        if not result.multi_face_landmarks:
            return False, 0.0, 0.0, None, None

        landmarks = result.multi_face_landmarks[0].landmark
        h, w = face_roi.shape[:2]

        def eye_aspect_ratio(indices):
            # EAR = (|p2 - p6| + |p3 - p5|) / (2 * |p1 - p4|): the two vertical
            # eye distances over the horizontal one; it drops toward zero as
            # the eyelid closes.
            pts = [(landmarks[i].x * w, landmarks[i].y * h) for i in indices]
            vertical = np.linalg.norm(np.array(pts[1]) - np.array(pts[5])) + \
                       np.linalg.norm(np.array(pts[2]) - np.array(pts[4]))
            horizontal = np.linalg.norm(np.array(pts[0]) - np.array(pts[3]))
            return vertical / (2.0 * horizontal + 1e-6)

        left_ear = eye_aspect_ratio(LEFT_EYE_IDX)
        right_ear = eye_aspect_ratio(RIGHT_EYE_IDX)

        blink = (left_ear < threshold) and (right_ear < threshold)

        left_eye_pts = np.array([(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in LEFT_EYE_IDX])
        right_eye_pts = np.array([(int(landmarks[i].x * w), int(landmarks[i].y * h)) for i in RIGHT_EYE_IDX])

        return blink, left_ear, right_ear, left_eye_pts, right_eye_pts

    except Exception as e:
        logger.error(f"Blink detection error: {str(e)}")
        return False, 0.0, 0.0, None, None
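
# Usage sketch for detect_blink on a standalone image ("face.jpg" is a
# hypothetical local path); the pipeline itself calls it per tracked face ROI.
def _demo_detect_blink(path: str = "face.jpg") -> None:
    img = cv2.imread(path)
    if img is None:
        print(f"Could not read {path}")
        return
    blink, l_ear, r_ear, _, _ = detect_blink(img, threshold=0.25)
    print(f"blink={blink} left_EAR={l_ear:.3f} right_EAR={r_ear:.3f}")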

def process_face_mesh(face_roi: np.ndarray):
    try:
        fm_proc = mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5
        )
        result = fm_proc.process(cv2.cvtColor(face_roi, cv2.COLOR_BGR2RGB))
        fm_proc.close()
        if result.multi_face_landmarks:
            return result.multi_face_landmarks[0]
        return None
    except Exception as e:
        logger.error(f"Face mesh error: {str(e)}")
        return None


def draw_face_mesh(image: np.ndarray, face_landmarks, config: Dict, pipeline_config: PipelineConfig):
    mesh_color_bgr = pipeline_config.mesh_color[::-1]
    contour_color_bgr = pipeline_config.contour_color[::-1]
    iris_color_bgr = pipeline_config.iris_color[::-1]

    if config.get('tesselation'):
        mp_drawing.draw_landmarks(
            image,
            face_landmarks,
            mp_face_mesh.FACEMESH_TESSELATION,
            landmark_drawing_spec=mp_drawing.DrawingSpec(color=mesh_color_bgr, thickness=1, circle_radius=1),
            connection_drawing_spec=mp_drawing.DrawingSpec(color=mesh_color_bgr, thickness=1),
        )
    if config.get('contours'):
        mp_drawing.draw_landmarks(
            image,
            face_landmarks,
            mp_face_mesh.FACEMESH_CONTOURS,
            landmark_drawing_spec=None,
            connection_drawing_spec=mp_drawing.DrawingSpec(color=contour_color_bgr, thickness=2)
        )
    if config.get('irises'):
        mp_drawing.draw_landmarks(
            image,
            face_landmarks,
            mp_face_mesh.FACEMESH_IRISES,
            landmark_drawing_spec=None,
            connection_drawing_spec=mp_drawing.DrawingSpec(color=iris_color_bgr, thickness=2)
        )

# Reference palette for eye-color classification, in RGB.
EYE_COLOR_RANGES = {
    "amber": (255, 191, 0),
    "blue": (0, 0, 255),
    "brown": (139, 69, 19),
    "green": (0, 128, 0),
    "gray": (128, 128, 128),
    "hazel": (102, 51, 0),
}


def classify_eye_color(rgb_color: Tuple[int, int, int]) -> str:
    """Return the palette entry nearest to rgb_color in Euclidean RGB distance."""
    if rgb_color is None:
        return "Unknown"
    min_dist = float('inf')
    best = "Unknown"
    for color_name, ref_rgb in EYE_COLOR_RANGES.items():
        dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(rgb_color, ref_rgb)))
        if dist < min_dist:
            min_dist = dist
            best = color_name
    return best


def get_dominant_color(image_roi, k=3):
    """K-means the ROI's pixels and return the largest cluster's center,
    in the same channel order as the input ROI."""
    if image_roi.size == 0:
        return None
    pixels = np.float32(image_roi.reshape(-1, 3))
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.1)
    _, labels, palette = cv2.kmeans(pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    _, counts = np.unique(labels, return_counts=True)
    dom_color = tuple(palette[np.argmax(counts)].astype(int).tolist())
    return dom_color
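
# Quick arithmetic check of the nearest-reference classifier above: a pixel
# near pure blue in RGB should land on the "blue" palette entry.
def _demo_classify_eye_color() -> str:
    return classify_eye_color((10, 20, 240))  # expected: "blue"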

def detect_eye_color(face_roi: np.ndarray, face_landmarks) -> Optional[str]:
    if face_landmarks is None:
        return None
    h, w = face_roi.shape[:2]
    iris_inds = set()
    for conn in mp_face_mesh.FACEMESH_IRISES:
        iris_inds.update(conn)

    iris_points = []
    for idx in iris_inds:
        lm = face_landmarks.landmark[idx]
        iris_points.append((int(lm.x * w), int(lm.y * h)))
    if not iris_points:
        return None

    min_x = min(pt[0] for pt in iris_points)
    max_x = max(pt[0] for pt in iris_points)
    min_y = min(pt[1] for pt in iris_points)
    max_y = max(pt[1] for pt in iris_points)

    pad = 5
    x1 = max(0, min_x - pad)
    y1 = max(0, min_y - pad)
    x2 = min(w, max_x + pad)
    y2 = min(h, max_y + pad)

    eye_roi = face_roi[y1:y2, x1:x2]
    if eye_roi.size == 0:
        # Check before resizing: cv2.resize raises on an empty ROI.
        return None
    eye_roi_resize = cv2.resize(eye_roi, (40, 40), interpolation=cv2.INTER_AREA)

    # face_roi is BGR (OpenCV), but the reference palette is RGB, so convert
    # before picking the dominant color.
    eye_roi_rgb = cv2.cvtColor(eye_roi_resize, cv2.COLOR_BGR2RGB)
    dom_rgb = get_dominant_color(eye_roi_rgb)
    if dom_rgb is not None:
        return classify_eye_color(dom_rgb)
    return None

class HandTracker:
    def __init__(self, min_detection_confidence=0.5, min_tracking_confidence=0.5):
        # static_image_mode=True treats every frame independently; MediaPipe
        # ignores the tracking confidence in that mode.
        self.hands = mp_hands.Hands(
            static_image_mode=True,
            max_num_hands=2,
            min_detection_confidence=min_detection_confidence,
            min_tracking_confidence=min_tracking_confidence,
        )
        logger.info("Initialized Mediapipe HandTracking")

    def detect_hands(self, image: np.ndarray):
        try:
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = self.hands.process(img_rgb)
            return results.multi_hand_landmarks, results.multi_handedness
        except Exception as e:
            logger.error(f"Hand detection error: {str(e)}")
            return None, None

    def draw_hands(self, image: np.ndarray, hand_landmarks, handedness, config: PipelineConfig):
        # `config` is the full PipelineConfig (the original hint said Dict),
        # since the color attributes below live on it.
        if not hand_landmarks:
            return image

        for i, hlms in enumerate(hand_landmarks):
            hl_color = config.hand_landmark_color[::-1]
            hc_color = config.hand_connection_color[::-1]
            mp_drawing.draw_landmarks(
                image,
                hlms,
                mp_hands.HAND_CONNECTIONS,
                mp_drawing.DrawingSpec(color=hl_color, thickness=2, circle_radius=4),
                mp_drawing.DrawingSpec(color=hc_color, thickness=2, circle_radius=2),
            )
            if handedness and i < len(handedness):
                label = handedness[i].classification[0].label
                score = handedness[i].classification[0].score
                text = f"{label}: {score:.2f}"

                wrist_lm = hlms.landmark[mp_hands.HandLandmark.WRIST]
                h, w_img, _ = image.shape
                cx, cy = int(wrist_lm.x * w_img), int(wrist_lm.y * h)
                ht_color = config.hand_text_color[::-1]
                cv2.putText(image, text, (cx, cy - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, ht_color, 2)
        return image

class FacePipeline:
    def __init__(self, config: PipelineConfig):
        self.config = config
        self.detector = None
        self.tracker = None
        self.facenet = None
        self.db = None
        self.hand_tracker = None
        self._initialized = False

    def initialize(self):
        try:
            self.detector = YOLOFaceDetector(
                model_path=self.config.detector['model_path'],
                device=self.config.detector['device']
            )
            self.tracker = FaceTracker(max_age=self.config.tracker['max_age'])
            self.facenet = FaceNetEmbedder(device=self.config.detector['device'])
            self.db = FaceDatabase()

            if self.config.hand['enable']:
                self.hand_tracker = HandTracker(
                    min_detection_confidence=self.config.hand['min_detection_confidence'],
                    min_tracking_confidence=self.config.hand['min_tracking_confidence']
                )

            self._initialized = True
            logger.info("FacePipeline initialized successfully.")
        except Exception as e:
            logger.error(f"Initialization failed: {str(e)}")
            self._initialized = False
            raise

    def process_frame(self, frame: np.ndarray) -> Tuple[np.ndarray, List[Dict]]:
        """
        Main pipeline processing: detection, tracking, hand detection, face mesh,
        blink detection, etc. Returns (annotated_frame, detection_results).
        """
        if not self._initialized:
            logger.error("Pipeline not initialized.")
            return frame, []

        try:
            detections = self.detector.detect(frame, self.config.detection_conf_thres)
            tracked_objs = self.tracker.update(detections, frame)
            annotated = frame.copy()
            results = []

            hand_landmarks_list = None
            handedness_list = None
            if self.config.hand['enable'] and self.hand_tracker:
                hand_landmarks_list, handedness_list = self.hand_tracker.detect_hands(annotated)
                annotated = self.hand_tracker.draw_hands(
                    annotated, hand_landmarks_list, handedness_list, self.config
                )

            for obj in tracked_objs:
                if not obj.is_confirmed():
                    continue

                track_id = obj.track_id
                bbox = obj.to_tlbr().astype(int)
                x1, y1, x2, y2 = bbox
                # Clamp to the frame: tracker boxes can drift slightly outside
                # it, and negative indices would silently wrap in numpy slicing.
                h_frame, w_frame = frame.shape[:2]
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(w_frame, x2), min(h_frame, y2)
                conf = getattr(obj, 'score', 1.0)
                cls = getattr(obj, 'class_id', 0)

                face_roi = frame[y1:y2, x1:x2]
                if face_roi.size == 0:
                    logger.warning(f"Empty face ROI for track={track_id}")
                    continue

                # Anti-spoofing: a low Laplacian variance suggests a flat
                # (printed or replayed) face.
                is_spoofed = False
                if self.config.anti_spoof.get('enable', True):
                    is_spoofed = not self.is_real_face(face_roi)
                    if is_spoofed:
                        cls = 1

                if is_spoofed:
                    box_color_bgr = self.config.spoofed_bbox_color[::-1]
                    name = "Spoofed"
                    similarity = 0.0
                else:
                    emb = self.facenet.get_embedding(face_roi)
                    if emb is not None and self.config.recognition.get('enable', True):
                        name, similarity = self.recognize_face(emb, self.config.recognition_conf_thres)
                    else:
                        name = "Unknown"
                        similarity = 0.0

                    box_color_rgb = (self.config.bbox_color if name != "Unknown"
                                     else self.config.unknown_bbox_color)
                    box_color_bgr = box_color_rgb[::-1]

                label_text = name
                cv2.rectangle(annotated, (x1, y1), (x2, y2), box_color_bgr, 2)
                cv2.putText(annotated, label_text, (x1, y1 - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, box_color_bgr, 2)

                blink = False
                if self.config.blink.get('enable', False):
                    blink, left_ear, right_ear, left_eye_pts, right_eye_pts = detect_blink(
                        face_roi, threshold=self.config.blink.get('ear_thresh', 0.25)
                    )
                    if left_eye_pts is not None and right_eye_pts is not None:
                        # Eye landmarks are relative to the ROI; shift to frame coords.
                        le_g = left_eye_pts + np.array([x1, y1])
                        re_g = right_eye_pts + np.array([x1, y1])

                        eye_outline_bgr = self.config.eye_outline_color[::-1]
                        cv2.polylines(annotated, [le_g], True, eye_outline_bgr, 1)
                        cv2.polylines(annotated, [re_g], True, eye_outline_bgr, 1)
                        if blink:
                            blink_msg_color = self.config.blink_text_color[::-1]
                            cv2.putText(annotated, "Blink Detected",
                                        (x1, y2 + 20),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                        blink_msg_color, 2)

                face_mesh_landmarks = None
                eye_color_name = None
                if (self.config.face_mesh_options.get('enable') or
                        self.config.eye_color.get('enable')):
                    face_mesh_landmarks = process_face_mesh(face_roi)
                    if face_mesh_landmarks:
                        if self.config.face_mesh_options.get('enable', False):
                            # The slice is a view of `annotated`, so drawing on
                            # it draws on the full frame.
                            draw_face_mesh(
                                annotated[y1:y2, x1:x2],
                                face_mesh_landmarks,
                                self.config.face_mesh_options,
                                self.config
                            )

                        if self.config.eye_color.get('enable', False):
                            color_found = detect_eye_color(face_roi, face_mesh_landmarks)
                            if color_found:
                                eye_color_name = color_found
                                text_col_bgr = self.config.eye_color_text_color[::-1]
                                cv2.putText(
                                    annotated, f"Eye Color: {eye_color_name}",
                                    (x1, y2 + 40),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                    text_col_bgr, 2
                                )

                detection_info = {
                    "track_id": track_id,
                    "bbox": (x1, y1, x2, y2),
                    "confidence": float(conf),
                    "class_id": cls,
                    "name": name,
                    "similarity": similarity,
                    "blink": blink if self.config.blink.get('enable') else None,
                    "face_mesh": bool(face_mesh_landmarks) if self.config.face_mesh_options.get('enable') else False,
                    "hands_detected": bool(hand_landmarks_list),
                    "hand_count": len(hand_landmarks_list) if hand_landmarks_list else 0,
                    "eye_color": eye_color_name if self.config.eye_color.get('enable') else None,
                }
                results.append(detection_info)

            return annotated, results

        except Exception as e:
            logger.error(f"Frame process error: {str(e)}")
            return frame, []

    def is_real_face(self, face_roi: np.ndarray) -> bool:
        """Heuristic anti-spoof check: low Laplacian variance (blur/flatness)
        is treated as a spoof."""
        try:
            gray = cv2.cvtColor(face_roi, cv2.COLOR_BGR2GRAY)
            lapv = cv2.Laplacian(gray, cv2.CV_64F).var()
            return lapv > self.config.anti_spoof.get('lap_thresh', 80.0)
        except Exception as e:
            logger.error(f"Anti-spoof error: {str(e)}")
            return False

    def recognize_face(self, embedding: np.ndarray, threshold: float) -> Tuple[str, float]:
        """Nearest-neighbor lookup over all stored embeddings; returns
        ("Unknown", best_sim) when the best similarity is below threshold."""
        try:
            best_name = "Unknown"
            best_sim = 0.0
            for lbl, embs in self.db.embeddings.items():
                for db_emb in embs:
                    sim = FacePipeline.cosine_similarity(embedding, db_emb)
                    if sim > best_sim:
                        best_sim = sim
                        best_name = lbl
            if best_sim < threshold:
                best_name = "Unknown"
            return best_name, best_sim
        except Exception as e:
            logger.error(f"Recognition error: {str(e)}")
            return ("Unknown", 0.0)

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        return float(np.dot(a, b) / ((np.linalg.norm(a) * np.linalg.norm(b)) + 1e-6))
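
# Illustrative real-time loop (assumes a webcam at index 0). The Gradio UI
# below only processes single images, so this sketches the intended video use.
def _demo_webcam_loop() -> None:
    pl = load_pipeline()
    cap = cv2.VideoCapture(0)
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            annotated, _detections = pl.process_frame(frame)
            cv2.imshow("FacePipeline", annotated)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        cv2.destroyAllWindows()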

pipeline: Optional[FacePipeline] = None


def load_pipeline() -> FacePipeline:
    """Global pipeline loader: create on first call, then reuse."""
    global pipeline
    if pipeline is None:
        cfg = PipelineConfig.load(CONFIG_PATH)
        pipeline = FacePipeline(cfg)
        pipeline.initialize()
    return pipeline


def hex_to_bgr(hexstr: str) -> Tuple[int, int, int]:
    """Parse '#rrggbb' (leading '#' optional) into a BGR tuple."""
    if not hexstr.startswith('#'):
        hexstr = '#' + hexstr
    h = hexstr.lstrip('#')
    if len(h) != 6:
        return (255, 0, 0)  # fall back to blue (in BGR) on malformed input
    r = int(h[0:2], 16)
    g = int(h[2:4], 16)
    b = int(h[4:6], 16)
    return (b, g, r)


def bgr_to_hex(bgr: Tuple[int, int, int]) -> str:
    b, g, r = bgr
    return f"#{r:02x}{g:02x}{b:02x}"

def update_config(
    enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
    show_tesselation, show_contours, show_irises,
    detection_conf, recognition_thresh, antispoof_thresh, blink_thresh, hand_det_conf, hand_track_conf,
    bbox_hex, spoofed_hex, unknown_hex, eye_hex, blink_hex,
    hand_landmark_hex, hand_connect_hex, hand_text_hex,
    mesh_hex, contour_hex, iris_hex, eye_color_text_hex
):
    pl = load_pipeline()
    cfg = pl.config

    cfg.recognition['enable'] = enable_recognition
    cfg.anti_spoof['enable'] = enable_antispoof
    cfg.blink['enable'] = enable_blink
    cfg.hand['enable'] = enable_hand
    cfg.eye_color['enable'] = enable_eyecolor
    cfg.face_mesh_options['enable'] = enable_facemesh

    cfg.face_mesh_options['tesselation'] = show_tesselation
    cfg.face_mesh_options['contours'] = show_contours
    cfg.face_mesh_options['irises'] = show_irises

    cfg.detection_conf_thres = detection_conf
    cfg.recognition_conf_thres = recognition_thresh
    cfg.anti_spoof['lap_thresh'] = antispoof_thresh
    cfg.blink['ear_thresh'] = blink_thresh
    cfg.hand['min_detection_confidence'] = hand_det_conf
    cfg.hand['min_tracking_confidence'] = hand_track_conf

    # hex_to_bgr returns BGR; reversing stores RGB, which is the order the
    # drawing code expects before it flips back to BGR at draw time.
    cfg.bbox_color = hex_to_bgr(bbox_hex)[::-1]
    cfg.spoofed_bbox_color = hex_to_bgr(spoofed_hex)[::-1]
    cfg.unknown_bbox_color = hex_to_bgr(unknown_hex)[::-1]
    cfg.eye_outline_color = hex_to_bgr(eye_hex)[::-1]
    cfg.blink_text_color = hex_to_bgr(blink_hex)[::-1]
    cfg.hand_landmark_color = hex_to_bgr(hand_landmark_hex)[::-1]
    cfg.hand_connection_color = hex_to_bgr(hand_connect_hex)[::-1]
    cfg.hand_text_color = hex_to_bgr(hand_text_hex)[::-1]
    cfg.mesh_color = hex_to_bgr(mesh_hex)[::-1]
    cfg.contour_color = hex_to_bgr(contour_hex)[::-1]
    cfg.iris_color = hex_to_bgr(iris_hex)[::-1]
    cfg.eye_color_text_color = hex_to_bgr(eye_color_text_hex)[::-1]

    cfg.save(CONFIG_PATH)
    logger.info("Configuration updated with:")
    logger.info(f"Recognition Enabled: {enable_recognition}")
    logger.info(f"Anti-spoof Enabled: {enable_antispoof}")
    logger.info(f"Blink Enabled: {enable_blink}")
    logger.info(f"Face Mesh Enabled: {enable_facemesh}, Tesselation: {show_tesselation}, Contours: {show_contours}, Irises: {show_irises}")
    logger.info(f"Thresholds - Detection Conf: {detection_conf}, Recognition: {recognition_thresh}, Anti-spoof: {antispoof_thresh}, Blink: {blink_thresh}, Hand Det Conf: {hand_det_conf}, Hand Track Conf: {hand_track_conf}")
    logger.info(f"Colors - BBox: {bbox_hex}, Spoofed: {spoofed_hex}, Unknown: {unknown_hex}, Eye Outline: {eye_hex}, Blink Text: {blink_hex}, Hand Landmark: {hand_landmark_hex}, Hand Connect: {hand_connect_hex}, Hand Text: {hand_text_hex}, Mesh: {mesh_hex}, Contour: {contour_hex}, Iris: {iris_hex}, Eye Color Text: {eye_color_text_hex}")

    return "Configuration saved successfully!"

def enroll_user(label_name: str, files: List[bytes]) -> str:
    """Enroll a user by name from multiple uploaded image files."""
    pl = load_pipeline()
    if not label_name:
        return "Please provide a user name."
    if not files:
        return "No images provided."

    enrolled_count = 0
    for file_bytes in files:
        if not file_bytes:
            continue
        try:
            img_array = np.frombuffer(file_bytes, np.uint8)
            img_bgr = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
            if img_bgr is None:
                continue

            dets = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
            for x1, y1, x2, y2, conf, cls in dets:
                roi = img_bgr[y1:y2, x1:x2]
                if roi.size == 0:
                    continue
                emb = pl.facenet.get_embedding(roi)
                if emb is not None:
                    pl.db.add_embedding(label_name, emb)
                    enrolled_count += 1
        except Exception as e:
            logger.error(f"Error enrolling user from file: {str(e)}")
            continue

    if enrolled_count > 0:
        pl.db.save()
        return f"Enrolled '{label_name}' with {enrolled_count} face(s)!"
    else:
        return "No faces detected in provided images."

def search_by_name(name: str) -> str:
    pl = load_pipeline()
    if not name:
        return "No name entered."
    embs = pl.db.get_embeddings_by_label(name)
    if embs:
        return f"'{name}' found with {len(embs)} embedding(s)."
    else:
        return f"No embeddings found for '{name}'."


def search_by_image(img: np.ndarray) -> str:
    pl = load_pipeline()
    if img is None:
        return "No image uploaded."
    img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    dets = pl.detector.detect(img_bgr, pl.config.detection_conf_thres)
    if not dets:
        return "No faces detected in the uploaded image."
    # Only the first detected face is used for the search.
    x1, y1, x2, y2, conf, cls = dets[0]
    roi = img_bgr[y1:y2, x1:x2]
    if roi.size == 0:
        return "Empty face ROI in the uploaded image."

    emb = pl.facenet.get_embedding(roi)
    if emb is None:
        return "Could not generate embedding from face."
    results = pl.db.search_by_image(emb, pl.config.recognition_conf_thres)
    if not results:
        return "No matches in the database under current threshold."
    lines = [f"- {lbl} (sim={sim:.3f})" for lbl, sim in results]
    return "Search results:\n" + "\n".join(lines)


def remove_user(label: str) -> str:
    pl = load_pipeline()
    if not label:
        return "No user label selected."
    pl.db.remove_label(label)
    pl.db.save()
    return f"User '{label}' removed."


def list_users() -> str:
    pl = load_pipeline()
    labels = pl.db.list_labels()
    if labels:
        return "Enrolled users:\n" + ", ".join(labels)
    return "No users enrolled."


def process_test_image(img: np.ndarray) -> Tuple[np.ndarray, str]:
    """Single-image test: run the pipeline and return the annotated image plus
    a string dump of the detection results."""
    if img is None:
        return None, "No image uploaded."
    pl = load_pipeline()
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    processed, detections = pl.process_frame(bgr)
    result_rgb = cv2.cvtColor(processed, cv2.COLOR_BGR2RGB)
    return result_rgb, str(detections)

def export_all_file() -> str:
    """
    Exports both the pipeline config and database embeddings into a single
    pickle file. Returns the file path for Gradio to handle the download.
    """
    pl = load_pipeline()
    combined_data = {
        "config": pl.config.__dict__,
        "database": pl.db.embeddings,
    }

    buf = io.BytesIO()
    pickle.dump(combined_data, buf)
    buf_bytes = buf.getvalue()

    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
        tmp_file.write(buf_bytes)
        temp_path = tmp_file.name
    return temp_path


def import_all_file(file_bytes: bytes, merge_db: bool = True) -> str:
    """
    Imports a single pickle file containing both the config and database.
    If merge_db=False, overwrites the existing DB; otherwise merges.
    """
    if not file_bytes:
        return "No file provided."

    try:
        buf = io.BytesIO(file_bytes)
        combined_data = pickle.load(buf)

        if not isinstance(combined_data, dict):
            return "Invalid combined data format."

        new_cfg_data = combined_data.get("config", {})
        new_cfg = PipelineConfig(**new_cfg_data)

        new_db_data = combined_data.get("database", {})

        global pipeline
        pipeline = FacePipeline(new_cfg)
        pipeline.initialize()

        if merge_db:
            for label, emb_list in new_db_data.items():
                if label not in pipeline.db.embeddings:
                    pipeline.db.embeddings[label] = []
                pipeline.db.embeddings[label].extend(emb_list)
        else:
            pipeline.db.embeddings = new_db_data

        pipeline.db.save()

        return "Config and database imported successfully!"

    except Exception as e:
        logger.error(f"Import all failed: {str(e)}")
        return f"Import failed: {str(e)}"

def export_config_file() -> str:
    """Export the current pipeline config as a downloadable file."""
    pl = load_pipeline()
    config_bytes = pl.config.export_config()
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
        tmp_file.write(config_bytes)
        temp_path = tmp_file.name
    return temp_path


def import_config_file(file_bytes: bytes) -> str:
    """Import a pipeline config from uploaded bytes and re-initialize the pipeline."""
    if not file_bytes:
        return "No file provided."
    try:
        new_cfg = PipelineConfig.import_config(file_bytes)
        pl = FacePipeline(new_cfg)
        pl.initialize()
        global pipeline
        pipeline = pl
        return "Imported config successfully!"
    except Exception as e:
        logger.error(f"Import config failed: {str(e)}")
        return f"Import failed: {str(e)}"


def export_db_file() -> str:
    """Export the current face database as a downloadable file."""
    pl = load_pipeline()
    db_bytes = pl.db.export_database()
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp_file:
        tmp_file.write(db_bytes)
        temp_path = tmp_file.name
    return temp_path


def import_db_file(db_bytes: bytes, merge: bool = True) -> str:
    """Import a face database from uploaded bytes; merge or overwrite the existing one."""
    if not db_bytes:
        return "No file provided."
    try:
        pl = load_pipeline()
        pl.db.import_database(db_bytes, merge=merge)
        return f"Database imported successfully, merge={merge}"
    except Exception as e:
        logger.error(f"Import DB failed: {str(e)}")
        return f"Import DB failed: {str(e)}"

def build_app():
    with gr.Blocks() as demo:
        gr.Markdown("# FaceRec: Comprehensive Face Recognition Pipeline")
        gr.Markdown("**Note:** After downloading, please rename the file to its appropriate extension (e.g., `config_export.pkl`, `database_export.pkl`).")

        with gr.Tab("Image Test"):
            gr.Markdown("Upload a single image to detect faces, run blink detection, face mesh, hand tracking, etc.")
            test_in = gr.Image(type="numpy", label="Upload Image")
            test_out = gr.Image()
            test_info = gr.Textbox(label="Detections")
            process_btn = gr.Button("Process Image")

            process_btn.click(
                fn=process_test_image,
                inputs=test_in,
                outputs=[test_out, test_info],
            )

        with gr.Tab("Configuration"):
            gr.Markdown("Adjust toggles, thresholds, and colors. Click Save to persist changes.")

            with gr.Row():
                enable_recognition = gr.Checkbox(label="Enable Recognition", value=True)
                enable_antispoof = gr.Checkbox(label="Enable Anti-Spoof", value=True)
                enable_blink = gr.Checkbox(label="Enable Blink Detection", value=True)
                enable_hand = gr.Checkbox(label="Enable Hand Tracking", value=True)
                enable_eyecolor = gr.Checkbox(label="Enable Eye Color Detection", value=False)
                enable_facemesh = gr.Checkbox(label="Enable Face Mesh", value=False)

            gr.Markdown("**Face Mesh Options**")
            with gr.Row():
                show_tesselation = gr.Checkbox(label="Tesselation", value=False)
                show_contours = gr.Checkbox(label="Contours", value=False)
                show_irises = gr.Checkbox(label="Irises", value=False)

            gr.Markdown("**Thresholds**")
            detection_conf = gr.Slider(0, 1, 0.4, step=0.01, label="Detection Confidence")
            recognition_thresh = gr.Slider(0.5, 1.0, 0.85, step=0.01, label="Recognition Threshold")
            antispoof_thresh = gr.Slider(0, 200, 80, step=1, label="Anti-Spoof Threshold")
            blink_thresh = gr.Slider(0, 0.5, 0.25, step=0.01, label="Blink EAR Threshold")
            hand_det_conf = gr.Slider(0, 1, 0.5, step=0.01, label="Hand Detection Confidence")
            hand_track_conf = gr.Slider(0, 1, 0.5, step=0.01, label="Hand Tracking Confidence")

            gr.Markdown("**Color Options (Hex)**")
            bbox_hex = gr.Textbox(label="Box Color (Recognized)", value="#00ff00")
            spoofed_hex = gr.Textbox(label="Box Color (Spoofed)", value="#ff0000")
            unknown_hex = gr.Textbox(label="Box Color (Unknown)", value="#ff0000")
            eye_hex = gr.Textbox(label="Eye Outline Color", value="#ffff00")
            blink_hex = gr.Textbox(label="Blink Text Color", value="#0000ff")

            hand_landmark_hex = gr.Textbox(label="Hand Landmark Color", value="#ffd24d")
            hand_connect_hex = gr.Textbox(label="Hand Connection Color", value="#cc6600")
            hand_text_hex = gr.Textbox(label="Hand Text Color", value="#ffffff")

            mesh_hex = gr.Textbox(label="Mesh Color", value="#64ff64")
            contour_hex = gr.Textbox(label="Contour Color", value="#c8c800")
            iris_hex = gr.Textbox(label="Iris Color", value="#ff00ff")
            eye_color_text_hex = gr.Textbox(label="Eye Color Text Color", value="#ffffff")

            save_btn = gr.Button("Save Configuration")
            save_msg = gr.Textbox(label="", interactive=False)

            save_btn.click(
                fn=update_config,
                inputs=[
                    enable_recognition, enable_antispoof, enable_blink, enable_hand, enable_eyecolor, enable_facemesh,
                    show_tesselation, show_contours, show_irises,
                    detection_conf, recognition_thresh, antispoof_thresh, blink_thresh, hand_det_conf, hand_track_conf,
                    bbox_hex, spoofed_hex, unknown_hex, eye_hex, blink_hex,
                    hand_landmark_hex, hand_connect_hex, hand_text_hex,
                    mesh_hex, contour_hex, iris_hex, eye_color_text_hex
                ],
                outputs=[save_msg]
            )

        with gr.Tab("Database Management"):
            gr.Markdown("Enroll multiple images per user, search by name or image, remove users, list all users.")

            with gr.Accordion("User Enrollment", open=False):
                enroll_name = gr.Textbox(label="User Name")
                enroll_paths = gr.File(file_count="multiple", type="binary", label="Upload Multiple Images")
                enroll_btn = gr.Button("Enroll User")
                enroll_result = gr.Textbox()

                enroll_btn.click(
                    fn=enroll_user,
                    inputs=[enroll_name, enroll_paths],
                    outputs=[enroll_result]
                )

            with gr.Accordion("User Search", open=False):
                search_mode = gr.Radio(["Name", "Image"], label="Search By", value="Name")
                search_name_box = gr.Dropdown(label="Select User", choices=[], value=None, visible=True)
                search_image_box = gr.Image(label="Upload Search Image", type="numpy", visible=False)
                search_btn = gr.Button("Search")
                search_out = gr.Textbox()

                def toggle_search(mode):
                    if mode == "Name":
                        return gr.update(visible=True), gr.update(visible=False)
                    else:
                        return gr.update(visible=False), gr.update(visible=True)

                search_mode.change(
                    fn=toggle_search,
                    inputs=[search_mode],
                    outputs=[search_name_box, search_image_box]
                )

                def do_search(mode, uname, img):
                    if mode == "Name":
                        return search_by_name(uname)
                    else:
                        return search_by_image(img)

                search_btn.click(
                    fn=do_search,
                    inputs=[search_mode, search_name_box, search_image_box],
                    outputs=[search_out]
                )

            with gr.Accordion("User Management Tools", open=False):
                list_btn = gr.Button("List Enrolled Users")
                list_out = gr.Textbox()
                list_btn.click(fn=lambda: list_users(), inputs=[], outputs=[list_out])

                def refresh_choices():
                    pl = load_pipeline()
                    return gr.update(choices=pl.db.list_labels())

                refresh_btn = gr.Button("Refresh User List")
                refresh_btn.click(fn=refresh_choices, inputs=[], outputs=[search_name_box])

                remove_box = gr.Dropdown(label="Select User to Remove", choices=[])
                remove_btn = gr.Button("Remove")
                remove_out = gr.Textbox()

                remove_btn.click(fn=remove_user, inputs=[remove_box], outputs=[remove_out])
                refresh_btn.click(fn=refresh_choices, inputs=[], outputs=[remove_box])

        with gr.Tab("Export / Import"):
            gr.Markdown("Export or import pipeline config (thresholds/colors) or face database (embeddings).")
            gr.Markdown("**Note:** After downloading, please rename the file to its appropriate extension (e.g., `config_export.pkl`, `database_export.pkl`).")

            gr.Markdown("**Export Individually (Download)**")
            export_config_btn = gr.Button("Export Config")
            export_config_download = gr.File(label="Download Config Export", type="binary")

            export_db_btn = gr.Button("Export Database")
            export_db_download = gr.File(label="Download Database Export", type="binary")

            export_config_btn.click(fn=export_config_file, inputs=[], outputs=[export_config_download])
            export_db_btn.click(fn=export_db_file, inputs=[], outputs=[export_db_download])

            gr.Markdown("**Import Individually (Upload)**")
            import_config_filebox = gr.File(label="Import Config File", file_count="single", type="binary")
            import_config_btn = gr.Button("Import Config")
            import_config_out = gr.Textbox()

            import_db_filebox = gr.File(label="Import Database File", file_count="single", type="binary")
            merge_db_checkbox = gr.Checkbox(label="Merge instead of overwrite?", value=True)
            import_db_btn = gr.Button("Import Database")
            import_db_out = gr.Textbox()

            import_config_btn.click(fn=import_config_file, inputs=[import_config_filebox], outputs=[import_config_out])
            import_db_btn.click(fn=import_db_file, inputs=[import_db_filebox, merge_db_checkbox], outputs=[import_db_out])

            gr.Markdown("---")
            gr.Markdown("**Export & Import Everything (Config + Database) Together**")
            gr.Markdown("**Note:** After downloading, please rename the file to `pipeline_export.pkl`.")

            export_all_btn = gr.Button("Export All (Config + DB)")
            export_all_download = gr.File(label="Download Combined Export", type="binary")

            export_all_btn.click(
                fn=export_all_file,
                inputs=[],
                outputs=[export_all_download]
            )

            import_all_in = gr.File(label="Import Combined File (Pickle)", file_count="single", type="binary")
            import_all_merge_cb = gr.Checkbox(label="Merge DB instead of overwrite?", value=True)
            import_all_btn = gr.Button("Import All")
            import_all_out = gr.Textbox()

            import_all_btn.click(
                fn=import_all_file,
                inputs=[import_all_in, import_all_merge_cb],
                outputs=[import_all_out]
            )

    return demo

def main():
    """Entry point to launch the Gradio app."""
    app = build_app()
    app.queue().launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()