ayush200399391001 commited on
Commit
87fdfe1
·
verified ·
1 Parent(s): 87c871c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +590 -387
app.py CHANGED
@@ -1,388 +1,591 @@
1
- import streamlit as st
2
- import os
3
- from groq import Groq
4
- from datetime import datetime
5
-
6
- # Set page config FIRST
7
- st.set_page_config(page_title="AI Medical Consultancy", layout="wide")
8
-
9
- # Custom CSS for styling
10
- st.markdown("""
11
- <style>
12
- /* Main container styling */
13
- .stApp {
14
- background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
15
- font-family: 'Arial', sans-serif;
16
- }
17
-
18
- /* Headers styling */
19
- h1, h2, h3 {
20
- color: #2c3e50 !important;
21
- border-bottom: 2px solid #3498db;
22
- padding-bottom: 0.3em;
23
- }
24
-
25
- /* Form containers */
26
- .stForm {
27
- background: rgba(255, 255, 255, 0.9);
28
- border-radius: 15px;
29
- padding: 2rem;
30
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
31
- margin: 1rem 0;
32
- }
33
-
34
- /* Input fields */
35
- .stTextInput input, .stNumberInput input, .stSelectbox select, .stTextArea textarea {
36
- border: 2px solid #3498db !important;
37
- border-radius: 8px !important;
38
- padding: 0.8rem !important;
39
- transition: all 0.3s ease;
40
- }
41
-
42
- .stTextInput input:focus, .stNumberInput input:focus,
43
- .stSelectbox select:focus, .stTextArea textarea:focus {
44
- border-color: #2980b9 !important;
45
- box-shadow: 0 0 8px rgba(52, 152, 219, 0.3) !important;
46
- }
47
-
48
- /* Buttons styling */
49
- .stButton>button {
50
- background: linear-gradient(135deg, #3498db 0%, #2980b9 100%) !important;
51
- color: white !important;
52
- border: none !important;
53
- border-radius: 8px !important;
54
- padding: 0.8rem 1.5rem !important;
55
- font-size: 1rem !important;
56
- transition: transform 0.2s ease;
57
- }
58
-
59
- .stButton>button:hover {
60
- transform: translateY(-2px);
61
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
62
- }
63
-
64
- /* Progress indicator */
65
- .progress-bar {
66
- display: flex;
67
- justify-content: space-between;
68
- margin: 2rem 0;
69
- padding: 1rem;
70
- background: rgba(255, 255, 255, 0.9);
71
- border-radius: 10px;
72
- }
73
-
74
- .step {
75
- flex: 1;
76
- text-align: center;
77
- padding: 0.5rem;
78
- font-weight: bold;
79
- color: #7f8c8d;
80
- }
81
-
82
- .step.active {
83
- color: #3498db;
84
- border-bottom: 3px solid #3498db;
85
- }
86
-
87
- /* Chat bubbles */
88
- .dr-message {
89
- background: #3498db;
90
- color: white;
91
- border-radius: 15px;
92
- padding: 1rem;
93
- margin: 1rem 0;
94
- max-width: 80%;
95
- width: fit-content;
96
- }
97
-
98
- /* Emergency alert */
99
- .emergency-alert {
100
- background: #e74c3c;
101
- color: white;
102
- padding: 2rem;
103
- border-radius: 15px;
104
- animation: pulse 1.5s infinite;
105
- }
106
-
107
- @keyframes pulse {
108
- 0% { transform: scale(1); }
109
- 50% { transform: scale(1.02); }
110
- 100% { transform: scale(1); }
111
- }
112
-
113
- /* Download button */
114
- .download-btn {
115
- background: linear-gradient(135deg, #2ecc71 0%, #27ae60 100%) !important;
116
- }
117
- </style>
118
- """, unsafe_allow_html=True)
119
-
120
- # Initialize session state variables
121
- if 'current_step' not in st.session_state:
122
- st.session_state.current_step = 0
123
- if 'symptom_details' not in st.session_state:
124
- st.session_state.symptom_details = [] # Initialize as an empty list
125
- if 'patient_info' not in st.session_state:
126
- st.session_state.patient_info = {}
127
-
128
- def initialize_groq_client():
129
- try:
130
- # Try to get the API key from Streamlit secrets
131
- api_key = None
132
- try:
133
- api_key = st.secrets.get("GROQ_API_KEY", os.getenv("GROQ_API_KEY"))
134
- except FileNotFoundError:
135
- st.warning("No `secrets.toml` file found. Please create one in the `.streamlit` folder.")
136
-
137
- # If not found, prompt the user to enter it
138
- if not api_key:
139
- api_key = st.text_input("Enter your Groq API Key:", type="password")
140
- if not api_key:
141
- st.warning("Please provide a valid Groq API key to proceed.")
142
- return False
143
-
144
- # Initialize the Groq client
145
- client = Groq(api_key=api_key)
146
- st.session_state.client = client
147
- return True
148
- except Exception as e:
149
- st.error(f"Error initializing Groq client: {str(e)}")
150
- return False
151
-
152
- def symptom_interrogation_step():
153
- client = st.session_state.client
154
- main_symptom = st.session_state.patient_info['main_symptom']
155
- step = len(st.session_state.symptom_details) # Use the number of collected details as the step
156
-
157
- if step == 0:
158
- # First question: ask about the main symptom
159
- medical_focus = {
160
- 'pain': "location/radiation/provoking factors",
161
- 'fever': "pattern/associated symptoms/response to meds",
162
- 'gi': "bowel changes/ingestion timing/associated symptoms",
163
- 'respiratory': "exertion relationship/sputum/triggers"
164
- }
165
- focus = medical_focus.get(main_symptom.lower(),
166
- "temporal pattern/severity progression/associated symptoms")
167
-
168
- prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom}
169
- focusing on {focus} to differentiate serious causes. Your task is to have a polite and simple conversation with a patient.
170
- Start by asking ONE specific follow-up question about their initial symptom: {main_symptom}.
171
- Ask only one question at a time to avoid overwhelming the patient.
172
- Keep your language clear, professional, and easy to understand."""
173
-
174
- messages = [
175
- {"role": "system", "content": "Ask focused clinical questions. One at a time."},
176
- {"role": "user", "content": prompt}
177
- ]
178
- else:
179
- # Subsequent questions: use the last Q&A to generate the next question
180
- last_qa = st.session_state.symptom_details[-1]
181
- prompt = f"""Last Q&A: {last_qa['question']} → {last_qa['answer']}
182
- Based on this, ask the NEXT most critical question to differentiate between
183
- possible causes of {main_symptom}. Consider red flags and likelihood."""
184
- messages = [{"role": "user", "content": prompt}]
185
-
186
- try:
187
- response = client.chat.completions.create(
188
- messages=messages,
189
- model="mixtral-8x7b-32768",
190
- temperature=0.3
191
- )
192
- question = response.choices[0].message.content.strip()
193
- if not question.endswith('?'):
194
- question += '?'
195
- st.session_state.current_question = question
196
- except Exception as e:
197
- st.error(f"Error generating question: {str(e)}")
198
- st.stop()
199
-
200
- def handle_symptom_interrogation():
201
- st.header("Symptom Analysis")
202
-
203
- if st.session_state.current_step == 1:
204
- symptom_interrogation_step()
205
- st.session_state.current_step = 2
206
-
207
- if 'current_question' in st.session_state:
208
- with st.form("symptom_qna"):
209
- st.markdown(f'<div class="dr-message">👨‍⚕️ {st.session_state.current_question}</div>', unsafe_allow_html=True)
210
- answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}")
211
-
212
- if st.form_submit_button("Next"):
213
- if answer:
214
- st.session_state.symptom_details.append({
215
- "question": st.session_state.current_question,
216
- "answer": answer
217
- })
218
- del st.session_state.current_question
219
-
220
- # Check for emergency after 3 questions
221
- if len(st.session_state.symptom_details) >= 3:
222
- last_answer = st.session_state.symptom_details[-1]['answer']
223
- try:
224
- urgency_check = st.session_state.client.chat.completions.create(
225
- messages=[{"role": "user", "content":
226
- f"Does '{last_answer}' indicate immediate emergency? Yes/No"}],
227
- model="mixtral-8x7b-32768",
228
- temperature=0
229
- ).choices[0].message.content
230
-
231
- if 'YES' in urgency_check.upper():
232
- st.markdown('<div class="emergency-alert">🚨 Emergency detected! Please seek immediate medical attention.</div>', unsafe_allow_html=True)
233
- st.session_state.current_step = 4
234
- return
235
- except Exception as e:
236
- st.error(f"Error checking urgency: {str(e)}")
237
-
238
- if len(st.session_state.symptom_details) < 7:
239
- st.session_state.current_step = 1
240
- st.rerun()
241
- else:
242
- st.session_state.current_step = 3
243
- st.rerun()
244
- else:
245
- st.warning("Please provide an answer")
246
-
247
- def collect_basic_info():
248
- st.header("Patient Information")
249
- with st.form("basic_info"):
250
- st.session_state.patient_info['name'] = st.text_input("Full Name")
251
- st.session_state.patient_info['age'] = st.number_input("Age", min_value=0, max_value=120)
252
- st.session_state.patient_info['gender'] = st.selectbox("Gender", ["Male", "Female", "Other"])
253
- st.session_state.patient_info['main_symptom'] = st.text_input("Main Symptom")
254
-
255
- if st.form_submit_button("Next"):
256
- if all([st.session_state.patient_info.get(k) for k in ['name', 'age', 'gender', 'main_symptom']]):
257
- st.session_state.current_step = 1
258
- st.rerun()
259
- else:
260
- st.warning("Please fill all required fields")
261
-
262
- def collect_medical_history():
263
- st.header("Medical History")
264
- with st.form("medical_history"):
265
- st.session_state.patient_info['medical_history'] = st.text_area("Relevant Medical History")
266
- st.session_state.patient_info['medications'] = st.text_area("Current Medications")
267
- st.session_state.patient_info['allergies'] = st.text_input("Known Allergies")
268
- st.session_state.patient_info['last_meal'] = st.text_input("Last Meal Time")
269
- st.session_state.patient_info['recent_travel'] = st.text_input("Recent Travel History")
270
-
271
- if st.form_submit_button("Submit"):
272
- st.session_state.current_step = 4
273
- st.rerun()
274
-
275
- def generate_risk_assessment():
276
- st.header("Risk Assessment")
277
-
278
- try:
279
- symptom_log = "\n".join(
280
- [f"Q: {q['question']}\nA: {q['answer']}"
281
- for q in st.session_state.symptom_details]
282
- )
283
-
284
- patient_profile = f"""
285
- **Patient Profile**
286
- Name: {st.session_state.patient_info['name']}
287
- Age: {st.session_state.patient_info['age']}
288
- Gender: {st.session_state.patient_info['gender']}
289
-
290
- **Primary Complaint**
291
- {st.session_state.patient_info['main_symptom']}
292
-
293
- **Symptom Interrogation**
294
- {symptom_log}
295
-
296
- **Medical History**
297
- {st.session_state.patient_info.get('medical_history', 'None reported')}
298
-
299
- **Current Medications**
300
- {st.session_state.patient_info.get('medications', 'None')}
301
-
302
- **Allergies**
303
- {st.session_state.patient_info.get('allergies', 'None reported')}
304
-
305
- **Recent Context**
306
- Last Meal: {st.session_state.patient_info.get('last_meal', 'Unknown')}
307
- Recent Travel: {st.session_state.patient_info.get('recent_travel', 'None')}
308
- """
309
-
310
- analysis_prompt = f"""STRICTLY follow these instructions:
311
- 1. Analyze this case: {patient_profile}
312
- 2. Output ONLY this exact format WITHOUT ANY additional text:
313
- [Age]-year-old [gender] with [symptom details]
314
-
315
- Example Output:
316
- "45-year-old man with severe chest pain radiating to the jaw"
317
-
318
- Your Output MUST BE:"""
319
-
320
- response = st.session_state.client.chat.completions.create(
321
- messages=[
322
- {"role": "system", "content": "You are a medical AI that outputs ONLY patient descriptions."},
323
- {"role": "user", "content": analysis_prompt}
324
- ],
325
- model="mixtral-8x7b-32768",
326
- temperature=0.3,
327
- max_tokens=100
328
- )
329
-
330
- risk_prompt = response.choices[0].message.content.strip('"')
331
-
332
- st.subheader("Clinical Summary")
333
- st.markdown(f"```\n{risk_prompt}\n```")
334
-
335
- # Create download button
336
- timestamp = datetime.now().strftime('%Y%m%d%H%M')
337
- filename = f"{st.session_state.patient_info['name'].replace(' ', '_')}_assessment_{timestamp}.txt"
338
- st.download_button(
339
- label="Download Assessment",
340
- data=risk_prompt,
341
- file_name=filename,
342
- mime="text/plain"
343
- )
344
-
345
- except Exception as e:
346
- st.error(f"Error generating risk assessment: {str(e)}")
347
-
348
- def main():
349
- st.title("🏥 AI Medical Consultancy")
350
-
351
- # Progress indicator
352
- steps_titles = ["Patient Info", "Symptoms", "Medical History", "Assessment"]
353
- progress_html = """
354
- <div class="progress-bar">
355
- <div class="step {}">{}</div>
356
- <div class="step {}">{}</div>
357
- <div class="step {}">{}</div>
358
- <div class="step {}">{}</div>
359
- </div>
360
- """.format(
361
- 'active' if st.session_state.current_step >= 0 else '',
362
- '1. Patient Info',
363
- 'active' if st.session_state.current_step >= 1 else '',
364
- '2. Symptoms',
365
- 'active' if st.session_state.current_step >= 3 else '',
366
- '3. History',
367
- 'active' if st.session_state.current_step >= 4 else '',
368
- '4. Report'
369
- )
370
- st.markdown(progress_html, unsafe_allow_html=True)
371
-
372
- if not initialize_groq_client():
373
- return
374
-
375
- steps = {
376
- 0: collect_basic_info,
377
- 1: handle_symptom_interrogation,
378
- 2: handle_symptom_interrogation,
379
- 3: collect_medical_history,
380
- 4: generate_risk_assessment
381
- }
382
-
383
- current_step = st.session_state.get('current_step', 0)
384
- if current_step in steps:
385
- steps[current_step]()
386
-
387
- if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  main()
 
1
+ import streamlit as st
2
+ import os
3
+ from groq import Groq
4
+ from datetime import datetime
5
+ from transformers import pipeline
6
+ import pandas as pd
7
+ import re
8
+
9
+ # Set page config FIRST
10
+ st.set_page_config(page_title="AI Medical Consultancy", layout="wide")
11
+
12
+ # Custom CSS for styling
13
+ st.markdown("""
14
+ <style>
15
+ /* Color Variables */
16
+ :root {
17
+ --primary: #3498db; /* Blue */
18
+ --secondary: #2c3e50; /* Dark accent */
19
+ --accent: #f1c40f; /* Yellow */
20
+ --success: #2ecc71; /* Positive actions */
21
+ --light: #ffffff; /* White backgrounds */
22
+ --dark: #000000; /* Black text/elements */
23
+ }
24
+
25
+ /* Main container styling */
26
+ .stApp {
27
+ background: linear-gradient(135deg, #3498db 0%, #e0e0e0 100%);
28
+ font-family: 'Arial', sans-serif;
29
+ }
30
+
31
+ /* Headers styling */
32
+ h1, h2, h3 {
33
+ color: var(--dark) !important;
34
+ border-bottom: 3px solid var(--primary);
35
+ padding-bottom: 0.3em;
36
+ }
37
+
38
+ /* Form containers */
39
+ .stForm {
40
+ background: #000000;
41
+ border: 1px solid rgba(44, 62, 80, 0.2);
42
+ border-radius: 15px;
43
+ padding: 2rem;
44
+ box-shadow: 0 8px 30px rgba(0, 0, 0, 0.12);
45
+ margin: 1rem 0;
46
+ }
47
+
48
+ /* Input fields */
49
+ .stTextInput input, .stNumberInput input,
50
+ .stSelectbox select, .stTextArea textarea {
51
+ border: 2px solid #00FFFF !important;
52
+ border-radius: 10px !important;
53
+ padding: 1rem !important;
54
+ background: #00FFFF !important;
55
+ transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
56
+ color: var(--dark) !important;
57
+ }
58
+
59
+ .stTextInput input:focus, .stNumberInput input:focus,
60
+ .stSelectbox select:focus, .stTextArea textarea:focus {
61
+ border-color: var(--primary) !important;
62
+ box-shadow: 0 0 12px rgba(52, 152, 219, 0.2) !important;
63
+ background: white !important;
64
+ color: var(--dark) !important;
65
+ }
66
+
67
+ /* Buttons styling */
68
+ .stButton>button {
69
+ background: linear-gradient(135deg, var(--primary) 0%, var(--accent) 100%) !important;
70
+ color: var(--dark) !important;
71
+ border: none !important;
72
+ border-radius: 10px !important;
73
+ padding: 1rem 2rem !important;
74
+ font-size: 1rem !important;
75
+ transition: all 0.3s ease;
76
+ position: relative;
77
+ overflow: hidden;
78
+ }
79
+
80
+ .stButton>button:hover {
81
+ transform: translateY(-2px);
82
+ box-shadow: 0 8px 15px rgba(52, 152, 219, 0.3);
83
+ opacity: 0.95;
84
+ }
85
+
86
+ .stButton>button:active {
87
+ transform: translateY(0);
88
+ opacity: 1;
89
+ }
90
+
91
+ /* Progress indicator */
92
+ .progress-bar {
93
+ display: flex;
94
+ justify-content: space-between;
95
+ margin: 2rem 0;
96
+ padding: 1rem;
97
+ background: rgba(255, 255, 255, 0.9);
98
+ border-radius: 10px;
99
+ color: var(--dark) !important;
100
+ }
101
+
102
+ .step {
103
+ flex: 1;
104
+ text-align: center;
105
+ padding: 1rem;
106
+ font-weight: 600;
107
+ color: #95a5a6;
108
+ position: relative;
109
+ }
110
+
111
+ .step.active {
112
+ color: var(--primary);
113
+ }
114
+
115
+ .step.active:after {
116
+ content: '';
117
+ position: absolute;
118
+ bottom: -1px;
119
+ left: 50%;
120
+ transform: translateX(-50%);
121
+ width: 40%;
122
+ height: 3px;
123
+ background: var(--primary);
124
+ }
125
+
126
+ /* Chat bubbles */
127
+ .dr-message {
128
+ background: linear-gradient(135deg, var(--primary) 0%, #2980b9 100%);
129
+ color: white;
130
+ border-radius: 20px 20px 20px 4px;
131
+ padding: 1.2rem 1.5rem;
132
+ margin: 1rem 0;
133
+ max-width: 80%;
134
+ width: fit-content;
135
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
136
+ }
137
+
138
+ .user-message {
139
+ background: linear-gradient(135deg, #f1c40f 0%, #e1b800 100%);
140
+ margin-left: auto;
141
+ border-radius: 20px 20px 4px 20px;
142
+ color: var(--dark) !important;
143
+ }
144
+
145
+ /* Emergency alert */
146
+ .emergency-alert {
147
+ background: linear-gradient(135deg, var(--accent) 0%, #c0392b 100%);
148
+ color: white;
149
+ padding: 2rem;
150
+ border-radius: 15px;
151
+ animation: pulse 1.5s infinite;
152
+ text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
153
+ }
154
+
155
+ @keyframes pulse {
156
+ 0% { transform: scale(1); }
157
+ 50% { transform: scale(1.02); }
158
+ 100% { transform: scale(1); }
159
+ }
160
+
161
+ /* Download button */
162
+ .download-btn {
163
+ background: linear-gradient(135deg, var(--success) 0%, #27ae60 100%) !important;
164
+ }
165
+
166
+ /* Enhanced Data Visualization Contrast */
167
+ .stDataFrame {
168
+ border: 1px solid rgba(0, 0, 0, 0.1);
169
+ border-radius: 12px;
170
+ overflow: hidden;
171
+ background: #f0f0f0;
172
+ color: var(--dark) !important;
173
+ }
174
+
175
+ /* Tabbed Interface Styling */
176
+ .stTabs [role="tablist"] {
177
+ gap: 10px;
178
+ padding: 8px;
179
+ background: rgba(240, 240, 240, 0.9);
180
+ border-radius: 12px;
181
+ color: var(--dark) !important;
182
+ }
183
+
184
+ .stTabs [role="tab"] {
185
+ background: #ffffff !important;
186
+ border-radius: 8px !important;
187
+ transition: all 0.3s ease;
188
+ color: var(--dark) !important;
189
+ }
190
+
191
+ .stTabs [role="tab"][aria-selected="true"] {
192
+ background: var(--primary) !important;
193
+ color: white !important;
194
+ transform: scale(1.05);
195
+ }
196
+ </style>
197
+ """, unsafe_allow_html=True)
198
+
199
+
200
+ # Initialize session state variables
201
+ if 'current_step' not in st.session_state:
202
+ st.session_state.current_step = 0
203
+ if 'symptom_details' not in st.session_state:
204
+ st.session_state.symptom_details = [] # Initialize as an empty list
205
+ if 'patient_info' not in st.session_state:
206
+ st.session_state.patient_info = {}
207
+
208
+ # Initialize models and dataset with caching
209
+ @st.cache_resource
210
+ def load_ner_model():
211
+ return pipeline("token-classification",
212
+ model="d4data/biomedical-ner-all",
213
+ aggregation_strategy='first')
214
+
215
+ @st.cache_resource
216
+ def load_dataset():
217
+ try:
218
+ data = pd.read_csv('DATASET.csv', encoding='latin1')
219
+ data['Symptom_lower'] = data['Symptom'].str.strip().str.lower()
220
+ return data
221
+ except FileNotFoundError:
222
+ st.error("DATASET.csv not found. Please ensure the file is in the correct directory.")
223
+ return None
224
+
225
+ pipe = load_ner_model()
226
+ data = load_dataset()
227
+ if data is None:
228
+ st.stop()
229
+
230
+ SEVERITY_KEYWORDS = {
231
+ 'normal': ['mild', 'occasional', 'controlled', 'temporary', 'fleeting'],
232
+ 'moderate': ['moderate', 'persistent', 'frequent', 'prolonged', 'bloating'],
233
+ 'severe': ['severe', 'extreme', 'crushing', 'radiating', 'blood', 'inability',
234
+ 'sweating', 'vomiting', 'fever', 'swelling', 'radiates', 'persistent >3 days']
235
+ }
236
+
237
+ def merge_entities(entities, text):
238
+ """Merge entities using dataset symptom names for multi-word matching"""
239
+ merged = []
240
+ detected = set()
241
+ text_lower = text.lower()
242
+
243
+ symptoms_sorted = data['Symptom'].str.lower().sort_values(
244
+ key=lambda x: x.str.len(), ascending=False).tolist()
245
+
246
+ for symptom in symptoms_sorted:
247
+ if symptom in text_lower and symptom not in detected:
248
+ start = text_lower.find(symptom)
249
+ end = start + len(symptom)
250
+ merged.append({
251
+ 'word': text[start:end],
252
+ 'start': start,
253
+ 'end': end,
254
+ 'entity_group': 'Sign_symptom'
255
+ })
256
+ detected.add(symptom)
257
+ return merged
258
+
259
+ def determine_condition(text, entity):
260
+ context_window = re.findall(r'\w+', text[max(0, entity['start']-20):entity['end']+20])
261
+ context = ' '.join(context_window).lower()
262
+
263
+ for level, keywords in SEVERITY_KEYWORDS.items():
264
+ if any(k in context for k in keywords):
265
+ return level.capitalize()
266
+ return 'Normal'
267
+
268
+ def get_risk_score(symptom, condition):
269
+ symptom_clean = symptom.strip().lower()
270
+ matches = data[
271
+ data['Symptom_lower'].str.contains(rf'\b{symptom_clean}\b', regex=True, flags=re.IGNORECASE)
272
+ ]
273
+ if not matches.empty:
274
+ matches = matches[matches['Condition'] == condition]
275
+ return matches['Risk Score'].values[0] if not matches.empty else 0
276
+ return 0
277
+
278
+ def calculate_risk_score(text):
279
+ entities = pipe(text)
280
+ merged_entities = merge_entities(entities, text)
281
+
282
+ score = 0
283
+ breakdown = []
284
+ for ent in merged_entities:
285
+ if ent['entity_group'] == "Sign_symptom":
286
+ symptom = ent['word'].strip().lower()
287
+ condition = determine_condition(text, ent)
288
+ risk = get_risk_score(symptom, condition)
289
+ score += risk
290
+ breakdown.append({
291
+ 'symptom': symptom.capitalize(),
292
+ 'condition': condition,
293
+ 'risk': risk
294
+ })
295
+ return score, breakdown
296
+
297
+ def initialize_groq_client():
298
+ try:
299
+ # Try to get the API key from Streamlit secrets
300
+ api_key = None
301
+ try:
302
+ api_key = st.secrets.get("GROQ_API_KEY", os.getenv("GROQ_API_KEY"))
303
+ except FileNotFoundError:
304
+ st.warning("No secrets.toml file found. Please create one in the .streamlit folder.")
305
+
306
+ # If not found, prompt the user to enter it
307
+ if not api_key:
308
+ api_key = st.text_input("Enter your Groq API Key:", type="password")
309
+ if not api_key:
310
+ st.warning("Please provide a valid Groq API key to proceed.")
311
+ return False
312
+
313
+ # Initialize the Groq client
314
+ client = Groq(api_key=api_key)
315
+ st.session_state.client = client
316
+ return True
317
+ except Exception as e:
318
+ st.error(f"Error initializing Groq client: {str(e)}")
319
+ return False
320
+
321
+ def symptom_interrogation_step():
322
+ client = st.session_state.client
323
+ main_symptom = st.session_state.patient_info['main_symptom']
324
+ step = len(st.session_state.symptom_details) # Use the number of collected details as the step
325
+
326
+ if step == 0:
327
+ # First question: ask about the main symptom
328
+ medical_focus = {
329
+ 'pain': "location/radiation/provoking factors",
330
+ 'fever': "pattern/associated symptoms/response to meds",
331
+ 'gi': "bowel changes/ingestion timing/associated symptoms",
332
+ 'respiratory': "exertion relationship/sputum/triggers"
333
+ }
334
+ focus = medical_focus.get(main_symptom.lower(),
335
+ "temporal pattern/severity progression/associated symptoms")
336
+
337
+ prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom}
338
+ focusing on {focus} to differentiate serious causes. Your task is to have a polite and simple conversation with a patient.
339
+ Start by asking ONE specific follow-up question about their initial symptom: {main_symptom}.
340
+ Ask only one question at a time to avoid overwhelming the patient.
341
+ Keep your language clear, professional, and easy to understand.
342
+
343
+ Dont display possibe symptoms or why you are asking questions."""
344
+
345
+ messages = [
346
+ {"role": "system", "content": "Ask focused clinical questions. One at a time."},
347
+ {"role": "user", "content": prompt}
348
+ ]
349
+ else:
350
+ # Subsequent questions: use the last Q&A to generate the next question
351
+ last_qa = st.session_state.symptom_details[-1]
352
+ prompt = f"""Last Q&A: {last_qa['question']} → {last_qa['answer']}
353
+ Based on this, ask the NEXT most critical question to differentiate between
354
+ possible causes of {main_symptom}. Consider red flags and likelihood."""
355
+ messages = [{"role": "user", "content": prompt}]
356
+
357
+ try:
358
+ response = client.chat.completions.create(
359
+ messages=messages,
360
+ model="mixtral-8x7b-32768",
361
+ temperature=0.3
362
+ )
363
+ question = response.choices[0].message.content.strip()
364
+ if not question.endswith('?'):
365
+ question += '?'
366
+ st.session_state.current_question = question
367
+ except Exception as e:
368
+ st.error(f"Error generating question: {str(e)}")
369
+ st.stop()
370
+
371
+ def handle_symptom_interrogation():
372
+ st.header("Symptom Analysis")
373
+
374
+ if st.session_state.current_step == 1:
375
+ symptom_interrogation_step()
376
+ st.session_state.current_step = 2
377
+
378
+ if 'current_question' in st.session_state:
379
+ with st.form("symptom_qna"):
380
+ st.markdown(f'<div class="dr-message">👨‍⚕ {st.session_state.current_question}</div>', unsafe_allow_html=True)
381
+ answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}")
382
+
383
+ if st.form_submit_button("Next"):
384
+ if answer:
385
+ st.session_state.symptom_details.append({
386
+ "question": st.session_state.current_question,
387
+ "answer": answer
388
+ })
389
+ del st.session_state.current_question
390
+
391
+ # Check for emergency after 3 questions
392
+ if len(st.session_state.symptom_details) >= 3:
393
+ last_answer = st.session_state.symptom_details[-1]['answer']
394
+ try:
395
+ urgency_check = st.session_state.client.chat.completions.create(
396
+ messages=[{"role": "user", "content":
397
+ f"Does '{last_answer}' indicate immediate emergency? Yes/No"}],
398
+ model="mixtral-8x7b-32768",
399
+ temperature=0
400
+ ).choices[0].message.content
401
+
402
+ if 'YES' in urgency_check.upper():
403
+ st.markdown('<div class="emergency-alert">🚨 Emergency detected! Please seek immediate medical attention.</div>', unsafe_allow_html=True)
404
+ st.session_state.current_step = 4
405
+ return
406
+ except Exception as e:
407
+ st.error(f"Error checking urgency: {str(e)}")
408
+
409
+ if len(st.session_state.symptom_details) < 7:
410
+ st.session_state.current_step = 1
411
+ st.rerun()
412
+ else:
413
+ st.session_state.current_step = 3
414
+ st.rerun()
415
+ else:
416
+ st.warning("Please provide an answer")
417
+
418
+ def collect_basic_info():
419
+ st.header("Patient Information")
420
+ with st.form("basic_info"):
421
+ st.session_state.patient_info['name'] = st.text_input("Full Name")
422
+ st.session_state.patient_info['age'] = st.number_input("Age", min_value=0, max_value=120)
423
+ st.session_state.patient_info['gender'] = st.selectbox("Gender", ["Male", "Female", "Other"])
424
+ st.session_state.patient_info['main_symptom'] = st.text_input("Main Symptom")
425
+
426
+ if st.form_submit_button("Next"):
427
+ if all([st.session_state.patient_info.get(k) for k in ['name', 'age', 'gender', 'main_symptom']]):
428
+ st.session_state.current_step = 1
429
+ st.rerun()
430
+ else:
431
+ st.warning("Please fill all required fields")
432
+
433
+ def collect_medical_history():
434
+ st.header("Medical History")
435
+ with st.form("medical_history"):
436
+ st.session_state.patient_info['medical_history'] = st.text_area("Relevant Medical History")
437
+ st.session_state.patient_info['medications'] = st.text_area("Current Medications")
438
+ st.session_state.patient_info['allergies'] = st.text_input("Known Allergies")
439
+ st.session_state.patient_info['last_meal'] = st.text_input("Last Meal Time")
440
+ st.session_state.patient_info['recent_travel'] = st.text_input("Recent Travel History")
441
+
442
+ if st.form_submit_button("Submit"):
443
+ st.session_state.current_step = 4
444
+ st.rerun()
445
+
446
+ def generate_risk_assessment():
447
+ st.header("Risk Assessment")
448
+
449
+ try:
450
+ symptom_log = "\n".join(
451
+ [f"Q: {q['question']}\nA: {q['answer']}"
452
+ for q in st.session_state.symptom_details]
453
+ )
454
+
455
+ patient_profile = f"""
456
+ *Patient Profile*
457
+ Name: {st.session_state.patient_info['name']}
458
+ Age: {st.session_state.patient_info['age']}
459
+ Gender: {st.session_state.patient_info['gender']}
460
+
461
+ *Primary Complaint*
462
+ {st.session_state.patient_info['main_symptom']}
463
+
464
+ *Symptom Interrogation*
465
+ {symptom_log}
466
+
467
+ *Medical History*
468
+ {st.session_state.patient_info.get('medical_history', 'None reported')}
469
+
470
+ *Current Medications*
471
+ {st.session_state.patient_info.get('medications', 'None')}
472
+
473
+ *Allergies*
474
+ {st.session_state.patient_info.get('allergies', 'None reported')}
475
+
476
+ *Recent Context*
477
+ Last Meal: {st.session_state.patient_info.get('last_meal', 'Unknown')}
478
+ Recent Travel: {st.session_state.patient_info.get('recent_travel', 'None')}
479
+ """
480
+
481
+ analysis_prompt = f"""STRICTLY follow these instructions:
482
+ 1. Analyze this case: {patient_profile}
483
+ 2. *Include ONLY symptoms the patient is actively experiencing*. Exclude all negated symptoms (e.g., "no fever," "denies breathlessness").
484
+ 3. Output *EXCLUSIVELY* in this format with NO additional text or explanations:
485
+ [Age]-year-old [gender] with [specific, present symptoms].
486
+
487
+ Example Output:
488
+ "45-year-old man with severe chest pain radiating to the jaw"
489
+
490
+ Your Output:"""
491
+
492
+ response = st.session_state.client.chat.completions.create(
493
+ messages=[
494
+ {"role": "system", "content": "You are a medical AI that outputs ONLY patient descriptions."},
495
+ {"role": "user", "content": analysis_prompt}
496
+ ],
497
+ model="mixtral-8x7b-32768",
498
+ temperature=0.3,
499
+ max_tokens=100
500
+ )
501
+
502
+ risk_prompt = response.choices[0].message.content.strip('"')
503
+
504
+ st.subheader("Clinical Summary")
505
+ st.markdown(f"\n{risk_prompt}\n")
506
+
507
+ # Add risk scoring
508
+ risk_score, breakdown = calculate_risk_score(risk_prompt)
509
+
510
+ # Display risk score with color coding
511
+ score_color = "#2ecc71" if risk_score < 40 else "#f1c40f" if risk_score < 70 else "#e74c3c"
512
+ st.markdown(f"""
513
+ <div style="background: {score_color};
514
+ color: white;
515
+ padding: 1.5rem;
516
+ border-radius: 10px;
517
+ text-align: center;
518
+ margin: 2rem 0;">
519
+ <h2 style="color: white; margin: 0;">Risk Assessment Score</h2>
520
+ <h1 style="font-size: 3.5rem; margin: 0.5rem 0;">{risk_score}/100</h1>
521
+ <p>{'Low Risk' if risk_score < 40 else 'Moderate Risk' if risk_score < 70 else 'High Risk'}</p>
522
+ </div>
523
+ """, unsafe_allow_html=True)
524
+
525
+ # Display breakdown
526
+ with st.expander("View Detailed Risk Breakdown", expanded=True):
527
+ for item in breakdown:
528
+ st.markdown(f"""
529
+ <div style="background: rgba(236, 240, 241, 0.5);
530
+ padding: 1rem;
531
+ border-radius: 8px;
532
+ margin: 0.5rem 0;">
533
+ <strong>{item['symptom']}</strong> ({item['condition']})
534
+ <div style="float: right;">+{item['risk']} pts</div>
535
+ </div>
536
+ """, unsafe_allow_html=True)
537
+
538
+ # Create download button
539
+ timestamp = datetime.now().strftime('%Y%m%d%H%M')
540
+ filename = f"{st.session_state.patient_info['name'].replace(' ', '')}_assessment{timestamp}.txt"
541
+ st.download_button(
542
+ label="Download Assessment",
543
+ data=risk_prompt,
544
+ file_name=filename,
545
+ mime="text/plain"
546
+ )
547
+
548
+ except Exception as e:
549
+ st.error(f"Error generating risk assessment: {str(e)}")
550
+
551
+ def main():
552
+ st.title("🏥 AI Medical Consultancy")
553
+
554
+ # Progress indicator
555
+ steps_titles = ["Patient Info", "Symptoms", "Medical History", "Assessment"]
556
+ progress_html = """
557
+ <div class="progress-bar">
558
+ <div class="step {}">{}</div>
559
+ <div class="step {}">{}</div>
560
+ <div class="step {}">{}</div>
561
+ <div class="step {}">{}</div>
562
+ </div>
563
+ """.format(
564
+ 'active' if st.session_state.current_step >= 0 else '',
565
+ '1. Patient Info',
566
+ 'active' if st.session_state.current_step >= 1 else '',
567
+ '2. Symptoms',
568
+ 'active' if st.session_state.current_step >= 3 else '',
569
+ '3. History',
570
+ 'active' if st.session_state.current_step >= 4 else '',
571
+ '4. Report'
572
+ )
573
+ st.markdown(progress_html, unsafe_allow_html=True)
574
+
575
+ if not initialize_groq_client():
576
+ return
577
+
578
+ steps = {
579
+ 0: collect_basic_info,
580
+ 1: handle_symptom_interrogation,
581
+ 2: handle_symptom_interrogation,
582
+ 3: collect_medical_history,
583
+ 4: generate_risk_assessment
584
+ }
585
+
586
+ current_step = st.session_state.get('current_step', 0)
587
+ if current_step in steps:
588
+ steps[current_step]()
589
+
590
+ if _name_ == "_main_":
591
  main()