import gradio as gr
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import re
import os

# Load the document store and build an in-memory FAISS index over its embeddings
def load_index_and_data():
    df = pd.read_pickle("data.pkl")
    vecs = np.array(df["embeddings"].tolist(), dtype=np.float32)
    idx = faiss.IndexFlatL2(vecs.shape[1])  # exact nearest-neighbor search by L2 distance
    idx.add(vecs)
    return df, idx
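
# A minimal sketch (not called by the app) of how data.pkl could be produced
# offline; it assumes only the "text" and "embeddings" columns that the code
# below relies on. The helper itself and its corpus argument are assumptions.
def build_data_pickle(corpus_texts, out_path="data.pkl"):
    model = SentenceTransformer("all-MiniLM-L6-v2")  # same encoder used for queries
    df = pd.DataFrame({"text": corpus_texts})
    df["embeddings"] = list(model.encode(corpus_texts))
    df.to_pickle(out_path)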

docs_df, index = load_index_and_data()

# Embedding model and Gemini setup
encoder = SentenceTransformer("all-MiniLM-L6-v2")

API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    raise EnvironmentError("Missing Gemini API key. Set the GEMINI_API_KEY environment variable.")
genai.configure(api_key=API_KEY)
llm = genai.GenerativeModel("gemini-2.0-flash")

# Normalize query text: lowercase, drop punctuation except "." and ",", collapse whitespace
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s.,]", " ", text)  # keep word chars, whitespace, periods, commas
    return " ".join(text.split())

# Retrieve the k most similar records and join their text into a single context block
def get_context(query, k=5):
    q_vec = encoder.encode([query])[0].astype(np.float32)
    _, indices = index.search(np.array([q_vec]), k)  # search returns (distances, indices)
    return "\n".join(docs_df.iloc[indices[0]]["text"].tolist())

# RAG-based Gemini response generation
def generate_answer(user_input, system_note, max_tokens, temp):
    query = clean_text(user_input)
    context = get_context(query)
    
    prompt = (
        f"Role Description:\n{system_note}\n\n"
        f"User Question:\n{user_input}\n\n"
        f"Knowledge Extracted From Records:\n{context}\n\n"
        f"Instructions:\n"
        f"- Analyze the user's query using ONLY the above context.\n"
        f"- Do NOT add external or made-up information.\n"
        f"- Begin with a brief summary of the identified condition or concern.\n"
        f"- Provide detailed reasoning and explanation in bullet points:\n"
        f"   • Include possible causes, symptoms, and diagnostic considerations.\n"
        f"   • Mention relevant terms or observations from context.\n"
        f"   • Explain how the context supports the conclusions.\n"
        f"- End with a short, clear recommendation (if context permits).\n"
        f"- Avoid medical advice unless the context contains it."
    )

    result = llm.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=max_tokens,
            temperature=temp
        )
    )
    return result.text.strip()
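
# result.text raises a ValueError when the response carries no usable candidates
# (e.g. when output is blocked by safety filters). A hedged variant of the
# return above; the fallback message is an assumption, not part of the app:
#
#     try:
#         return result.text.strip()
#     except ValueError:
#         return "No usable answer was returned (the response may have been blocked)."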

# Gradio interface
demo = gr.Interface(
    fn=generate_answer,
    inputs=[
        gr.Textbox(label="Ask Something", placeholder="Describe your symptom or condition..."),
        gr.Textbox(
            value="You are a virtual medical assistant using past medical records to respond intelligently.",
            label="System Role"
        ),
        gr.Slider(50, 500, value=300, step=10, label="Max Tokens"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Creativity (Temperature)")
    ],
    outputs=gr.Textbox(label="AI Diagnosis"),
    title="🩺 Smart Medical Query Assistant",
    description="Submit a health-related question. The assistant analyzes similar past records to respond accurately and clearly."
)

if __name__ == "__main__":
    demo.launch()
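
# Expected invocation (the script name "app.py" is an assumption):
#   GEMINI_API_KEY=<your-key> python app.py
# Calling demo.launch(share=True) would additionally expose a temporary public URL.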