Rammohan0504 commited on
Commit
70542c8
·
verified ·
1 Parent(s): a70b9fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -122
app.py CHANGED
@@ -1,15 +1,17 @@
1
- # Enhanced Face-Based Lab Test Predictor with AI Models for 30 Lab Metrics
2
-
3
  import gradio as gr
4
  import cv2
5
  import numpy as np
6
  import mediapipe as mp
7
  from sklearn.linear_model import LinearRegression
8
  import random
 
 
9
 
 
10
  mp_face_mesh = mp.solutions.face_mesh
11
  face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
12
 
 
13
  def extract_features(image, landmarks):
14
  red_channel = image[:, :, 2]
15
  green_channel = image[:, :, 1]
@@ -29,11 +31,8 @@ def train_model(output_range):
29
  model = LinearRegression().fit(X, y)
30
  return model
31
 
32
- import joblib
33
  hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
34
-
35
- hemoglobin_r2 = 0.385
36
- import joblib
37
  spo2_model = joblib.load("spo2_model_simulated.pkl")
38
  hr_model = joblib.load("heart_rate_model.pkl")
39
 
@@ -59,6 +58,7 @@ models = {
59
  "Temperature": train_model((97, 99))
60
  }
61
 
 
62
  def get_risk_color(value, normal_range):
63
  low, high = normal_range
64
  if value < low:
@@ -68,6 +68,7 @@ def get_risk_color(value, normal_range):
68
  else:
69
  return ("Normal", "✅", "#CCFFCC")
70
 
 
71
  def build_table(title, rows):
72
  html = (
73
  f'<div style="margin-bottom: 24px;">'
@@ -81,92 +82,50 @@ def build_table(title, rows):
81
  html += '</tbody></table></div>'
82
  return html
83
 
84
- def analyze_video(video_path):
85
- import matplotlib.pyplot as plt
86
- from PIL import Image
87
- cap = cv2.VideoCapture(video_path)
88
- brightness_vals = []
89
- green_vals = []
90
- frame_sample = None
91
- while True:
92
- ret, frame = cap.read()
93
- if not ret:
94
- break
95
- if frame_sample is None:
96
- frame_sample = frame.copy()
97
- gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
98
- green = frame[:, :, 1]
99
- brightness_vals.append(np.mean(gray))
100
- green_vals.append(np.mean(green))
101
- cap.release()
102
- # simulate HR via std deviation signal
103
- brightness_std = np.std(brightness_vals) / 255
104
- green_std = np.std(green_vals) / 255
105
- tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
106
- hr_features = [brightness_std, green_std, tone_index]
107
- heart_rate = float(np.clip(hr_model.predict([hr_features])[0], 60, 100))
108
- skin_tone_index = np.mean(frame_sample[100:150, 100:150]) / 255 if frame_sample[100:150, 100:150].size else 0.5
109
- brightness_variation = np.std(cv2.cvtColor(frame_sample, cv2.COLOR_BGR2GRAY)) / 255
110
- spo2_features = [heart_rate, brightness_variation, skin_tone_index]
111
- spo2 = spo2_model.predict([spo2_features])[0]
112
- rr = int(12 + abs(heart_rate % 5 - 2))
113
- plt.figure(figsize=(6, 2))
114
- plt.plot(brightness_vals, label='rPPG Signal')
115
- plt.title("Simulated rPPG Signal")
116
- plt.xlabel("Frame")
117
- plt.ylabel("Brightness")
118
- plt.legend()
119
- plt.tight_layout()
120
- plot_path = "/tmp/ppg_plot.png"
121
- plt.savefig(plot_path)
122
- plt.close()
123
- # Reuse frame_sample for full analysis
124
- frame_rgb = cv2.cvtColor(frame_sample, cv2.COLOR_BGR2RGB)
125
- result = face_mesh.process(frame_rgb)
126
- if not result.multi_face_landmarks:
127
- return "<div style='color:red;'>⚠️ Face not detected in video.</div>", frame_rgb
128
- landmarks = result.multi_face_landmarks[0].landmark
129
- features = extract_features(frame_rgb, landmarks)
130
- test_values = {}
131
- r2_scores = {}
132
- for label in models:
133
- if label == "Hemoglobin":
134
- prediction = models[label].predict([features])[0]
135
- test_values[label] = prediction
136
- r2_scores[label] = hemoglobin_r2
137
- else:
138
- value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
139
- test_values[label] = value
140
- r2_scores[label] = 0.0
141
- html_output = "".join([
142
- f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
143
- build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
144
- build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
145
- build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
146
- build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
147
- build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
148
- build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
149
- build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
150
- ])
151
- summary = "<div style='margin-top:20px;padding:12px;border:1px dashed #999;background:#fcfcfc;'>"
152
- summary += "<h4>📝 Summary for You</h4><ul>"
153
- if test_values["Hemoglobin"] < 13.5:
154
- summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
155
- if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
156
- summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
157
- if test_values["Bilirubin"] > 1.2:
158
- summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
159
- if test_values["HbA1c"] > 5.7:
160
- summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
161
- if spo2 < 95:
162
- summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
163
- summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
164
- html_output += summary
165
- html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
166
- html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
167
- html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
168
- return html_output, frame_rgb
169
 
 
170
  def analyze_face(image):
171
  if image is None:
172
  return "<div style='color:red;'>⚠️ Error: No image provided.</div>", None
@@ -182,7 +141,7 @@ def analyze_face(image):
182
  if label == "Hemoglobin":
183
  prediction = models[label].predict([features])[0]
184
  test_values[label] = prediction
185
- r2_scores[label] = hemoglobin_r2
186
  else:
187
  value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
188
  test_values[label] = value
@@ -199,40 +158,22 @@ def analyze_face(image):
199
  spo2_features = [heart_rate, brightness_variation, skin_tone_index]
200
  spo2 = spo2_model.predict([spo2_features])[0]
201
  rr = int(12 + abs(heart_rate % 5 - 2))
202
- html_output = "".join([
203
  f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
204
  build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
205
  build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
206
  build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
207
  build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
208
- build_table("🧁 Metabolic & Thyroid", [("FBS", test_values["FBS"], (70, 110)), ("HbA1c", test_values["HbA1c"], (4.0, 5.7)), ("TSH", test_values["TSH"], (0.4, 4.0))]),
209
- build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))]),
210
- build_table("🩹 Other Indicators", [("Cortisol", test_values["Cortisol"], (5, 25)), ("Albumin", test_values["Albumin"], (3.5, 5.5))])
211
  ])
212
- summary = "<div style='margin-top:20px;padding:12px;border:1px dashed #999;background:#fcfcfc;'>"
213
- summary += "<h4>📝 Summary for You</h4><ul>"
214
- if test_values["Hemoglobin"] < 13.5:
215
- summary += "<li>Your hemoglobin is a bit low — this could mean mild anemia.</li>"
216
- if test_values["Iron"] < 60 or test_values["Ferritin"] < 30:
217
- summary += "<li>Low iron storage detected — consider an iron profile test.</li>"
218
- if test_values["Bilirubin"] > 1.2:
219
- summary += "<li>Elevated bilirubin — possible jaundice. Recommend LFT.</li>"
220
- if test_values["HbA1c"] > 5.7:
221
- summary += "<li>High HbA1c — prediabetes indication. Recommend glucose check.</li>"
222
- if spo2 < 95:
223
- summary += "<li>Low SpO₂ — suggest retesting with a pulse oximeter.</li>"
224
- summary += "</ul><p><strong>💡 Tip:</strong> This is an AI-based estimate. Please follow up with a lab.</p></div>"
225
- html_output += summary
226
- html_output += "<br><div style='margin-top:20px;padding:12px;border:2px solid #2d87f0;background:#f2faff;text-align:center;border-radius:8px;'>"
227
- html_output += "<h4>📞 Book a Lab Test</h4><p>Prefer confirmation? Find certified labs near you.</p>"
228
- html_output += "<button style='padding:10px 20px;background:#007BFF;color:#fff;border:none;border-radius:5px;cursor:pointer;'>Find Labs Near Me</button></div>"
229
- return html_output, frame_rgb
230
 
 
231
  with gr.Blocks() as demo:
232
- gr.Markdown("""
233
- # 🧠 Face-Based Lab Test AI Report (Video Mode)
234
- Upload a short face video (10–30s) to infer health diagnostics using rPPG analysis.
235
- """)
236
  with gr.Row():
237
  with gr.Column():
238
  mode_selector = gr.Radio(label="Choose Input Mode", choices=["Image", "Video"], value="Image")
@@ -244,11 +185,8 @@ with gr.Blocks() as demo:
244
  result_image = gr.Image(label="📷 Key Frame Snapshot")
245
 
246
  def route_inputs(mode, image, video):
247
- return analyze_video(video) if mode == "Video" else analyze_face(image)
248
 
249
  submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
250
 
251
- gr.Markdown("""---
252
- ✅ Table Format • AI Prediction • rPPG-based HR • Dynamic Summary • Multilingual Support • CTA""")
253
-
254
  demo.launch()
 
 
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
  import mediapipe as mp
5
  from sklearn.linear_model import LinearRegression
6
  import random
7
+ import base64
8
+ import joblib
9
 
10
+ # Initialize the face mesh model
11
  mp_face_mesh = mp.solutions.face_mesh
12
  face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1, refine_landmarks=True, min_detection_confidence=0.5)
13
 
14
+ # Functions for feature extraction
15
  def extract_features(image, landmarks):
16
  red_channel = image[:, :, 2]
17
  green_channel = image[:, :, 1]
 
31
  model = LinearRegression().fit(X, y)
32
  return model
33
 
34
+ # Load models
35
  hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
 
 
 
36
  spo2_model = joblib.load("spo2_model_simulated.pkl")
37
  hr_model = joblib.load("heart_rate_model.pkl")
38
 
 
58
  "Temperature": train_model((97, 99))
59
  }
60
 
61
+ # Helper function for risk level color coding
62
  def get_risk_color(value, normal_range):
63
  low, high = normal_range
64
  if value < low:
 
68
  else:
69
  return ("Normal", "✅", "#CCFFCC")
70
 
71
+ # Function to build table for test results
72
  def build_table(title, rows):
73
  html = (
74
  f'<div style="margin-bottom: 24px;">'
 
82
  html += '</tbody></table></div>'
83
  return html
84
 
85
+ # Build health card layout
86
+ def build_health_card(profile_image, test_results, summary):
87
+ html = f"""
88
+ <div style="font-family: Arial, sans-serif; max-width: 600px; margin: 20px auto; border-radius: 12px; background-color: #f3f8fc; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); padding: 20px; color: #333;">
89
+ <div style="display: flex; align-items: center; margin-bottom: 20px;">
90
+ <img src="data:image/png;base64,{profile_image}" alt="Profile" style="width: 80px; height: 80px; border-radius: 50%; margin-right: 15px;">
91
+ <div>
92
+ <h2 style="margin: 0; font-size: 24px;">Health Card</h2>
93
+ <p style="margin: 5px 0; color: #777;">Lab Test Results</p>
94
+ </div>
95
+ </div>
96
+
97
+ <div style="font-size: 16px; margin-bottom: 20px;">
98
+ <h3 style="font-size: 18px; margin-bottom: 10px;">🩸 Hematology</h3>
99
+ {test_results['Hematology']}
100
+ <h3 style="font-size: 18px; margin-bottom: 10px;">🧬 Iron Panel</h3>
101
+ {test_results['Iron Panel']}
102
+ <h3 style="font-size: 18px; margin-bottom: 10px;">🧬 Liver & Kidney</h3>
103
+ {test_results['Liver & Kidney']}
104
+ <h3 style="font-size: 18px; margin-bottom: 10px;">🧪 Electrolytes</h3>
105
+ {test_results['Electrolytes']}
106
+ <h3 style="font-size: 18px; margin-bottom: 10px;">❤️ Vitals</h3>
107
+ {test_results['Vitals']}
108
+ </div>
109
+
110
+ <div style="background-color: #ffffff; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);">
111
+ <h4 style="margin: 0;">📝 Summary for You</h4>
112
+ <ul style="margin-top: 10px; color: #555;">
113
+ {summary}
114
+ </ul>
115
+ </div>
116
+
117
+ <div style="margin-top: 20px; text-align: center;">
118
+ <h4>📞 Book a Lab Test</h4>
119
+ <p style="color: #777;">Prefer confirmation? Find certified labs near you.</p>
120
+ <button style="padding: 10px 20px; background-color: #007BFF; color: white; border: none; border-radius: 5px; cursor: pointer;">
121
+ Find Labs Near Me
122
+ </button>
123
+ </div>
124
+ </div>
125
+ """
126
+ return html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ # Analyze face and return results
129
  def analyze_face(image):
130
  if image is None:
131
  return "<div style='color:red;'>⚠️ Error: No image provided.</div>", None
 
141
  if label == "Hemoglobin":
142
  prediction = models[label].predict([features])[0]
143
  test_values[label] = prediction
144
+ r2_scores[label] = 0.385
145
  else:
146
  value = models[label].predict([[random.uniform(0.2, 0.5) for _ in range(7)]])[0]
147
  test_values[label] = value
 
158
  spo2_features = [heart_rate, brightness_variation, skin_tone_index]
159
  spo2 = spo2_model.predict([spo2_features])[0]
160
  rr = int(12 + abs(heart_rate % 5 - 2))
161
+ html_output = "".join([
162
  f'<div style="font-size:14px;color:#888;margin-bottom:10px;">Hemoglobin R² Score: {r2_scores.get("Hemoglobin", "NA"):.2f}</div>',
163
  build_table("🩸 Hematology", [("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)), ("WBC Count", test_values["WBC Count"], (4.0, 11.0)), ("Platelet Count", test_values["Platelet Count"], (150, 450))]),
164
  build_table("🧬 Iron Panel", [("Iron", test_values["Iron"], (60, 170)), ("Ferritin", test_values["Ferritin"], (30, 300)), ("TIBC", test_values["TIBC"], (250, 400))]),
165
  build_table("🧬 Liver & Kidney", [("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)), ("Creatinine", test_values["Creatinine"], (0.6, 1.2)), ("Urea", test_values["Urea"], (7, 20))]),
166
  build_table("🧪 Electrolytes", [("Sodium", test_values["Sodium"], (135, 145)), ("Potassium", test_values["Potassium"], (3.5, 5.1))]),
167
+ build_table("❤️ Vitals", [("SpO2", spo2, (95, 100)), ("Heart Rate", heart_rate, (60, 100)), ("Respiratory Rate", rr, (12, 20)), ("Temperature", test_values["Temperature"], (97, 99)), ("BP Systolic", test_values["BP Systolic"], (90, 120)), ("BP Diastolic", test_values["BP Diastolic"], (60, 80))])
 
 
168
  ])
169
+ summary = "<ul><li>Your hemoglobin is a bit low — this could mean mild anemia.</li><li>Low iron storage detected — consider an iron profile test.</li><li>Elevated bilirubin — possible jaundice. Recommend LFT.</li><li>High HbA1c — prediabetes indication. Recommend glucose check.</li><li>Low SpO₂ — suggest retesting with a pulse oximeter.</li></ul>"
170
+
171
+ health_card_html = build_health_card("profile_image_placeholder_base64", html_output, summary)
172
+ return health_card_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ # Create Gradio interface
175
  with gr.Blocks() as demo:
176
+ gr.Markdown("""# 🧠 Face-Based Lab Test AI Report (Video Mode)""")
 
 
 
177
  with gr.Row():
178
  with gr.Column():
179
  mode_selector = gr.Radio(label="Choose Input Mode", choices=["Image", "Video"], value="Image")
 
185
  result_image = gr.Image(label="📷 Key Frame Snapshot")
186
 
187
  def route_inputs(mode, image, video):
188
+ return analyze_face(image) if mode == "Image" else analyze_face(video)
189
 
190
  submit_btn.click(fn=route_inputs, inputs=[mode_selector, image_input, video_input], outputs=[result_html, result_image])
191
 
 
 
 
192
  demo.launch()