import html
import os
import re

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

# Known response sections: (lower-case header text, display section name).
_SECTION_HEADERS = (
    ("symptoms", "Symptoms"),
    ("signs", "Signs"),
    ("risk factors", "Risk Factors"),
    ("past medical history", "Risk Factors"),
    ("diagnostic criteria", "Diagnostic Criteria"),
)

# Markdown list markers such as "- ", "* ", "+ ", "• ", "1. " or "2) ".
_BULLET_RE = re.compile(r"^(?:[-*+•]\s+|\d+[.)]\s+)")


# Load data and FAISS index
def load_data_and_index(path="data.pkl"):
    """Load the document table and build an exact L2 FAISS index over it.

    Args:
        path: Pickled DataFrame with at least 'embeddings', 'label', 'text'
            and 'source' columns (adjust for HF Spaces if needed).

    Returns:
        ``(docs_df, index)`` where row ``i`` of ``docs_df`` corresponds to
        FAISS id ``i``.
    """
    # NOTE(review): read_pickle can execute arbitrary code while loading;
    # only ever point this at a trusted, locally produced file.
    docs_df = pd.read_pickle(path)
    embeddings = np.array(docs_df["embeddings"].tolist(), dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return docs_df, index


docs_df, index = load_data_and_index()

# Query encoder. Presumably the same model that produced the stored
# embeddings (TODO confirm) -- otherwise the FAISS distances are meaningless.
minilm = SentenceTransformer("all-MiniLM-L6-v2")

# Configure Gemini API using Hugging Face Secrets
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel("gemini-2.0-flash")


def preprocess_text(text):
    """Normalise a query before embedding.

    Lower-cases, turns newlines/tabs into spaces, replaces every character
    other than word characters, whitespace and ``.,;:>-`` with a space, and
    collapses runs of whitespace.
    """
    text = text.lower()
    text = text.replace("\n", " ").replace("\t", " ")
    text = re.sub(r"[^\w\s.,;:>-]", " ", text)
    return " ".join(text.split()).strip()


def retrieve_docs(query, k=5):
    """Return the ``k`` nearest documents to ``query`` from the FAISS index.

    Args:
        query: Already-preprocessed query text.
        k: Maximum number of documents to return; clamped to the index size.

    Returns:
        A new DataFrame with columns 'label', 'text', 'source' and
        'distance' (squared L2 distance from ``IndexFlatL2``), nearest first.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return pd.DataFrame(columns=["label", "text", "source", "distance"])

    query_embedding = np.asarray(
        minilm.encode([query], show_progress_bar=False)[0], dtype=np.float32
    ).reshape(1, -1)
    distances, indices = index.search(query_embedding, k)

    # FAISS pads with id -1 when fewer than k hits exist; iloc[-1] would
    # silently return the last row, so drop those slots.
    valid = indices[0] >= 0
    # .copy() so adding 'distance' writes to an independent frame rather
    # than a view of docs_df (avoids pandas chained-assignment warnings).
    retrieved_docs = docs_df.iloc[indices[0][valid]][["label", "text", "source"]].copy()
    retrieved_docs["distance"] = distances[0][valid]
    return retrieved_docs


def _match_section(head, *, allow_prefix):
    """Map header text to a known section name, or return None.

    ``allow_prefix`` also accepts headers that merely start with a known
    keyword (e.g. "Risk factors for X"); otherwise an exact match is needed.
    """
    key = " ".join(head.lower().split())
    for keyword, section in _SECTION_HEADERS:
        if key == keyword or (allow_prefix and key.startswith(keyword)):
            return section
    return None


def parse_response(response_text):
    """Split model output into the fixed set of display sections.

    Recognised headers ("Symptoms:", "**Signs:**", "## Risk Factors",
    "Past medical history: ...", "Diagnostic Criteria:") switch the current
    section; any text after the header's colon becomes an item. A non-bullet
    line with an unknown "Header:" switches to "Other". Bullet items never
    switch sections, so "- Chest pain: sharp" stays in the current section.
    Header-only lines and blank lines are not emitted as items.

    Returns:
        Dict of section name -> list of item strings, in display order.
    """
    sections = {
        "Symptoms": [],
        "Signs": [],
        "Risk Factors": [],
        "Diagnostic Criteria": [],
        "Other": [],
    }
    current_section = "Other"

    for raw_line in response_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        bullet = _BULLET_RE.match(line)
        content = line[bullet.end():] if bullet else line
        # Drop markdown heading / emphasis markers around the text.
        cleaned = content.strip("#*_ ")
        head, sep, rest = cleaned.partition(":")

        section = _match_section(head, allow_prefix=bool(sep) and not bullet)
        if section is not None:
            current_section = section
            item = rest
        elif sep and not bullet:
            # Unknown "Header: ..." line starts the catch-all section; keep the
            # whole line as an item unless it is a bare header.
            current_section = "Other"
            item = cleaned if rest.strip("*_ ") else ""
        else:
            item = content

        item = item.replace("**", "").strip(" *_")
        if item:
            sections[current_section].append(item)

    return sections


def _render_html(sections):
    """Render parsed sections as a styled HTML card; all text is HTML-escaped."""
    parts = [
        '<div style="font-family: Arial, sans-serif; padding: 16px; '
        'border: 1px solid #ddd; border-radius: 8px;">',
        '<h2 style="margin-top: 0; color: #2c3e50;">AI Response</h2>',
        '<p style="color: #555;">Based on the provided context, here is the '
        "information relevant to your query:</p>",
    ]
    rendered_any = False
    for section, items in sections.items():
        if not items:  # Only include sections that have content
            continue
        rendered_any = True
        parts.append('<div style="margin-bottom: 12px;">')
        parts.append(
            f'<h3 style="color: #2980b9; margin-bottom: 4px;">{html.escape(section)}</h3>'
        )
        parts.append('<ul style="margin-top: 0;">')
        parts.extend(f"<li>{html.escape(item)}</li>" for item in items)
        parts.append("</ul>")
        parts.append("</div>")
    if not rendered_any:
        parts.append(
            '<p style="color: #a94442;">No response could be generated for this query.</p>'
        )
    parts.append("</div>")
    return "\n".join(parts)


def respond(message, system_message, max_tokens, temperature, top_p):
    """Gradio handler: retrieve context, query Gemini, return styled HTML.

    Args:
        message: Raw user query.
        system_message: Instruction text prepended to the prompt.
        max_tokens: Maximum output tokens (slider value; converted to int).
        temperature: Sampling temperature forwarded to Gemini.
        top_p: Nucleus-sampling probability forwarded to Gemini.

    Returns:
        An HTML string for the ``gr.HTML`` output component.
    """
    if not message or not message.strip():
        return "<p>Please enter a query.</p>"

    preprocessed_query = preprocess_text(message)
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
    context = "\n".join(retrieved_docs["text"].tolist())

    # The raw (not preprocessed) message goes into the prompt so the model
    # sees the user's original wording.
    prompt = (
        f"{system_message}\n\n"
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        "Generate a concise response to the query based only on the provided context. "
        "Structure the response with clear sections like 'Symptoms:', 'Signs:', "
        "'Risk Factors:', and 'Diagnostic Criteria:' where applicable."
    )

    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
        ),
    )
    try:
        answer = response.text.strip()
    except ValueError:
        # .text raises when the response has no valid parts (e.g. it was
        # blocked); fall through to the "no response" message.
        answer = ""

    return _render_html(parse_response(answer))


# Simple Gradio Interface with HTML output
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(
            label="Your Query",
            placeholder="Enter your medical question here (e.g., diabetes, heart failure)...",
        ),
        gr.Textbox(
            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
            label="System Message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
        # gemini-2.0-flash accepts temperatures in [0.0, 2.0].
        gr.Slider(minimum=0.1, maximum=2.0, value=0.75, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    outputs=gr.HTML(label="Diagnosis"),
    title="🏥 Medical Query Assistant",
    description="A medical assistant that diagnoses patient queries using AI and past records, with styled output for any condition.",
)

if __name__ == "__main__":
    demo.launch()