Update app.py
Browse files
app.py
CHANGED
@@ -1,388 +1,591 @@
|
|
1 |
-
import streamlit as st
|
2 |
-
import os
|
3 |
-
from groq import Groq
|
4 |
-
from datetime import datetime
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
#
|
10 |
-
st.
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
}
|
24 |
-
|
25 |
-
/*
|
26 |
-
.
|
27 |
-
background:
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
}
|
47 |
-
|
48 |
-
/*
|
49 |
-
.
|
50 |
-
|
51 |
-
|
52 |
-
border:
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
}
|
58 |
-
|
59 |
-
.
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
border
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
main()
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
from groq import Groq
|
4 |
+
from datetime import datetime
|
5 |
+
from transformers import pipeline
|
6 |
+
import pandas as pd
|
7 |
+
import re
|
8 |
+
|
9 |
+
# Set page config FIRST
|
10 |
+
st.set_page_config(page_title="AI Medical Consultancy", layout="wide")
|
11 |
+
|
12 |
+
# Custom CSS for styling
|
13 |
+
st.markdown("""
|
14 |
+
<style>
|
15 |
+
/* Color Variables */
|
16 |
+
:root {
|
17 |
+
--primary: #3498db; /* Blue */
|
18 |
+
--secondary: #2c3e50; /* Dark accent */
|
19 |
+
--accent: #f1c40f; /* Yellow */
|
20 |
+
--success: #2ecc71; /* Positive actions */
|
21 |
+
--light: #ffffff; /* White backgrounds */
|
22 |
+
--dark: #000000; /* Black text/elements */
|
23 |
+
}
|
24 |
+
|
25 |
+
/* Main container styling */
|
26 |
+
.stApp {
|
27 |
+
background: linear-gradient(135deg, #3498db 0%, #e0e0e0 100%);
|
28 |
+
font-family: 'Arial', sans-serif;
|
29 |
+
}
|
30 |
+
|
31 |
+
/* Headers styling */
|
32 |
+
h1, h2, h3 {
|
33 |
+
color: var(--dark) !important;
|
34 |
+
border-bottom: 3px solid var(--primary);
|
35 |
+
padding-bottom: 0.3em;
|
36 |
+
}
|
37 |
+
|
38 |
+
/* Form containers */
|
39 |
+
.stForm {
|
40 |
+
background: #000000;
|
41 |
+
border: 1px solid rgba(44, 62, 80, 0.2);
|
42 |
+
border-radius: 15px;
|
43 |
+
padding: 2rem;
|
44 |
+
box-shadow: 0 8px 30px rgba(0, 0, 0, 0.12);
|
45 |
+
margin: 1rem 0;
|
46 |
+
}
|
47 |
+
|
48 |
+
/* Input fields */
|
49 |
+
.stTextInput input, .stNumberInput input,
|
50 |
+
.stSelectbox select, .stTextArea textarea {
|
51 |
+
border: 2px solid #00FFFF !important;
|
52 |
+
border-radius: 10px !important;
|
53 |
+
padding: 1rem !important;
|
54 |
+
background: #00FFFF !important;
|
55 |
+
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
56 |
+
color: var(--dark) !important;
|
57 |
+
}
|
58 |
+
|
59 |
+
.stTextInput input:focus, .stNumberInput input:focus,
|
60 |
+
.stSelectbox select:focus, .stTextArea textarea:focus {
|
61 |
+
border-color: var(--primary) !important;
|
62 |
+
box-shadow: 0 0 12px rgba(52, 152, 219, 0.2) !important;
|
63 |
+
background: white !important;
|
64 |
+
color: var(--dark) !important;
|
65 |
+
}
|
66 |
+
|
67 |
+
/* Buttons styling */
|
68 |
+
.stButton>button {
|
69 |
+
background: linear-gradient(135deg, var(--primary) 0%, var(--accent) 100%) !important;
|
70 |
+
color: var(--dark) !important;
|
71 |
+
border: none !important;
|
72 |
+
border-radius: 10px !important;
|
73 |
+
padding: 1rem 2rem !important;
|
74 |
+
font-size: 1rem !important;
|
75 |
+
transition: all 0.3s ease;
|
76 |
+
position: relative;
|
77 |
+
overflow: hidden;
|
78 |
+
}
|
79 |
+
|
80 |
+
.stButton>button:hover {
|
81 |
+
transform: translateY(-2px);
|
82 |
+
box-shadow: 0 8px 15px rgba(52, 152, 219, 0.3);
|
83 |
+
opacity: 0.95;
|
84 |
+
}
|
85 |
+
|
86 |
+
.stButton>button:active {
|
87 |
+
transform: translateY(0);
|
88 |
+
opacity: 1;
|
89 |
+
}
|
90 |
+
|
91 |
+
/* Progress indicator */
|
92 |
+
.progress-bar {
|
93 |
+
display: flex;
|
94 |
+
justify-content: space-between;
|
95 |
+
margin: 2rem 0;
|
96 |
+
padding: 1rem;
|
97 |
+
background: rgba(255, 255, 255, 0.9);
|
98 |
+
border-radius: 10px;
|
99 |
+
color: var(--dark) !important;
|
100 |
+
}
|
101 |
+
|
102 |
+
.step {
|
103 |
+
flex: 1;
|
104 |
+
text-align: center;
|
105 |
+
padding: 1rem;
|
106 |
+
font-weight: 600;
|
107 |
+
color: #95a5a6;
|
108 |
+
position: relative;
|
109 |
+
}
|
110 |
+
|
111 |
+
.step.active {
|
112 |
+
color: var(--primary);
|
113 |
+
}
|
114 |
+
|
115 |
+
.step.active:after {
|
116 |
+
content: '';
|
117 |
+
position: absolute;
|
118 |
+
bottom: -1px;
|
119 |
+
left: 50%;
|
120 |
+
transform: translateX(-50%);
|
121 |
+
width: 40%;
|
122 |
+
height: 3px;
|
123 |
+
background: var(--primary);
|
124 |
+
}
|
125 |
+
|
126 |
+
/* Chat bubbles */
|
127 |
+
.dr-message {
|
128 |
+
background: linear-gradient(135deg, var(--primary) 0%, #2980b9 100%);
|
129 |
+
color: white;
|
130 |
+
border-radius: 20px 20px 20px 4px;
|
131 |
+
padding: 1.2rem 1.5rem;
|
132 |
+
margin: 1rem 0;
|
133 |
+
max-width: 80%;
|
134 |
+
width: fit-content;
|
135 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
|
136 |
+
}
|
137 |
+
|
138 |
+
.user-message {
|
139 |
+
background: linear-gradient(135deg, #f1c40f 0%, #e1b800 100%);
|
140 |
+
margin-left: auto;
|
141 |
+
border-radius: 20px 20px 4px 20px;
|
142 |
+
color: var(--dark) !important;
|
143 |
+
}
|
144 |
+
|
145 |
+
/* Emergency alert */
|
146 |
+
.emergency-alert {
|
147 |
+
background: linear-gradient(135deg, var(--accent) 0%, #c0392b 100%);
|
148 |
+
color: white;
|
149 |
+
padding: 2rem;
|
150 |
+
border-radius: 15px;
|
151 |
+
animation: pulse 1.5s infinite;
|
152 |
+
text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
153 |
+
}
|
154 |
+
|
155 |
+
@keyframes pulse {
|
156 |
+
0% { transform: scale(1); }
|
157 |
+
50% { transform: scale(1.02); }
|
158 |
+
100% { transform: scale(1); }
|
159 |
+
}
|
160 |
+
|
161 |
+
/* Download button */
|
162 |
+
.download-btn {
|
163 |
+
background: linear-gradient(135deg, var(--success) 0%, #27ae60 100%) !important;
|
164 |
+
}
|
165 |
+
|
166 |
+
/* Enhanced Data Visualization Contrast */
|
167 |
+
.stDataFrame {
|
168 |
+
border: 1px solid rgba(0, 0, 0, 0.1);
|
169 |
+
border-radius: 12px;
|
170 |
+
overflow: hidden;
|
171 |
+
background: #f0f0f0;
|
172 |
+
color: var(--dark) !important;
|
173 |
+
}
|
174 |
+
|
175 |
+
/* Tabbed Interface Styling */
|
176 |
+
.stTabs [role="tablist"] {
|
177 |
+
gap: 10px;
|
178 |
+
padding: 8px;
|
179 |
+
background: rgba(240, 240, 240, 0.9);
|
180 |
+
border-radius: 12px;
|
181 |
+
color: var(--dark) !important;
|
182 |
+
}
|
183 |
+
|
184 |
+
.stTabs [role="tab"] {
|
185 |
+
background: #ffffff !important;
|
186 |
+
border-radius: 8px !important;
|
187 |
+
transition: all 0.3s ease;
|
188 |
+
color: var(--dark) !important;
|
189 |
+
}
|
190 |
+
|
191 |
+
.stTabs [role="tab"][aria-selected="true"] {
|
192 |
+
background: var(--primary) !important;
|
193 |
+
color: white !important;
|
194 |
+
transform: scale(1.05);
|
195 |
+
}
|
196 |
+
</style>
|
197 |
+
""", unsafe_allow_html=True)
|
198 |
+
|
199 |
+
|
200 |
+
# Initialize session state variables
|
201 |
+
if 'current_step' not in st.session_state:
|
202 |
+
st.session_state.current_step = 0
|
203 |
+
if 'symptom_details' not in st.session_state:
|
204 |
+
st.session_state.symptom_details = [] # Initialize as an empty list
|
205 |
+
if 'patient_info' not in st.session_state:
|
206 |
+
st.session_state.patient_info = {}
|
207 |
+
|
208 |
+
# Initialize models and dataset with caching
|
209 |
+
@st.cache_resource
|
210 |
+
def load_ner_model():
|
211 |
+
return pipeline("token-classification",
|
212 |
+
model="d4data/biomedical-ner-all",
|
213 |
+
aggregation_strategy='first')
|
214 |
+
|
215 |
+
@st.cache_resource
|
216 |
+
def load_dataset():
|
217 |
+
try:
|
218 |
+
data = pd.read_csv('DATASET.csv', encoding='latin1')
|
219 |
+
data['Symptom_lower'] = data['Symptom'].str.strip().str.lower()
|
220 |
+
return data
|
221 |
+
except FileNotFoundError:
|
222 |
+
st.error("DATASET.csv not found. Please ensure the file is in the correct directory.")
|
223 |
+
return None
|
224 |
+
|
225 |
+
pipe = load_ner_model()
|
226 |
+
data = load_dataset()
|
227 |
+
if data is None:
|
228 |
+
st.stop()
|
229 |
+
|
230 |
+
SEVERITY_KEYWORDS = {
|
231 |
+
'normal': ['mild', 'occasional', 'controlled', 'temporary', 'fleeting'],
|
232 |
+
'moderate': ['moderate', 'persistent', 'frequent', 'prolonged', 'bloating'],
|
233 |
+
'severe': ['severe', 'extreme', 'crushing', 'radiating', 'blood', 'inability',
|
234 |
+
'sweating', 'vomiting', 'fever', 'swelling', 'radiates', 'persistent >3 days']
|
235 |
+
}
|
236 |
+
|
237 |
+
def merge_entities(entities, text):
|
238 |
+
"""Merge entities using dataset symptom names for multi-word matching"""
|
239 |
+
merged = []
|
240 |
+
detected = set()
|
241 |
+
text_lower = text.lower()
|
242 |
+
|
243 |
+
symptoms_sorted = data['Symptom'].str.lower().sort_values(
|
244 |
+
key=lambda x: x.str.len(), ascending=False).tolist()
|
245 |
+
|
246 |
+
for symptom in symptoms_sorted:
|
247 |
+
if symptom in text_lower and symptom not in detected:
|
248 |
+
start = text_lower.find(symptom)
|
249 |
+
end = start + len(symptom)
|
250 |
+
merged.append({
|
251 |
+
'word': text[start:end],
|
252 |
+
'start': start,
|
253 |
+
'end': end,
|
254 |
+
'entity_group': 'Sign_symptom'
|
255 |
+
})
|
256 |
+
detected.add(symptom)
|
257 |
+
return merged
|
258 |
+
|
259 |
+
def determine_condition(text, entity):
|
260 |
+
context_window = re.findall(r'\w+', text[max(0, entity['start']-20):entity['end']+20])
|
261 |
+
context = ' '.join(context_window).lower()
|
262 |
+
|
263 |
+
for level, keywords in SEVERITY_KEYWORDS.items():
|
264 |
+
if any(k in context for k in keywords):
|
265 |
+
return level.capitalize()
|
266 |
+
return 'Normal'
|
267 |
+
|
268 |
+
def get_risk_score(symptom, condition):
|
269 |
+
symptom_clean = symptom.strip().lower()
|
270 |
+
matches = data[
|
271 |
+
data['Symptom_lower'].str.contains(rf'\b{symptom_clean}\b', regex=True, flags=re.IGNORECASE)
|
272 |
+
]
|
273 |
+
if not matches.empty:
|
274 |
+
matches = matches[matches['Condition'] == condition]
|
275 |
+
return matches['Risk Score'].values[0] if not matches.empty else 0
|
276 |
+
return 0
|
277 |
+
|
278 |
+
def calculate_risk_score(text):
|
279 |
+
entities = pipe(text)
|
280 |
+
merged_entities = merge_entities(entities, text)
|
281 |
+
|
282 |
+
score = 0
|
283 |
+
breakdown = []
|
284 |
+
for ent in merged_entities:
|
285 |
+
if ent['entity_group'] == "Sign_symptom":
|
286 |
+
symptom = ent['word'].strip().lower()
|
287 |
+
condition = determine_condition(text, ent)
|
288 |
+
risk = get_risk_score(symptom, condition)
|
289 |
+
score += risk
|
290 |
+
breakdown.append({
|
291 |
+
'symptom': symptom.capitalize(),
|
292 |
+
'condition': condition,
|
293 |
+
'risk': risk
|
294 |
+
})
|
295 |
+
return score, breakdown
|
296 |
+
|
297 |
+
def initialize_groq_client():
|
298 |
+
try:
|
299 |
+
# Try to get the API key from Streamlit secrets
|
300 |
+
api_key = None
|
301 |
+
try:
|
302 |
+
api_key = st.secrets.get("GROQ_API_KEY", os.getenv("GROQ_API_KEY"))
|
303 |
+
except FileNotFoundError:
|
304 |
+
st.warning("No secrets.toml file found. Please create one in the .streamlit folder.")
|
305 |
+
|
306 |
+
# If not found, prompt the user to enter it
|
307 |
+
if not api_key:
|
308 |
+
api_key = st.text_input("Enter your Groq API Key:", type="password")
|
309 |
+
if not api_key:
|
310 |
+
st.warning("Please provide a valid Groq API key to proceed.")
|
311 |
+
return False
|
312 |
+
|
313 |
+
# Initialize the Groq client
|
314 |
+
client = Groq(api_key=api_key)
|
315 |
+
st.session_state.client = client
|
316 |
+
return True
|
317 |
+
except Exception as e:
|
318 |
+
st.error(f"Error initializing Groq client: {str(e)}")
|
319 |
+
return False
|
320 |
+
|
321 |
+
def symptom_interrogation_step():
|
322 |
+
client = st.session_state.client
|
323 |
+
main_symptom = st.session_state.patient_info['main_symptom']
|
324 |
+
step = len(st.session_state.symptom_details) # Use the number of collected details as the step
|
325 |
+
|
326 |
+
if step == 0:
|
327 |
+
# First question: ask about the main symptom
|
328 |
+
medical_focus = {
|
329 |
+
'pain': "location/radiation/provoking factors",
|
330 |
+
'fever': "pattern/associated symptoms/response to meds",
|
331 |
+
'gi': "bowel changes/ingestion timing/associated symptoms",
|
332 |
+
'respiratory': "exertion relationship/sputum/triggers"
|
333 |
+
}
|
334 |
+
focus = medical_focus.get(main_symptom.lower(),
|
335 |
+
"temporal pattern/severity progression/associated symptoms")
|
336 |
+
|
337 |
+
prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom}
|
338 |
+
focusing on {focus} to differentiate serious causes. Your task is to have a polite and simple conversation with a patient.
|
339 |
+
Start by asking ONE specific follow-up question about their initial symptom: {main_symptom}.
|
340 |
+
Ask only one question at a time to avoid overwhelming the patient.
|
341 |
+
Keep your language clear, professional, and easy to understand.
|
342 |
+
|
343 |
+
Dont display possibe symptoms or why you are asking questions."""
|
344 |
+
|
345 |
+
messages = [
|
346 |
+
{"role": "system", "content": "Ask focused clinical questions. One at a time."},
|
347 |
+
{"role": "user", "content": prompt}
|
348 |
+
]
|
349 |
+
else:
|
350 |
+
# Subsequent questions: use the last Q&A to generate the next question
|
351 |
+
last_qa = st.session_state.symptom_details[-1]
|
352 |
+
prompt = f"""Last Q&A: {last_qa['question']} → {last_qa['answer']}
|
353 |
+
Based on this, ask the NEXT most critical question to differentiate between
|
354 |
+
possible causes of {main_symptom}. Consider red flags and likelihood."""
|
355 |
+
messages = [{"role": "user", "content": prompt}]
|
356 |
+
|
357 |
+
try:
|
358 |
+
response = client.chat.completions.create(
|
359 |
+
messages=messages,
|
360 |
+
model="mixtral-8x7b-32768",
|
361 |
+
temperature=0.3
|
362 |
+
)
|
363 |
+
question = response.choices[0].message.content.strip()
|
364 |
+
if not question.endswith('?'):
|
365 |
+
question += '?'
|
366 |
+
st.session_state.current_question = question
|
367 |
+
except Exception as e:
|
368 |
+
st.error(f"Error generating question: {str(e)}")
|
369 |
+
st.stop()
|
370 |
+
|
371 |
+
def handle_symptom_interrogation():
|
372 |
+
st.header("Symptom Analysis")
|
373 |
+
|
374 |
+
if st.session_state.current_step == 1:
|
375 |
+
symptom_interrogation_step()
|
376 |
+
st.session_state.current_step = 2
|
377 |
+
|
378 |
+
if 'current_question' in st.session_state:
|
379 |
+
with st.form("symptom_qna"):
|
380 |
+
st.markdown(f'<div class="dr-message">👨⚕ {st.session_state.current_question}</div>', unsafe_allow_html=True)
|
381 |
+
answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}")
|
382 |
+
|
383 |
+
if st.form_submit_button("Next"):
|
384 |
+
if answer:
|
385 |
+
st.session_state.symptom_details.append({
|
386 |
+
"question": st.session_state.current_question,
|
387 |
+
"answer": answer
|
388 |
+
})
|
389 |
+
del st.session_state.current_question
|
390 |
+
|
391 |
+
# Check for emergency after 3 questions
|
392 |
+
if len(st.session_state.symptom_details) >= 3:
|
393 |
+
last_answer = st.session_state.symptom_details[-1]['answer']
|
394 |
+
try:
|
395 |
+
urgency_check = st.session_state.client.chat.completions.create(
|
396 |
+
messages=[{"role": "user", "content":
|
397 |
+
f"Does '{last_answer}' indicate immediate emergency? Yes/No"}],
|
398 |
+
model="mixtral-8x7b-32768",
|
399 |
+
temperature=0
|
400 |
+
).choices[0].message.content
|
401 |
+
|
402 |
+
if 'YES' in urgency_check.upper():
|
403 |
+
st.markdown('<div class="emergency-alert">🚨 Emergency detected! Please seek immediate medical attention.</div>', unsafe_allow_html=True)
|
404 |
+
st.session_state.current_step = 4
|
405 |
+
return
|
406 |
+
except Exception as e:
|
407 |
+
st.error(f"Error checking urgency: {str(e)}")
|
408 |
+
|
409 |
+
if len(st.session_state.symptom_details) < 7:
|
410 |
+
st.session_state.current_step = 1
|
411 |
+
st.rerun()
|
412 |
+
else:
|
413 |
+
st.session_state.current_step = 3
|
414 |
+
st.rerun()
|
415 |
+
else:
|
416 |
+
st.warning("Please provide an answer")
|
417 |
+
|
418 |
+
def collect_basic_info():
|
419 |
+
st.header("Patient Information")
|
420 |
+
with st.form("basic_info"):
|
421 |
+
st.session_state.patient_info['name'] = st.text_input("Full Name")
|
422 |
+
st.session_state.patient_info['age'] = st.number_input("Age", min_value=0, max_value=120)
|
423 |
+
st.session_state.patient_info['gender'] = st.selectbox("Gender", ["Male", "Female", "Other"])
|
424 |
+
st.session_state.patient_info['main_symptom'] = st.text_input("Main Symptom")
|
425 |
+
|
426 |
+
if st.form_submit_button("Next"):
|
427 |
+
if all([st.session_state.patient_info.get(k) for k in ['name', 'age', 'gender', 'main_symptom']]):
|
428 |
+
st.session_state.current_step = 1
|
429 |
+
st.rerun()
|
430 |
+
else:
|
431 |
+
st.warning("Please fill all required fields")
|
432 |
+
|
433 |
+
def collect_medical_history():
|
434 |
+
st.header("Medical History")
|
435 |
+
with st.form("medical_history"):
|
436 |
+
st.session_state.patient_info['medical_history'] = st.text_area("Relevant Medical History")
|
437 |
+
st.session_state.patient_info['medications'] = st.text_area("Current Medications")
|
438 |
+
st.session_state.patient_info['allergies'] = st.text_input("Known Allergies")
|
439 |
+
st.session_state.patient_info['last_meal'] = st.text_input("Last Meal Time")
|
440 |
+
st.session_state.patient_info['recent_travel'] = st.text_input("Recent Travel History")
|
441 |
+
|
442 |
+
if st.form_submit_button("Submit"):
|
443 |
+
st.session_state.current_step = 4
|
444 |
+
st.rerun()
|
445 |
+
|
446 |
+
def generate_risk_assessment():
|
447 |
+
st.header("Risk Assessment")
|
448 |
+
|
449 |
+
try:
|
450 |
+
symptom_log = "\n".join(
|
451 |
+
[f"Q: {q['question']}\nA: {q['answer']}"
|
452 |
+
for q in st.session_state.symptom_details]
|
453 |
+
)
|
454 |
+
|
455 |
+
patient_profile = f"""
|
456 |
+
*Patient Profile*
|
457 |
+
Name: {st.session_state.patient_info['name']}
|
458 |
+
Age: {st.session_state.patient_info['age']}
|
459 |
+
Gender: {st.session_state.patient_info['gender']}
|
460 |
+
|
461 |
+
*Primary Complaint*
|
462 |
+
{st.session_state.patient_info['main_symptom']}
|
463 |
+
|
464 |
+
*Symptom Interrogation*
|
465 |
+
{symptom_log}
|
466 |
+
|
467 |
+
*Medical History*
|
468 |
+
{st.session_state.patient_info.get('medical_history', 'None reported')}
|
469 |
+
|
470 |
+
*Current Medications*
|
471 |
+
{st.session_state.patient_info.get('medications', 'None')}
|
472 |
+
|
473 |
+
*Allergies*
|
474 |
+
{st.session_state.patient_info.get('allergies', 'None reported')}
|
475 |
+
|
476 |
+
*Recent Context*
|
477 |
+
Last Meal: {st.session_state.patient_info.get('last_meal', 'Unknown')}
|
478 |
+
Recent Travel: {st.session_state.patient_info.get('recent_travel', 'None')}
|
479 |
+
"""
|
480 |
+
|
481 |
+
analysis_prompt = f"""STRICTLY follow these instructions:
|
482 |
+
1. Analyze this case: {patient_profile}
|
483 |
+
2. *Include ONLY symptoms the patient is actively experiencing*. Exclude all negated symptoms (e.g., "no fever," "denies breathlessness").
|
484 |
+
3. Output *EXCLUSIVELY* in this format with NO additional text or explanations:
|
485 |
+
[Age]-year-old [gender] with [specific, present symptoms].
|
486 |
+
|
487 |
+
Example Output:
|
488 |
+
"45-year-old man with severe chest pain radiating to the jaw"
|
489 |
+
|
490 |
+
Your Output:"""
|
491 |
+
|
492 |
+
response = st.session_state.client.chat.completions.create(
|
493 |
+
messages=[
|
494 |
+
{"role": "system", "content": "You are a medical AI that outputs ONLY patient descriptions."},
|
495 |
+
{"role": "user", "content": analysis_prompt}
|
496 |
+
],
|
497 |
+
model="mixtral-8x7b-32768",
|
498 |
+
temperature=0.3,
|
499 |
+
max_tokens=100
|
500 |
+
)
|
501 |
+
|
502 |
+
risk_prompt = response.choices[0].message.content.strip('"')
|
503 |
+
|
504 |
+
st.subheader("Clinical Summary")
|
505 |
+
st.markdown(f"\n{risk_prompt}\n")
|
506 |
+
|
507 |
+
# Add risk scoring
|
508 |
+
risk_score, breakdown = calculate_risk_score(risk_prompt)
|
509 |
+
|
510 |
+
# Display risk score with color coding
|
511 |
+
score_color = "#2ecc71" if risk_score < 40 else "#f1c40f" if risk_score < 70 else "#e74c3c"
|
512 |
+
st.markdown(f"""
|
513 |
+
<div style="background: {score_color};
|
514 |
+
color: white;
|
515 |
+
padding: 1.5rem;
|
516 |
+
border-radius: 10px;
|
517 |
+
text-align: center;
|
518 |
+
margin: 2rem 0;">
|
519 |
+
<h2 style="color: white; margin: 0;">Risk Assessment Score</h2>
|
520 |
+
<h1 style="font-size: 3.5rem; margin: 0.5rem 0;">{risk_score}/100</h1>
|
521 |
+
<p>{'Low Risk' if risk_score < 40 else 'Moderate Risk' if risk_score < 70 else 'High Risk'}</p>
|
522 |
+
</div>
|
523 |
+
""", unsafe_allow_html=True)
|
524 |
+
|
525 |
+
# Display breakdown
|
526 |
+
with st.expander("View Detailed Risk Breakdown", expanded=True):
|
527 |
+
for item in breakdown:
|
528 |
+
st.markdown(f"""
|
529 |
+
<div style="background: rgba(236, 240, 241, 0.5);
|
530 |
+
padding: 1rem;
|
531 |
+
border-radius: 8px;
|
532 |
+
margin: 0.5rem 0;">
|
533 |
+
<strong>{item['symptom']}</strong> ({item['condition']})
|
534 |
+
<div style="float: right;">+{item['risk']} pts</div>
|
535 |
+
</div>
|
536 |
+
""", unsafe_allow_html=True)
|
537 |
+
|
538 |
+
# Create download button
|
539 |
+
timestamp = datetime.now().strftime('%Y%m%d%H%M')
|
540 |
+
filename = f"{st.session_state.patient_info['name'].replace(' ', '')}_assessment{timestamp}.txt"
|
541 |
+
st.download_button(
|
542 |
+
label="Download Assessment",
|
543 |
+
data=risk_prompt,
|
544 |
+
file_name=filename,
|
545 |
+
mime="text/plain"
|
546 |
+
)
|
547 |
+
|
548 |
+
except Exception as e:
|
549 |
+
st.error(f"Error generating risk assessment: {str(e)}")
|
550 |
+
|
551 |
+
def main():
|
552 |
+
st.title("🏥 AI Medical Consultancy")
|
553 |
+
|
554 |
+
# Progress indicator
|
555 |
+
steps_titles = ["Patient Info", "Symptoms", "Medical History", "Assessment"]
|
556 |
+
progress_html = """
|
557 |
+
<div class="progress-bar">
|
558 |
+
<div class="step {}">{}</div>
|
559 |
+
<div class="step {}">{}</div>
|
560 |
+
<div class="step {}">{}</div>
|
561 |
+
<div class="step {}">{}</div>
|
562 |
+
</div>
|
563 |
+
""".format(
|
564 |
+
'active' if st.session_state.current_step >= 0 else '',
|
565 |
+
'1. Patient Info',
|
566 |
+
'active' if st.session_state.current_step >= 1 else '',
|
567 |
+
'2. Symptoms',
|
568 |
+
'active' if st.session_state.current_step >= 3 else '',
|
569 |
+
'3. History',
|
570 |
+
'active' if st.session_state.current_step >= 4 else '',
|
571 |
+
'4. Report'
|
572 |
+
)
|
573 |
+
st.markdown(progress_html, unsafe_allow_html=True)
|
574 |
+
|
575 |
+
if not initialize_groq_client():
|
576 |
+
return
|
577 |
+
|
578 |
+
steps = {
|
579 |
+
0: collect_basic_info,
|
580 |
+
1: handle_symptom_interrogation,
|
581 |
+
2: handle_symptom_interrogation,
|
582 |
+
3: collect_medical_history,
|
583 |
+
4: generate_risk_assessment
|
584 |
+
}
|
585 |
+
|
586 |
+
current_step = st.session_state.get('current_step', 0)
|
587 |
+
if current_step in steps:
|
588 |
+
steps[current_step]()
|
589 |
+
|
590 |
+
if _name_ == "_main_":
|
591 |
main()
|