# Clinical_RAG / app.py
import gradio as gr
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import re
import os
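
# Dependencies (hedged note): this app assumes gradio, faiss-cpu, numpy, pandas,
# sentence-transformers, and google-generativeai are installed (e.g. via the Space's
# requirements.txt); exact pinned versions are not shown here.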

# Load documents and FAISS index
def load_index_and_data():
    df = pd.read_pickle("data.pkl")
    # Stack the precomputed embeddings into a float32 matrix for FAISS
    vecs = np.array(df['embeddings'].tolist(), dtype=np.float32)
    idx = faiss.IndexFlatL2(vecs.shape[1])  # exact L2 (Euclidean) search
    idx.add(vecs)
    return df, idx
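
# NOTE (assumption): "data.pkl" is expected to hold a pandas DataFrame with at least a
# "text" column (document chunks) and an "embeddings" column of all-MiniLM-L6-v2 vectors.
# A minimal offline sketch for producing such a file might look like:
#     chunks = [...]  # list of clinical record snippets (hypothetical)
#     embs = SentenceTransformer("all-MiniLM-L6-v2").encode(chunks)
#     pd.DataFrame({"text": chunks, "embeddings": list(embs)}).to_pickle("data.pkl")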
docs_df, index = load_index_and_data()

# Embedding model and Gemini setup
encoder = SentenceTransformer("all-MiniLM-L6-v2")

# GEMINI_API_KEY must be provided via the environment (e.g. as a Space secret)
API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    raise EnvironmentError("Missing Gemini API key. Set the GEMINI_API_KEY environment variable.")
genai.configure(api_key=API_KEY)
llm = genai.GenerativeModel("gemini-2.0-flash")

# Clean text input: lowercase, drop punctuation except "." and ",", collapse whitespace
def clean_text(text):
    text = text.lower()
    text = re.sub(r"[^\w\s.,]", " ", text)
    return " ".join(text.split())
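
# Example: clean_text("What causes High Fever?!") -> "what causes high fever"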

# Retrieve relevant document context: the k nearest chunks by L2 distance in the FAISS index
def get_context(query, k=5):
    q_vec = encoder.encode([query])[0].astype(np.float32)
    _, indices = index.search(np.array([q_vec]), k)
    return "\n".join(docs_df.iloc[indices[0]]["text"].tolist())
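
# Quick sanity check (hypothetical query; assumes data.pkl loaded above):
#     print(get_context("persistent cough and mild fever", k=3))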

# RAG-based Gemini response generation
def generate_answer(user_input, system_note, max_tokens, temp):
    query = clean_text(user_input)
    context = get_context(query)
    prompt = (
        f"Role Description:\n{system_note}\n\n"
        f"User Question:\n{user_input}\n\n"
        f"Knowledge Extracted From Records:\n{context}\n\n"
        "Instructions:\n"
        "- Analyze the user's query using ONLY the above context.\n"
        "- Do NOT add external or made-up information.\n"
        "- Begin with a brief summary of the identified condition or concern.\n"
        "- Provide detailed reasoning and explanation in bullet points:\n"
        "  • Include possible causes, symptoms, and diagnostic considerations.\n"
        "  • Mention relevant terms or observations from context.\n"
        "  • Explain how the context supports the conclusions.\n"
        "- End with a short, clear recommendation (if context permits).\n"
        "- Avoid medical advice unless the context contains it."
    )
    result = llm.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=max_tokens,
            temperature=temp
        )
    )
    return result.text.strip()
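
# Example call (hypothetical inputs), mirroring how the Gradio UI below invokes it:
#     answer = generate_answer(
#         "I have had a dry cough and night sweats for two weeks",
#         "You are a virtual medical assistant using past medical records to respond intelligently.",
#         max_tokens=300,
#         temp=0.4,
#     )
# NOTE (assumption): result.text can raise if the response is blocked or empty; a more
# robust version might wrap the call in try/except and return a fallback message.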

# Gradio interface
demo = gr.Interface(
    fn=generate_answer,
    inputs=[
        gr.Textbox(label="Ask Something", placeholder="Describe your symptom or condition..."),
        gr.Textbox(
            value="You are a virtual medical assistant using past medical records to respond intelligently.",
            label="System Role"
        ),
        gr.Slider(50, 500, value=300, step=10, label="Max Tokens"),
        gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Creativity (Temperature)")
    ],
    outputs=gr.Textbox(label="AI Diagnosis"),
    title="🩺 Smart Medical Query Assistant",
    description="Submit a health-related question. The assistant analyzes similar past records to respond accurately and clearly."
)

if __name__ == "__main__":
    demo.launch()