# ayush200399391001's picture
# Update app.py
# ae6054f verified
# raw
# history blame
# 16.5 kB
import streamlit as st
import os
from groq import Groq
from datetime import datetime
# Page-level configuration is issued ahead of every other Streamlit call.
st.set_page_config(
    page_title="AI Medical Consultancy",
    layout="wide",
)

# Stylesheet for the whole app: theme variables, forms, inputs, buttons,
# progress bar, chat bubbles, emergency banner, data frames and tabs.
# The text is injected verbatim into the page, so it is runtime data.
_APP_STYLESHEET = """
<style>
/* Color Variables */
:root {
--primary: #3498db; /* Blue */
--secondary: #2c3e50; /* Dark accent */
--accent: #f1c40f; /* Yellow */
--success: #2ecc71; /* Positive actions */
--light: #ffffff; /* White backgrounds */
--dark: #000000; /* Black text/elements */
}
/* Main container styling */
.stApp {
background: linear-gradient(135deg, #3498db 0%, #e0e0e0 100%);
font-family: 'Arial', sans-serif;
}
/* Headers styling */
h1, h2, h3 {
color: var(--dark) !important;
border-bottom: 3px solid var(--primary);
padding-bottom: 0.3em;
}
/* Form containers */
.stForm {
background: #000000;
border: 1px solid rgba(44, 62, 80, 0.2);
border-radius: 15px;
padding: 2rem;
box-shadow: 0 8px 30px rgba(0, 0, 0, 0.12);
margin: 1rem 0;
}
/* Input fields */
.stTextInput input, .stNumberInput input,
.stSelectbox select, .stTextArea textarea {
border: 2px solid #00FFFF !important;
border-radius: 10px !important;
padding: 1rem !important;
background: #00FFFF !important;
transition: all 0.3s cubic-bezier(0.4, 0, 0.2, 1);
color: var(--dark) !important;
}
.stTextInput input:focus, .stNumberInput input:focus,
.stSelectbox select:focus, .stTextArea textarea:focus {
border-color: var(--primary) !important;
box-shadow: 0 0 12px rgba(52, 152, 219, 0.2) !important;
background: white !important;
color: var(--dark) !important;
}
/* Buttons styling */
.stButton>button {
background: linear-gradient(135deg, var(--primary) 0%, var(--accent) 100%) !important;
color: var(--dark) !important;
border: none !important;
border-radius: 10px !important;
padding: 1rem 2rem !important;
font-size: 1rem !important;
transition: all 0.3s ease;
position: relative;
overflow: hidden;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 8px 15px rgba(52, 152, 219, 0.3);
opacity: 0.95;
}
.stButton>button:active {
transform: translateY(0);
opacity: 1;
}
/* Progress indicator */
.progress-bar {
display: flex;
justify-content: space-between;
margin: 2rem 0;
padding: 1rem;
background: rgba(255, 255, 255, 0.9);
border-radius: 10px;
color: var(--dark) !important;
}
.step {
flex: 1;
text-align: center;
padding: 1rem;
font-weight: 600;
color: #95a5a6;
position: relative;
}
.step.active {
color: var(--primary);
}
.step.active:after {
content: '';
position: absolute;
bottom: -1px;
left: 50%;
transform: translateX(-50%);
width: 40%;
height: 3px;
background: var(--primary);
}
/* Chat bubbles */
.dr-message {
background: linear-gradient(135deg, var(--primary) 0%, #2980b9 100%);
color: white;
border-radius: 20px 20px 20px 4px;
padding: 1.2rem 1.5rem;
margin: 1rem 0;
max-width: 80%;
width: fit-content;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}
.user-message {
background: linear-gradient(135deg, #f1c40f 0%, #e1b800 100%);
margin-left: auto;
border-radius: 20px 20px 4px 20px;
color: var(--dark) !important;
}
/* Emergency alert */
.emergency-alert {
background: linear-gradient(135deg, var(--accent) 0%, #c0392b 100%);
color: white;
padding: 2rem;
border-radius: 15px;
animation: pulse 1.5s infinite;
text-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.02); }
100% { transform: scale(1); }
}
/* Download button */
.download-btn {
background: linear-gradient(135deg, var(--success) 0%, #27ae60 100%) !important;
}
/* Enhanced Data Visualization Contrast */
.stDataFrame {
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: 12px;
overflow: hidden;
background: #f0f0f0;
color: var(--dark) !important;
}
/* Tabbed Interface Styling */
.stTabs [role="tablist"] {
gap: 10px;
padding: 8px;
background: rgba(240, 240, 240, 0.9);
border-radius: 12px;
color: var(--dark) !important;
}
.stTabs [role="tab"] {
background: #ffffff !important;
border-radius: 8px !important;
transition: all 0.3s ease;
color: var(--dark) !important;
}
.stTabs [role="tab"][aria-selected="true"] {
background: var(--primary) !important;
color: white !important;
transform: scale(1.05);
}
</style>
"""
st.markdown(_APP_STYLESHEET, unsafe_allow_html=True)
# Seed the per-session defaults. A key that already exists is left alone, so
# values stored by earlier reruns survive.
_SESSION_DEFAULTS = (
    ("current_step", int),       # wizard position; int() -> 0
    ("symptom_details", list),   # collected question/answer dicts; starts empty
    ("patient_info", dict),      # form values keyed by field name; starts empty
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
def initialize_groq_client():
    """Create a Groq client and store it in ``st.session_state.client``.

    The API key is looked up in this order:

    1. ``GROQ_API_KEY`` in Streamlit secrets,
    2. the ``GROQ_API_KEY`` environment variable,
    3. a password text input shown to the user.

    Returns:
        bool: True when a client was created and stored; False when no key is
        available yet or creating the client raised (an error is displayed).
    """
    try:
        secrets_file_missing = False
        try:
            api_key = st.secrets.get("GROQ_API_KEY")
        except FileNotFoundError:
            # A missing secrets.toml is not fatal: the environment variable or
            # manual entry below can still supply the key.
            secrets_file_missing = True
            api_key = None

        # Consult the environment separately.  Passing os.getenv(...) as the
        # default of st.secrets.get() loses the value whenever the secrets
        # lookup itself raises.
        if not api_key:
            api_key = os.getenv("GROQ_API_KEY")

        # Last resort: ask the user.
        if not api_key:
            if secrets_file_missing:
                st.warning("No `secrets.toml` file found. Please create one in the `.streamlit` folder.")
            api_key = st.text_input("Enter your Groq API Key:", type="password")
            if not api_key:
                st.warning("Please provide a valid Groq API key to proceed.")
                return False

        st.session_state.client = Groq(api_key=api_key)
        return True
    except Exception as e:
        st.error(f"Error initializing Groq client: {str(e)}")
        return False
def symptom_interrogation_step():
    """Ask the model for the next follow-up question and store it.

    Reads the client, the main symptom and the collected Q&A list from
    ``st.session_state``.  On success the question (with a trailing ``?``
    guaranteed) is written to ``st.session_state.current_question``.  On any
    error the message is shown and the script run is halted with ``st.stop()``.
    """
    client = st.session_state.client
    main_symptom = st.session_state.patient_info['main_symptom']
    messages = _build_interrogation_messages(
        main_symptom, st.session_state.symptom_details
    )
    try:
        # NOTE(review): confirm this model id is still served by the provider.
        response = client.chat.completions.create(
            messages=messages,
            model="mixtral-8x7b-32768",
            temperature=0.3
        )
        question = response.choices[0].message.content.strip()
        if not question.endswith('?'):
            question += '?'
        st.session_state.current_question = question
    except Exception as e:
        st.error(f"Error generating question: {str(e)}")
        st.stop()


def _build_interrogation_messages(main_symptom, symptom_details):
    """Return the chat messages that request the next question.

    With no answers collected yet, the opening prompt about ``main_symptom`` is
    used.  Otherwise the prompt carries the whole Q&A transcript so the model
    can see what has already been asked.
    """
    system_message = {"role": "system", "content": "Ask focused clinical questions. One at a time."}

    if not symptom_details:
        # First question: ask about the main symptom.
        medical_focus = {
            'pain': "location/radiation/provoking factors",
            'fever': "pattern/associated symptoms/response to meds",
            'gi': "bowel changes/ingestion timing/associated symptoms",
            'respiratory': "exertion relationship/sputum/triggers"
        }
        focus = medical_focus.get(main_symptom.lower(),
                                  "temporal pattern/severity progression/associated symptoms")
        prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom}
focusing on {focus} to differentiate serious causes. Your task is to have a polite and simple conversation with a patient.
Start by asking ONE specific follow-up question about their initial symptom: {main_symptom}.
Ask only one question at a time to avoid overwhelming the patient.
Keep your language clear, professional, and easy to understand.
Dont display possibe symptoms or why you are asking questions."""
        return [system_message, {"role": "user", "content": prompt}]

    # Follow-up questions: label each question and answer and include every
    # earlier exchange.  Running the last question straight into its answer,
    # with no history, lets the model misread the text and repeat itself.
    transcript = "\n".join(
        f"Q: {qa['question']}\nA: {qa['answer']}" for qa in symptom_details
    )
    prompt = f"""Conversation so far:
{transcript}
Based on this, ask the NEXT most critical question to differentiate between
possible causes of {main_symptom}. Consider red flags and likelihood.
Ask exactly ONE new question that has not already been asked."""
    return [system_message, {"role": "user", "content": prompt}]
def handle_symptom_interrogation():
    """Run one round of the symptom question/answer loop.

    When ``current_step`` is 1 a new question is requested through
    ``symptom_interrogation_step()`` and the step becomes 2.  The pending
    question is shown in a form; a submitted, non-empty answer is appended to
    ``st.session_state.symptom_details``.  From the third answer onwards the
    latest answer is screened by the model for an emergency, which shows an
    alert and sets the step to 4.  Otherwise the step goes back to 1 until
    seven answers exist, then on to 3, followed by a rerun.
    """
    # Function-scope imports keep this block self-contained.
    import html
    import re

    urgency_check_from = 3  # answers collected before emergency screening starts
    max_questions = 7       # answers collected before leaving the Q&A loop

    st.header("Symptom Analysis")
    if st.session_state.current_step == 1:
        symptom_interrogation_step()
        st.session_state.current_step = 2
    if 'current_question' not in st.session_state:
        return

    with st.form("symptom_qna"):
        # The question is model output shaped by patient-typed text; escape it
        # before embedding it in raw HTML so it cannot inject markup.
        safe_question = html.escape(st.session_state.current_question)
        st.markdown(f'<div class="dr-message">👨‍⚕️ {safe_question}</div>', unsafe_allow_html=True)
        answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}")

        if not st.form_submit_button("Next"):
            return
        if not answer:
            st.warning("Please provide an answer")
            return

        details = st.session_state.symptom_details
        details.append({
            "question": st.session_state.current_question,
            "answer": answer
        })
        del st.session_state.current_question

        # Check for emergency after 3 questions
        if len(details) >= urgency_check_from:
            last_answer = details[-1]['answer']
            try:
                urgency_check = st.session_state.client.chat.completions.create(
                    messages=[{"role": "user", "content":
                               f"Does '{last_answer}' indicate immediate emergency? Yes/No"}],
                    model="mixtral-8x7b-32768",
                    temperature=0
                ).choices[0].message.content
                # Whole-word match: a plain substring test also fires on words
                # such as "yesterday" or "eyes" inside a "No" reply.
                if re.search(r"\byes\b", urgency_check, re.IGNORECASE):
                    st.markdown('<div class="emergency-alert">🚨 Emergency detected! Please seek immediate medical attention.</div>', unsafe_allow_html=True)
                    st.session_state.current_step = 4
                    # No rerun here, so the alert stays on screen for this run.
                    return
            except Exception as e:
                st.error(f"Error checking urgency: {str(e)}")

        st.session_state.current_step = 1 if len(details) < max_questions else 3
        st.rerun()
def collect_basic_info():
    """Render the patient-information form (wizard step 0).

    Widget values are written to ``st.session_state.patient_info`` under
    ``name``, ``age``, ``gender`` and ``main_symptom``.  When the form is
    submitted with every field filled, ``current_step`` becomes 1 and the
    script reruns; otherwise a warning is shown.
    """
    st.header("Patient Information")
    with st.form("basic_info"):
        info = st.session_state.patient_info
        # Strip so that whitespace-only input counts as empty.
        info['name'] = st.text_input("Full Name").strip()
        info['age'] = st.number_input("Age", min_value=0, max_value=120)
        info['gender'] = st.selectbox("Gender", ["Male", "Female", "Other"])
        info['main_symptom'] = st.text_input("Main Symptom").strip()

        if st.form_submit_button("Next"):
            text_fields_filled = all(
                info.get(field) for field in ('name', 'gender', 'main_symptom')
            )
            # Age is tested against None, not truthiness: 0 is inside the
            # widget's allowed range (min_value=0) and must not be rejected.
            if text_fields_filled and info.get('age') is not None:
                st.session_state.current_step = 1
                st.rerun()
            else:
                st.warning("Please fill all required fields")
def collect_medical_history():
    """Render the medical-history form (wizard step 3).

    Each widget's value is stored in ``st.session_state.patient_info`` under
    the key listed below.  Submitting the form sets ``current_step`` to 4 and
    reruns the script.
    """
    # (patient_info key, widget factory, label) in display order.
    history_fields = (
        ('medical_history', st.text_area, "Relevant Medical History"),
        ('medications', st.text_area, "Current Medications"),
        ('allergies', st.text_input, "Known Allergies"),
        ('last_meal', st.text_input, "Last Meal Time"),
        ('recent_travel', st.text_input, "Recent Travel History"),
    )

    st.header("Medical History")
    with st.form("medical_history"):
        for field_key, make_widget, label in history_fields:
            st.session_state.patient_info[field_key] = make_widget(label)

        submitted = st.form_submit_button("Submit")
        if submitted:
            st.session_state.current_step = 4
            st.rerun()
def generate_risk_assessment():
    """Show the one-line clinical summary and offer it as a download (step 4).

    The summary is produced by the model from the patient profile and the
    collected Q&A, then kept in ``st.session_state.risk_summary``.  Any
    exception is reported with ``st.error``.
    """
    st.header("Risk Assessment")
    try:
        info = st.session_state.patient_info

        # Reuse a summary generated earlier in this session.  Clicking the
        # download button reruns the script, and without this every rerun
        # would issue another model request and could change the text.
        summary = st.session_state.get('risk_summary')
        if summary is None:
            symptom_log = "\n".join(
                f"Q: {qa['question']}\nA: {qa['answer']}"
                for qa in st.session_state.symptom_details
            )
            # `or` instead of a .get() default: a form field submitted blank
            # is stored as '' and should fall back the same way as a missing key.
            medical_history = info.get('medical_history') or 'None reported'
            medications = info.get('medications') or 'None'
            allergies = info.get('allergies') or 'None reported'
            last_meal = info.get('last_meal') or 'Unknown'
            recent_travel = info.get('recent_travel') or 'None'

            patient_profile = f"""
**Patient Profile**
Name: {info['name']}
Age: {info['age']}
Gender: {info['gender']}
**Primary Complaint**
{info['main_symptom']}
**Symptom Interrogation**
{symptom_log}
**Medical History**
{medical_history}
**Current Medications**
{medications}
**Allergies**
{allergies}
**Recent Context**
Last Meal: {last_meal}
Recent Travel: {recent_travel}
"""
            analysis_prompt = f"""STRICTLY follow these instructions:
1. Analyze this case: {patient_profile}
2. *Include ONLY symptoms the patient is actively experiencing*. Exclude all negated symptoms (e.g., "no fever," "denies breathlessness").
3. Output *EXCLUSIVELY* in this format with NO additional text or explanations:
[Age]-year-old [gender] with [specific, present symptoms].
Example Output:
"45-year-old man with severe chest pain radiating to the jaw"
Your Output:"""
            response = st.session_state.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": "You are a medical AI that outputs ONLY patient descriptions."},
                    {"role": "user", "content": analysis_prompt}
                ],
                model="mixtral-8x7b-32768",
                temperature=0.3,
                max_tokens=100
            )
            # Trim whitespace first so surrounding quotes are actually at the
            # ends of the string when they are stripped.
            summary = response.choices[0].message.content.strip().strip('"').strip()
            st.session_state.risk_summary = summary

        st.subheader("Clinical Summary")
        st.markdown(f"```\n{summary}\n```")

        # Create download button
        timestamp = datetime.now().strftime('%Y%m%d%H%M')
        filename = f"{info['name'].replace(' ', '_')}_assessment_{timestamp}.txt"
        st.download_button(
            label="Download Assessment",
            data=summary,
            file_name=filename,
            mime="text/plain"
        )
    except Exception as e:
        st.error(f"Error generating risk assessment: {str(e)}")
def main():
    """Render the page: title, progress bar, then the handler for the current step.

    Nothing beyond the title and progress bar is drawn while
    ``initialize_groq_client()`` returns False.
    """
    st.title("🏥 AI Medical Consultancy")

    current_step = st.session_state.get('current_step', 0)

    # Progress indicator: (step at which the stage lights up, label).
    # Steps 1 and 2 both belong to the symptom stage, so "History" starts at 3.
    stages = (
        (0, '1. Patient Info'),
        (1, '2. Symptoms'),
        (3, '3. History'),
        (4, '4. Report'),
    )
    stage_rows = "\n".join(
        '<div class="step {}">{}</div>'.format(
            'active' if current_step >= threshold else '', label
        )
        for threshold, label in stages
    )
    progress_html = f'\n<div class="progress-bar">\n{stage_rows}\n</div>\n'
    st.markdown(progress_html, unsafe_allow_html=True)

    if not initialize_groq_client():
        return

    # Step number -> handler.  Step 2 is the "question pending" state of the
    # same symptom handler used for step 1.
    steps = {
        0: collect_basic_info,
        1: handle_symptom_interrogation,
        2: handle_symptom_interrogation,
        3: collect_medical_history,
        4: generate_risk_assessment
    }
    handler = steps.get(current_step)
    if handler is not None:
        handler()


if __name__ == "__main__":
    main()