# Hugging Face Space file view (Space status: Sleeping). File size: 6,232 bytes.
# Viewer blame column (abbreviated commit per line group): e573d3e ed8c0cb e573d3e a9d2dec ee32f79 e573d3e ee32f79 e573d3e a9d2dec e573d3e a9d2dec 421c21c a9d2dec 421c21c a9d2dec 421c21c a9d2dec 421c21c ee32f79 a9d2dec e573d3e ee32f79 e573d3e ee32f79 e573d3e 421c21c a9d2dec e573d3e
import html
import os
import re

import faiss
import gradio as gr
import google.generativeai as genai
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
# Load data and FAISS index
def load_data_and_index(path="data.pkl"):
    """Load the document table and build an exact L2 FAISS index over its embeddings.

    Args:
        path: Pickled pandas DataFrame with an ``embeddings`` column (one
            fixed-length vector per row) plus the ``label``, ``text`` and
            ``source`` columns that ``retrieve_docs`` reads. Defaults to
            ``data.pkl`` (adjust for the deployment, e.g. HF Spaces).

    Returns:
        ``(docs_df, index)``: the DataFrame and a ``faiss.IndexFlatL2`` whose
        i-th stored vector is row i of ``docs_df`` (vectors are added in row
        order).

    Raises:
        ValueError: if the file yields no 2-D embedding matrix (e.g. it is empty).
    """
    # NOTE: unpickling can execute arbitrary code; only point this at a trusted file.
    docs_df = pd.read_pickle(path)
    embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
    # An empty table gives a 1-D array, for which shape[1] would raise an opaque IndexError.
    if embeddings.ndim != 2 or embeddings.shape[0] == 0:
        raise ValueError(f"No usable embeddings found in {path!r}.")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return docs_df, index


docs_df, index = load_data_and_index()
# Configure Gemini using the GEMINI_API_KEY Hugging Face Secret. The key is checked
# before the SentenceTransformer is constructed, so a missing secret fails fast
# instead of after the embedding model has been loaded (and possibly downloaded).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel('gemini-2.0-flash')

# Query encoder. Presumably the same model that produced docs_df['embeddings'];
# the FAISS index dimension must match its output size — TODO confirm.
minilm = SentenceTransformer('all-MiniLM-L6-v2')
# Preprocess text function
def preprocess_text(text):
    """Normalise a query string before it is embedded.

    The text is lower-cased. Every character that is not a word character,
    whitespace, or one of ``. , ; : > -`` becomes a space. Runs of whitespace
    (including newlines and tabs) then collapse to single spaces, and the
    ends are trimmed.

    Args:
        text: Raw user input.

    Returns:
        The normalised string ("" for blank input).
    """
    kept = re.sub(r'[^\w\s.,;:>-]', ' ', text.lower())
    # split()/join both collapses internal whitespace and trims the ends.
    return ' '.join(kept.split())
# Retrieve documents
def retrieve_docs(query, k=5):
    """Return up to ``k`` documents whose embeddings are nearest to ``query``.

    The query is embedded with the module-level ``minilm`` encoder and searched
    in the module-level FAISS ``index``; hits are looked up positionally in
    ``docs_df``.

    Args:
        query: Text to embed (the caller passes it already preprocessed).
        k: Maximum number of neighbours to return.

    Returns:
        A new DataFrame with ``label``, ``text``, ``source`` and ``distance``
        (squared L2 distance, as reported by ``IndexFlatL2``) columns, nearest
        first. It has fewer than ``k`` rows when the index holds fewer than
        ``k`` vectors.
    """
    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
    distances, indices = index.search(np.array([query_embedding]), k)
    # FAISS pads unfilled neighbour slots with id -1; iloc[-1] would silently
    # return the last document, so drop those slots.
    found = indices[0] >= 0
    # .copy() makes an independent frame, so adding the distance column below
    # cannot trigger SettingWithCopyWarning or touch docs_df.
    retrieved_docs = docs_df.iloc[indices[0][found]][['label', 'text', 'source']].copy()
    retrieved_docs['distance'] = distances[0][found]
    return retrieved_docs
# Parse response into structured sections
# Markdown list markers ("- ", "* ", "+ ", "• ", "1. ", "2) ") that prefix items.
_BULLET_RE = re.compile(r'^(?:[-*+•]|\d+[.)])\s+')
# Markdown heading markers ("## ").
_HEADING_RE = re.compile(r'^#+\s+')


def parse_response(response_text):
    """Split a model answer into fixed clinical sections.

    The answer is scanned line by line, and each line is handled as follows:

    * A recognised header switches the current section. The headers are
      "Symptoms:", "Signs:", "Risk Factors", "Past Medical History:" and
      "Diagnostic Criteria:", matched case-insensitively at the start of
      the line.
    * Any other line that contains a colon, is not indented and is not a
      list item switches to the catch-all "Other" section.
    * Every non-empty line that does not end with ':' (i.e. is not a bare
      header) is appended to the current section.

    Markdown decoration is removed before matching and from the stored items:
    bold markers (``**``), a leading heading marker (``#``) and a leading list
    marker. This means "**Symptoms:**" and "## Signs:" are still recognised
    as headers.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        A dict mapping "Symptoms", "Signs", "Risk Factors",
        "Diagnostic Criteria" and "Other" to lists of item strings, in the
        order the items appeared.
    """
    sections = {
        "Symptoms": [],
        "Signs": [],
        "Risk Factors": [],
        "Diagnostic Criteria": [],
        "Other": [],
    }
    current_section = "Other"
    for raw_line in response_text.split('\n'):
        # Bold markers are pure decoration and would hide "Symptoms:" behind "**".
        line = raw_line.replace('**', '').strip()
        # Indentation must be read from the raw line: the stripped text never
        # starts with whitespace, which made the original nested-item check dead.
        indented = raw_line[:1].isspace()
        is_bullet = _BULLET_RE.match(line) is not None
        line = _HEADING_RE.sub('', _BULLET_RE.sub('', line, count=1), count=1).strip()
        if not line:
            continue
        lowered = line.lower()
        if lowered.startswith("symptoms:"):
            current_section = "Symptoms"
        elif lowered.startswith("signs:"):
            current_section = "Signs"
        elif lowered.startswith(("risk factors", "past medical history:")):
            current_section = "Risk Factors"
        elif lowered.startswith("diagnostic criteria:"):
            current_section = "Diagnostic Criteria"
        elif ':' in line and not indented and not is_bullet:
            # An unrecognised top-level "Heading: ..." line starts the catch-all section;
            # list items such as "- Glucose: 126 mg/dL" stay in the current section.
            current_section = "Other"
        if not line.endswith(':'):
            sections[current_section].append(line)
    return sections
# Respond function with generic HTML formatting
# Stylesheet and opening markup shared by every rendered answer.
_HTML_HEADER = """
<style>
.diagnosis-container { font-family: Arial, sans-serif; line-height: 1.6; padding: 15px; max-width: 800px; margin: auto; }
h2 { color: #2c3e50; font-size: 24px; margin-bottom: 15px; border-bottom: 2px solid #2980b9; padding-bottom: 5px; }
h3 { color: #2980b9; font-size: 18px; margin-top: 15px; margin-bottom: 8px; }
ul { margin: 0; padding-left: 25px; }
li { margin-bottom: 6px; color: #34495e; }
p { margin: 5px 0; color: #34495e; }
</style>
<div class="diagnosis-container">
<h2>AI Response</h2>
<p>Based on the provided context, here is the information relevant to your query:</p>
"""

# Shown instead of crashing when Gemini returns no usable text.
_NO_ANSWER_TEXT = "The model did not return an answer for this query (it may have been blocked by safety filters)."


def _build_prompt(message, system_message, context):
    """Assemble the Gemini prompt from the system message, raw query and retrieved context."""
    return (
        f"{system_message}\n\n"
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        "Generate a concise response to the query based only on the provided context. "
        "Structure the response with clear sections like 'Symptoms:', 'Signs:', 'Risk Factors:', and 'Diagnostic Criteria:' where applicable."
    )


def _render_html(sections):
    """Render non-empty sections as a styled HTML fragment; all text is HTML-escaped."""
    parts = [_HTML_HEADER]
    for section, items in sections.items():
        if not items:  # Only include sections that have content
            continue
        # Items may repeat their header ("Symptoms: fever"); drop that prefix.
        prefix = re.compile(rf"^{re.escape(section)}:", flags=re.IGNORECASE)
        parts.append(f"<h3>{html.escape(section)}</h3><ul>")
        for item in items:
            cleaned_item = prefix.sub("", item).strip()
            # Model output is untrusted (and can echo the user's query), so escape
            # it rather than letting it inject markup or scripts into the page.
            parts.append(f"<li>{html.escape(cleaned_item)}</li>")
        parts.append("</ul>")
    parts.append("</div>")
    return "".join(parts)


def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer a query with retrieval-augmented Gemini generation, rendered as HTML.

    The preprocessed query retrieves 5 documents, whose ``text`` values become
    the prompt context. The raw query is what goes into the prompt itself.

    Args:
        message: The user's query.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Output-token limit; coerced to int because UI sliders may
            deliver floats.
        temperature: Sampling temperature forwarded to Gemini.
        top_p: Nucleus-sampling threshold forwarded to Gemini.

    Returns:
        An HTML string with one heading and bullet list per non-empty section.
    """
    preprocessed_query = preprocess_text(message)
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
    context = "\n".join(retrieved_docs['text'].tolist())
    prompt = _build_prompt(message, system_message, context)

    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=int(max_tokens),
            temperature=temperature,
            top_p=top_p,
        ),
    )
    try:
        answer = response.text.strip()
    except ValueError:
        # google.generativeai raises ValueError from `.text` when the response has
        # no valid text part (e.g. the candidate was blocked).
        answer = _NO_ANSWER_TEXT

    return _render_html(parse_response(answer))
# Simple Gradio Interface with HTML output. Input values are passed to
# respond(message, system_message, max_tokens, temperature, top_p) in this order,
# and its HTML string is shown in the output component.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Query", placeholder="Enter your medical question here (e.g., diabetes, heart failure)..."),
        gr.Textbox(
            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
            label="System Message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
        # Gemini 2.0 Flash accepts temperatures in [0.0, 2.0]; larger values are rejected by the API.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.75, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    outputs=gr.HTML(label="Diagnosis"),
    title="🏥 Medical Query Assistant",
    description="A medical assistant that diagnoses patient queries using AI and past records, with styled output for any condition.",
)

if __name__ == "__main__":
    demo.launch()