David Driscoll committed
Commit 107dab2 · 1 Parent(s): dfc63b4
Model overhaul
app.py CHANGED

Old version (left column of the split diff; removed lines are prefixed with "-"):

@@ -2,262 +2,234 @@ import gradio as gr
 import cv2
 import numpy as np
 import torch
-from torchvision import models, transforms
-from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
 from PIL import Image
 import mediapipe as mp
-

 # -----------------------------
-# Configuration
 # -----------------------------
-SKIP_RATE = 1  # For image processing, always run the analysis
-
-# Use GPU if available
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Desired input size for faster inference
 DESIRED_SIZE = (640, 480)

 # -----------------------------
-#
-# -----------------------------
-posture_cache = {"landmarks": None, "text": "Initializing...", "counter": 0}
-emotion_cache = {"text": "Initializing...", "counter": 0}
-objects_cache = {"boxes": None, "text": "Initializing...", "object_list_text": "", "counter": 0}
-faces_cache = {"boxes": None, "text": "Initializing...", "counter": 0}
-
-# -----------------------------
-# Initialize Models and Helpers
 # -----------------------------
-mp_pose = mp.solutions.pose
-pose = mp_pose.Pose()
-mp_drawing = mp.solutions.drawing_utils
-
 mp_face_detection = mp.solutions.face_detection
 face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)

-object_detection_model = models.detection.fasterrcnn_resnet50_fpn(
-    weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
-)
-object_detection_model.eval().to(device)  # Move model to GPU if available
-
-obj_transform = transforms.Compose([transforms.ToTensor()])
-
-# Initialize the FER emotion detector
-emotion_detector = FER(mtcnn=True)
-
-# Retrieve object categories from model weights metadata
-object_categories = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.meta["categories"]
-
 # -----------------------------
-#
 # -----------------------------
-def draw_posture_overlay(raw_frame, landmarks):
-    # Draw connector lines using MediaPipe's POSE_CONNECTIONS
-    for connection in mp_pose.POSE_CONNECTIONS:
-        start_idx, end_idx = connection
-        if start_idx < len(landmarks) and end_idx < len(landmarks):
-            start_point = landmarks[start_idx]
-            end_point = landmarks[end_idx]
-            cv2.line(raw_frame, start_point, end_point, (50, 205, 50), 2)
-    # Draw landmark points in lime green (BGR: (50,205,50))
-    for (x, y) in landmarks:
-        cv2.circle(raw_frame, (x, y), 4, (50, 205, 50), -1)
-    return raw_frame

-
-
-
-

 # -----------------------------
-#
 # -----------------------------
-def compute_posture_overlay(image):
-    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-    h, w, _ = frame_bgr.shape
-    frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
-    small_h, small_w, _ = frame_bgr_small.shape
-
-    frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
-    pose_results = pose.process(frame_rgb_small)

-
-
-
-
-
-
-
-
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-        detections = object_detection_model([img_tensor])[0]
-
-    threshold = 0.8
-    boxes = []
-    object_list = []
-    for box, score, label in zip(detections["boxes"], detections["scores"], detections["labels"]):
-        if score > threshold:
-            boxes.append(tuple(box.int().cpu().numpy()))
-            label_idx = int(label)
-            label_name = object_categories[label_idx] if label_idx < len(object_categories) else "Unknown"
-            object_list.append(f"{label_name} ({score:.2f})")
-    text = f"Detected {len(boxes)} object(s)" if boxes else "No objects detected"
-    object_list_text = " | ".join(object_list) if object_list else "None"
-    return boxes, text, object_list_text
-
-def compute_faces_overlay(image):
-    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-    h, w, _ = frame_bgr.shape
-    frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
-    small_h, small_w, _ = frame_bgr_small.shape
-
-    frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
-    face_results = face_detection.process(frame_rgb_small)
-
-    boxes = []
     if face_results.detections:
-
-
-
-
-
-
-
-
     else:
-
-

 # -----------------------------
-#
 # -----------------------------
-def analyze_posture_current(image):
-    global posture_cache
-    posture_cache["counter"] += 1
-    current_frame = np.array(image)
-    if posture_cache["counter"] % SKIP_RATE == 0 or posture_cache["landmarks"] is None:
-        landmarks, text = compute_posture_overlay(image)
-        posture_cache["landmarks"] = landmarks
-        posture_cache["text"] = text

-
-
-

-

-def
-
-
-    current_frame = np.array(image)
-    if emotion_cache["counter"] % SKIP_RATE == 0 or emotion_cache["text"] is None:
-        text = compute_emotion_overlay(image)
-        emotion_cache["text"] = text

-

-def
-
-
-    current_frame = np.array(image)
-    if objects_cache["counter"] % SKIP_RATE == 0 or objects_cache["boxes"] is None:
-        boxes, text, object_list_text = compute_objects_overlay(image)
-        objects_cache["boxes"] = boxes
-        objects_cache["text"] = text
-        objects_cache["object_list_text"] = object_list_text
-
-    output = current_frame.copy()
-    if objects_cache["boxes"]:
-        output = draw_boxes_overlay(output, objects_cache["boxes"], (255, 255, 0))
-    combined_text = f"Object Detection: {objects_cache['text']}<br>Details: {objects_cache['object_list_text']}"
-    return output, f"<div style='color: lime !important;'>{combined_text}</div>"
-
-def analyze_faces_current(image):
-    global faces_cache
-    faces_cache["counter"] += 1
-    current_frame = np.array(image)
-    if faces_cache["counter"] % SKIP_RATE == 0 or faces_cache["boxes"] is None:
-        boxes, text = compute_faces_overlay(image)
-        faces_cache["boxes"] = boxes
-        faces_cache["text"] = text
-
-    output = current_frame.copy()
-    if faces_cache["boxes"]:
-        output = draw_boxes_overlay(output, faces_cache["boxes"], (0, 0, 255))
-    return output, f"<div style='color: lime !important;'>Face Detection: {faces_cache['text']}</div>"
-
-def analyze_all(image):
-    current_frame = np.array(image).copy()
-
-    # Posture Analysis
-    landmarks, posture_text = compute_posture_overlay(image)
-    if landmarks:
-        current_frame = draw_posture_overlay(current_frame, landmarks)
-
-    # Emotion Analysis
-    emotion_text = compute_emotion_overlay(image)
-
-    # Object Detection
-    boxes_obj, objects_text, object_list_text = compute_objects_overlay(image)
-    if boxes_obj:
-        current_frame = draw_boxes_overlay(current_frame, boxes_obj, (255, 255, 0))
-
-    # Face Detection
-    boxes_face, faces_text = compute_faces_overlay(image)
-    if boxes_face:
-        current_frame = draw_boxes_overlay(current_frame, boxes_face, (0, 0, 255))
-
-    # Combined Analysis Text
-    combined_text = (
-        f"<b>Posture Analysis:</b> {posture_text}<br>"
-        f"<b>Emotion Analysis:</b> {emotion_text}<br>"
-        f"<b>Object Detection:</b> {objects_text}<br>"
-        f"<b>Detected Objects:</b> {object_list_text}<br>"
-        f"<b>Face Detection:</b> {faces_text}"
-    )
-
-    # Image Description Panel (High-Tech)
-    if object_list_text and object_list_text != "None":
-        description_text = f"Image Description: The scene features {object_list_text}."
-    else:
-        description_text = "Image Description: No prominent objects detected."
-
-    combined_text += f"<br><br><div style='border:1px solid lime; padding:10px; box-shadow: 0 0 10px lime;'><b>{description_text}</b></div>"
-
-    combined_text_html = f"<div style='color: lime !important;'>{combined_text}</div>"
-    return current_frame, combined_text_html

 # -----------------------------
-# Custom CSS (
 # -----------------------------
 custom_css = """
 @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700&display=swap');
@@ -289,50 +261,55 @@ input, button, .output {
 """

 # -----------------------------
-# Create
 # -----------------------------
-
-    fn=
-    inputs=gr.Image(label="Upload
-    outputs=[gr.Image(type="numpy", label="
-
-
     live=False
 )

 emotion_interface = gr.Interface(
-    fn=
-    inputs=gr.Image(label="Upload
-    outputs=[gr.Image(type="numpy", label="
-
-
     live=False
 )

-
-    fn=
-    inputs=gr.Image(label="Upload
-    outputs=[gr.Image(type="numpy", label="
-
-
     live=False
 )

-
-    fn=
-    inputs=gr.Image(label="Upload
-    outputs=[gr.Image(type="numpy", label="
-
-
     live=False
 )

-
-    fn=
-    inputs=gr.Image(label="Upload an Image for
-    outputs=[gr.Image(type="numpy", label="
-
-
     live=False
 )
@@ -340,17 +317,29 @@ all_interface = gr.Interface(
 # Create a Tabbed Interface
 # -----------------------------
 tabbed_interface = gr.TabbedInterface(
-    interface_list=[
-
 )

 # -----------------------------
-# Wrap in a Blocks Layout
 # -----------------------------
 demo = gr.Blocks(css=custom_css)
 with demo:
-    gr.Markdown("<h1 class='gradio-title' style='color: #32CD32;'>Multi-Analysis
-    gr.Markdown("<p class='gradio-description' style='color: #32CD32;'>Upload an image to run
     tabbed_interface.render()

 if __name__ == "__main__":
New version (right column of the split diff; added lines are prefixed with "+"):

@@ -2,262 +2,234 @@ import gradio as gr
 import cv2
 import numpy as np
 import torch
 from PIL import Image
 import mediapipe as mp
+
+from transformers import (
+    AutoFeatureExtractor,
+    AutoModel,
+    AutoImageProcessor,
+    AutoModelForImageClassification,
+    AutoModelForSemanticSegmentation
+)

 # -----------------------------
+# Configuration & Device Setup
 # -----------------------------
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 DESIRED_SIZE = (640, 480)

 # -----------------------------
+# Initialize Mediapipe Face Detection
 # -----------------------------
 mp_face_detection = mp.solutions.face_detection
 face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)

 # -----------------------------
+# Load New Models from Hugging Face
 # -----------------------------

+# 1. Facial Recognition & Identification (facebook/dino-vitb16)
+facial_recognition_extractor = AutoFeatureExtractor.from_pretrained("facebook/dino-vitb16")
+facial_recognition_model = AutoModel.from_pretrained("facebook/dino-vitb16")
+facial_recognition_model.to(device)
+facial_recognition_model.eval()
+
+# Create a dummy database for demonstration (embeddings of dimension 768 assumed)
+dummy_database = {
+    "Alice": torch.randn(768).to(device),
+    "Bob": torch.randn(768).to(device)
+}
+
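A note on the dummy database just above: random 768-dimensional vectors will almost never exceed the 0.7 cosine-similarity threshold used later in compute_facial_recognition, so every query will report "No match found". A minimal sketch of how real reference embeddings could be enrolled instead, reusing the extractor, model, and device defined in this file (enroll_face is a hypothetical helper, not part of this commit):

def enroll_face(name, image_path, database):
    """Embed a reference photo with DINO-ViT and store it in the database under `name`."""
    ref = Image.open(image_path).convert("RGB")
    inputs = facial_recognition_extractor(ref, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = facial_recognition_model(**inputs)
    # Mean-pool the token embeddings, matching compute_facial_recognition below
    database[name] = outputs.last_hidden_state.mean(dim=1).squeeze()

# Example (paths are placeholders):
# enroll_face("Alice", "reference_photos/alice.jpg", dummy_database)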
+# 2. Emotion Detection (nateraw/facial-expression-recognition)
+emotion_processor = AutoImageProcessor.from_pretrained("nateraw/facial-expression-recognition")
+emotion_model = AutoModelForImageClassification.from_pretrained("nateraw/facial-expression-recognition")
+emotion_model.to(device)
+emotion_model.eval()
+
+# 3. Age & Gender Prediction (oayu/age-gender-estimation)
+age_gender_processor = AutoImageProcessor.from_pretrained("oayu/age-gender-estimation")
+age_gender_model = AutoModelForImageClassification.from_pretrained("oayu/age-gender-estimation")
+age_gender_model.to(device)
+age_gender_model.eval()
+
+# 4. Face Parsing (hila-chefer/face-parsing)
+face_parsing_processor = AutoImageProcessor.from_pretrained("hila-chefer/face-parsing")
+face_parsing_model = AutoModelForSemanticSegmentation.from_pretrained("hila-chefer/face-parsing")
+face_parsing_model.to(device)
+face_parsing_model.eval()
+
+# 5. Deepfake Detection (microsoft/FaceForensics)
+deepfake_processor = AutoImageProcessor.from_pretrained("microsoft/FaceForensics")
+deepfake_model = AutoModelForImageClassification.from_pretrained("microsoft/FaceForensics")
+deepfake_model.to(device)
+deepfake_model.eval()

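All five checkpoints above are downloaded and loaded eagerly at import time, so the app fails to start if any single from_pretrained call raises (for example, if one of the repository names listed in the diff is unavailable on the Hub or the machine is offline). A minimal sketch of a guarded loader for the classifier-style checkpoints (try_load_classifier is a hypothetical helper, not part of this commit):

def try_load_classifier(checkpoint):
    """Return (processor, model) on success, or (None, None) if the checkpoint cannot be loaded."""
    try:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        model = AutoModelForImageClassification.from_pretrained(checkpoint)
        return processor, model.to(device).eval()
    except Exception as exc:  # repo not found, gated model, network error, ...
        print(f"Skipping {checkpoint}: {exc}")
        return None, None

# Example:
# emotion_processor, emotion_model = try_load_classifier("nateraw/facial-expression-recognition")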
 # -----------------------------
+# Helper Functions for New Inferences
 # -----------------------------

+def compute_facial_recognition(image):
+    """
+    Detects a face using MediaPipe, crops it, and computes its embedding with DINO-ViT.
+    Compares the embedding against a dummy database to "identify" the person.
+    """
+    frame = np.array(image)
+    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+    frame_resized = cv2.resize(frame_bgr, DESIRED_SIZE)
+    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
+
+    face_results = face_detection.process(frame_rgb)
+    if face_results.detections:
+        detection = face_results.detections[0]
+        bbox = detection.location_data.relative_bounding_box
+        h, w, _ = frame_rgb.shape
+        x = int(bbox.xmin * w)
+        y = int(bbox.ymin * h)
+        box_w = int(bbox.width * w)
+        box_h = int(bbox.height * h)
+        face_crop = frame_rgb[y:y+box_h, x:x+box_w]
+        face_image = Image.fromarray(face_crop)
+
+        inputs = facial_recognition_extractor(face_image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = facial_recognition_model(**inputs)
+        # Use mean pooling over the last hidden state to get an embedding vector
+        embeddings = outputs.last_hidden_state.mean(dim=1).squeeze()
+
+        # Compare against dummy database using cosine similarity
+        best_score = -1
+        best_name = "Unknown"
+        for name, db_emb in dummy_database.items():
+            cos_sim = torch.nn.functional.cosine_similarity(embeddings, db_emb, dim=0)
+            if cos_sim > best_score:
+                best_score = cos_sim
+                best_name = name
+        threshold = 0.7  # dummy threshold for identification
+        if best_score > threshold:
+            result = f"Identified as {best_name} (sim: {best_score:.2f})"
+        else:
+            result = f"No match found (best: {best_name}, sim: {best_score:.2f})"
+        return face_crop, result
     else:
+        return frame, "No face detected"
+
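One caveat in compute_facial_recognition (and in the similar emotion and age/gender helpers below): MediaPipe's relative_bounding_box can extend slightly outside the frame, in which case frame_rgb[y:y+box_h, x:x+box_w] with a negative x or y can produce an empty crop that the downstream processors cannot handle. A small shared helper that clamps the box to the frame bounds is one way to guard against this (crop_first_face is a hypothetical sketch, not part of this commit):

def crop_first_face(frame_rgb, detections):
    """Crop the first detected face, clamping the box to the frame bounds."""
    bbox = detections[0].location_data.relative_bounding_box
    h, w, _ = frame_rgb.shape
    x0 = max(int(bbox.xmin * w), 0)
    y0 = max(int(bbox.ymin * h), 0)
    x1 = min(int((bbox.xmin + bbox.width) * w), w)
    y1 = min(int((bbox.ymin + bbox.height) * h), h)
    return frame_rgb[y0:y1, x0:x1]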
+def compute_emotion_detection(image):
+    """
+    Detects a face, crops it, and classifies the facial expression.
+    """
+    frame = np.array(image)
+    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+    frame_resized = cv2.resize(frame_bgr, DESIRED_SIZE)
+    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
+
+    face_results = face_detection.process(frame_rgb)
+    if face_results.detections:
+        detection = face_results.detections[0]
+        bbox = detection.location_data.relative_bounding_box
+        h, w, _ = frame_rgb.shape
+        x = int(bbox.xmin * w)
+        y = int(bbox.ymin * h)
+        box_w = int(bbox.width * w)
+        box_h = int(bbox.height * h)
+        face_crop = frame_rgb[y:y+box_h, x:x+box_w]
+        face_image = Image.fromarray(face_crop)
+
+        inputs = emotion_processor(face_image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = emotion_model(**inputs)
+        logits = outputs.logits
+        pred = logits.argmax(-1).item()
+        label = emotion_model.config.id2label[pred]
+        return face_crop, f"Emotion: {label}"
     else:
+        return frame, "No face detected"
+
+def compute_age_gender(image):
+    """
+    Detects a face, crops it, and predicts the age & gender.
+    """
+    frame = np.array(image)
+    frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+    frame_resized = cv2.resize(frame_bgr, DESIRED_SIZE)
+    frame_rgb = cv2.cvtColor(frame_resized, cv2.COLOR_BGR2RGB)
+
+    face_results = face_detection.process(frame_rgb)
     if face_results.detections:
+        detection = face_results.detections[0]
+        bbox = detection.location_data.relative_bounding_box
+        h, w, _ = frame_rgb.shape
+        x = int(bbox.xmin * w)
+        y = int(bbox.ymin * h)
+        box_w = int(bbox.width * w)
+        box_h = int(bbox.height * h)
+        face_crop = frame_rgb[y:y+box_h, x:x+box_w]
+        face_image = Image.fromarray(face_crop)
+
+        inputs = age_gender_processor(face_image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = age_gender_model(**inputs)
+        logits = outputs.logits
+        pred = logits.argmax(-1).item()
+        label = age_gender_model.config.id2label[pred]
+        return face_crop, f"Age & Gender: {label}"
     else:
+        return frame, "No face detected"
+
+def compute_face_parsing(image):
+    """
+    Runs face parsing (segmentation) on the provided image.
+    """
+    image_pil = Image.fromarray(np.array(image))
+    inputs = face_parsing_processor(image_pil, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = face_parsing_model(**inputs)
+    logits = outputs.logits  # shape: (batch, num_labels, H, W)
+    segmentation = logits.argmax(dim=1)[0].cpu().numpy()
+    # For visualization, we apply a color map to the segmentation mask.
+    segmentation_norm = np.uint8(255 * segmentation / (segmentation.max() + 1e-5))
+    segmentation_color = cv2.applyColorMap(segmentation_norm, cv2.COLORMAP_JET)
+    return segmentation_color, "Face Parsing completed"
+
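Segmentation models generally return logits at a reduced spatial resolution, so the mask produced by compute_face_parsing above is not aligned pixel-for-pixel with the input image. A minimal sketch of upsampling the logits to the input size before taking the argmax, assuming the processor and model loaded earlier in this file:

import torch.nn.functional as F

def parse_face_fullres(image_pil):
    """Return a face-parsing label map at the resolution of the input PIL image."""
    inputs = face_parsing_processor(image_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = face_parsing_model(**inputs).logits  # (1, num_labels, h, w)
    # PIL size is (width, height); interpolate expects (height, width)
    logits = F.interpolate(logits, size=image_pil.size[::-1], mode="bilinear", align_corners=False)
    return logits.argmax(dim=1)[0].cpu().numpy()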
+def compute_deepfake_detection(image):
+    """
+    Runs deepfake detection on the image.
+    """
+    image_pil = Image.fromarray(np.array(image))
+    inputs = deepfake_processor(image_pil, return_tensors="pt").to(device)
+    with torch.no_grad():
+        outputs = deepfake_model(**inputs)
+    logits = outputs.logits
+    pred = logits.argmax(-1).item()
+    label = deepfake_model.config.id2label[pred]
+    return np.array(image), f"Deepfake Detection: {label}"

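compute_deepfake_detection reports only the argmax label; a softmax probability makes the verdict easier to judge. A short sketch under the same assumptions as the function above (deepfake_score is hypothetical, not part of this commit):

def deepfake_score(image_pil):
    """Return (label, probability) for the top class of the deepfake classifier."""
    inputs = deepfake_processor(image_pil, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = deepfake_model(**inputs).logits
    probs = logits.softmax(dim=-1)[0]
    pred = int(probs.argmax())
    return deepfake_model.config.id2label[pred], float(probs[pred])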
 # -----------------------------
+# Analysis Functions (Wrapping Inference & Green Text)
 # -----------------------------

+def analyze_facial_recognition(image):
+    annotated_face, result = compute_facial_recognition(image)
+    return annotated_face, f"<div style='color: lime !important;'>Facial Recognition: {result}</div>"

+def analyze_emotion_detection(image):
+    face_crop, result = compute_emotion_detection(image)
+    return face_crop, f"<div style='color: lime !important;'>{result}</div>"

+def analyze_age_gender(image):
+    face_crop, result = compute_age_gender(image)
+    return face_crop, f"<div style='color: lime !important;'>{result}</div>"

+def analyze_face_parsing(image):
+    segmentation, result = compute_face_parsing(image)
+    return segmentation, f"<div style='color: lime !important;'>{result}</div>"

+def analyze_deepfake_detection(image):
+    output, result = compute_deepfake_detection(image)
+    return output, f"<div style='color: lime !important;'>{result}</div>"

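The five analyze_* wrappers above differ only in which compute_* function they call and in the optional prefix on the result string, so a small factory could remove the repetition. A sketch (make_analyzer is hypothetical, not part of this commit):

def make_analyzer(compute_fn, prefix=""):
    """Wrap a compute_* function so its text result is rendered in lime-green HTML."""
    def analyze(image):
        output, result = compute_fn(image)
        return output, f"<div style='color: lime !important;'>{prefix}{result}</div>"
    return analyze

# Equivalent to the handwritten wrappers, e.g.:
# analyze_emotion_detection = make_analyzer(compute_emotion_detection)
# analyze_facial_recognition = make_analyzer(compute_facial_recognition, "Facial Recognition: ")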
 # -----------------------------
+# Custom CSS (All Text in Green)
 # -----------------------------
 custom_css = """
 @import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700&display=swap');

@@ -289,50 +261,55 @@ input, button, .output {
 """

 # -----------------------------
+# Create Gradio Interfaces for New Models
 # -----------------------------
+facial_recognition_interface = gr.Interface(
+    fn=analyze_facial_recognition,
+    inputs=gr.Image(label="Upload a Face Image for Facial Recognition"),
+    outputs=[gr.Image(type="numpy", label="Cropped Face / Embedding Visualization"),
+             gr.HTML(label="Facial Recognition Result")],
+    title="Facial Recognition & Identification",
+    description="Extracts facial embeddings using facebook/dino-vitb16 and identifies the face by comparing against a dummy database.",
     live=False
 )

 emotion_interface = gr.Interface(
+    fn=analyze_emotion_detection,
+    inputs=gr.Image(label="Upload a Face Image for Emotion Detection"),
+    outputs=[gr.Image(type="numpy", label="Cropped Face"),
+             gr.HTML(label="Emotion Detection")],
+    title="Emotion Detection",
+    description="Classifies the facial expression using nateraw/facial-expression-recognition.",
     live=False
 )

+age_gender_interface = gr.Interface(
+    fn=analyze_age_gender,
+    inputs=gr.Image(label="Upload a Face Image for Age & Gender Prediction"),
+    outputs=[gr.Image(type="numpy", label="Cropped Face"),
+             gr.HTML(label="Age & Gender Prediction")],
+    title="Age & Gender Prediction",
+    description="Predicts age and gender from the face using oayu/age-gender-estimation.",
     live=False
 )

+face_parsing_interface = gr.Interface(
+    fn=analyze_face_parsing,
+    inputs=gr.Image(label="Upload a Face Image for Face Parsing"),
+    outputs=[gr.Image(type="numpy", label="Segmentation Overlay"),
+             gr.HTML(label="Face Parsing")],
+    title="Face Parsing",
+    description="Segments face regions (eyes, nose, lips, hair, etc.) using hila-chefer/face-parsing.",
     live=False
 )

+deepfake_interface = gr.Interface(
+    fn=analyze_deepfake_detection,
+    inputs=gr.Image(label="Upload an Image for Deepfake Detection"),
+    outputs=[gr.Image(type="numpy", label="Input Image"),
+             gr.HTML(label="Deepfake Detection")],
+    title="Deepfake Detection",
+    description="Detects manipulated or deepfake images using microsoft/FaceForensics.",
     live=False
 )

@@ -340,17 +317,29 @@ all_interface = gr.Interface(
 # Create a Tabbed Interface
 # -----------------------------
 tabbed_interface = gr.TabbedInterface(
+    interface_list=[
+        facial_recognition_interface,
+        emotion_interface,
+        age_gender_interface,
+        face_parsing_interface,
+        deepfake_interface
+    ],
+    tab_names=[
+        "Facial Recognition",
+        "Emotion Detection",
+        "Age & Gender",
+        "Face Parsing",
+        "Deepfake Detection"
+    ]
 )

 # -----------------------------
+# Wrap in a Blocks Layout & Launch
 # -----------------------------
 demo = gr.Blocks(css=custom_css)
 with demo:
+    gr.Markdown("<h1 class='gradio-title' style='color: #32CD32;'>Multi-Analysis Face App</h1>")
+    gr.Markdown("<p class='gradio-description' style='color: #32CD32;'>Upload an image to run advanced face analysis using state-of-the-art Hugging Face models.</p>")
     tabbed_interface.render()

 if __name__ == "__main__":
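The hunk ends at the __main__ guard, so the launch call itself is not visible in this diff. For reference, a minimal Gradio entry point for the Blocks object defined above usually looks like the following (the arguments used in the actual file may differ):

if __name__ == "__main__":
    demo.launch()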
