Rammohan0504 commited on
Commit
420f765
·
verified ·
1 Parent(s): a70b9fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -128
app.py CHANGED
@@ -1,15 +1,16 @@
1
- # Enhanced Face-Based Lab Test Predictor with AI Models for 30 Lab Metrics
2
-
3
  import gradio as gr
4
  import cv2
5
  import numpy as np
6
  import mediapipe as mp
7
  from sklearn.linear_model import LinearRegression
8
  import random
 
9
 
 
10
  mp_face_mesh = mp.solutions.face_mesh
11
  face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
12
 
 
13
  def extract_features(image, landmarks):
14
  red_channel = image[:, :, 2]
15
  green_channel = image[:, :, 1]
@@ -21,6 +22,7 @@ def extract_features(image, landmarks):
21
 
22
  return [red_percent, green_percent, blue_percent]
23
 
 
24
  def train_model(output_range):
25
  X = [[random.uniform(0.2, 0.5), random.uniform(0.05, 0.2), random.uniform(0.05, 0.2),
26
  random.uniform(0.2, 0.5), random.uniform(0.2, 0.5), random.uniform(0.2, 0.5),
@@ -29,14 +31,12 @@ def train_model(output_range):
29
  model = LinearRegression().fit(X, y)
30
  return model
31
 
32
- import joblib
33
  hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
34
-
35
- hemoglobin_r2 = 0.385
36
- import joblib
37
  spo2_model = joblib.load("spo2_model_simulated.pkl")
38
  hr_model = joblib.load("heart_rate_model.pkl")
39
 
 
40
  models = {
41
  "Hemoglobin": hemoglobin_model,
42
  "WBC Count": train_model((4.0, 11.0)),
@@ -59,6 +59,7 @@ models = {
59
  "Temperature": train_model((97, 99))
60
  }
61
 
 
62
  def get_risk_color(value, normal_range):
63
  low, high = normal_range
64
  if value < low:
@@ -68,6 +69,7 @@ def get_risk_color(value, normal_range):
68
  else:
69
  return ("Normal", "✅", "#CCFFCC")
70
 
 
71
  def build_table(title, rows):
72
  html = (
73
  f'<div style="margin-bottom: 24px;">'
@@ -81,9 +83,9 @@ def build_table(title, rows):
81
  html += '</tbody></table></div>'
82
  return html
83
 
 
84
  def analyze_video(video_path):
85
  import matplotlib.pyplot as plt
86
- from PIL import Image
87
  cap = cv2.VideoCapture(video_path)
88
  brightness_vals = []
89
  green_vals = []
@@ -99,135 +101,26 @@ def analyze_video(video_path):
99
  brightness_vals.append(np.mean(gray))
100
  green_vals.append(np.mean(green))
101
  cap.release()
102
- # simulate HR via std deviation signal
 
103
  brightness_std = np.std(brightness_vals) / 255
104
  green_std = np.std(green_vals) / 255
105
  tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
106
  hr_features = [brightness_std, green_std, tone_index]
107
  heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
108
- skin_tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
109
- brightness_variation = np.std(cv2.cvtColor(frame_sample, cv2.COLOR_BGR2GRAY)) / 255
110
- spo2_features = [heart_rate, brightness_variation, skin_tone_index]
111
- spo2 = spo2_model.predict([spo2_features])[0]
112
- rr = int(12 + abs(heart_rate % 5 - 2))
113
- plt.figure(figsize=(6, 2))
114
- plt.plot(brightness_vals, label='rPPG Signal')
115
- plt.title("Simulated rPPG Signal")
116
- plt.xlabel("Frame")
117
- plt.ylabel("Brightness")
118
- plt.legend()
119
- plt.tight_layout()
120
- plot_path = "/tmp/ppg_plot.png"
121
- plt.savefig(plot_path)
122
- plt.close()
123
- # Reuse frame_sample for full analysis
124
- frame_rgb = cv2.cvtColor(frame_sample, cv2.COLOR_BGR2RGB)
125
- result = face_mesh.process(frame_rgb)
126
- if not result.multi_face_landmarks:
127
- return "<div style='color:red;'>⚠️ Face not detected in video.</div>", frame_rgb
128
- landmarks = result.multi_face_landmarks[0].landmark
129
- features = extract_features(frame_rgb, landmarks)
130
- test_values = {}
131
- r2_scores = {}
132
- for label in models:
133
- if label == "Hemoglobin":
134
- prediction = models[label].predict([features])[0]
135
- test_values[label] = prediction
136
- r2_scores[label] = hemoglobin_r2
137
- else:
138
- value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
139
- test_values[label] = value
140
- r2_scores[label] = 0.0
141
- html_output = "".join([
142
- f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
143
- build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
144
- build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
145
- build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
146
- build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
147
- build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
148
- build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
149
- build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
150
- ])
151
- summary = "<div style='margin-top:20px;padding:12px;border:1px dashed #999;background:#fcfcfc;'>"
152
- summary += "<h4>📝 Summary for You</h4><ul>"
153
- if test_values["Hemoglobin"] < 13.5:
154
- summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
155
- if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
156
- summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
157
- if test_values["Bilirubin"] > 1.2:
158
- summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
159
- if test_values["HbA1c"] > 5.7:
160
- summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
161
- if spo2 < 95:
162
- summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
163
- summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
164
- html_output += summary
165
- html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
166
- html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
167
- html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
168
- return html_output, frame_rgb
169
-
170
- def analyze_face(image):
171
- if image is None:
172
- return "<div style='color:red;'>⚠️ Error: No image provided.</div>", None
173
- frame_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
174
- result = face_mesh.process(frame_rgb)
175
- if not result.multi_face_landmarks:
176
- return "<div style='color:red;'>⚠️ Error: Face not detected.</div>", None
177
- landmarks = result.multi_face_landmarks[0].landmark
178
- features = extract_features(frame_rgb, landmarks)
179
- test_values = {}
180
- r2_scores = {}
181
- for label in models:
182
- if label == "Hemoglobin":
183
- prediction = models[label].predict([features])[0]
184
- test_values[label] = prediction
185
- r2_scores[label] = hemoglobin_r2
186
- else:
187
- value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
188
- test_values[label] = value
189
- r2_scores[label] = 0.0 # simulate other 7D inputs
190
- gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)
191
- green_std = np.std(frame_rgb[:, :, 1]) / 255
192
- brightness_std = np.std(gray) / 255
193
- tone_index = np.mean(frame_rgb[100:150, 100:150]) / 255 if frame_rgb[100:150, 100:150].size else 0.5
194
- hr_features = [brightness_std, green_std, tone_index]
195
- heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
196
- skin_patch = frame_rgb[100:150, 100:150]
197
- skin_tone_index = np.mean(skin_patch) / 255 if skin_patch.size else 0.5
198
- brightness_variation = np.std(cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)) / 255
199
- spo2_features = [heart_rate, brightness_variation, skin_tone_index]
200
  spo2 = spo2_model.predict([spo2_features])[0]
201
- rr = int(12 + abs(heart_rate % 5 - 2))
 
202
  html_output = "".join([
203
- f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
204
- build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
205
- build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
206
- build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
207
- build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
208
- build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
209
- build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
210
- build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
211
  ])
212
- summary = "<div style='margin-top:20px;padding:12px;border:1px dashed #999;background:#fcfcfc;'>"
213
- summary += "<h4>📝 Summary for You</h4><ul>"
214
- if test_values["Hemoglobin"] < 13.5:
215
- summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
216
- if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
217
- summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
218
- if test_values["Bilirubin"] > 1.2:
219
- summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
220
- if test_values["HbA1c"] > 5.7:
221
- summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
222
- if spo2 < 95:
223
- summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
224
- summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
225
- html_output += summary
226
- html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
227
- html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
228
- html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
229
- return html_output, frame_rgb
230
 
 
231
  with gr.Blocks() as demo:
232
  gr.Markdown("""
233
  # 🧠 Face-Based Lab Test AI Report (Video Mode)
@@ -244,7 +137,7 @@ with gr.Blocks() as demo:
244
  result_image = gr.Image(label="📷 Key Frame Snapshot")
245
 
246
  def route_inputs(mode, image, video):
247
- return analyze_video(video) if mode == "Video" else analyze_face(image)
248
 
249
  submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
250
 
 
 
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import mediapipe as mp
5
  from sklearn.linear_model import LinearRegression
6
  import random
7
+ import joblib
8
 
9
+ # Setup for Face Mesh detection
10
  mp_face_mesh = mp.solutions.face_mesh
11
  face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
12
 
13
+ # Function to extract color features from the image
14
  def extract_features(image, landmarks):
15
  red_channel = image[:, :, 2]
16
  green_channel = image[:, :, 1]
 
22
 
23
  return [red_percent, green_percent, blue_percent]
24
 
25
+ # Mock models training (for demonstration)
26
  def train_model(output_range):
27
  X = [[random.uniform(0.2, 0.5), random.uniform(0.05, 0.2), random.uniform(0.05, 0.2),
28
  random.uniform(0.2, 0.5), random.uniform(0.2, 0.5), random.uniform(0.2, 0.5),
 
31
  model = LinearRegression().fit(X, y)
32
  return model
33
 
34
+ # Load pre-trained models for Hemoglobin, SPO2, and Heart Rate
35
  hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
 
 
 
36
  spo2_model = joblib.load("spo2_model_simulated.pkl")
37
  hr_model = joblib.load("heart_rate_model.pkl")
38
 
39
+ # Model dictionary setup for other tests
40
  models = {
41
  "Hemoglobin": hemoglobin_model,
42
  "WBC Count": train_model((4.0, 11.0)),
 
59
  "Temperature": train_model((97, 99))
60
  }
61
 
62
+ # Function to determine risk level
63
  def get_risk_color(value, normal_range):
64
  low, high = normal_range
65
  if value < low:
 
69
  else:
70
  return ("Normal", "✅", "#CCFFCC")
71
 
72
+ # Function to build an HTML table for displaying test results
73
  def build_table(title, rows):
74
  html = (
75
  f'<div style="margin-bottom: 24px;">'
 
83
  html += '</tbody></table></div>'
84
  return html
85
 
86
+ # Analyzing video for health metrics
87
  def analyze_video(video_path):
88
  import matplotlib.pyplot as plt
 
89
  cap = cv2.VideoCapture(video_path)
90
  brightness_vals = []
91
  green_vals = []
 
101
  brightness_vals.append(np.mean(gray))
102
  green_vals.append(np.mean(green))
103
  cap.release()
104
+
105
+ # Simulate heart rate and SPO2 estimation
106
  brightness_std = np.std(brightness_vals) / 255
107
  green_std = np.std(green_vals) / 255
108
  tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
109
  hr_features = [brightness_std, green_std, tone_index]
110
  heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
111
+ spo2_features = [heart_rate, np.std(brightness_vals), np.mean(frame_sample[100:150, 100:150])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  spo2 = spo2_model.predict([spo2_features])[0]
113
+
114
+ # Generating the health card with test results
115
  html_output = "".join([
116
+ build_table("🩸 Hematology", [("Hemoglobin", models["Hemoglobin"].predict([hr_features])[0], (13.5, 17.5))]),
117
+ build_table("🧬 Iron Panel", [("Iron", models["Iron"].predict([hr_features])[0], (60, 170))]),
118
+ build_table("🧪 Electrolytes", [("Sodium", models["Sodium"].predict([hr_features])[0], (135, 145))]),
119
+ build_table("❤️ Vitals", [("Heart Rate", heart_rate, (60, 100)), ("SpO2", spo2, (95, 100))]),
 
 
 
 
120
  ])
121
+ return html_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ # Gradio Interface setup
124
  with gr.Blocks() as demo:
125
  gr.Markdown("""
126
  # 🧠 Face-Based Lab Test AI Report (Video Mode)
 
137
  result_image = gr.Image(label="📷 Key Frame Snapshot")
138
 
139
  def route_inputs(mode, image, video):
140
+ return analyze_video(video) if mode == "Video" else analyze_video(image)
141
 
142
  submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
143