EswariNani commited on
Commit
acab3de
ยท
verified ยท
1 Parent(s): 49eeabd

Upload virtualhealth.py

Browse files
Files changed (1) hide show
  1. virtualhealth.py +152 -0
virtualhealth.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import xgboost as xgb
2
+ import pickle
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import streamlit as st
7
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
8
+ import nltk
9
+ from nltk.tokenize import word_tokenize
10
+ from nltk.corpus import stopwords
11
+ import re
12
+
13
+ # ๐Ÿ”น Download stopwords only when needed
14
+ nltk.download("stopwords")
15
+ nltk.download("punkt")
16
+ nltk.download('punkt_tab')
17
+
18
+ # Load English stopwords
19
+ stop_words = set(stopwords.words("english"))
20
+
21
+ # ============================
22
+ # ๐Ÿ”น 1. Load Pretrained Medical Q&A Model
23
+ # ============================
24
+ # qa_model_name = "deepset/roberta-base-squad2" # Better model for medical Q&A
25
+ # tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
26
+ # qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
27
+ model_name = "dmis-lab/biobert-large-cased-v1.1-squad" # โœ… Updated Model
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)
30
+ # ============================
31
+ # ๐Ÿ”น 2. Load Symptom Checker Model
32
+ # ============================
33
+ model = xgb.XGBClassifier()
34
+ model.load_model("symptom_disease_model.json") # Load trained model
35
+ label_encoder = pickle.load(open("label_encoder.pkl", "rb")) # Load label encoder
36
+ X_train = pd.read_csv("X_train.csv") # Load symptoms
37
+ symptom_list = X_train.columns.tolist()
38
+
39
+ # ============================
40
+ # ๐Ÿ”น 3. Load Precaution Data
41
+ # ============================
42
+ precaution_df = pd.read_csv("Disease precaution.csv")
43
+ precaution_dict = {
44
+ row["Disease"].strip().lower(): [row[f"Precaution_{i}"] for i in range(1, 5) if pd.notna(row[f"Precaution_{i}"])]
45
+ for _, row in precaution_df.iterrows()
46
+ }
47
+
48
+ # ============================
49
+ # ๐Ÿ”น 4. Load Medical Context
50
+ # ============================
51
+ def load_medical_context():
52
+ with open("medical_context.txt", "r", encoding="utf-8") as file:
53
+ return file.read()
54
+
55
+ medical_context = load_medical_context()
56
+
57
+ # ============================
58
+ # ๐Ÿ”น 5. Doctor Database
59
+ # ============================
60
+ doctor_database = {
61
+ "malaria": [{"name": "Dr. Rajesh Kumar", "specialty": "Infectious Diseases", "location": "Apollo Hospital", "contact": "9876543210"}],
62
+ "diabetes": [{"name": "Dr. Anil Mehta", "specialty": "Endocrinologist", "location": "AIIMS Delhi", "contact": "9876543233"}],
63
+ "heart attack": [{"name": "Dr. Vikram Singh", "specialty": "Cardiologist", "location": "Medanta Hospital", "contact": "9876543255"}],
64
+ }
65
+
66
+ # ============================
67
+ # ๐Ÿ”น 6. Predict Disease from Symptoms
68
+ # ============================
69
+ def predict_disease(user_symptoms):
70
+ """Predicts disease based on user symptoms using the trained XGBoost model."""
71
+ input_vector = np.zeros(len(symptom_list))
72
+
73
+ for symptom in user_symptoms:
74
+ if symptom in symptom_list:
75
+ input_vector[symptom_list.index(symptom)] = 1
76
+
77
+ input_vector = input_vector.reshape(1, -1) # Reshape for model input
78
+ predicted_class = model.predict(input_vector)[0] # Predict disease
79
+ predicted_disease = label_encoder.inverse_transform([predicted_class])[0]
80
+
81
+ return predicted_disease
82
+
83
+ # ============================
84
+ # ๐Ÿ”น 7. Get Precautions for a Disease
85
+ # ============================
86
+ def get_precautions(disease):
87
+ """Returns the precautions for a given disease."""
88
+ return precaution_dict.get(disease.lower(), ["No precautions available"])
89
+
90
+ # ============================
91
+ # ๐Ÿ”น 8. Answer Medical Questions (Q&A Model)
92
+ # ============================
93
+ def get_medical_answer(question):
94
+ """Uses the pre-trained Q&A model to answer general medical questions."""
95
+ inputs = tokenizer(question, medical_context, return_tensors="pt", truncation=True, max_length=512)
96
+ with torch.no_grad():
97
+ outputs = qa_model(**inputs)
98
+
99
+ answer_start = torch.argmax(outputs.start_logits)
100
+ answer_end = torch.argmax(outputs.end_logits) + 1
101
+
102
+ answer = tokenizer.convert_tokens_to_string(
103
+ tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][answer_start:answer_end])
104
+ )
105
+
106
+ if answer.strip() in ["", "[CLS]", "<s>"]:
107
+ return "I'm not sure. Please consult a medical professional."
108
+
109
+ return answer
110
+ # ============================
111
+ # ๐Ÿ”น 9. Book a Doctor's Appointment
112
+ # ============================
113
+ def book_appointment(disease):
114
+ """Finds a doctor for the given disease and returns appointment details."""
115
+ disease = disease.lower().strip()
116
+ doctors = doctor_database.get(disease, [])
117
+ if not doctors:
118
+ return f"Sorry, no available doctors found for {disease}."
119
+
120
+ doctor = doctors[0]
121
+ return f"Appointment booked with **{doctor['name']}** ({doctor['specialty']}) at **{doctor['location']}**.\nContact: {doctor['contact']}"
122
+
123
+ # ============================
124
+ # ๐Ÿ”น 10. Handle User Queries
125
+ # ============================
126
+ def handle_user_query(user_query):
127
+ """Handles user queries related to symptoms, diseases, and doctor appointments."""
128
+ user_query = user_query.lower().strip()
129
+
130
+ # Check if query is about symptoms
131
+ if "symptoms" in user_query or "signs" in user_query:
132
+ disease = user_query.replace("symptoms", "").replace("signs", "").strip()
133
+ return get_medical_answer(f"What are the symptoms of {disease}?")
134
+
135
+ # Check if query is about treatment
136
+ elif "treatment" in user_query or "treat" in user_query:
137
+ disease = user_query.replace("treatment", "").replace("treat", "").strip()
138
+ return get_medical_answer(f"What is the treatment for {disease}?")
139
+
140
+ # Check for doctor recommendation
141
+ elif "who should i see" in user_query:
142
+ disease = user_query.replace("who should i see for", "").strip()
143
+ return book_appointment(disease)
144
+
145
+ # Check for appointment booking
146
+ elif "book appointment" in user_query:
147
+ disease = user_query.replace("book appointment for", "").strip()
148
+ return book_appointment(disease)
149
+
150
+ # Default case: general medical question
151
+ else:
152
+ return get_medical_answer(user_query)