David Driscoll committed
Commit: 473b2d5
Parent(s): f3de933

fix emotion, output vector

app.py CHANGED
@@ -6,9 +6,8 @@ from torchvision import models, transforms
 from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
 from PIL import Image
 import mediapipe as mp
-
-
-from transformers import AutoImageProcessor, AutoModelForImageClassification
+from fer import FER  # Facial emotion recognition
+from transformers import AutoFeatureExtractor, AutoModel
 
 # -----------------------------
 # Configuration
@@ -28,6 +27,7 @@ faces_cache = {"boxes": None, "text": "Initializing...", "counter": 0}
 # -----------------------------
 # Initialize Models and Helpers
 # -----------------------------
+# MediaPipe Pose and Face Detection
 mp_pose = mp.solutions.pose
 pose = mp_pose.Pose()
 mp_drawing = mp.solutions.drawing_utils
@@ -35,22 +35,27 @@ mp_drawing = mp.solutions.drawing_utils
 mp_face_detection = mp.solutions.face_detection
 face_detection = mp_face_detection.FaceDetection(min_detection_confidence=0.5)
 
+# Object Detection using Faster R-CNN
 object_detection_model = models.detection.fasterrcnn_resnet50_fpn(
     weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT
 )
 object_detection_model.eval().to(device)
 obj_transform = transforms.Compose([transforms.ToTensor()])
 
-# Initialize the
-
-emotion_processor = AutoImageProcessor.from_pretrained("nateraw/fer")
-emotion_model = AutoModelForImageClassification.from_pretrained("nateraw/fer")
-emotion_model.to(device)
-emotion_model.eval()
+# Initialize the FER emotion detector (using the FER package)
+emotion_detector = FER(mtcnn=True)
 
 # Retrieve object categories from model weights metadata
 object_categories = FasterRCNN_ResNet50_FPN_Weights.DEFAULT.meta["categories"]
 
+# -----------------------------
+# Facial Recognition Model (DINO-ViT)
+# -----------------------------
+facial_recognition_extractor = AutoFeatureExtractor.from_pretrained("facebook/dino-vitb16")
+facial_recognition_model = AutoModel.from_pretrained("facebook/dino-vitb16")
+facial_recognition_model.to(device)
+facial_recognition_model.eval()
+
 # -----------------------------
 # Overlay Drawing Functions
 # -----------------------------
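The hunk above wires up the two new models: the FER detector (with MTCNN face detection) and the facebook/dino-vitb16 backbone used later for face embeddings. A minimal standalone smoke test of both, assuming `fer`, `transformers`, and `torch` are installed; the blank frame is only a stand-in, not a real face image:

import numpy as np
import torch
from fer import FER
from transformers import AutoFeatureExtractor, AutoModel

detector = FER(mtcnn=True)
extractor = AutoFeatureExtractor.from_pretrained("facebook/dino-vitb16")
model = AutoModel.from_pretrained("facebook/dino-vitb16").eval()

frame = np.zeros((224, 224, 3), dtype=np.uint8)  # stand-in RGB frame, contains no face
print(detector.detect_emotions(frame))           # [] -> no face found in the blank frame

inputs = extractor(images=frame, return_tensors="pt")
with torch.no_grad():
    embedding = model(**inputs).last_hidden_state.mean(dim=1)
print(embedding.shape)  # torch.Size([1, 768]) for dino-vitb16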
@@ -96,36 +101,14 @@ def compute_posture_overlay(image):
     return landmarks, text
 
 def compute_emotion_overlay(image):
-    """
-    This function mimics the original FER-based expression recognition,
-    but uses a Hugging Face emotion model instead.
-    """
+    # Use the FER package (exactly as in your provided code)
     frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
     frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
     frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
-
-    face_results = face_detection.process(frame_rgb_small)
-
-    if face_results.detections:
-        detection = face_results.detections[0]
-        bbox = detection.location_data.relative_bounding_box
-        h, w, _ = frame_rgb_small.shape
-        x = int(bbox.xmin * w)
-        y = int(bbox.ymin * h)
-        box_w = int(bbox.width * w)
-        box_h = int(bbox.height * h)
-        face_crop = frame_rgb_small[y:y+box_h, x:x+box_w]
-        face_image = Image.fromarray(face_crop)
-
-        # Process face crop with the Hugging Face emotion model
-        inputs = emotion_processor(face_image, return_tensors="pt").to(device)
-        with torch.no_grad():
-            outputs = emotion_model(**inputs)
-        logits = outputs.logits
-        probs = torch.softmax(logits, dim=-1)
-        score, pred = torch.max(probs, dim=-1)
-        label = emotion_model.config.id2label[pred.item()]
-        text = f"{label} ({score.item():.2f})"
+    emotions = emotion_detector.detect_emotions(frame_rgb_small)
+    if emotions:
+        top_emotion, score = max(emotions[0]["emotions"].items(), key=lambda x: x[1])
+        text = f"{top_emotion} ({score:.2f})"
     else:
         text = "No face detected"
     return text
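For reference, `FER.detect_emotions` returns one dict per detected face with a bounding `box` and an `emotions` score dict; the `max(...)` call above reduces that dict to a single label. An illustrative reduction with made-up scores:

emotions = [{
    "box": [64, 48, 128, 128],
    "emotions": {"angry": 0.02, "disgust": 0.01, "fear": 0.03, "happy": 0.81,
                 "sad": 0.04, "surprise": 0.05, "neutral": 0.04},
}]
top_emotion, score = max(emotions[0]["emotions"].items(), key=lambda x: x[1])
print(f"{top_emotion} ({score:.2f})")  # happy (0.81)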
@@ -172,6 +155,37 @@ def compute_faces_overlay(image):
         text = "No faces detected"
     return boxes, text
 
+def compute_facial_recognition_vector(image):
+    """
+    Detects a face using MediaPipe, crops it, and computes its embedding vector
+    using facebook/dino-vitb16. The raw vector is returned as a string.
+    """
+    frame_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+    frame_bgr_small = cv2.resize(frame_bgr, DESIRED_SIZE)
+    frame_rgb_small = cv2.cvtColor(frame_bgr_small, cv2.COLOR_BGR2RGB)
+    face_results = face_detection.process(frame_rgb_small)
+    if face_results.detections:
+        detection = face_results.detections[0]
+        bbox = detection.location_data.relative_bounding_box
+        h, w, _ = frame_rgb_small.shape
+        x = int(bbox.xmin * w)
+        y = int(bbox.ymin * h)
+        box_w = int(bbox.width * w)
+        box_h = int(bbox.height * h)
+        face_crop = frame_rgb_small[y:y+box_h, x:x+box_w]
+        face_image = Image.fromarray(face_crop)
+        inputs = facial_recognition_extractor(face_image, return_tensors="pt").to(device)
+        with torch.no_grad():
+            outputs = facial_recognition_model(**inputs)
+        # Mean pooling of the last hidden state to obtain a vector representation
+        vector = outputs.last_hidden_state.mean(dim=1).squeeze()
+        vector_np = vector.cpu().numpy()
+        # Format vector as a string with limited decimal places
+        vector_str = np.array2string(vector_np, precision=2, separator=',')
+        return face_crop, vector_str
+    else:
+        return np.array(image), "No face detected"
+
 # -----------------------------
 # Main Analysis Functions for Single Image
 # -----------------------------
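`compute_facial_recognition_vector` stringifies the embedding for display only; any numeric use would start from the underlying array. A sketch (not part of this commit) of how two such DINO-ViT embeddings could be compared with cosine similarity, where `vec_a` and `vec_b` stand for hypothetical 768-dimensional NumPy arrays taken before the `np.array2string` step:

import numpy as np

def cosine_similarity(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
    # Values near 1.0 suggest the two face crops are close in embedding space.
    return float(np.dot(vec_a, vec_b) / (np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-8))

vec_a = np.random.rand(768)  # placeholder embeddings for illustration only
vec_b = np.random.rand(768)
print(cosine_similarity(vec_a, vec_b))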
@@ -225,6 +239,11 @@ def analyze_faces_current(image):
     output = draw_boxes_overlay(output, faces_cache["boxes"], (0, 0, 255))
     return output, f"<div style='color: lime !important;'>Face Detection: {faces_cache['text']}</div>"
 
+def analyze_facial_recognition(image):
+    # Compute and return the facial vector (and the cropped face)
+    face_crop, vector_str = compute_facial_recognition_vector(image)
+    return face_crop, f"<div style='color: lime !important;'>Facial Vector: {vector_str}</div>"
+
 def analyze_all(image):
     current_frame = np.array(image).copy()
     # Posture Analysis
@@ -304,7 +323,7 @@ emotion_interface = gr.Interface(
     inputs=gr.Image(label="Upload an Image for Emotion Analysis"),
     outputs=[gr.Image(type="numpy", label="Annotated Output"), gr.HTML(label="Emotion Analysis")],
     title="Emotion",
-    description="Detects facial emotions using
+    description="Detects facial emotions using FER.",
     live=False
 )
 
@@ -326,6 +345,15 @@ faces_interface = gr.Interface(
     live=False
 )
 
+facial_recognition_interface = gr.Interface(
+    fn=analyze_facial_recognition,
+    inputs=gr.Image(label="Upload a Face Image for Facial Recognition"),
+    outputs=[gr.Image(type="numpy", label="Cropped Face"), gr.HTML(label="Facial Recognition")],
+    title="Facial Recognition",
+    description="Extracts and outputs the facial vector using facebook/dino-vitb16.",
+    live=False
+)
+
 all_interface = gr.Interface(
     fn=analyze_all,
     inputs=gr.Image(label="Upload an Image for All Inferences"),
@@ -336,8 +364,22 @@ all_interface = gr.Interface(
 )
 
 tabbed_interface = gr.TabbedInterface(
-    interface_list=[
-
+    interface_list=[
+        posture_interface,
+        emotion_interface,
+        objects_interface,
+        faces_interface,
+        facial_recognition_interface,
+        all_interface
+    ],
+    tab_names=[
+        "Posture",
+        "Emotion",
+        "Objects",
+        "Faces",
+        "Facial Recognition",
+        "All Inferences"
+    ]
 )
 
 # -----------------------------
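In `gr.TabbedInterface`, `interface_list` and `tab_names` pair up by position, so the new Facial Recognition tab sits fifth in both lists. A minimal, self-contained sketch of that pairing, assuming only that `gradio` is installed:

import gradio as gr

upper = gr.Interface(fn=lambda s: s.upper(), inputs="text", outputs="text")
lower = gr.Interface(fn=lambda s: s.lower(), inputs="text", outputs="text")

tabs = gr.TabbedInterface(
    interface_list=[upper, lower],  # tab i renders interface_list[i]
    tab_names=["Upper", "Lower"],   # and is labelled tab_names[i]
)

if __name__ == "__main__":
    tabs.launch()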
@@ -346,7 +388,7 @@ tabbed_interface = gr.TabbedInterface(
 demo = gr.Blocks(css=custom_css)
 with demo:
     gr.Markdown("<h1 class='gradio-title' style='color: #32CD32;'>Multi-Analysis Image App</h1>")
-    gr.Markdown("<p class='gradio-description' style='color: #32CD32;'>Upload an image to run high-tech analysis for posture, emotions, objects, and
+    gr.Markdown("<p class='gradio-description' style='color: #32CD32;'>Upload an image to run high-tech analysis for posture, emotions, objects, faces, and facial embeddings.</p>")
     tabbed_interface.render()
 
 if __name__ == "__main__":