File size: 6,232 Bytes
e573d3e
 
 
 
 
 
 
 
 
 
 
ed8c0cb
e573d3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9d2dec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee32f79
e573d3e
 
 
 
 
 
 
ee32f79
e573d3e
 
 
 
a9d2dec
 
e573d3e
 
 
 
 
 
 
 
 
 
 
 
a9d2dec
 
 
421c21c
 
 
a9d2dec
 
 
 
 
 
421c21c
 
a9d2dec
 
421c21c
 
a9d2dec
 
 
 
 
 
 
 
 
 
421c21c
 
 
 
 
ee32f79
 
 
a9d2dec
e573d3e
 
ee32f79
e573d3e
ee32f79
e573d3e
 
 
 
 
 
 
 
 
421c21c
a9d2dec
 
e573d3e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import gradio as gr
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import re
import os

# Load data and FAISS index
def load_data_and_index():
    """Read the pickled document store and build an in-memory FAISS L2 index.

    Returns:
        tuple: (docs_df, index) — the DataFrame loaded from ``data.pkl``
        (must contain an 'embeddings' column of equal-length vectors) and a
        ``faiss.IndexFlatL2`` populated with those vectors, row-aligned with
        the DataFrame.
    """
    frame = pd.read_pickle("data.pkl")  # Adjust path for HF Spaces
    vectors = np.array(frame['embeddings'].tolist(), dtype=np.float32)
    flat_index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search, no training step
    flat_index.add(vectors)
    return frame, flat_index

# Module-level setup: runs once at import/startup, so the Space fails fast
# on missing data or configuration instead of on the first user request.
docs_df, index = load_data_and_index()

# Load SentenceTransformer (embedding model used to encode queries in retrieve_docs)
minilm = SentenceTransformer('all-MiniLM-L6-v2')

# Configure Gemini API using Hugging Face Secrets
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    # Raise at startup rather than letting every request fail with an auth error.
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
# LLM handle used by respond() for generation.
model = genai.GenerativeModel('gemini-2.0-flash')

# Preprocess text function
def preprocess_text(text):
    """Normalize raw query text for embedding.

    Lowercases, flattens newlines/tabs, blanks out characters outside a small
    allowed set (word chars, whitespace, ``.,;:>-``), and collapses runs of
    whitespace to single spaces.
    """
    lowered = text.lower().replace('\n', ' ').replace('\t', ' ')
    # Replace any disallowed character with a space rather than deleting it,
    # so adjacent tokens don't fuse together.
    cleaned = re.sub(r'[^\w\s.,;:>-]', ' ', lowered)
    return ' '.join(cleaned.split())

# Retrieve documents
def retrieve_docs(query, k=5):
    """Return the *k* nearest documents to *query* from the FAISS index.

    Args:
        query: Preprocessed query string (see preprocess_text).
        k: Number of neighbors to retrieve (default 5).

    Returns:
        pd.DataFrame with columns ['label', 'text', 'source', 'distance'],
        ordered by ascending L2 distance.
    """
    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    distances, indices = index.search(np.array([query_embedding]), k)
    # .copy() is required: assigning a new column on a chained .iloc[...][cols]
    # slice raises SettingWithCopyWarning and may not write through in pandas.
    retrieved_docs = docs_df.iloc[indices[0]][['label', 'text', 'source']].copy()
    retrieved_docs['distance'] = distances[0]
    return retrieved_docs

# Parse response into structured sections
def parse_response(response_text):
    """Bucket the model's plain-text answer into named sections.

    Each line is assigned to the most recently seen section header
    ('Symptoms:', 'Signs:', 'Risk Factors'/'Past Medical History:',
    'Diagnostic Criteria:'); other 'Header:'-style lines reset the bucket
    to 'Other'. Header-only lines (ending with ':') are not stored as
    content; non-empty content lines are stored stripped, prefix intact.
    """
    sections = {
        "Symptoms": [],
        "Signs": [],
        "Risk Factors": [],
        "Diagnostic Criteria": [],
        "Other": [],
    }

    current = "Other"
    for raw_line in response_text.split('\n'):
        stripped = raw_line.strip()
        lowered = stripped.lower()
        if lowered.startswith("symptoms:"):
            current = "Symptoms"
        elif lowered.startswith("signs:"):
            current = "Signs"
        elif lowered.startswith(("risk factors", "past medical history:")):
            current = "Risk Factors"
        elif lowered.startswith("diagnostic criteria:"):
            current = "Diagnostic Criteria"
        elif stripped and not stripped.startswith((' ', '\t')) and ':' in stripped:
            # Unrecognized header-like line: dump following content in 'Other'.
            current = "Other"
        if stripped and not stripped.endswith(':'):
            sections[current].append(stripped)

    return sections

# Respond function with generic HTML formatting
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer a medical query via RAG + Gemini and return a styled HTML fragment.

    Args:
        message: Raw user query from the UI.
        system_message: System prompt prepended to the generation request.
        max_tokens: Cap on generated tokens (mapped to max_output_tokens).
        temperature: Sampling temperature for Gemini.
        top_p: Nucleus-sampling parameter, forwarded to Gemini.

    Returns:
        str: HTML with inline CSS; one <h3>/<ul> group per non-empty section
        parsed from the model's answer.
    """
    # Preprocess the user message the same way indexed documents were normalized.
    preprocessed_query = preprocess_text(message)

    # Retrieve relevant documents and join them into one context string.
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
    context = "\n".join(retrieved_docs['text'].tolist())

    # Construct the prompt with system message and RAG context
    prompt = f"{system_message}\n\n"
    prompt += (
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        f"Generate a concise response to the query based only on the provided context. "
        f"Structure the response with clear sections like 'Symptoms:', 'Signs:', 'Risk Factors:', and 'Diagnostic Criteria:' where applicable."
    )

    # Generate response with Gemini. Fix: top_p was previously accepted from the
    # UI but never forwarded; GenerationConfig supports it directly.
    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p
        )
    )
    answer = response.text.strip()

    # Parse the response into sections for structured display.
    sections = parse_response(answer)

    # Format the response into HTML with CSS styling
    html_response = """
    <style>
        .diagnosis-container { font-family: Arial, sans-serif; line-height: 1.6; padding: 15px; max-width: 800px; margin: auto; }
        h2 { color: #2c3e50; font-size: 24px; margin-bottom: 15px; border-bottom: 2px solid #2980b9; padding-bottom: 5px; }
        h3 { color: #2980b9; font-size: 18px; margin-top: 15px; margin-bottom: 8px; }
        ul { margin: 0; padding-left: 25px; }
        li { margin-bottom: 6px; color: #34495e; }
        p { margin: 5px 0; color: #34495e; }
    </style>
    <div class="diagnosis-container">
        <h2>AI Response</h2>
        <p>Based on the provided context, here is the information relevant to your query:</p>
    """

    # Add sections dynamically; empty sections are skipped entirely.
    for section, items in sections.items():
        if items:
            html_response += f"<h3>{section}</h3>"
            html_response += "<ul>"
            for item in items:
                # Remove a leading "Section:" prefix if present. Fix: escape the
                # section name so regex metacharacters can never corrupt the pattern.
                cleaned_item = re.sub(rf"^{re.escape(section)}:", "", item, flags=re.IGNORECASE).strip()
                html_response += f"<li>{cleaned_item}</li>"
            html_response += "</ul>"

    html_response += "</div>"
    return html_response

# Simple Gradio Interface with HTML output.
# Input order must match respond(message, system_message, max_tokens, temperature, top_p).
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Query", placeholder="Enter your medical question here (e.g., diabetes, heart failure)..."),
        gr.Textbox(
            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
            label="System Message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.75, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",  # NOTE(review): verify respond() actually forwards this to the model
        ),
    ],
    outputs=gr.HTML(label="Diagnosis"),  # respond() returns an HTML fragment with inline CSS
    title="🏥 Medical Query Assistant",
    description="A medical assistant that diagnoses patient queries using AI and past records, with styled output for any condition."
)

# Launch only when run as a script (HF Spaces also invokes this entry point).
if __name__ == "__main__":
    demo.launch()