import streamlit as st import os import pandas as pd import re from groq import Groq from datetime import datetime # Set page config st.set_page_config(page_title="AI Medical Consultancy", layout="wide") # Load custom CSS def load_css(): try: with open("style.css") as f: st.markdown(f"", unsafe_allow_html=True) except FileNotFoundError: st.warning("CSS file not found. Please ensure 'style.css' is in the directory.") load_css() # Initialize session state variables if 'current_step' not in st.session_state: st.session_state.current_step = 0 if 'symptom_details' not in st.session_state: st.session_state.symptom_details = [] if 'patient_info' not in st.session_state: st.session_state.patient_info = {} if 'appointment_details' not in st.session_state: st.session_state.appointment_details = None if 'appointment_summary' not in st.session_state: st.session_state.appointment_summary = None if 'analysis_results' not in st.session_state: st.session_state.analysis_results = None class MedicalAnalysisSystem: def _init_(self, dataset_path): try: self.data = pd.read_csv(dataset_path) print("Dataset Columns:", self.data.columns.tolist()) # Debug print("Sample Data:\n", self.data.head()) # Debug # Clean data - remove placeholder rows self.data = self.data[~self.data['Symptom'].str.contains('Symptom|Condition', case=False)] self.data['Risk Score'] = pd.to_numeric(self.data['Risk Score'], errors='coerce') # Handle missing values self.data['Risk Score'] = self.data['Risk Score'].fillna(0) # Calculate MAX_RISK_SCORE dynamically symptom_max_risk = self.data.groupby('Symptom')['Risk Score'].max().sum() max_age = 120 max_age_risk = (max_age - 40) * 0.05 if max_age > 40 else 0 self.MAX_RISK_SCORE = symptom_max_risk + max_age_risk self.local_messages = [] self.severity_mapping = { 'Mild': ['mild', 'slight', 'minor', 'low grade'], 'Moderate': ['moderate', 'medium', 'average'], 'Severe': ['severe', 'high', 'extreme', 'critical', 'intense', 'very bad', 'acute'] } self.negation_words = {'no', 'not', 'denies', 'without', 'negative', 'none', 'denied'} except Exception as e: st.error(f"Dataset Error: {str(e)}") raise def add_patient_data(self, patient_message): try: if not patient_message: raise ValueError("Patient message cannot be empty") self.local_messages.append({ 'message': patient_message, 'timestamp': datetime.now().timestamp() }) except Exception as e: st.error(f"Error adding patient data: {str(e)}") def extract_info_from_bot_response(self, bot_response_data): try: if not bot_response_data: return 0, [], {} bot_response_text = str(bot_response_data) bot_response_lower = bot_response_text.lower() # Age extraction age = 0 age_pattern = r'(\d{1,3})\s*(?:years?-?old|yo|years|-years-old?)' age_match = re.search(age_pattern, bot_response_text, re.IGNORECASE) if age_match: age = int(age_match.group(1)) if not (0 <= age <= 120): age = 20 # Symptom extraction symptoms = [] for symptom in self.data['Symptom'].unique(): symptom_lower = symptom.lower() pattern = re.compile(r'\b' + re.escape(symptom_lower) + r'\b', re.IGNORECASE) matches = pattern.finditer(bot_response_lower) for match in matches: start_pos = match.start() preceding_text = bot_response_lower[:start_pos].split() preceding_words = preceding_text[-3:] if not any(neg in preceding_words for neg in self.negation_words): symptoms.append(symptom) break # Severity analysis symptom_severity = {} for symptom in symptoms: symptom_lower = symptom.lower() highest_severity_score = 0 pattern = re.compile(r'\b' + re.escape(symptom_lower) + r'\b', re.IGNORECASE) matches = pattern.finditer(bot_response_lower) for match in matches: start, end = match.start(), match.end() words = bot_response_lower.split() match_index = len(bot_response_lower[:start].split()) context_start = max(0, match_index - 5) context_end = min(len(words), match_index + 6) context = ' '.join(words[context_start:context_end]) for severity, keywords in self.severity_mapping.items(): for keyword in keywords: if re.search(r'\b' + re.escape(keyword) + r'\b', context): condition_data = self.data[(self.data['Symptom'] == symptom) & (self.data['Condition'] == severity)] if not condition_data.empty: risk_score = condition_data['Risk Score'].values[0] if risk_score > highest_severity_score: highest_severity_score = risk_score if highest_severity_score == 0: highest_severity_score = self.data[self.data['Symptom'] == symptom]['Risk Score'].max() symptom_severity[symptom] = highest_severity_score return age, symptoms, symptom_severity except Exception as e: st.error(f"Extraction Error: {str(e)}") return 0, [], {} def calculate_risk_score(self, age, symptoms, symptom_severity): try: # Validate symptoms valid_symptoms = [s for s in symptoms if s in self.data['Symptom'].values] if not valid_symptoms: return "Unknown", 0, 0 # Calculate scores with validation symptom_risk = sum(float(symptom_severity.get(s, 0)) for s in valid_symptoms) age_risk = max((age - 40) * 0.05, 0) if age >= 40 else 0 final_score = symptom_risk + age_risk # Ensure we don't divide by zero max_score = self.MAX_RISK_SCORE if self.MAX_RISK_SCORE > 0 else 1 risk_pct = min(100, max(0, (final_score / max_score) * 100)) if risk_pct <= 30: label = "Low" elif risk_pct <= 70: label = "Medium" else: label = "High" return label, final_score, round(risk_pct, 1) except Exception as e: st.error(f"Risk Calculation Error: {str(e)}") return "Low", 0, 0 def analyze_patient_data(self, patient_message): """Full analysis workflow""" try: # Clean input message patient_message = patient_message.replace("Symptom", "").replace("Condition", "") self.add_patient_data(patient_message) age, symptoms, severity = self.extract_info_from_bot_response(patient_message) # Filter invalid symptoms valid_symptoms = [s for s in symptoms if s in self.data['Symptom'].values] if not valid_symptoms: return {"error": "No valid symptoms detected"} # Get unique conditions from valid symptoms conditions = self.data[self.data['Symptom'].isin(valid_symptoms)]['Condition'].unique() valid_conditions = [c for c in conditions if c not in ['Normal', 'Moderate', 'Severe', 'Condition']] risk_label, risk_score, risk_pct = self.calculate_risk_score(age, valid_symptoms, severity) return { 'age': age, 'symptoms': valid_symptoms, 'symptom_severity': severity, 'risk_label': risk_label, 'risk_score': round(risk_score, 2), 'risk_percentage': risk_pct, 'possible_conditions': valid_conditions, 'analysis_timestamp': datetime.now().isoformat() } except Exception as e: return {"error": f"Analysis Error: {str(e)}"} def process_user_data(self): try: if not self.local_messages: return {"error": "No messages available"} latest = max(self.local_messages, key=lambda x: x['timestamp']) age, symptoms, severity = self.extract_info_from_bot_response(latest['message']) if not symptoms: return {"error": "No symptoms detected"} risk_label, risk_score, risk_pct = self.calculate_risk_score(age, symptoms, severity) return { 'age': age, 'symptoms': symptoms, 'symptom_severity': severity, 'risk_label': risk_label, 'risk_score': round(risk_score, 2), 'risk_percentage': risk_pct, 'possible_conditions': self.data[self.data['Symptom'].isin(symptoms)]['Condition'].unique().tolist(), 'analysis_timestamp': datetime.now().isoformat() } except Exception as e: return {"error": f"Processing Error: {str(e)}"} def initialize_groq_client(): try: api_key = st.secrets.get("GROQ_API_KEY", os.getenv("GROQ_API_KEY")) if not api_key: api_key = st.text_input("Enter Groq API Key:", type="password") if not api_key: return False st.session_state.client = Groq(api_key=api_key) return True except Exception as e: st.error(f"Groq Error: {str(e)}") return False def symptom_interrogation_step(): client = st.session_state.client main_symptom = st.session_state.patient_info['main_symptom'] step = len(st.session_state.symptom_details) if step == 0: medical_focus = { 'pain': "location/radiation/provoking factors", 'fever': "pattern/associated symptoms/response to meds", 'gi': "bowel changes/ingestion timing/associated symptoms", 'respiratory': "exertion relationship/sputum/triggers" } focus = medical_focus.get(main_symptom.lower(), "temporal pattern/severity progression/associated symptoms") prompt = f"""As an ER physician, ask ONE high-yield question about {main_symptom} focusing on {focus}. Use simple, patient-friendly language. Ask only ONE question.""" else: last_qa = st.session_state.symptom_details[-1] prompt = f"""Based on previous Q: {last_qa['question']} → A: {last_qa['answer']} Ask the NEXT critical question about {main_symptom} considering red flags.""" try: response = client.chat.completions.create( messages=[{"role": "user", "content": prompt}], model="mixtral-8x7b-32768", temperature=0.3 ) question = response.choices[0].message.content.strip() if not question.endswith('?'): question += '?' st.session_state.current_question = question except Exception as e: st.error(f"Question Generation Error: {str(e)}") st.stop() def handle_symptom_interrogation(): st.header("Symptom Analysis") if st.session_state.current_step == 1: symptom_interrogation_step() st.session_state.current_step = 2 if 'current_question' in st.session_state: with st.form("symptom_qna"): st.markdown(f'
', unsafe_allow_html=True) answer = st.text_input("Your answer:", key=f"answer_{len(st.session_state.symptom_details)}") if st.form_submit_button("Next"): if answer: st.session_state.symptom_details.append({ "question": st.session_state.current_question, "answer": answer }) del st.session_state.current_question # Emergency check if len(st.session_state.symptom_details) >= 3: last_answer = st.session_state.symptom_details[-1]['answer'] try: urgency_check = st.session_state.client.chat.completions.create( messages=[{"role": "user", "content": f"Does this indicate emergency? '{last_answer}' Yes/No"}], model="mixtral-8x7b-32768", temperature=0 ).choices[0].message.content if 'YES' in urgency_check.upper(): st.markdown('