Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
-
# Enhanced Face-Based Lab Test Predictor with AI Models for 30 Lab Metrics
|
2 |
-
|
3 |
import gradio as gr
|
4 |
import cv2
|
5 |
import numpy as np
|
6 |
import mediapipe as mp
|
7 |
from sklearn.linear_model import LinearRegression
|
8 |
import random
|
|
|
9 |
|
|
|
10 |
mp_face_mesh = mp.solutions.face_mesh
|
11 |
face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
|
12 |
|
|
|
13 |
def extract_features(image, landmarks):
|
14 |
red_channel = image[:, :, 2]
|
15 |
green_channel = image[:, :, 1]
|
@@ -21,6 +22,7 @@ def extract_features(image, landmarks):
|
|
21 |
|
22 |
return [red_percent, green_percent, blue_percent]
|
23 |
|
|
|
24 |
def train_model(output_range):
|
25 |
X = [[random.uniform(0.2, 0.5), random.uniform(0.05, 0.2), random.uniform(0.05, 0.2),
|
26 |
random.uniform(0.2, 0.5), random.uniform(0.2, 0.5), random.uniform(0.2, 0.5),
|
@@ -29,14 +31,12 @@ def train_model(output_range):
|
|
29 |
model = LinearRegression().fit(X, y)
|
30 |
return model
|
31 |
|
32 |
-
|
33 |
hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
|
34 |
-
|
35 |
-
hemoglobin_r2 = 0.385
|
36 |
-
import joblib
|
37 |
spo2_model = joblib.load("spo2_model_simulated.pkl")
|
38 |
hr_model = joblib.load("heart_rate_model.pkl")
|
39 |
|
|
|
40 |
models = {
|
41 |
"Hemoglobin": hemoglobin_model,
|
42 |
"WBC Count": train_model((4.0, 11.0)),
|
@@ -59,6 +59,7 @@ models = {
|
|
59 |
"Temperature": train_model((97, 99))
|
60 |
}
|
61 |
|
|
|
62 |
def get_risk_color(value, normal_range):
|
63 |
low, high = normal_range
|
64 |
if value < low:
|
@@ -68,6 +69,7 @@ def get_risk_color(value, normal_range):
|
|
68 |
else:
|
69 |
return ("Normal", "✅", "#CCFFCC")
|
70 |
|
|
|
71 |
def build_table(title, rows):
|
72 |
html = (
|
73 |
f'<div style="margin-bottom: 24px;">'
|
@@ -81,9 +83,9 @@ def build_table(title, rows):
|
|
81 |
html += '</tbody></table></div>'
|
82 |
return html
|
83 |
|
|
|
84 |
def analyze_video(video_path):
|
85 |
import matplotlib.pyplot as plt
|
86 |
-
from PIL import Image
|
87 |
cap = cv2.VideoCapture(video_path)
|
88 |
brightness_vals = []
|
89 |
green_vals = []
|
@@ -99,135 +101,26 @@ def analyze_video(video_path):
|
|
99 |
brightness_vals.append(np.mean(gray))
|
100 |
green_vals.append(np.mean(green))
|
101 |
cap.release()
|
102 |
-
|
|
|
103 |
brightness_std = np.std(brightness_vals) / 255
|
104 |
green_std = np.std(green_vals) / 255
|
105 |
tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
|
106 |
hr_features = [brightness_std, green_std, tone_index]
|
107 |
heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
|
108 |
-
|
109 |
-
brightness_variation = np.std(cv2.cvtColor(frame_sample, cv2.COLOR_BGR2GRAY)) / 255
|
110 |
-
spo2_features = [heart_rate, brightness_variation, skin_tone_index]
|
111 |
-
spo2 = spo2_model.predict([spo2_features])[0]
|
112 |
-
rr = int(12 + abs(heart_rate % 5 - 2))
|
113 |
-
plt.figure(figsize=(6, 2))
|
114 |
-
plt.plot(brightness_vals, label='rPPG Signal')
|
115 |
-
plt.title("Simulated rPPG Signal")
|
116 |
-
plt.xlabel("Frame")
|
117 |
-
plt.ylabel("Brightness")
|
118 |
-
plt.legend()
|
119 |
-
plt.tight_layout()
|
120 |
-
plot_path = "/tmp/ppg_plot.png"
|
121 |
-
plt.savefig(plot_path)
|
122 |
-
plt.close()
|
123 |
-
# Reuse frame_sample for full analysis
|
124 |
-
frame_rgb = cv2.cvtColor(frame_sample, cv2.COLOR_BGR2RGB)
|
125 |
-
result = face_mesh.process(frame_rgb)
|
126 |
-
if not result.multi_face_landmarks:
|
127 |
-
return "<div style='color:red;'>⚠️ Face not detected in video.</div>", frame_rgb
|
128 |
-
landmarks = result.multi_face_landmarks[0].landmark
|
129 |
-
features = extract_features(frame_rgb, landmarks)
|
130 |
-
test_values = {}
|
131 |
-
r2_scores = {}
|
132 |
-
for label in models:
|
133 |
-
if label == "Hemoglobin":
|
134 |
-
prediction = models[label].predict([features])[0]
|
135 |
-
test_values[label] = prediction
|
136 |
-
r2_scores[label] = hemoglobin_r2
|
137 |
-
else:
|
138 |
-
value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
|
139 |
-
test_values[label] = value
|
140 |
-
r2_scores[label] = 0.0
|
141 |
-
html_output = "".join([
|
142 |
-
f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
|
143 |
-
build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
|
144 |
-
build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
|
145 |
-
build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
|
146 |
-
build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
|
147 |
-
build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
|
148 |
-
build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
|
149 |
-
build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
|
150 |
-
])
|
151 |
-
summary = "<div style='margin-top:20px;padding:12px;border:1px dashed #999;background:#fcfcfc;'>"
|
152 |
-
summary += "<h4>📝 Summary for You</h4><ul>"
|
153 |
-
if test_values["Hemoglobin"] < 13.5:
|
154 |
-
summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
|
155 |
-
if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
|
156 |
-
summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
|
157 |
-
if test_values["Bilirubin"] > 1.2:
|
158 |
-
summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
|
159 |
-
if test_values["HbA1c"] > 5.7:
|
160 |
-
summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
|
161 |
-
if spo2 < 95:
|
162 |
-
summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
|
163 |
-
summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
|
164 |
-
html_output += summary
|
165 |
-
html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
|
166 |
-
html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
|
167 |
-
html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
|
168 |
-
return html_output, frame_rgb
|
169 |
-
|
170 |
-
def analyze_face(image):
|
171 |
-
if image is None:
|
172 |
-
return "<div style='color:red;'>⚠️ Error: No image provided.</div>", None
|
173 |
-
frame_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
174 |
-
result = face_mesh.process(frame_rgb)
|
175 |
-
if not result.multi_face_landmarks:
|
176 |
-
return "<div style='color:red;'>⚠️ Error: Face not detected.</div>", None
|
177 |
-
landmarks = result.multi_face_landmarks[0].landmark
|
178 |
-
features = extract_features(frame_rgb, landmarks)
|
179 |
-
test_values = {}
|
180 |
-
r2_scores = {}
|
181 |
-
for label in models:
|
182 |
-
if label == "Hemoglobin":
|
183 |
-
prediction = models[label].predict([features])[0]
|
184 |
-
test_values[label] = prediction
|
185 |
-
r2_scores[label] = hemoglobin_r2
|
186 |
-
else:
|
187 |
-
value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
|
188 |
-
test_values[label] = value
|
189 |
-
r2_scores[label] = 0.0 # simulate other 7D inputs
|
190 |
-
gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)
|
191 |
-
green_std = np.std(frame_rgb[:, :, 1]) / 255
|
192 |
-
brightness_std = np.std(gray) / 255
|
193 |
-
tone_index = np.mean(frame_rgb[100:150, 100:150]) / 255 if frame_rgb[100:150, 100:150].size else 0.5
|
194 |
-
hr_features = [brightness_std, green_std, tone_index]
|
195 |
-
heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
|
196 |
-
skin_patch = frame_rgb[100:150, 100:150]
|
197 |
-
skin_tone_index = np.mean(skin_patch) / 255 if skin_patch.size else 0.5
|
198 |
-
brightness_variation = np.std(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)) / 255
|
199 |
-
spo2_features = [heart_rate, brightness_variation, skin_tone_index]
|
200 |
spo2 = spo2_model.predict([spo2_features])[0]
|
201 |
-
|
|
|
202 |
html_output = "".join([
|
203 |
-
|
204 |
-
build_table("
|
205 |
-
build_table("
|
206 |
-
build_table("
|
207 |
-
build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
|
208 |
-
build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
|
209 |
-
build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
|
210 |
-
build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
|
211 |
])
|
212 |
-
|
213 |
-
summary += "<h4>📝 Summary for You</h4><ul>"
|
214 |
-
if test_values["Hemoglobin"] < 13.5:
|
215 |
-
summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
|
216 |
-
if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
|
217 |
-
summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
|
218 |
-
if test_values["Bilirubin"] > 1.2:
|
219 |
-
summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
|
220 |
-
if test_values["HbA1c"] > 5.7:
|
221 |
-
summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
|
222 |
-
if spo2 < 95:
|
223 |
-
summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
|
224 |
-
summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
|
225 |
-
html_output += summary
|
226 |
-
html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
|
227 |
-
html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
|
228 |
-
html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
|
229 |
-
return html_output, frame_rgb
|
230 |
|
|
|
231 |
with gr.Blocks() as demo:
|
232 |
gr.Markdown("""
|
233 |
# 🧠 Face-Based Lab Test AI Report (Video Mode)
|
@@ -244,7 +137,7 @@ with gr.Blocks() as demo:
|
|
244 |
result_image = gr.Image(label="📷 Key Frame Snapshot")
|
245 |
|
246 |
def route_inputs(mode, image, video):
|
247 |
-
return analyze_video(video) if mode == "Video" else
|
248 |
|
249 |
submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
|
250 |
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
import mediapipe as mp
|
5 |
from sklearn.linear_model import LinearRegression
|
6 |
import random
|
7 |
+
import joblib
|
8 |
|
9 |
+
# Setup for Face Mesh detection
|
10 |
mp_face_mesh = mp.solutions.face_mesh
|
11 |
face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
|
12 |
|
13 |
+
# Function to extract color features from the image
|
14 |
def extract_features(image, landmarks):
|
15 |
red_channel = image[:, :, 2]
|
16 |
green_channel = image[:, :, 1]
|
|
|
22 |
|
23 |
return [red_percent, green_percent, blue_percent]
|
24 |
|
25 |
+
# Mock models training (for demonstration)
|
26 |
def train_model(output_range):
|
27 |
X = [[random.uniform(0.2, 0.5), random.uniform(0.05, 0.2), random.uniform(0.05, 0.2),
|
28 |
random.uniform(0.2, 0.5), random.uniform(0.2, 0.5), random.uniform(0.2, 0.5),
|
|
|
31 |
model = LinearRegression().fit(X, y)
|
32 |
return model
|
33 |
|
34 |
+
# Load pre-trained models for Hemoglobin, SPO2, and Heart Rate
|
35 |
hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
|
|
|
|
|
|
|
36 |
spo2_model = joblib.load("spo2_model_simulated.pkl")
|
37 |
hr_model = joblib.load("heart_rate_model.pkl")
|
38 |
|
39 |
+
# Model dictionary setup for other tests
|
40 |
models = {
|
41 |
"Hemoglobin": hemoglobin_model,
|
42 |
"WBC Count": train_model((4.0, 11.0)),
|
|
|
59 |
"Temperature": train_model((97, 99))
|
60 |
}
|
61 |
|
62 |
+
# Function to determine risk level
|
63 |
def get_risk_color(value, normal_range):
|
64 |
low, high = normal_range
|
65 |
if value < low:
|
|
|
69 |
else:
|
70 |
return ("Normal", "✅", "#CCFFCC")
|
71 |
|
72 |
+
# Function to build an HTML table for displaying test results
|
73 |
def build_table(title, rows):
|
74 |
html = (
|
75 |
f'<div style="margin-bottom: 24px;">'
|
|
|
83 |
html += '</tbody></table></div>'
|
84 |
return html
|
85 |
|
86 |
+
# Analyzing video for health metrics
|
87 |
def analyze_video(video_path):
|
88 |
import matplotlib.pyplot as plt
|
|
|
89 |
cap = cv2.VideoCapture(video_path)
|
90 |
brightness_vals = []
|
91 |
green_vals = []
|
|
|
101 |
brightness_vals.append(np.mean(gray))
|
102 |
green_vals.append(np.mean(green))
|
103 |
cap.release()
|
104 |
+
|
105 |
+
# Simulate heart rate and SPO2 estimation
|
106 |
brightness_std = np.std(brightness_vals) / 255
|
107 |
green_std = np.std(green_vals) / 255
|
108 |
tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
|
109 |
hr_features = [brightness_std, green_std, tone_index]
|
110 |
heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
|
111 |
+
spo2_features = [heart_rate, np.std(brightness_vals), np.mean(frame_sample[100:150, 100:150])]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
spo2 = spo2_model.predict([spo2_features])[0]
|
113 |
+
|
114 |
+
# Generating the health card with test results
|
115 |
html_output = "".join([
|
116 |
+
build_table("🩸 Hematology", [("Hemoglobin", models["Hemoglobin"].predict([hr_features])[0], (13.5, 17.5))]),
|
117 |
+
build_table("🧬 Iron Panel", [("Iron", models["Iron"].predict([hr_features])[0], (60, 170))]),
|
118 |
+
build_table("🧪 Electrolytes", [("Sodium", models["Sodium"].predict([hr_features])[0], (135, 145))]),
|
119 |
+
build_table("❤️ Vitals", [("Heart Rate", heart_rate, (60, 100)), ("SpO2", spo2, (95, 100))]),
|
|
|
|
|
|
|
|
|
120 |
])
|
121 |
+
return html_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
+
# Gradio Interface setup
|
124 |
with gr.Blocks() as demo:
|
125 |
gr.Markdown("""
|
126 |
# 🧠 Face-Based Lab Test AI Report (Video Mode)
|
|
|
137 |
result_image = gr.Image(label="📷 Key Frame Snapshot")
|
138 |
|
139 |
def route_inputs(mode, image, video):
|
140 |
+
return analyze_video(video) if mode == "Video" else analyze_video(image)
|
141 |
|
142 |
submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
|
143 |
|