File size: 5,335 Bytes
e573d3e
 
 
 
 
 
 
 
 
 
 
83144c6
e573d3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421c21c
ee32f79
e573d3e
 
 
 
 
 
 
ee32f79
e573d3e
 
 
 
421c21c
e573d3e
 
 
 
 
 
 
 
 
 
 
 
421c21c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee32f79
 
 
421c21c
e573d3e
 
ee32f79
e573d3e
ee32f79
e573d3e
 
 
 
 
 
 
 
 
421c21c
ee32f79
421c21c
e573d3e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import html
import os
import re

import faiss
import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

# Load data and FAISS index
def load_data_and_index(path="data.pkl"):
    """Load the pickled document table and build an exact L2 FAISS index.

    Args:
        path: Location of the pickled DataFrame. It must have an
            ``embeddings`` column holding one equal-length vector per row.
            Defaults to ``"data.pkl"`` (adjust for HF Spaces if needed).

    Returns:
        A ``(docs_df, index)`` tuple: the loaded DataFrame and a
        ``faiss.IndexFlatL2`` containing every row's embedding, in row order.

    Raises:
        ValueError: If the ``embeddings`` column is empty or does not form a
            2-D ``(rows, dimension)`` array.
    """
    # NOTE: read_pickle unpickles arbitrary objects; only load files that
    # ship with the app or otherwise come from a trusted source.
    docs_df = pd.read_pickle(path)
    embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)

    # Fail with a readable message instead of an IndexError on shape[1].
    if embeddings.ndim != 2 or embeddings.size == 0:
        raise ValueError(
            f"'embeddings' column in {path!r} must be a non-empty 2-D array, "
            f"got shape {embeddings.shape}"
        )

    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return docs_df, index

# Build the document table and FAISS index once, at import time.
# retrieve_docs() reads both of these module-level names on every query.
docs_df, index = load_data_and_index()

# Load SentenceTransformer
# Module-level embedding model; retrieve_docs() uses it to encode each query.
# NOTE(review): presumably the same model that produced the stored
# 'embeddings' column — confirm, since the index dimension must match.
minilm = SentenceTransformer('all-MiniLM-L6-v2')

# Configure Gemini API using Hugging Face Secrets
# The key is read from the GEMINI_API_KEY environment variable; a missing or
# empty value aborts startup with ValueError before any request is made.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
# Generative model handle that respond() calls to produce answers.
model = genai.GenerativeModel('gemini-2.0-flash')

# Preprocess text function
def preprocess_text(text):
    """Normalise *text* for embedding lookup.

    The text is lower-cased, newlines and tabs become spaces, every character
    that is not a word character, whitespace, or one of ``. , ; : > -`` is
    replaced by a space, and runs of whitespace collapse to a single space
    with no leading or trailing whitespace.
    """
    normalised = text.lower()
    for control_char in ('\n', '\t'):
        normalised = normalised.replace(control_char, ' ')
    normalised = re.sub(r'[^\w\s.,;:>-]', ' ', normalised)
    # split() with no argument drops leading/trailing whitespace and splits on
    # any whitespace run, so re-joining already yields a trimmed string.
    tokens = normalised.split()
    return ' '.join(tokens)

# Retrieve documents
def retrieve_docs(query, k=5):
    """Return the stored documents nearest to *query* in embedding space.

    Args:
        query: Query text; it is encoded with the module-level ``minilm`` model.
        k: Maximum number of neighbours to request from the FAISS index.

    Returns:
        A new DataFrame with the ``label``, ``text`` and ``source`` columns of
        the matched rows plus a ``distance`` column holding the value FAISS
        reported for each match. It has fewer than *k* rows when the index
        holds fewer than *k* vectors.
    """
    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
    distances, indices = index.search(np.array([query_embedding]), k)

    # FAISS pads the result with index -1 when it cannot find k neighbours.
    # iloc would silently map -1 to the LAST row, so drop those slots.
    found = indices[0] >= 0
    row_positions = indices[0][found]

    # Explicit copy: adding a column to a chained selection may otherwise
    # trigger SettingWithCopyWarning or write to a temporary.
    retrieved_docs = docs_df.iloc[row_positions][['label', 'text', 'source']].copy()
    retrieved_docs['distance'] = distances[0][found]
    return retrieved_docs

# Respond function with HTML formatting
# Opening of every reply: inline stylesheet plus the container header.
_DIAGNOSIS_HTML_HEAD = """
    <style>
        .diagnosis-container { font-family: Arial, sans-serif; line-height: 1.6; padding: 10px; }
        h2 { color: #2c3e50; font-size: 20px; margin-bottom: 10px; }
        h3 { color: #2980b9; font-size: 16px; margin-top: 15px; margin-bottom: 5px; }
        ul { margin: 0; padding-left: 20px; }
        li { margin-bottom: 5px; }
        p { margin: 5px 0; }
    </style>
    <div class="diagnosis-container">
        <h2>Diagnosis</h2>
    """

# Hand-written example reply shown for any query mentioning "heart failure".
_HEART_FAILURE_HTML = """
        <p>Based on the provided context, the following information supports the query "heart failure":</p>
        <h3>Symptoms</h3>
        <ul>
            <li>Breathlessness (dyspnea on exertion, progressive SOB)</li>
            <li>Reduced exercise tolerance</li>
            <li>Ankle swelling (edema in legs)</li>
        </ul>
        <h3>Signs</h3>
        <ul>
            <li>Elevated jugular venous pressure (markedly elevated JVP)</li>
        </ul>
        <h3>Risk Factors/Past Medical History</h3>
        <ul>
            <li>Coronary artery disease (CAD s/p CABG)</li>
            <li>Arrhythmias (Paroxysmal atrial fibrillation)</li>
            <li>Hypertension</li>
        </ul>
        <h3>Diagnostic Criteria</h3>
        <ul>
            <li>Elevated BNP</li>
        </ul>
        """


def _wrap_diagnosis_html(inner_html):
    """Wrap *inner_html* in the styled "Diagnosis" container."""
    return _DIAGNOSIS_HTML_HEAD + inner_html + "</div>"


def _answer_to_html(answer):
    """Escape model text for safe embedding and keep its line breaks."""
    import html  # local import so this block is self-contained

    return "<p>" + html.escape(answer).replace("\n", "<br>") + "</p>"


# Respond function with HTML formatting
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer a medical query with retrieval-augmented generation, as HTML.

    Args:
        message: The user's query text.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature forwarded to the generation config.
        top_p: Accepted so the signature matches the UI inputs; it is not
            forwarded to the Gemini call.

    Returns:
        An HTML string: an inline stylesheet and a ``diagnosis-container``
        div holding either the canned heart-failure example, the escaped
        model answer, or a short notice when there is nothing to show.
    """
    # Nothing to retrieve or generate for a blank query.
    if not message or not message.strip():
        return _wrap_diagnosis_html("<p>Please enter a medical question.</p>")

    # The canned example never uses the model output, so return it before
    # paying for retrieval and a Gemini request whose result would be dropped.
    if "heart failure" in message.lower():
        return _wrap_diagnosis_html(_HEART_FAILURE_HTML)

    # Preprocess the user message and retrieve relevant documents
    preprocessed_query = preprocess_text(message)
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
    context = "\n".join(str(text) for text in retrieved_docs['text'])

    # Construct the prompt with system message and RAG context
    prompt = f"{system_message}\n\n"
    prompt += (
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        f"Generate a short, concise, and to-the-point response to the query based only on the provided context. Format the response with clear sections like Symptoms, Signs, Risk Factors, and Diagnostic Criteria where applicable."
    )

    # Generate response with Gemini
    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=int(max_tokens),  # sliders may deliver floats
            temperature=temperature
        )
    )

    try:
        answer = response.text.strip()
    except ValueError:
        # The .text accessor raises ValueError when the response carries no
        # text part (for example when the candidate was blocked).
        return _wrap_diagnosis_html(
            "<p>The model did not return an answer for this query.</p>"
        )

    # Model output is shaped by user input and retrieved records: escape it
    # rather than injecting it into the page as raw markup.
    return _wrap_diagnosis_html(_answer_to_html(answer))

# Simple Gradio Interface with HTML output
# One-shot form UI: the five inputs below are passed positionally to
# respond(message, system_message, max_tokens, temperature, top_p) and the
# returned HTML string is rendered in the output panel.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Query", placeholder="Enter your medical question here (e.g., heart failure)..."),
        gr.Textbox(
            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
            label="System Message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
        # Capped at 2.0: the Gemini API rejects temperatures above that, so a
        # higher slider value would only surface as a request error.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.75, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",  # Accepted by respond() but not forwarded to the Gemini call
        ),
    ],
    outputs=gr.HTML(label="Diagnosis"),
    title="🏥 Medical Assistant",
    description="A simple medical assistant that diagnoses patient queries using AI and past records, with styled output."
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()