Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import mediapipe as mp | |
| from sklearn.linear_model import LinearRegression | |
| import random | |
| import base64 | |
| import joblib | |
| import pandas as pd | |
| from reportlab.lib.pagesizes import letter | |
| from reportlab.pdfgen import canvas | |
| from io import BytesIO | |
# MediaPipe Face Mesh detector, created once at import time and shared by every
# call to analyze_face. Settings: still-image mode, at most one face,
# refine_landmarks enabled, 0.5 minimum detection confidence.
_FACE_MESH_OPTIONS = dict(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5,
)
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(**_FACE_MESH_OPTIONS)
# Functions for feature extraction
def extract_features(image, landmarks, *, channel_order="BGR"):
    """Return the mean red, green and blue intensity of *image* as percentages.

    Args:
        image: H x W x C array with at least three colour channels; assumes
            8-bit intensities (0-255), since each mean is divided by 255.
        landmarks: Face landmarks of the detected face. Currently unused (the
            statistics cover the whole image); accepted for interface
            compatibility.
        channel_order: "BGR" (default, OpenCV's native order, red in channel 2)
            or "RGB" (red in channel 0). Case-insensitive.

    Returns:
        ``[red_percent, green_percent, blue_percent]``, each ``100 * mean / 255``.

    Raises:
        ValueError: if *channel_order* is neither "BGR" nor "RGB".
    """
    order = channel_order.upper()
    if order not in ("BGR", "RGB"):
        raise ValueError(f"channel_order must be 'BGR' or 'RGB', got {channel_order!r}")
    # Green is channel 1 in both layouts; only red and blue swap places.
    red_index, blue_index = (2, 0) if order == "BGR" else (0, 2)
    return [100 * np.mean(image[:, :, index]) / 255 for index in (red_index, 1, blue_index)]
def train_model(output_range, *, n_samples=100, seed=None):
    """Fit a placeholder LinearRegression on randomly generated data.

    NOTE: both the 7 input features and the targets are independent uniform
    noise (targets drawn from *output_range*), so the fitted model carries no
    real signal. The app uses these models as stand-ins for tests that have
    no trained model.

    Args:
        output_range: ``(low, high)`` bounds for the random training targets.
        n_samples: Number of random training rows (default 100).
        seed: Optional seed for a private ``random.Random``, making the fit
            reproducible. With ``None`` (default) the module-level ``random``
            generator is used, as before.

    Returns:
        The fitted ``LinearRegression`` (expects 7 features per row).
    """
    rng = random if seed is None else random.Random(seed)
    # Feature columns 2 and 3 are drawn from a narrower band than the rest.
    X = [[
        rng.uniform(0.2, 0.5),
        rng.uniform(0.05, 0.2),
        rng.uniform(0.05, 0.2),
        rng.uniform(0.2, 0.5),
        rng.uniform(0.2, 0.5),
        rng.uniform(0.2, 0.5),
        rng.uniform(0.2, 0.5),
    ] for _ in range(n_samples)]
    y = [rng.uniform(*output_range) for _ in X]
    return LinearRegression().fit(X, y)
# Load the pre-trained models from the working directory.
# SECURITY: joblib files are pickles; unpickling can execute arbitrary code, so
# only load .pkl files from a trusted source.
try:
    hemoglobin_model = joblib.load("hemoglobin_model_from_anemia_dataset.pkl")
    spo2_model = joblib.load("spo2_model_simulated.pkl")
    hr_model = joblib.load("heart_rate_model.pkl")
except FileNotFoundError as missing:
    import sys

    # Loads run in order, so this reports the first file that was not found.
    print(
        "Error: One or more .pkl model files are missing. Please upload them.",
        f"(first missing file: {missing.filename})",
        file=sys.stderr,
    )
    # sys.exit rather than the site-provided exit(), which is not guaranteed to exist.
    sys.exit(1)
# (low, high) target ranges for the placeholder regressors. Every test listed
# here gets a model fitted by train_model on random noise, so its predictions
# carry no real signal; only "Hemoglobin" uses a model loaded from disk.
_SIMULATED_OUTPUT_RANGES = {
    "WBC Count": (4.0, 11.0),
    "Platelet Count": (150, 450),
    "Iron": (60, 170),
    "Ferritin": (30, 300),
    "TIBC": (250, 400),
    "Bilirubin": (0.3, 1.2),
    "Creatinine": (0.6, 1.2),
    "Urea": (7, 20),
    "Sodium": (135, 145),
    "Potassium": (3.5, 5.1),
    "TSH": (0.4, 4.0),
    "Cortisol": (5, 25),
    "FBS": (70, 110),
    "HbA1c": (4.0, 5.7),
    "Albumin": (3.5, 5.5),
    "BP Systolic": (90, 120),
    "BP Diastolic": (60, 80),
    "Temperature": (97, 99),
}
models = {"Hemoglobin": hemoglobin_model}
models.update({
    test_name: train_model(bounds)
    for test_name, bounds in _SIMULATED_OUTPUT_RANGES.items()
})
# Helper function for risk level color coding
def get_risk_color(value, normal_range):
    """Classify *value* against an inclusive ``(low, high)`` reference range.

    Returns a ``(level, icon, background_hex)`` triple:
    ``("Low", "🔻", "#fff3cd")`` when value < low,
    ``("High", "🔺", "#f8d7da")`` when value > high,
    and ``("Normal", "✅", "#d4edda")`` otherwise (bounds count as normal).
    A value that compares false both ways (e.g. NaN) is reported as Normal.
    """
    lower_bound, upper_bound = normal_range
    if value < lower_bound:
        verdict = "Low"
    elif value > upper_bound:
        verdict = "High"
    else:
        verdict = "Normal"
    styling = {
        "Low": ("🔻", "#fff3cd"),
        "High": ("🔺", "#f8d7da"),
        "Normal": ("✅", "#d4edda"),
    }
    return (verdict, *styling[verdict])
# Function to build table for test results
def build_table(title, rows):
    """Render one results panel as an HTML card containing a 4-column table.

    Args:
        title: Panel heading text (HTML-escaped before insertion).
        rows: Iterable of ``(label, value, (low, high))`` tuples. Each value is
            classified with ``get_risk_color``; out-of-range rows take that
            level's background colour, in-range rows alternate grey/white.

    Returns:
        HTML string with the columns Test, Result, Range and Level.
    """
    from html import escape

    parts = [
        '<div style="margin-bottom: 25px; border-radius: 8px; overflow: hidden; border: 1px solid #e0e0e0;">',
        '<div style="background: linear-gradient(135deg, #f5f7fa, #c3cfe2); padding: 12px 16px; border-bottom: 1px solid #e0e0e0;">',
        f'<h4 style="margin: 0; color: #2c3e50; font-size: 16px; font-weight: 600;">{escape(str(title))}</h4>',
        '</div>',
        '<table style="width:100%; border-collapse:collapse; background: white;">',
        '<thead><tr style="background:#f8f9fa;"><th style="padding:12px 8px;border-bottom:2px solid #dee2e6;color:#495057;font-weight:600;text-align:left;font-size:13px;">Test</th><th style="padding:12px 8px;border-bottom:2px solid #dee2e6;color:#495057;font-weight:600;text-align:center;font-size:13px;">Result</th><th style="padding:12px 8px;border-bottom:2px solid #dee2e6;color:#495057;font-weight:600;text-align:center;font-size:13px;">Range</th><th style="padding:12px 8px;border-bottom:2px solid #dee2e6;color:#495057;font-weight:600;text-align:center;font-size:13px;">Level</th></tr></thead><tbody>',
    ]
    # Text colour of the Level cell; any other level (i.e. "Low") is amber.
    level_colours = {"Normal": "#28a745", "High": "#dc3545"}
    for index, (label, value, ref) in enumerate(rows):
        level, icon, level_bg = get_risk_color(value, ref)
        if level == "Normal":
            row_bg = "#f8f9fa" if index % 2 == 0 else "white"
        else:
            row_bg = level_bg
        # Only platelet counts (150-450) read naturally as whole numbers. WBC
        # (4.0-11.0) needs decimals: rounding it would show e.g. 11.4 as "11"
        # while the row is flagged High.
        value_str = f"{value:.0f}" if "Platelet" in label else f"{value:.2f}"
        text_colour = level_colours.get(level, "#ffc107")
        parts.append(
            f'<tr style="background:{row_bg};border-bottom:1px solid #e9ecef;">'
            f'<td style="padding:10px 8px;color:#2c3e50;font-weight:500;">{escape(str(label))}</td>'
            f'<td style="padding:10px 8px;text-align:center;color:#2c3e50;font-weight:600;">{value_str}</td>'
            f'<td style="padding:10px 8px;text-align:center;color:#6c757d;font-size:12px;">{ref[0]} - {ref[1]}</td>'
            f'<td style="padding:10px 8px;text-align:center;font-weight:600;color:{text_colour};">{icon} {level}</td>'
            '</tr>'
        )
    parts.append('</tbody></table></div>')
    return "".join(parts)
# Build health card layout
def build_health_card(profile_image, test_results, summary, patient_name="", patient_age="", patient_gender="", patient_id=""):
    """Assemble the complete HTML health card for one analysis.

    Args:
        profile_image: Base64-encoded PNG, embedded as a
            ``data:image/png;base64`` avatar.
        test_results: Mapping of pre-rendered panel HTML for the keys
            "Hematology", "Iron Panel", "Liver & Kidney", "Electrolytes" and
            "Vitals"; inserted verbatim (generated by this module).
        summary: Pre-rendered summary HTML, inserted verbatim.
        patient_name, patient_age, patient_gender, patient_id: Optional
            user-entered details. They come from free-text UI inputs, so they
            are HTML-escaped before being placed in the markup.

    Returns:
        The card markup followed by a ``<style>`` block whose print rules
        show only ``#health-card`` and hide its buttons.

    Raises:
        KeyError: if *test_results* lacks one of the five panel keys.
    """
    from datetime import datetime
    from html import escape

    current_date = datetime.now().strftime("%B %d, %Y")

    # Numeric UI inputs may deliver whole numbers as floats; show "25", not "25.0".
    if isinstance(patient_age, float) and patient_age.is_integer():
        patient_age = int(patient_age)

    # Escape user-supplied text so it cannot inject markup/script into gr.HTML.
    heading = escape(str(patient_name)) if patient_name else "Lab Test Results"
    patient_id_html = f"<div>Patient ID: {escape(str(patient_id))}</div>" if patient_id else ""
    if patient_age and patient_gender:
        subtitle = f"Age: {escape(str(patient_age))} | Gender: {escape(str(patient_gender))}"
    else:
        subtitle = "AI-Generated Health Analysis"

    card_html = f"""
<div id="health-card" style="font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; max-width: 700px; margin: 20px auto; border-radius: 16px; background: linear-gradient(135deg, #e3f2fd 0%, #f3e5f5 100%); border: 2px solid #ddd; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.15); padding: 30px; color: #1a1a1a;">
  <div style="background-color: rgba(255, 255, 255, 0.9); border-radius: 12px; padding: 20px; margin-bottom: 25px; border: 1px solid #e0e0e0;">
    <div style="display: flex; align-items: center; margin-bottom: 15px;">
      <div style="background: linear-gradient(135deg, #64b5f6, #42a5f5); padding: 8px 16px; border-radius: 8px; margin-right: 20px;">
        <h3 style="margin: 0; font-size: 16px; color: white; font-weight: 600;">HEALTH CARD</h3>
      </div>
      <div style="margin-left: auto; text-align: right; color: #666; font-size: 12px;">
        <div>Report Date: {current_date}</div>
        {patient_id_html}
      </div>
    </div>
    <div style="display: flex; align-items: center;">
      <img src="data:image/png;base64,{profile_image}" alt="Profile" style="width: 90px; height: 90px; border-radius: 50%; margin-right: 20px; border: 3px solid #fff; box-shadow: 0 4px 12px rgba(0,0,0,0.1);">
      <div>
        <h2 style="margin: 0; font-size: 28px; color: #2c3e50; font-weight: 700;">{heading}</h2>
        <p style="margin: 4px 0 0 0; color: #666; font-size: 14px;">{subtitle}</p>
        <p style="margin: 4px 0 0 0; color: #888; font-size: 12px;">Face-Based Health Analysis Report</p>
      </div>
    </div>
  </div>
  <div style="background-color: rgba(255, 255, 255, 0.95); border-radius: 12px; padding: 25px; margin-bottom: 25px; border: 1px solid #e0e0e0;">
    {test_results['Hematology']}
    {test_results['Iron Panel']}
    {test_results['Liver & Kidney']}
    {test_results['Electrolytes']}
    {test_results['Vitals']}
  </div>
  <div style="background-color: rgba(255, 255, 255, 0.95); padding: 20px; border-radius: 12px; border: 1px solid #e0e0e0; margin-bottom: 25px;">
    <h4 style="margin: 0 0 15px 0; color: #2c3e50; font-size: 18px; font-weight: 600;">📝 Summary & Recommendations</h4>
    <div style="color: #444; line-height: 1.6;">
      {summary}
    </div>
  </div>
  <div style="display: flex; gap: 15px; justify-content: center; flex-wrap: wrap;">
    <button onclick="window.print()" style="padding: 12px 24px; background: linear-gradient(135deg, #4caf50, #45a049); color: white; border: none; border-radius: 8px; cursor: pointer; font-weight: 600; font-size: 14px; box-shadow: 0 4px 12px rgba(76, 175, 80, 0.3); transition: all 0.3s;">
      📥 Download Report
    </button>
    <button style="padding: 12px 24px; background: linear-gradient(135deg, #2196f3, #1976d2); color: white; border: none; border-radius: 8px; cursor: pointer; font-weight: 600; font-size: 14px; box-shadow: 0 4px 12px rgba(33, 150, 243, 0.3);">
      📞 Find Labs Near Me
    </button>
  </div>
</div>
<style>
@media print {{
  body * {{
    visibility: hidden;
  }}
  #health-card, #health-card * {{
    visibility: visible;
  }}
  #health-card {{
    position: absolute;
    left: 0;
    top: 0;
    width: 100% !important;
    max-width: none !important;
    margin: 0 !important;
    box-shadow: none !important;
    border: none !important;
  }}
  button {{
    display: none !important;
  }}
}}
</style>
"""
    return card_html
# Function to generate PDF from HTML content using reportlab
def html_to_text_lines(html_content, width=85):
    """Convert report HTML into wrapped plain-text lines suitable for a PDF.

    - ``<style>``, ``<script>`` and ``<button>`` elements are dropped entirely.
    - Table cells are joined with " | ".
    - Closing block-level tags and ``<br>`` start new lines.
    - Remaining tags are removed and HTML entities are decoded.
    - Text is NFKC-normalised (e.g. "₂" becomes "2"). Characters outside cp1252
      (e.g. emoji) are then dropped, since ReportLab's built-in fonts use
      WinAnsi (cp1252-like) encoding.
    - Blank lines are removed and long lines are wrapped at *width* characters.
    """
    import re
    import textwrap
    import unicodedata
    from html import unescape

    text = re.sub(r"(?is)<(style|script|button)\b.*?</\1\s*>", "", html_content)
    text = re.sub(r"(?i)</t[dh]\s*>", " | ", text)
    text = re.sub(r"(?i)</(?:div|p|tr|li|ul|ol|table|thead|tbody|h[1-6])\s*>|<br\s*/?>", "\n", text)
    text = unescape(re.sub(r"<[^>]*>", "", text))
    text = unicodedata.normalize("NFKC", text)
    text = text.encode("cp1252", "ignore").decode("cp1252")

    lines = []
    for raw_line in text.splitlines():
        line = " ".join(raw_line.split()).strip(" |")
        if line:
            lines.extend(textwrap.wrap(line, width) or [line])
    return lines


def generate_pdf(html_content):
    """Render *html_content* as a plain-text, letter-size, multi-page PDF.

    The markup is flattened to text with :func:`html_to_text_lines`. Lines are
    then laid out in 12 pt Helvetica from y=750 down to a 40 pt bottom margin,
    and a new page starts whenever a page is full. At least one page is always
    written.

    Returns:
        ``io.BytesIO`` holding the PDF bytes, rewound to position 0.
    """
    buffer = BytesIO()
    pdf = canvas.Canvas(buffer, pagesize=letter)
    left, top, bottom = 40, 750, 40
    font_size = 12
    leading = font_size * 1.2
    # Number of baselines that fit between top and bottom (inclusive of the first).
    lines_per_page = int((top - bottom) // leading) + 1

    lines = html_to_text_lines(html_content)
    for start in range(0, max(len(lines), 1), lines_per_page):
        text = pdf.beginText(left, top)
        text.setFont("Helvetica", font_size, leading)
        text.textLines(lines[start:start + lines_per_page])
        pdf.drawText(text)
        pdf.showPage()
    pdf.save()
    buffer.seek(0)
    return buffer
# Modified analyze_face function, plus its private helpers.
def _read_frame_bgr(input_data):
    """Load the frame to analyse, normalised to OpenCV's BGR channel order.

    Args:
        input_data: A video file path (``str``; only its first frame is used)
            or an RGB image array as delivered by ``gr.Image(type="numpy")``.

    Returns:
        ``(frame_bgr, None)`` on success, ``(None, error_html)`` on failure.
    """
    if isinstance(input_data, str):  # video input (file path)
        capture = cv2.VideoCapture(input_data)
        if not capture.isOpened():
            return None, "<div style='color:red;'>⚠️ Error: Could not open video.</div>"
        try:
            ok, frame = capture.read()
        finally:
            capture.release()
        if not ok:
            return None, "<div style='color:red;'>⚠️ Error: Could not read video frame.</div>"
        return frame, None  # OpenCV already decodes frames as BGR
    if input_data is None:
        return None, "<div style='color:red;'>⚠️ Error: No image provided.</div>"
    # Gradio supplies RGB; convert so both input paths share one channel order.
    return cv2.cvtColor(input_data, cv2.COLOR_RGB2BGR), None


def _estimate_vitals(frame_rgb):
    """Estimate ``(spo2, heart_rate, respiratory_rate)`` from image statistics.

    Uses the loaded ``hr_model`` and ``spo2_model``. Heart rate is clipped to
    60-100. The respiratory rate is a fixed formula derived from heart rate,
    not a measurement.
    """
    gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY)
    brightness_std = np.std(gray) / 255
    green_std = np.std(frame_rgb[:, :, 1]) / 255
    # Fixed 50x50 patch, presumably meant to land on skin; 0.5 if it is empty.
    patch = frame_rgb[100:150, 100:150]
    tone_index = np.mean(patch) / 255 if patch.size else 0.5

    heart_rate = float(np.clip(hr_model.predict([[brightness_std, green_std, tone_index]])[0], 60, 100))
    spo2 = spo2_model.predict([[heart_rate, brightness_std, tone_index]])[0]
    rr = int(12 + abs(heart_rate % 5 - 2))
    return spo2, heart_rate, rr


def _build_summary(rows):
    """Summarise out-of-range results as an HTML bullet list.

    Args:
        rows: ``(label, value, (low, high))`` tuples, as passed to build_table.

    Returns:
        ``<ul>…</ul>`` containing one item per out-of-range result, with a
        follow-up hint for a few known tests. If nothing is out of range, it
        contains a single "all within range" item instead. A reminder that the
        values are estimates is always appended.
    """
    hints = {
        ("Hemoglobin", "Low"): "this could mean mild anemia.",
        ("Iron", "Low"): "consider an iron profile test.",
        ("Ferritin", "Low"): "low iron storage — consider an iron profile test.",
        ("Bilirubin", "High"): "possible jaundice. Recommend LFT.",
        ("SpO2", "Low"): "suggest retesting with a pulse oximeter.",
    }
    items = []
    for label, value, (low, high) in rows:
        level = get_risk_color(value, (low, high))[0]
        if level == "Normal":
            continue
        detail = f"{label} is {level.lower()} ({value:.2f}; reference {low} - {high})"
        hint = hints.get((label, level))
        items.append(f"<li>{detail} — {hint}</li>" if hint else f"<li>{detail}.</li>")
    if not items:
        items.append("<li>All estimated values are within their reference ranges.</li>")
    items.append(
        "<li>These values are AI estimates from a single face image, not laboratory "
        "measurements. Confirm any concern with a real lab test.</li>"
    )
    return "<ul>" + "".join(items) + "</ul>"


def analyze_face(input_data):
    """Analyse one face image / video frame and build the health report.

    Args:
        input_data: A video file path (``str``) or an RGB image array from Gradio.

    Returns:
        ``(health_card_html, pdf_buffer)`` on success, where ``pdf_buffer`` is
        the in-memory PDF returned by ``generate_pdf``. Returns
        ``(error_html, None)`` when no frame can be read or no face is detected.

    Patient details are read from the module-level ``current_patient_details``
    set by ``route_inputs``. Empty strings are used if it has not been set.
    """
    frame_bgr, error_html = _read_frame_bgr(input_data)
    if error_html is not None:
        return error_html, None

    # Fixed working resolution keeps processing time bounded (aspect ratio is not preserved).
    frame_bgr = cv2.resize(frame_bgr, (640, 480))
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB

    result = face_mesh.process(frame_rgb)
    if not result.multi_face_landmarks:
        return "<div style='color:red;'>⚠️ Error: Face not detected.</div>", None
    landmarks = result.multi_face_landmarks[0].landmark

    # extract_features reads channel 2 as red, i.e. it expects a BGR frame.
    features = extract_features(frame_bgr, landmarks)
    # NOTE(review): column names assumed to match how the hemoglobin model was trained — confirm.
    features_df = pd.DataFrame([features], columns=["feature1", "feature2", "feature3"])

    test_values = {}
    for label, model in models.items():
        if label == "Hemoglobin":
            test_values[label] = model.predict(features_df)[0]
        else:
            # Placeholder models (see train_model) are fed random noise, not image features.
            noise = [[random.uniform(0.2, 0.5) for _ in range(7)]]
            test_values[label] = model.predict(noise)[0]

    spo2, heart_rate, rr = _estimate_vitals(frame_rgb)

    # panel key -> (heading, rows); the same rows feed both the tables and the summary.
    # NOTE(review): 13.5-17.5 is a typical adult-male hemoglobin range; consider gender-specific ranges.
    panels = {
        "Hematology": ("🩸 Hematology", [
            ("Hemoglobin", test_values["Hemoglobin"], (13.5, 17.5)),
            ("WBC Count", test_values["WBC Count"], (4.0, 11.0)),
            ("Platelet Count", test_values["Platelet Count"], (150, 450)),
        ]),
        "Iron Panel": ("🧬 Iron Panel", [
            ("Iron", test_values["Iron"], (60, 170)),
            ("Ferritin", test_values["Ferritin"], (30, 300)),
            ("TIBC", test_values["TIBC"], (250, 400)),
        ]),
        "Liver & Kidney": ("🧬 Liver & Kidney", [
            ("Bilirubin", test_values["Bilirubin"], (0.3, 1.2)),
            ("Creatinine", test_values["Creatinine"], (0.6, 1.2)),
            ("Urea", test_values["Urea"], (7, 20)),
        ]),
        "Electrolytes": ("🧪 Electrolytes", [
            ("Sodium", test_values["Sodium"], (135, 145)),
            ("Potassium", test_values["Potassium"], (3.5, 5.1)),
        ]),
        "Vitals": ("❤️ Vitals", [
            ("SpO2", spo2, (95, 100)),
            ("Heart Rate", heart_rate, (60, 100)),
            ("Respiratory Rate", rr, (12, 20)),
            ("Temperature", test_values["Temperature"], (97, 99)),
            ("BP Systolic", test_values["BP Systolic"], (90, 120)),
            ("BP Diastolic", test_values["BP Diastolic"], (60, 80)),
        ]),
    }
    test_results = {key: build_table(title, rows) for key, (title, rows) in panels.items()}
    summary = _build_summary([row for _, rows in panels.values() for row in rows])

    # cv2.imencode expects BGR, so encode the BGR frame to keep the avatar's colours right.
    _, png_bytes = cv2.imencode('.png', frame_bgr)
    profile_image_base64 = base64.b64encode(png_bytes).decode('utf-8')

    patient = globals().get("current_patient_details") or {}
    health_card_html = build_health_card(
        profile_image_base64,
        test_results,
        summary,
        patient.get('name', ''),
        patient.get('age', ''),
        patient.get('gender', ''),
        patient.get('id', ''),
    )
    pdf_file = generate_pdf(health_card_html)
    return health_card_html, pdf_file
# Modified route_inputs function
def route_inputs(mode, image, video, patient_name, patient_age, patient_gender, patient_id):
    """Dispatch a UI submission to analyze_face.

    Returns an error immediately when the input for the chosen *mode*
    ("Image" or "Video") is missing. Otherwise it stores the patient details
    in the module-level ``current_patient_details`` (read by analyze_face) and
    analyses *image* when mode is "Image", else *video*.

    Returns:
        ``(html, pdf)`` from analyze_face, or ``(error_html, None)``.
    """
    global current_patient_details

    use_image = mode == "Image"
    if use_image and image is None:
        return "<div style='color:red;'>⚠️ Error: No image provided.</div>", None
    if mode == "Video" and video is None:
        return "<div style='color:red;'>⚠️ Error: No video provided.</div>", None

    current_patient_details = dict(
        name=patient_name,
        age=patient_age,
        gender=patient_gender,
        id=patient_id,
    )
    media = image if use_image else video
    report_html, report_pdf = analyze_face(media)
    return report_html, report_pdf
# Create Gradio interface
import tempfile


def _analyze_for_ui(mode, image, video, patient_name, patient_age, patient_gender, patient_id):
    """Run route_inputs and expose its PDF buffer as a downloadable file.

    route_inputs returns the PDF as an in-memory BytesIO, or None on error.
    gr.File needs a file path, so the bytes are written to a temporary .pdf.
    It is created with delete=False so it still exists when Gradio serves it.
    """
    report_html, pdf_buffer = route_inputs(mode, image, video, patient_name, patient_age, patient_gender, patient_id)
    if pdf_buffer is None:
        return report_html, None
    with tempfile.NamedTemporaryFile(prefix="health_report_", suffix=".pdf", delete=False) as pdf_file:
        pdf_file.write(pdf_buffer.getvalue())
    return report_html, pdf_file.name


with gr.Blocks() as demo:
    gr.Markdown("""# 🧠 Face-Based Lab Test AI Report (Video Mode)""")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Patient Information")
            patient_name = gr.Textbox(label="Patient Name", placeholder="Enter patient name")
            patient_age = gr.Number(label="Age", value=25, minimum=1, maximum=120)
            patient_gender = gr.Radio(label="Gender", choices=["Male", "Female", "Other"], value="Male")
            patient_id = gr.Textbox(label="Patient ID", placeholder="Enter patient ID (optional)")
            gr.Markdown("### Image/Video Input")
            mode_selector = gr.Radio(label="Choose Input Mode",
                                     choices=["Image", "Video"],
                                     value="Image")
            image_input = gr.Image(type="numpy", label="📸 Upload Face Image")
            video_input = gr.Video(label="Upload Face Video",
                                   sources=["upload", "webcam"])
            submit_btn = gr.Button("🔍 Analyze")
        with gr.Column():
            result_html = gr.HTML(label="🧪 Health Report Table")
            # The PDF is delivered as a downloadable file. Previously the PDF
            # buffer was routed into a gr.Image, which cannot display it, and
            # the separate download button had no handler.
            result_pdf = gr.File(label="📥 Download PDF")
    submit_btn.click(fn=_analyze_for_ui,
                     inputs=[mode_selector, image_input, video_input, patient_name, patient_age, patient_gender, patient_id],
                     outputs=[result_html, result_pdf])

# Launch Gradio for Replit (only when run as a script, not when imported).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)