|
import streamlit as st |
|
import os |
|
from groq import Groq |
|
from datetime import datetime |
|
|
|
|
|
st.set_page_config(page_title="AI Medical Consultancy", layout="wide")

# Global theme overrides injected as raw HTML (hence unsafe_allow_html=True).
# Some classes below (.dr-message, .emergency-alert, .progress-bar, .step) are
# targeted by HTML snippets this app renders with st.markdown.
st.markdown("""
<style>
/* Color Variables */
:root {
    --primary: #3498db;   /* Blue */
    --secondary: #2c3e50; /* Dark accent */
    --accent: #f1c40f;    /* Yellow */
    --success: #2ecc71;   /* Positive actions */
    --light: #ffffff;     /* White backgrounds */
    --dark: #000000;      /* Black text/elements */
    --danger: #e74c3c;    /* Emergency red */
}

/* Main container styling */
.stApp {
    background: linear-gradient(135deg, #3498db 0%, #e0e0e0 100%);
    font-family: 'Arial', sans-serif;
}

/* Headers styling */
h1, h2, h3 {
    color: var(--dark) !important;
    border-bottom: 3px solid var(--primary);
    padding-bottom: 0.3em;
}

/* Form containers */
.stForm {
    background: #000000;
    border: 1px solid rgba(44, 62, 80, 0.2);
    border-radius: 15px;
    padding: 2rem;
    box-shadow: 0 8px 30px rgba(0, 0, 0, 0.12);
    margin: 1rem 0;
}

/* The form panel is black, so pin widget labels to white; otherwise a light
   Streamlit theme draws its default dark label text on the black panel. */
.stForm label, .stForm label p {
    color: var(--light) !important;
}

/* Input fields */
.stTextInput input, .stNumberInput input,
.stSelectbox select, .stTextArea textarea {
    border: 2px solid #00FFFF !important;
    border-radius: 10px !important;
    padding: 1rem !important;
    background: #00FFFF !important;
    transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
    color: var(--dark) !important;
}

.stTextInput input:focus, .stNumberInput input:focus,
.stSelectbox select:focus, .stTextArea textarea:focus {
    border-color: var(--primary) !important;
    box-shadow: 0 0 12px rgba(52, 152, 219, 0.2) !important;
    background: white !important;
    color: var(--dark) !important;
}

/* Buttons styling */
.stButton>button {
    background: linear-gradient(135deg, var(--primary) 0%, var(--accent) 100%) !important;
    color: var(--dark) !important;
    border: none !important;
    border-radius: 10px !important;
    padding: 1rem 2rem !important;
    font-size: 1rem !important;
    transition: all 0.3s ease;
    position: relative;
    overflow: hidden;
}

.stButton>button:hover {
    transform: translateY(-2px);
    box-shadow: 0 8px 15px rgba(52, 152, 219, 0.3);
    opacity: 0.95;
}

.stButton>button:active {
    transform: translateY(0);
    opacity: 1;
}

/* Progress indicator */
.progress-bar {
    display: flex;
    justify-content: space-between;
    margin: 2rem 0;
    padding: 1rem;
    background: rgba(255, 255, 255, 0.9);
    border-radius: 10px;
    color: var(--dark) !important;
}

.step {
    flex: 1;
    text-align: center;
    padding: 1rem;
    font-weight: 600;
    color: #95a5a6;
    position: relative;
}

.step.active {
    color: var(--primary);
}

.step.active:after {
    content: '';
    position: absolute;
    bottom: -1px;
    left: 50%;
    transform: translateX(-50%);
    width: 40%;
    height: 3px;
    background: var(--primary);
}

/* Chat bubbles */
.dr-message {
    background: linear-gradient(135deg, var(--primary) 0%, #2980b9 100%);
    color: white;
    border-radius: 20px 20px 20px 4px;
    padding: 1.2rem 1.5rem;
    margin: 1rem 0;
    max-width: 80%;
    width: fit-content;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}

.user-message {
    background: linear-gradient(135deg, #f1c40f 0%, #e1b800 100%);
    margin-left: auto;
    border-radius: 20px 20px 4px 20px;
    color: var(--dark) !important;
}

/* Emergency alert: red gradient so the white text stays legible
   (white on the yellow accent color has very low contrast). */
.emergency-alert {
    background: linear-gradient(135deg, var(--danger) 0%, #c0392b 100%);
    color: white;
    padding: 2rem;
    border-radius: 15px;
    animation: pulse 1.5s infinite;
    text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}

@keyframes pulse {
    0% { transform: scale(1); }
    50% { transform: scale(1.02); }
    100% { transform: scale(1); }
}

/* Download button */
.download-btn {
    background: linear-gradient(135deg, var(--success) 0%, #27ae60 100%) !important;
}

/* Enhanced Data Visualization Contrast */
.stDataFrame {
    border: 1px solid rgba(0, 0, 0, 0.1);
    border-radius: 12px;
    overflow: hidden;
    background: #f0f0f0;
    color: var(--dark) !important;
}

/* Tabbed Interface Styling */
.stTabs [role="tablist"] {
    gap: 10px;
    padding: 8px;
    background: rgba(240, 240, 240, 0.9);
    border-radius: 12px;
    color: var(--dark) !important;
}

.stTabs [role="tab"] {
    background: #ffffff !important;
    border-radius: 8px !important;
    transition: all 0.3s ease;
    color: var(--dark) !important;
}

.stTabs [role="tab"][aria-selected="true"] {
    background: var(--primary) !important;
    color: white !important;
    transform: scale(1.05);
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
def _init_session_state():
    """Seed per-session wizard state on the first script run only.

    Keys that already exist (from an earlier rerun) are left untouched:
      current_step    -- index of the wizard page main() dispatches to
      symptom_details -- list of {"question": ..., "answer": ...} dicts
      patient_info    -- dict of intake / history fields keyed by name
    """
    seeds = (
        ("current_step", 0),
        ("symptom_details", []),
        ("patient_info", {}),
    )
    for state_key, initial_value in seeds:
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = initial_value


_init_session_state()
|
|
|
def initialize_groq_client():
    """Create a Groq client and store it in ``st.session_state.client``.

    The API key is looked up in order:
      1. Streamlit secrets (``GROQ_API_KEY``),
      2. the ``GROQ_API_KEY`` environment variable,
      3. a password text input rendered on the page.

    Returns:
        True when a client was created. False when no key is available yet,
        or when client construction failed. In both False cases a warning or
        error is shown in the UI.
    """
    try:
        api_key = None
        secrets_file_missing = False
        try:
            api_key = st.secrets.get("GROQ_API_KEY")
        except FileNotFoundError:
            # No secrets.toml. The environment variable below is still
            # consulted instead of being discarded together with the secrets
            # lookup.
            secrets_file_missing = True

        if not api_key:
            api_key = os.getenv("GROQ_API_KEY")

        if not api_key:
            if secrets_file_missing:
                st.warning("No `secrets.toml` file found. Please create one in the `.streamlit` folder.")
            api_key = st.text_input("Enter your Groq API Key:", type="password")
            if not api_key:
                st.warning("Please provide a valid Groq API key to proceed.")
                return False

        st.session_state.client = Groq(api_key=api_key)
        return True
    except Exception as e:
        st.error(f"Error initializing Groq client: {str(e)}")
        return False
|
|
|
def symptom_interrogation_step(): |
|
client = st.session_state.client |
|
main_symptom = st.session_state.patient_info['main_symptom'] |
|
step = len(st.session_state.symptom_details) |
|
|
|
if step == 0: |
|
|
|
medical_focus = { |
|
'pain': "location/radiation/provoking factors", |
|
'fever': "pattern/associated symptoms/response to meds", |
|
'gi': "bowel changes/ingestion timing/associated symptoms", |
|
'respiratory': "exertion relationship/sputum/triggers" |
|
} |
|
focus = medical_focus.get(main_symptom.lower(), |
|
"temporal pattern/severity progression/associated symptoms") |
|
|
|
prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom} |
|
focusing on {focus} to differentiate serious causes. Your task is to have a polite and simple conversation with a patient. |
|
Start by asking ONE specific follow-up question about their initial symptom: {main_symptom}. |
|
Ask only one question at a time to avoid overwhelming the patient. |
|
Keep your language clear, professional, and easy to understand. |
|
|
|
Dont display possibe symptoms or why you are asking questions.""" |
|
|
|
messages = [ |
|
{"role": "system", "content": "Ask focused clinical questions. One at a time."}, |
|
{"role": "user", "content": prompt} |
|
] |
|
else: |
|
|
|
last_qa = st.session_state.symptom_details[-1] |
|
prompt = f"""Last Q&A: {last_qa['question']} → {last_qa['answer']} |
|
Based on this, ask the NEXT most critical question to differentiate between |
|
possible causes of {main_symptom}. Consider red flags and likelihood.""" |
|
messages = [{"role": "user", "content": prompt}] |
|
|
|
try: |
|
response = client.chat.completions.create( |
|
messages=messages, |
|
model="mixtral-8x7b-32768", |
|
temperature=0.3 |
|
) |
|
question = response.choices[0].message.content.strip() |
|
if not question.endswith('?'): |
|
question += '?' |
|
st.session_state.current_question = question |
|
except Exception as e: |
|
st.error(f"Error generating question: {str(e)}") |
|
st.stop() |
|
|
|
def handle_symptom_interrogation(): |
|
st.header("Symptom Analysis") |
|
|
|
if st.session_state.current_step == 1: |
|
symptom_interrogation_step() |
|
st.session_state.current_step = 2 |
|
|
|
if 'current_question' in st.session_state: |
|
with st.form("symptom_qna"): |
|
st.markdown(f'<div class="dr-message">👨⚕️ {st.session_state.current_question}</div>', unsafe_allow_html=True) |
|
answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}") |
|
|
|
if st.form_submit_button("Next"): |
|
if answer: |
|
st.session_state.symptom_details.append({ |
|
"question": st.session_state.current_question, |
|
"answer": answer |
|
}) |
|
del st.session_state.current_question |
|
|
|
|
|
if len(st.session_state.symptom_details) >= 3: |
|
last_answer = st.session_state.symptom_details[-1]['answer'] |
|
try: |
|
urgency_check = st.session_state.client.chat.completions.create( |
|
messages=[{"role": "user", "content": |
|
f"Does '{last_answer}' indicate immediate emergency? Yes/No"}], |
|
model="mixtral-8x7b-32768", |
|
temperature=0 |
|
).choices[0].message.content |
|
|
|
if 'YES' in urgency_check.upper(): |
|
st.markdown('<div class="emergency-alert">🚨 Emergency detected! Please seek immediate medical attention.</div>', unsafe_allow_html=True) |
|
st.session_state.current_step = 4 |
|
return |
|
except Exception as e: |
|
st.error(f"Error checking urgency: {str(e)}") |
|
|
|
if len(st.session_state.symptom_details) < 7: |
|
st.session_state.current_step = 1 |
|
st.rerun() |
|
else: |
|
st.session_state.current_step = 3 |
|
st.rerun() |
|
else: |
|
st.warning("Please provide an answer") |
|
|
|
def collect_basic_info(): |
|
st.header("Patient Information") |
|
with st.form("basic_info"): |
|
st.session_state.patient_info['name'] = st.text_input("Full Name") |
|
st.session_state.patient_info['age'] = st.number_input("Age", min_value=0, max_value=120) |
|
st.session_state.patient_info['gender'] = st.selectbox("Gender", ["Male", "Female", "Other"]) |
|
st.session_state.patient_info['main_symptom'] = st.text_input("Main Symptom") |
|
|
|
if st.form_submit_button("Next"): |
|
if all([st.session_state.patient_info.get(k) for k in ['name', 'age', 'gender', 'main_symptom']]): |
|
st.session_state.current_step = 1 |
|
st.rerun() |
|
else: |
|
st.warning("Please fill all required fields") |
|
|
|
def collect_medical_history():
    """Render the medical-history form and jump to the assessment on submit.

    Each widget's current value is stored in ``st.session_state.patient_info``
    under the field's key on every run. Pressing "Submit" sets
    ``current_step`` to 4 and reruns the script.
    """
    st.header("Medical History")
    record = st.session_state.patient_info

    # (patient_info key, widget factory, label), in on-screen order.
    history_fields = (
        ('medical_history', st.text_area, "Relevant Medical History"),
        ('medications', st.text_area, "Current Medications"),
        ('allergies', st.text_input, "Known Allergies"),
        ('last_meal', st.text_input, "Last Meal Time"),
        ('recent_travel', st.text_input, "Recent Travel History"),
    )

    with st.form("medical_history"):
        for field_key, widget, label in history_fields:
            record[field_key] = widget(label)

        submitted = st.form_submit_button("Submit")
        if submitted:
            st.session_state.current_step = 4
            st.rerun()
|
|
|
def generate_risk_assessment(): |
|
st.header("Risk Assessment") |
|
|
|
try: |
|
symptom_log = "\n".join( |
|
[f"Q: {q['question']}\nA: {q['answer']}" |
|
for q in st.session_state.symptom_details] |
|
) |
|
|
|
patient_profile = f""" |
|
**Patient Profile** |
|
Name: {st.session_state.patient_info['name']} |
|
Age: {st.session_state.patient_info['age']} |
|
Gender: {st.session_state.patient_info['gender']} |
|
|
|
**Primary Complaint** |
|
{st.session_state.patient_info['main_symptom']} |
|
|
|
**Symptom Interrogation** |
|
{symptom_log} |
|
|
|
**Medical History** |
|
{st.session_state.patient_info.get('medical_history', 'None reported')} |
|
|
|
**Current Medications** |
|
{st.session_state.patient_info.get('medications', 'None')} |
|
|
|
**Allergies** |
|
{st.session_state.patient_info.get('allergies', 'None reported')} |
|
|
|
**Recent Context** |
|
Last Meal: {st.session_state.patient_info.get('last_meal', 'Unknown')} |
|
Recent Travel: {st.session_state.patient_info.get('recent_travel', 'None')} |
|
""" |
|
|
|
analysis_prompt = f"""STRICTLY follow these instructions: |
|
1. Analyze this case: {patient_profile} |
|
2. *Include ONLY symptoms the patient is actively experiencing*. Exclude all negated symptoms (e.g., "no fever," "denies breathlessness"). |
|
3. Output *EXCLUSIVELY* in this format with NO additional text or explanations: |
|
[Age]-year-old [gender] with [specific, present symptoms]. |
|
Example Output: |
|
"45-year-old man with severe chest pain radiating to the jaw" |
|
Your Output:""" |
|
|
|
|
|
response = st.session_state.client.chat.completions.create( |
|
messages=[ |
|
{"role": "system", "content": "You are a medical AI that outputs ONLY patient descriptions."}, |
|
{"role": "user", "content": analysis_prompt} |
|
], |
|
model="mixtral-8x7b-32768", |
|
temperature=0.3, |
|
max_tokens=100 |
|
) |
|
|
|
risk_prompt = response.choices[0].message.content.strip('"') |
|
|
|
st.subheader("Clinical Summary") |
|
st.markdown(f"```\n{risk_prompt}\n```") |
|
|
|
|
|
timestamp = datetime.now().strftime('%Y%m%d%H%M') |
|
filename = f"{st.session_state.patient_info['name'].replace(' ', '_')}_assessment_{timestamp}.txt" |
|
st.download_button( |
|
label="Download Assessment", |
|
data=risk_prompt, |
|
file_name=filename, |
|
mime="text/plain" |
|
) |
|
|
|
except Exception as e: |
|
st.error(f"Error generating risk assessment: {str(e)}") |
|
|
|
def main():
    """Render the app: title, progress indicator, then the current wizard step.

    ``st.session_state.current_step`` selects the page:
      0 basic info, 1-2 symptom interrogation, 3 medical history,
      4 risk assessment.
    Nothing after the progress bar is rendered until a Groq client exists.
    """
    st.title("🏥 AI Medical Consultancy")

    current_step = st.session_state.get('current_step', 0)

    # (label, first step at which the label is highlighted). "Symptoms" spans
    # both step 1 (generate question) and step 2 (answer it).
    progress_steps = (
        ('1. Patient Info', 0),
        ('2. Symptoms', 1),
        ('3. History', 3),
        ('4. Report', 4),
    )
    step_divs = []
    for label, first_active_step in progress_steps:
        css_class = "step active" if current_step >= first_active_step else "step"
        step_divs.append(f'<div class="{css_class}">{label}</div>')
    # Emitted without line breaks or indentation: indented HTML passed to
    # st.markdown can be parsed as a Markdown code block instead of rendered.
    st.markdown(
        f'<div class="progress-bar">{"".join(step_divs)}</div>',
        unsafe_allow_html=True,
    )

    if not initialize_groq_client():
        return

    steps = {
        0: collect_basic_info,
        1: handle_symptom_interrogation,
        2: handle_symptom_interrogation,
        3: collect_medical_history,
        4: generate_risk_assessment,
    }

    handler = steps.get(current_step)
    if handler is not None:
        handler()


if __name__ == "__main__":
    main()