David Driscoll committed on
Commit 473b2d5 · 1 Parent(s): f3de933

fix emotion, output vector

Files changed (1)
  1. app.py +82 -40
app.py CHANGED
@@ -6,9 +6,8 @@ from torchvision import models, transforms
 from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
 from PIL import Image
 import mediapipe as mp
-
-# Hugging Face imports for emotion detection
-from transformers import AutoImageProcessor, AutoModelForImageClassification
+from fer import FER  # Facial emotion recognition
+from transformers import AutoFeatureExtractor, AutoModel
 
 # -----------------------------
 # Configuration
@@ -28,6 +27,7 @@ faces_cache = {"boxes": None, "text": "Initializing...", "counter": 0}
 # -----------------------------
 # Initialize Models and Helpers
 # -----------------------------
+# MediaPipe Pose and Face Detection
 mp_pose = mp.solutions.pose
 pose = mp_pose.Pose()
 mp_drawing = mp.solutions.drawing_utils
@@ -35,22 +35,27 @@ mp_drawing = mp.solutions.drawing_utils
 mp_face_detection = mp.solutions.face_detection
 face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)
 
+# Object Detection using Faster R-CNN
 object_detection_model = models.detection.fasterrcnn_resnet50_fpn(
     weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
 )
 object_detection_model.eval().to(device)
 obj_transform = transforms.Compose([transforms.ToTensor()])
 
-# Initialize the Hugging Face emotion detection model.
-# (Using the public "nateraw/fer" repo to mimic expression recognition.)
-emotion_processor = AutoImageProcessor.from_pretrained("nateraw/fer")
-emotion_model = AutoModelForImageClassification.from_pretrained("nateraw/fer")
-emotion_model.to(device)
-emotion_model.eval()
+# Initialize the FER emotion detector (using the FER package)
+emotion_detector = FER(mtcnn=True)
 
 # Retrieve object categories from model weights metadata
 object_categories = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.meta["categories"]
 
+# -----------------------------
+# Facial Recognition Model (DINO-ViT)
+# -----------------------------
+facial_recognition_extractor = AutoFeatureExtractor.from_pretrained("facebook/dino-vitb16")
+facial_recognition_model = AutoModel.from_pretrained("facebook/dino-vitb16")
+facial_recognition_model.to(device)
+facial_recognition_model.eval()
+
 # -----------------------------
 # Overlay Drawing Functions
 # -----------------------------
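Note on the new loader calls: AutoFeatureExtractor does load facebook/dino-vitb16, but recent transformers releases steer image checkpoints toward AutoImageProcessor. The drop-in variant below is an assumption about the installed transformers version, not part of the commit.

# Drop-in variant (assumes a recent transformers release; not part of the commit):
# AutoImageProcessor loads the same preprocessing config for this checkpoint and
# should behave identically, without the feature-extractor deprecation warning.
from transformers import AutoImageProcessor, AutoModel

facial_recognition_extractor = AutoImageProcessor.from_pretrained("facebook/dino-vitb16")
facial_recognition_model = AutoModel.from_pretrained("facebook/dino-vitb16")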
@@ -96,36 +101,14 @@ def compute_posture_overlay(image):
     return landmarks, text
 
 def compute_emotion_overlay(image):
-    """
-    This function mimics the original FER-based expression recognition,
-    but uses a Hugging Face emotion model instead.
-    """
+    # Use the FER package (exactly as in your provided code)
     frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
     frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
     frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
-
-    # Use MediaPipe to detect a face and crop it
-    face_results = face_detection.process(frame_rgb_small)
-    if face_results.detections:
-        detection = face_results.detections[0]
-        bbox = detection.location_data.relative_bounding_box
-        h, w, _ = frame_rgb_small.shape
-        x = int(bbox.xmin * w)
-        y = int(bbox.ymin * h)
-        box_w = int(bbox.width * w)
-        box_h = int(bbox.height * h)
-        face_crop = frame_rgb_small[y:y+box_h, x:x+box_w]
-        face_image = Image.fromarray(face_crop)
-
-        # Process face crop with the Hugging Face emotion model
-        inputs = emotion_processor(face_image, return_tensors="pt").to(device)
-        with torch.no_grad():
-            outputs = emotion_model(**inputs)
-        logits = outputs.logits
-        probs = torch.softmax(logits, dim=-1)
-        score, pred = torch.max(probs, dim=-1)
-        label = emotion_model.config.id2label[pred.item()]
-        text = f"{label} ({score.item():.2f})"
+    emotions = emotion_detector.detect_emotions(frame_rgb_small)
+    if emotions:
+        top_emotion, score = max(emotions[0]["emotions"].items(), key=lambda x: x[1])
+        text = f"{top_emotion} ({score:.2f})"
     else:
         text = "No face detected"
     return text
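For context on the new call path: the standalone sketch below (not part of the commit) shows roughly what FER.detect_emotions returns and mirrors the selection logic added above. The image path and the example scores in the comment are placeholders.

# Standalone sketch of the fer API used above (not part of the commit).
import cv2
from fer import FER

detector = FER(mtcnn=True)                 # same constructor as in the commit
frame = cv2.imread("face.jpg")             # placeholder test image (BGR, as read by OpenCV)

results = detector.detect_emotions(frame)
# Roughly one entry per detected face, e.g.:
# [{"box": [x, y, w, h],
#   "emotions": {"angry": 0.02, "disgust": 0.00, "fear": 0.05, "happy": 0.81,
#                "sad": 0.03, "surprise": 0.07, "neutral": 0.02}}]
if results:
    label, score = max(results[0]["emotions"].items(), key=lambda kv: kv[1])
    print(f"{label} ({score:.2f})")        # mirrors the text built in compute_emotion_overlay
else:
    print("No face detected")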
@@ -172,6 +155,37 @@ def compute_faces_overlay(image):
         text = "No faces detected"
     return boxes, text
 
+def compute_facial_recognition_vector(image):
+    """
+    Detects a face using MediaPipe, crops it, and computes its embedding vector
+    using facebook/dino-vitb16. The raw vector is returned as a string.
+    """
+    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
+    frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
+    face_results = face_detection.process(frame_rgb_small)
+    if face_results.detections:
+        detection = face_results.detections[0]
+        bbox = detection.location_data.relative_bounding_box
+        h, w, _ = frame_rgb_small.shape
+        x = int(bbox.xmin * w)
+        y = int(bbox.ymin * h)
+        box_w = int(bbox.width * w)
+        box_h = int(bbox.height * h)
+        face_crop = frame_rgb_small[y:y+box_h, x:x+box_w]
+        face_image = Image.fromarray(face_crop)
+        inputs = facial_recognition_extractor(face_image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = facial_recognition_model(**inputs)
+        # Mean pooling of the last hidden state to obtain a vector representation
+        vector = outputs.last_hidden_state.mean(dim=1).squeeze()
+        vector_np = vector.cpu().numpy()
+        # Format vector as a string with limited decimal places
+        vector_str = np.array2string(vector_np, precision=2, separator=',')
+        return face_crop, vector_str
+    else:
+        return np.array(image), "No face detected"
+
 # -----------------------------
 # Main Analysis Functions for Single Image
 # -----------------------------
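The "output vector" in the commit message refers to the string returned above. As an illustrative follow-up (not part of the commit), the sketch below computes the same mean-pooled DINO-ViT embedding numerically and compares two faces with cosine similarity; the file paths, helper name, and threshold interpretation are assumptions.

# Illustrative sketch (not part of the commit): comparing two faces with the
# same mean-pooled DINO-ViT embedding that compute_facial_recognition_vector
# formats as a string.
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
extractor = AutoFeatureExtractor.from_pretrained("facebook/dino-vitb16")
model = AutoModel.from_pretrained("facebook/dino-vitb16").to(device).eval()

def embed_face(path: str) -> np.ndarray:
    # Mean-pool the last hidden state into a single 768-dim vector (ViT-B/16).
    image = Image.open(path).convert("RGB")
    inputs = extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

emb_a = embed_face("face_a.jpg")           # placeholder face crops
emb_b = embed_face("face_b.jpg")
similarity = float(np.dot(emb_a, emb_b) / (np.linalg.norm(emb_a) * np.linalg.norm(emb_b)))
print(f"cosine similarity: {similarity:.3f}")   # closer to 1.0 = more likely the same face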
@@ -225,6 +239,11 @@ def analyze_faces_current(image):
         output = draw_boxes_overlay(output, faces_cache["boxes"], (0, 0, 255))
     return output, f"<div style='color: lime !important;'>Face Detection: {faces_cache['text']}</div>"
 
+def analyze_facial_recognition(image):
+    # Compute and return the facial vector (and the cropped face)
+    face_crop, vector_str = compute_facial_recognition_vector(image)
+    return face_crop, f"<div style='color: lime !important;'>Facial Vector: {vector_str}</div>"
+
 def analyze_all(image):
     current_frame = np.array(image).copy()
     # Posture Analysis
@@ -304,7 +323,7 @@ emotion_interface = gr.Interface(
     inputs=gr.Image(label="Upload an Image for Emotion Analysis"),
     outputs=[gr.Image(type="numpy", label="Annotated Output"), gr.HTML(label="Emotion Analysis")],
     title="Emotion",
-    description="Detects facial emotions using a Hugging Face model.",
+    description="Detects facial emotions using FER.",
     live=False
 )
 
@@ -326,6 +345,15 @@ faces_interface = gr.Interface(
     live=False
 )
 
+facial_recognition_interface = gr.Interface(
+    fn=analyze_facial_recognition,
+    inputs=gr.Image(label="Upload a Face Image for Facial Recognition"),
+    outputs=[gr.Image(type="numpy", label="Cropped Face"), gr.HTML(label="Facial Recognition")],
+    title="Facial Recognition",
+    description="Extracts and outputs the facial vector using facebook/dino-vitb16.",
+    live=False
+)
+
 all_interface = gr.Interface(
     fn=analyze_all,
     inputs=gr.Image(label="Upload an Image for All Inferences"),
@@ -336,8 +364,22 @@ all_interface = gr.Interface(
 )
 
 tabbed_interface = gr.TabbedInterface(
-    interface_list=[posture_interface, emotion_interface, objects_interface, faces_interface, all_interface],
-    tab_names=["Posture", "Emotion", "Objects", "Faces", "All Inferences"]
+    interface_list=[
+        posture_interface,
+        emotion_interface,
+        objects_interface,
+        faces_interface,
+        facial_recognition_interface,
+        all_interface
+    ],
+    tab_names=[
+        "Posture",
+        "Emotion",
+        "Objects",
+        "Faces",
+        "Facial Recognition",
+        "All Inferences"
+    ]
 )
 
 # -----------------------------
@@ -346,7 +388,7 @@ tabbed_interface = gr.TabbedInterface(
 demo = gr.Blocks(css=custom_css)
 with demo:
     gr.Markdown("<h1 class='gradio-title' style='color: #32CD32;'>Multi-Analysis Image App</h1>")
-    gr.Markdown("<p class='gradio-description' style='color: #32CD32;'>Upload an image to run high-tech analysis for posture, emotions, objects, and faces.</p>")
+    gr.Markdown("<p class='gradio-description' style='color: #32CD32;'>Upload an image to run high-tech analysis for posture, emotions, objects, faces, and facial embeddings.</p>")
     tabbed_interface.render()
 
 if __name__ == "__main__":
 