Files changed (1) hide show
  1. app.py +416 -312
app.py CHANGED
@@ -1,258 +1,269 @@
1
  import streamlit as st
2
  import os
 
 
3
  from groq import Groq
4
  from datetime import datetime
5
 
6
- # Set page config FIRST
7
  st.set_page_config(page_title="AI Medical Consultancy", layout="wide")
8
 
9
- # Custom CSS for styling
10
- st.markdown("""
11
- <style>
12
- /* Color Variables */
13
- :root {
14
- --primary: #3498db; /* Blue */
15
- --secondary: #2c3e50; /* Dark accent */
16
- --accent: #f1c40f; /* Yellow */
17
- --success: #2ecc71; /* Positive actions */
18
- --light: #ffffff; /* White backgrounds */
19
- --dark: #000000; /* Black text/elements */
20
- }
21
- /* Main container styling */
22
- .stApp {
23
- background: linear-gradient(135deg, #3498db 0%, #e0e0e0 100%);
24
- font-family: 'Arial', sans-serif;
25
- }
26
- /* Headers styling */
27
- h1, h2, h3 {
28
- color: var(--dark) !important;
29
- border-bottom: 3px solid var(--primary);
30
- padding-bottom: 0.3em;
31
- }
32
- /* Form containers */
33
- .stForm {
34
- background: #000000;
35
- border: 1px solid rgba(44, 62, 80, 0.2);
36
- border-radius: 15px;
37
- padding: 2rem;
38
- box-shadow: 0 8px 30px rgba(0, 0, 0, 0.12);
39
- margin: 1rem 0;
40
- }
41
- /* Input fields */
42
- .stTextInput input, .stNumberInput input,
43
- .stSelectbox select, .stTextArea textarea {
44
- border: 2px solid #00FFFF !important;
45
- border-radius: 10px !important;
46
- padding: 1rem !important;
47
- background: #00FFFF !important;
48
- transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
49
- color: var(--dark) !important;
50
- }
51
- .stTextInput input:focus, .stNumberInput input:focus,
52
- .stSelectbox select:focus, .stTextArea textarea:focus {
53
- border-color: var(--primary) !important;
54
- box-shadow: 0 0 12px rgba(52, 152, 219, 0.2) !important;
55
- background: white !important;
56
- color: var(--dark) !important;
57
- }
58
- /* Buttons styling */
59
- .stButton>button {
60
- background: linear-gradient(135deg, var(--primary) 0%, var(--accent) 100%) !important;
61
- color: var(--dark) !important;
62
- border: none !important;
63
- border-radius: 10px !important;
64
- padding: 1rem 2rem !important;
65
- font-size: 1rem !important;
66
- transition: all 0.3s ease;
67
- position: relative;
68
- overflow: hidden;
69
- }
70
- .stButton>button:hover {
71
- transform: translateY(-2px);
72
- box-shadow: 0 8px 15px rgba(52, 152, 219, 0.3);
73
- opacity: 0.95;
74
- }
75
- .stButton>button:active {
76
- transform: translateY(0);
77
- opacity: 1;
78
- }
79
- /* Progress indicator */
80
- .progress-bar {
81
- display: flex;
82
- justify-content: space-between;
83
- margin: 2rem 0;
84
- padding: 1rem;
85
- background: rgba(255, 255, 255, 0.9);
86
- border-radius: 10px;
87
- color: var(--dark) !important;
88
- }
89
- .step {
90
- flex: 1;
91
- text-align: center;
92
- padding: 1rem;
93
- font-weight: 600;
94
- color: #95a5a6;
95
- position: relative;
96
- }
97
- .step.active {
98
- color: var(--primary);
99
- }
100
- .step.active:after {
101
- content: '';
102
- position: absolute;
103
- bottom: -1px;
104
- left: 50%;
105
- transform: translateX(-50%);
106
- width: 40%;
107
- height: 3px;
108
- background: var(--primary);
109
- }
110
- /* Chat bubbles */
111
- .dr-message {
112
- background: linear-gradient(135deg, var(--primary) 0%, #2980b9 100%);
113
- color: white;
114
- border-radius: 20px 20px 20px 4px;
115
- padding: 1.2rem 1.5rem;
116
- margin: 1rem 0;
117
- max-width: 80%;
118
- width: fit-content;
119
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
120
- }
121
- .user-message {
122
- background: linear-gradient(135deg, #f1c40f 0%, #e1b800 100%);
123
- margin-left: auto;
124
- border-radius: 20px 20px 4px 20px;
125
- color: var(--dark) !important;
126
- }
127
- /* Emergency alert */
128
- .emergency-alert {
129
- background: linear-gradient(135deg, var(--accent) 0%, #c0392b 100%);
130
- color: white;
131
- padding: 2rem;
132
- border-radius: 15px;
133
- animation: pulse 1.5s infinite;
134
- text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
135
- }
136
- @keyframes pulse {
137
- 0% { transform: scale(1); }
138
- 50% { transform: scale(1.02); }
139
- 100% { transform: scale(1); }
140
- }
141
- /* Download button */
142
- .download-btn {
143
- background: linear-gradient(135deg, var(--success) 0%, #27ae60 100%) !important;
144
- }
145
- /* Enhanced Data Visualization Contrast */
146
- .stDataFrame {
147
- border: 1px solid rgba(0, 0, 0, 0.1);
148
- border-radius: 12px;
149
- overflow: hidden;
150
- background: #f0f0f0;
151
- color: var(--dark) !important;
152
- }
153
- /* Tabbed Interface Styling */
154
- .stTabs [role="tablist"] {
155
- gap: 10px;
156
- padding: 8px;
157
- background: rgba(240, 240, 240, 0.9);
158
- border-radius: 12px;
159
- color: var(--dark) !important;
160
- }
161
- .stTabs [role="tab"] {
162
- background: #ffffff !important;
163
- border-radius: 8px !important;
164
- transition: all 0.3s ease;
165
- color: var(--dark) !important;
166
- }
167
- .stTabs [role="tab"][aria-selected="true"] {
168
- background: var(--primary) !important;
169
- color: white !important;
170
- transform: scale(1.05);
171
- }
172
- </style>
173
- """, unsafe_allow_html=True)
174
 
 
175
 
176
  # Initialize session state variables
177
  if 'current_step' not in st.session_state:
178
  st.session_state.current_step = 0
179
  if 'symptom_details' not in st.session_state:
180
- st.session_state.symptom_details = [] # Initialize as an empty list
181
  if 'patient_info' not in st.session_state:
182
  st.session_state.patient_info = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  def initialize_groq_client():
185
  try:
186
- # Try to get the API key from Streamlit secrets
187
- api_key = None
188
- try:
189
- api_key = st.secrets.get("GROQ_API_KEY", os.getenv("GROQ_API_KEY"))
190
- except FileNotFoundError:
191
- st.warning("No `secrets.toml` file found. Please create one in the `.streamlit` folder.")
192
-
193
- # If not found, prompt the user to enter it
194
  if not api_key:
195
- api_key = st.text_input("Enter your Groq API Key:", type="password")
196
- if not api_key:
197
- st.warning("Please provide a valid Groq API key to proceed.")
198
- return False
199
 
200
- # Initialize the Groq client
201
- client = Groq(api_key=api_key)
202
- st.session_state.client = client
203
  return True
204
  except Exception as e:
205
- st.error(f"Error initializing Groq client: {str(e)}")
206
  return False
207
 
208
  def symptom_interrogation_step():
209
  client = st.session_state.client
210
  main_symptom = st.session_state.patient_info['main_symptom']
211
- step = len(st.session_state.symptom_details) # Use the number of collected details as the step
212
 
213
  if step == 0:
214
- # First question: ask about the main symptom
215
  medical_focus = {
216
  'pain': "location/radiation/provoking factors",
217
  'fever': "pattern/associated symptoms/response to meds",
218
  'gi': "bowel changes/ingestion timing/associated symptoms",
219
  'respiratory': "exertion relationship/sputum/triggers"
220
  }
221
- focus = medical_focus.get(main_symptom.lower(),
222
- "temporal pattern/severity progression/associated symptoms")
223
-
224
  prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom}
225
- focusing on {focus} to differentiate serious causes. Your task is to have a polite and simple conversation with a patient.
226
- Start by asking ONE specific follow-up question about their initial symptom: {main_symptom}.
227
- Ask only one question at a time to avoid overwhelming the patient.
228
- Keep your language clear, professional, and easy to understand.
229
-
230
- Dont display possibe symptoms or why you are asking questions."""
231
-
232
- messages = [
233
- {"role": "system", "content": "Ask focused clinical questions. One at a time."},
234
- {"role": "user", "content": prompt}
235
- ]
236
  else:
237
- # Subsequent questions: use the last Q&A to generate the next question
238
  last_qa = st.session_state.symptom_details[-1]
239
- prompt = f"""Last Q&A: {last_qa['question']} β†’ {last_qa['answer']}
240
- Based on this, ask the NEXT most critical question to differentiate between
241
- possible causes of {main_symptom}. Consider red flags and likelihood."""
242
- messages = [{"role": "user", "content": prompt}]
243
 
244
  try:
245
  response = client.chat.completions.create(
246
- messages=messages,
247
  model="mixtral-8x7b-32768",
248
  temperature=0.3
249
  )
250
  question = response.choices[0].message.content.strip()
251
- if not question.endswith('?'):
252
- question += '?'
253
  st.session_state.current_question = question
254
  except Exception as e:
255
- st.error(f"Error generating question: {str(e)}")
256
  st.stop()
257
 
258
  def handle_symptom_interrogation():
@@ -264,7 +275,7 @@ def handle_symptom_interrogation():
264
 
265
  if 'current_question' in st.session_state:
266
  with st.form("symptom_qna"):
267
- st.markdown(f'<div class="dr-message">πŸ‘¨β€βš•οΈ {st.session_state.current_question}</div>', unsafe_allow_html=True)
268
  answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}")
269
 
270
  if st.form_submit_button("Next"):
@@ -274,31 +285,27 @@ def handle_symptom_interrogation():
274
  "answer": answer
275
  })
276
  del st.session_state.current_question
277
-
278
- # Check for emergency after 3 questions
279
  if len(st.session_state.symptom_details) >= 3:
280
  last_answer = st.session_state.symptom_details[-1]['answer']
281
  try:
282
  urgency_check = st.session_state.client.chat.completions.create(
283
- messages=[{"role": "user", "content":
284
- f"Does '{last_answer}' indicate immediate emergency? Yes/No"}],
285
  model="mixtral-8x7b-32768",
286
  temperature=0
287
  ).choices[0].message.content
288
-
289
  if 'YES' in urgency_check.upper():
290
- st.markdown('<div class="emergency-alert">🚨 Emergency detected! Please seek immediate medical attention.</div>', unsafe_allow_html=True)
291
  st.session_state.current_step = 4
292
  return
293
- except Exception as e:
294
- st.error(f"Error checking urgency: {str(e)}")
295
-
296
  if len(st.session_state.symptom_details) < 7:
297
  st.session_state.current_step = 1
298
- st.rerun()
299
  else:
300
  st.session_state.current_step = 3
301
- st.rerun()
302
  else:
303
  st.warning("Please provide an answer")
304
 
@@ -311,11 +318,11 @@ def collect_basic_info():
311
  st.session_state.patient_info['main_symptom'] = st.text_input("Main Symptom")
312
 
313
  if st.form_submit_button("Next"):
314
- if all([st.session_state.patient_info.get(k) for k in ['name', 'age', 'gender', 'main_symptom']]):
315
  st.session_state.current_step = 1
316
  st.rerun()
317
  else:
318
- st.warning("Please fill all required fields")
319
 
320
  def collect_medical_history():
321
  st.header("Medical History")
@@ -331,116 +338,213 @@ def collect_medical_history():
331
  st.rerun()
332
 
333
  def generate_risk_assessment():
334
- st.header("Risk Assessment")
335
 
336
  try:
337
- symptom_log = "\n".join(
338
- [f"Q: {q['question']}\nA: {q['answer']}"
339
- for q in st.session_state.symptom_details]
340
- )
341
-
342
  patient_profile = f"""
343
- **Patient Profile**
344
  Name: {st.session_state.patient_info['name']}
345
  Age: {st.session_state.patient_info['age']}
346
  Gender: {st.session_state.patient_info['gender']}
 
347
 
348
- **Primary Complaint**
349
- {st.session_state.patient_info['main_symptom']}
350
-
351
- **Symptom Interrogation**
352
  {symptom_log}
353
 
354
- **Medical History**
355
- {st.session_state.patient_info.get('medical_history', 'None reported')}
356
-
357
- **Current Medications**
358
- {st.session_state.patient_info.get('medications', 'None')}
359
-
360
- **Allergies**
361
- {st.session_state.patient_info.get('allergies', 'None reported')}
362
-
363
- **Recent Context**
364
- Last Meal: {st.session_state.patient_info.get('last_meal', 'Unknown')}
365
- Recent Travel: {st.session_state.patient_info.get('recent_travel', 'None')}
366
  """
367
 
368
- analysis_prompt = f"""STRICTLY follow these instructions:
369
- 1. Analyze this case: {patient_profile}
370
- 2. *Include ONLY symptoms the patient is actively experiencing*. Exclude all negated symptoms (e.g., "no fever," "denies breathlessness").
371
- 3. Output *EXCLUSIVELY* in this format with NO additional text or explanations:
372
- [Age]-year-old [gender] with [specific, present symptoms].
373
- Example Output:
374
- "45-year-old man with severe chest pain radiating to the jaw"
375
- Your Output:"""
376
-
377
-
378
- response = st.session_state.client.chat.completions.create(
379
- messages=[
380
- {"role": "system", "content": "You are a medical AI that outputs ONLY patient descriptions."},
381
- {"role": "user", "content": analysis_prompt}
382
- ],
383
- model="mixtral-8x7b-32768",
384
- temperature=0.3,
385
- max_tokens=100
386
- )
387
-
388
- risk_prompt = response.choices[0].message.content.strip('"')
389
 
390
- st.subheader("Clinical Summary")
391
- st.markdown(f"```\n{risk_prompt}\n```")
392
-
393
- # Create download button
394
- timestamp = datetime.now().strftime('%Y%m%d%H%M')
395
- filename = f"{st.session_state.patient_info['name'].replace(' ', '_')}_assessment_{timestamp}.txt"
396
- st.download_button(
397
- label="Download Assessment",
398
- data=risk_prompt,
399
- file_name=filename,
400
- mime="text/plain"
401
- )
402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  except Exception as e:
404
- st.error(f"Error generating risk assessment: {str(e)}")
405
 
406
- def main():
407
- st.title("πŸ₯ AI Medical Consultancy")
408
 
409
- # Progress indicator
410
- steps_titles = ["Patient Info", "Symptoms", "Medical History", "Assessment"]
411
- progress_html = """
412
- <div class="progress-bar">
413
- <div class="step {}">{}</div>
414
- <div class="step {}">{}</div>
415
- <div class="step {}">{}</div>
416
- <div class="step {}">{}</div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
417
  </div>
418
- """.format(
419
- 'active' if st.session_state.current_step >= 0 else '',
420
- '1. Patient Info',
421
- 'active' if st.session_state.current_step >= 1 else '',
422
- '2. Symptoms',
423
- 'active' if st.session_state.current_step >= 3 else '',
424
- '3. History',
425
- 'active' if st.session_state.current_step >= 4 else '',
426
- '4. Report'
427
- )
428
- st.markdown(progress_html, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
 
430
  if not initialize_groq_client():
 
431
  return
 
 
 
432
 
433
- steps = {
434
- 0: collect_basic_info,
435
- 1: handle_symptom_interrogation,
436
- 2: handle_symptom_interrogation,
437
- 3: collect_medical_history,
438
- 4: generate_risk_assessment
439
- }
440
-
441
- current_step = st.session_state.get('current_step', 0)
442
- if current_step in steps:
443
- steps[current_step]()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
- if __name__ == "__main__":
446
  main()
 
1
  import streamlit as st
2
  import os
3
+ import pandas as pd
4
+ import re
5
  from groq import Groq
6
  from datetime import datetime
7
 
8
+ # Set page config
9
  st.set_page_config(page_title="AI Medical Consultancy", layout="wide")
10
 
11
+ # Load custom CSS
12
+ def load_css():
13
+ try:
14
+ with open("style.css") as f:
15
+ st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
16
+ except FileNotFoundError:
17
+ st.warning("CSS file not found. Please ensure 'style.css' is in the directory.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ load_css()
20
 
21
  # Initialize session state variables
22
  if 'current_step' not in st.session_state:
23
  st.session_state.current_step = 0
24
  if 'symptom_details' not in st.session_state:
25
+ st.session_state.symptom_details = []
26
  if 'patient_info' not in st.session_state:
27
  st.session_state.patient_info = {}
28
+ if 'appointment_details' not in st.session_state:
29
+ st.session_state.appointment_details = None
30
+ if 'appointment_summary' not in st.session_state:
31
+ st.session_state.appointment_summary = None
32
+ if 'analysis_results' not in st.session_state:
33
+ st.session_state.analysis_results = None
34
+
35
+ class MedicalAnalysisSystem:
36
+ def _init_(self, dataset_path):
37
+ try:
38
+ self.data = pd.read_csv(dataset_path)
39
+ print("Dataset Columns:", self.data.columns.tolist()) # Debug
40
+ print("Sample Data:\n", self.data.head()) # Debug
41
+
42
+ # Clean data - remove placeholder rows
43
+ self.data = self.data[~self.data['Symptom'].str.contains('Symptom|Condition', case=False)]
44
+
45
+ self.data['Risk Score'] = pd.to_numeric(self.data['Risk Score'], errors='coerce')
46
+ # Handle missing values
47
+ self.data['Risk Score'] = self.data['Risk Score'].fillna(0)
48
+
49
+ # Calculate MAX_RISK_SCORE dynamically
50
+ symptom_max_risk = self.data.groupby('Symptom')['Risk Score'].max().sum()
51
+ max_age = 120
52
+ max_age_risk = (max_age - 40) * 0.05 if max_age > 40 else 0
53
+ self.MAX_RISK_SCORE = symptom_max_risk + max_age_risk
54
+
55
+ self.local_messages = []
56
+ self.severity_mapping = {
57
+ 'Mild': ['mild', 'slight', 'minor', 'low grade'],
58
+ 'Moderate': ['moderate', 'medium', 'average'],
59
+ 'Severe': ['severe', 'high', 'extreme', 'critical', 'intense', 'very bad', 'acute']
60
+ }
61
+ self.negation_words = {'no', 'not', 'denies', 'without', 'negative', 'none', 'denied'}
62
+ except Exception as e:
63
+ st.error(f"Dataset Error: {str(e)}")
64
+ raise
65
+
66
+ def add_patient_data(self, patient_message):
67
+ try:
68
+ if not patient_message:
69
+ raise ValueError("Patient message cannot be empty")
70
+ self.local_messages.append({
71
+ 'message': patient_message,
72
+ 'timestamp': datetime.now().timestamp()
73
+ })
74
+ except Exception as e:
75
+ st.error(f"Error adding patient data: {str(e)}")
76
+
77
+ def extract_info_from_bot_response(self, bot_response_data):
78
+ try:
79
+ if not bot_response_data:
80
+ return 0, [], {}
81
+
82
+ bot_response_text = str(bot_response_data)
83
+ bot_response_lower = bot_response_text.lower()
84
+
85
+ # Age extraction
86
+ age = 0
87
+ age_pattern = r'(\d{1,3})\s*(?:years?-?old|yo|years|-years-old?)'
88
+ age_match = re.search(age_pattern, bot_response_text, re.IGNORECASE)
89
+ if age_match:
90
+ age = int(age_match.group(1))
91
+ if not (0 <= age <= 120): age = 20
92
+
93
+ # Symptom extraction
94
+ symptoms = []
95
+ for symptom in self.data['Symptom'].unique():
96
+ symptom_lower = symptom.lower()
97
+ pattern = re.compile(r'\b' + re.escape(symptom_lower) + r'\b', re.IGNORECASE)
98
+ matches = pattern.finditer(bot_response_lower)
99
+ for match in matches:
100
+ start_pos = match.start()
101
+ preceding_text = bot_response_lower[:start_pos].split()
102
+ preceding_words = preceding_text[-3:]
103
+ if not any(neg in preceding_words for neg in self.negation_words):
104
+ symptoms.append(symptom)
105
+ break
106
+
107
+ # Severity analysis
108
+ symptom_severity = {}
109
+ for symptom in symptoms:
110
+ symptom_lower = symptom.lower()
111
+ highest_severity_score = 0
112
+ pattern = re.compile(r'\b' + re.escape(symptom_lower) + r'\b', re.IGNORECASE)
113
+ matches = pattern.finditer(bot_response_lower)
114
+ for match in matches:
115
+ start, end = match.start(), match.end()
116
+ words = bot_response_lower.split()
117
+ match_index = len(bot_response_lower[:start].split())
118
+ context_start = max(0, match_index - 5)
119
+ context_end = min(len(words), match_index + 6)
120
+ context = ' '.join(words[context_start:context_end])
121
+ for severity, keywords in self.severity_mapping.items():
122
+ for keyword in keywords:
123
+ if re.search(r'\b' + re.escape(keyword) + r'\b', context):
124
+ condition_data = self.data[(self.data['Symptom'] == symptom) &
125
+ (self.data['Condition'] == severity)]
126
+ if not condition_data.empty:
127
+ risk_score = condition_data['Risk Score'].values[0]
128
+ if risk_score > highest_severity_score:
129
+ highest_severity_score = risk_score
130
+ if highest_severity_score == 0:
131
+ highest_severity_score = self.data[self.data['Symptom'] == symptom]['Risk Score'].max()
132
+ symptom_severity[symptom] = highest_severity_score
133
+
134
+ return age, symptoms, symptom_severity
135
+
136
+ except Exception as e:
137
+ st.error(f"Extraction Error: {str(e)}")
138
+ return 0, [], {}
139
+
140
+ def calculate_risk_score(self, age, symptoms, symptom_severity):
141
+ try:
142
+ # Validate symptoms
143
+ valid_symptoms = [s for s in symptoms if s in self.data['Symptom'].values]
144
+ if not valid_symptoms:
145
+ return "Unknown", 0, 0
146
+
147
+ # Calculate scores with validation
148
+ symptom_risk = sum(float(symptom_severity.get(s, 0)) for s in valid_symptoms)
149
+ age_risk = max((age - 40) * 0.05, 0) if age >= 40 else 0
150
+ final_score = symptom_risk + age_risk
151
+
152
+ # Ensure we don't divide by zero
153
+ max_score = self.MAX_RISK_SCORE if self.MAX_RISK_SCORE > 0 else 1
154
+ risk_pct = min(100, max(0, (final_score / max_score) * 100))
155
+
156
+ if risk_pct <= 30: label = "Low"
157
+ elif risk_pct <= 70: label = "Medium"
158
+ else: label = "High"
159
+
160
+ return label, final_score, round(risk_pct, 1)
161
+ except Exception as e:
162
+ st.error(f"Risk Calculation Error: {str(e)}")
163
+ return "Low", 0, 0
164
+
165
+ def analyze_patient_data(self, patient_message):
166
+ """Full analysis workflow"""
167
+ try:
168
+ # Clean input message
169
+ patient_message = patient_message.replace("Symptom", "").replace("Condition", "")
170
+
171
+ self.add_patient_data(patient_message)
172
+ age, symptoms, severity = self.extract_info_from_bot_response(patient_message)
173
+
174
+ # Filter invalid symptoms
175
+ valid_symptoms = [s for s in symptoms if s in self.data['Symptom'].values]
176
+ if not valid_symptoms:
177
+ return {"error": "No valid symptoms detected"}
178
+
179
+ # Get unique conditions from valid symptoms
180
+ conditions = self.data[self.data['Symptom'].isin(valid_symptoms)]['Condition'].unique()
181
+ valid_conditions = [c for c in conditions if c not in ['Normal', 'Moderate', 'Severe', 'Condition']]
182
+
183
+ risk_label, risk_score, risk_pct = self.calculate_risk_score(age, valid_symptoms, severity)
184
+
185
+ return {
186
+ 'age': age,
187
+ 'symptoms': valid_symptoms,
188
+ 'symptom_severity': severity,
189
+ 'risk_label': risk_label,
190
+ 'risk_score': round(risk_score, 2),
191
+ 'risk_percentage': risk_pct,
192
+ 'possible_conditions': valid_conditions,
193
+ 'analysis_timestamp': datetime.now().isoformat()
194
+ }
195
+ except Exception as e:
196
+ return {"error": f"Analysis Error: {str(e)}"}
197
+
198
+ def process_user_data(self):
199
+ try:
200
+ if not self.local_messages:
201
+ return {"error": "No messages available"}
202
+
203
+ latest = max(self.local_messages, key=lambda x: x['timestamp'])
204
+ age, symptoms, severity = self.extract_info_from_bot_response(latest['message'])
205
+
206
+ if not symptoms: return {"error": "No symptoms detected"}
207
+
208
+ risk_label, risk_score, risk_pct = self.calculate_risk_score(age, symptoms, severity)
209
+
210
+ return {
211
+ 'age': age,
212
+ 'symptoms': symptoms,
213
+ 'symptom_severity': severity,
214
+ 'risk_label': risk_label,
215
+ 'risk_score': round(risk_score, 2),
216
+ 'risk_percentage': risk_pct,
217
+ 'possible_conditions': self.data[self.data['Symptom'].isin(symptoms)]['Condition'].unique().tolist(),
218
+ 'analysis_timestamp': datetime.now().isoformat()
219
+ }
220
+ except Exception as e:
221
+ return {"error": f"Processing Error: {str(e)}"}
222
 
223
  def initialize_groq_client():
224
  try:
225
+ api_key = st.secrets.get("GROQ_API_KEY", os.getenv("GROQ_API_KEY"))
 
 
 
 
 
 
 
226
  if not api_key:
227
+ api_key = st.text_input("Enter Groq API Key:", type="password")
228
+ if not api_key: return False
 
 
229
 
230
+ st.session_state.client = Groq(api_key=api_key)
 
 
231
  return True
232
  except Exception as e:
233
+ st.error(f"Groq Error: {str(e)}")
234
  return False
235
 
236
  def symptom_interrogation_step():
237
  client = st.session_state.client
238
  main_symptom = st.session_state.patient_info['main_symptom']
239
+ step = len(st.session_state.symptom_details)
240
 
241
  if step == 0:
 
242
  medical_focus = {
243
  'pain': "location/radiation/provoking factors",
244
  'fever': "pattern/associated symptoms/response to meds",
245
  'gi': "bowel changes/ingestion timing/associated symptoms",
246
  'respiratory': "exertion relationship/sputum/triggers"
247
  }
248
+ focus = medical_focus.get(main_symptom.lower(), "temporal pattern/severity progression/associated symptoms")
 
 
249
  prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom}
250
+ focusing on {focus}. Use simple, patient-friendly language. Ask only ONE question."""
 
 
 
 
 
 
 
 
 
 
251
  else:
 
252
  last_qa = st.session_state.symptom_details[-1]
253
+ prompt = f"""Based on previous Q: {last_qa['question']} β†’ A: {last_qa['answer']}
254
+ Ask the NEXT critical question about {main_symptom} considering red flags."""
 
 
255
 
256
  try:
257
  response = client.chat.completions.create(
258
+ messages=[{"role": "user", "content": prompt}],
259
  model="mixtral-8x7b-32768",
260
  temperature=0.3
261
  )
262
  question = response.choices[0].message.content.strip()
263
+ if not question.endswith('?'): question += '?'
 
264
  st.session_state.current_question = question
265
  except Exception as e:
266
+ st.error(f"Question Generation Error: {str(e)}")
267
  st.stop()
268
 
269
  def handle_symptom_interrogation():
 
275
 
276
  if 'current_question' in st.session_state:
277
  with st.form("symptom_qna"):
278
+ st.markdown(f'<div class="dr-message">πŸ‘¨β€βš• {st.session_state.current_question}</div>', unsafe_allow_html=True)
279
  answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}")
280
 
281
  if st.form_submit_button("Next"):
 
285
  "answer": answer
286
  })
287
  del st.session_state.current_question
288
+
289
+ # Emergency check
290
  if len(st.session_state.symptom_details) >= 3:
291
  last_answer = st.session_state.symptom_details[-1]['answer']
292
  try:
293
  urgency_check = st.session_state.client.chat.completions.create(
294
+ messages=[{"role": "user", "content": f"Does this indicate emergency? '{last_answer}' Yes/No"}],
 
295
  model="mixtral-8x7b-32768",
296
  temperature=0
297
  ).choices[0].message.content
 
298
  if 'YES' in urgency_check.upper():
299
+ st.markdown('<div class="emergency-alert">🚨 Emergency Detected! Seek Immediate Care.</div>', unsafe_allow_html=True)
300
  st.session_state.current_step = 4
301
  return
302
+ except: pass
303
+
 
304
  if len(st.session_state.symptom_details) < 7:
305
  st.session_state.current_step = 1
 
306
  else:
307
  st.session_state.current_step = 3
308
+ st.rerun()
309
  else:
310
  st.warning("Please provide an answer")
311
 
 
318
  st.session_state.patient_info['main_symptom'] = st.text_input("Main Symptom")
319
 
320
  if st.form_submit_button("Next"):
321
+ if all(st.session_state.patient_info.get(k) for k in ['name', 'age', 'gender', 'main_symptom']):
322
  st.session_state.current_step = 1
323
  st.rerun()
324
  else:
325
+ st.warning("Please fill all fields")
326
 
327
  def collect_medical_history():
328
  st.header("Medical History")
 
338
  st.rerun()
339
 
340
  def generate_risk_assessment():
341
+ st.header("Comprehensive Assessment")
342
 
343
  try:
344
+ # Generate clinical summary
345
+ symptom_log = "\n".join([f"Q: {q['question']}\nA: {q['answer']}" for q in st.session_state.symptom_details])
 
 
 
346
  patient_profile = f"""
 
347
  Name: {st.session_state.patient_info['name']}
348
  Age: {st.session_state.patient_info['age']}
349
  Gender: {st.session_state.patient_info['gender']}
350
+ Main Symptom: {st.session_state.patient_info['main_symptom']}
351
 
352
+ Symptom Details:
 
 
 
353
  {symptom_log}
354
 
355
+ Medical History: {st.session_state.patient_info.get('medical_history', 'N/A')}
356
+ Medications: {st.session_state.patient_info.get('medications', 'N/A')}
357
+ Allergies: {st.session_state.patient_info.get('allergies', 'N/A')}
 
 
 
 
 
 
 
 
 
358
  """
359
 
360
+ # Risk analysis
361
+ analysis_system = MedicalAnalysisSystem("DATASET.csv")
362
+ analysis_results = analysis_system.analyze_patient_data(patient_profile)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
+ # Store the analysis results in session state
365
+ st.session_state.analysis_results = analysis_results
 
 
 
 
 
 
 
 
 
 
366
 
367
+ col1, col2 = st.columns(2)
368
+ with col1:
369
+ st.subheader("Clinical Summary")
370
+ st.markdown(f"\n{patient_profile}\n")
371
+
372
+ with col2:
373
+ st.subheader("Risk Analysis")
374
+ if "error" in analysis_results:
375
+ st.error(analysis_results["error"])
376
+ else:
377
+ st.metric("Risk Level", analysis_results['risk_label'])
378
+ st.progress(analysis_results['risk_percentage'] / 100)
379
+ st.write(f"*Score*: {analysis_results['risk_score']:.1f}/{analysis_system.MAX_RISK_SCORE:.1f}")
380
+
381
+ # Download report
382
+ report_content = f"CLINICAL SUMMARY:\n{patient_profile}\n\nRISK ANALYSIS:\n{analysis_results}"
383
+ st.download_button("Download Full Report", report_content, "medical_report.txt")
384
+
385
  except Exception as e:
386
+ st.error(f"Assessment Error: {str(e)}")
387
 
388
+ def schedule_appointment():
389
+ st.header("πŸš‘ Schedule Specialist Appointment")
390
 
391
+ # Doctor database
392
+ doctors = [
393
+ {
394
+ 'name': 'Dr. Sarah Johnson',
395
+ 'hospital': 'City General Hospital',
396
+ 'specialty': 'Cardiology',
397
+ 'slots': ['2024-03-25 09:00', '2024-03-25 10:00', '2024-03-26 11:00'],
398
+ 'contact': '555-0101',
399
+ 'emergency': True
400
+ },
401
+ {
402
+ 'name': 'Dr. Michael Chen',
403
+ 'hospital': 'Metropolitan Health',
404
+ 'specialty': 'Neurology',
405
+ 'slots': ['2024-03-25 14:00', '2024-03-26 09:30', '2024-03-27 15:00'],
406
+ 'contact': '555-0102',
407
+ 'emergency': True
408
+ },
409
+ {
410
+ 'name': 'Dr. Emily White',
411
+ 'hospital': 'Sunrise Clinic',
412
+ 'specialty': 'General Practice',
413
+ 'slots': ['2024-03-24 10:00', '2024-03-25 11:00', '2024-03-26 16:00'],
414
+ 'contact': '555-0103',
415
+ 'emergency': False
416
+ },
417
+ {
418
+ 'name': 'Dr. Raj Patel',
419
+ 'hospital': 'Westside Medical Center',
420
+ 'specialty': 'Orthopedics',
421
+ 'slots': ['2024-03-25 08:00', '2024-03-26 10:00', '2024-03-27 09:00'],
422
+ 'contact': '555-0104',
423
+ 'emergency': True
424
+ },
425
+ {
426
+ 'name': 'Dr. Linda Garcia',
427
+ 'hospital': "Children's Hospital",
428
+ 'specialty': 'Pediatrics',
429
+ 'slots': ['2024-03-25 12:00', '2024-03-26 14:00', '2024-03-27 10:00'],
430
+ 'contact': '555-0105',
431
+ 'emergency': True
432
+ }
433
+ ]
434
+
435
+ risk_data = st.session_state.get('analysis_results', {})
436
+
437
+ # Check if risk_data is None or empty
438
+ if not risk_data or "error" in risk_data:
439
+ st.error("No risk assessment available. Please complete the assessment first.")
440
+ return
441
+
442
+ risk_label = risk_data.get('risk_label', 'Low')
443
+
444
+ # Priority explanation
445
+ st.markdown(f"""
446
+ <div class="priority-banner">
447
+ Your current risk level: <strong>{risk_label}</strong> priority
448
+ <br>{(risk_label == 'High') and 'πŸŸ₯ Urgent - Same day appointments available'
449
+ or (risk_label == 'Medium') and '🟨 Semi-Urgent - Next day appointments'
450
+ or '🟩 Routine - Book within 3 days'}
451
  </div>
452
+ """, unsafe_allow_html=True)
453
+
454
+ # Filter doctors based on risk
455
+ if risk_label == 'High':
456
+ available_doctors = [d for d in doctors if d['emergency']]
457
+ else:
458
+ available_doctors = doctors
459
+
460
+ # Display doctors in columns
461
+ cols = st.columns(2)
462
+ for idx, doctor in enumerate(available_doctors):
463
+ with cols[idx % 2]:
464
+ with st.container():
465
+ st.subheader(f"πŸ₯ {doctor['hospital']}")
466
+ st.markdown(f"""
467
+ *Doctor*: {doctor['name']}
468
+ *Specialty*: {doctor['specialty']}
469
+ *Contact*: {doctor['contact']}
470
+ """)
471
+
472
+ # Sort slots based on risk
473
+ slots = sorted(doctor['slots'], key=lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M'))
474
+ if risk_label == 'Low':
475
+ slots = slots[::-1]
476
+
477
+ selected_slot = st.selectbox(f"Available slots with {doctor['name']}",
478
+ slots,
479
+ key=f"slot_{idx}")
480
+
481
+ if st.button(f"Book with {doctor['name']}", key=f"book_{idx}"):
482
+ st.session_state.appointment_details = {
483
+ 'doctor': doctor['name'],
484
+ 'hospital': doctor['hospital'],
485
+ 'time': selected_slot,
486
+ 'contact': doctor['contact'],
487
+ 'risk_level': risk_label
488
+ }
489
+ st.success("Appointment booked successfully!")
490
+ st.balloons()
491
+
492
+ # Generate appointment summary
493
+ summary = f"""
494
+ *Patient Name*: {st.session_state.patient_info['name']}
495
+ *Age*: {st.session_state.patient_info['age']}
496
+ *Booked Appointment*:
497
+ - Doctor: {doctor['name']}
498
+ - Hospital: {doctor['hospital']}
499
+ - Time: {selected_slot}
500
+ - Contact: {doctor['contact']}
501
+ - Priority Level: {risk_label}
502
+ """
503
+ st.session_state.appointment_summary = summary
504
+
505
+ # Show download button
506
+ st.download_button("Download Appointment Details",
507
+ summary,
508
+ "appointment_confirmation.txt",
509
+ help="Save your appointment details")
510
+ def main():
511
+ st.title("πŸ₯ AI Medical Consultancy")
512
 
513
+ # Initialize Groq client
514
  if not initialize_groq_client():
515
+ st.warning("Please provide a valid Groq API key to proceed.")
516
  return
517
+
518
+ # Define steps for the progress bar
519
+ steps = ["Patient Info", "Symptoms", "History", "Report", "Booking"]
520
 
521
+ # Display progress bar
522
+ progress = f"""
523
+ <div class="progress-bar">
524
+ {"".join(f'<div class="step {"active" if st.session_state.current_step >= i else ""}">{i+1}. {step}</div>'
525
+ for i, step in enumerate(steps))}
526
+ </div>
527
+ """
528
+ st.markdown(progress, unsafe_allow_html=True)
529
+
530
+ # Step routing logic
531
+ if st.session_state.current_step == 0:
532
+ collect_basic_info() # Step 1: Collect patient information
533
+ elif st.session_state.current_step in [1, 2]:
534
+ handle_symptom_interrogation() # Step 2: Symptom analysis
535
+ elif st.session_state.current_step == 3:
536
+ collect_medical_history() # Step 3: Collect medical history
537
+ elif st.session_state.current_step == 4:
538
+ generate_risk_assessment() # Step 4: Generate risk assessment
539
+ if st.button("πŸ“… Schedule Doctor Appointment"):
540
+ st.session_state.current_step = 5 # Move to the booking step
541
+ st.rerun()
542
+ elif st.session_state.current_step == 5:
543
+ schedule_appointment() # Step 5: Schedule appointment with a doctor
544
+
545
+ # Debugging: Show session state (optional)
546
+ if st.sidebar.checkbox("Show Session State (Debug)"):
547
+ st.sidebar.write(st.session_state)
548
 
549
+ if _name_ == "_main_":
550
  main()