Spaces:
Running
Running
File size: 3,777 Bytes
e573d3e ed8c0cb e573d3e 914eefe e573d3e 914eefe e573d3e 914eefe ee32f79 914eefe e573d3e 914eefe e573d3e 5574a92 e573d3e 914eefe e573d3e 5574a92 e573d3e 914eefe ee32f79 5574a92 e573d3e ee32f79 e573d3e ee32f79 e573d3e 914eefe e573d3e 914eefe 5574a92 e573d3e 914eefe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import gradio as gr
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import re
import os
# Load data and FAISS index
def load_data_and_index():
    """Load the pickled document table and build a flat L2 FAISS index over it.

    Reads ``data.pkl`` (relative to the working directory), stacks the
    ``embeddings`` column into a float32 matrix and adds every row to a
    ``faiss.IndexFlatL2`` whose dimension is the matrix's column count.

    Returns:
        tuple: ``(docs_df, index)`` -- the loaded DataFrame and the populated
        FAISS index. Rows are added in DataFrame order, so an index position
        corresponds to the same positional row of the DataFrame.
    """
    # NOTE: unpickling can execute arbitrary code; only ship a trusted data.pkl.
    frame = pd.read_pickle("data.pkl")  # Adjust path for HF Spaces
    vectors = np.array(frame['embeddings'].tolist(), dtype=np.float32)
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    return frame, flat_index
# Build the document table and its FAISS index once, at import time.
docs_df, index = load_data_and_index()

# Load SentenceTransformer (used to embed incoming queries)
minilm = SentenceTransformer('all-MiniLM-L6-v2')

# Configure Gemini API using Hugging Face Secrets.
# The key comes from the environment; a missing or empty value aborts startup.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
model = genai.GenerativeModel('gemini-2.0-flash')
# Preprocess text
def preprocess_text(text):
    """Normalize free text before it is embedded.

    The text is lower-cased, every character that is not a word character,
    whitespace or one of ``. , ; : > -`` is replaced by a space, and all runs
    of whitespace (including newlines and tabs) collapse to a single space
    with no leading or trailing space.

    Args:
        text: The string to normalize.

    Returns:
        str: The normalized string.
    """
    lowered = text.lower()
    cleaned = re.sub(r'[^\w\s.,;:>-]', ' ', lowered)
    # split() with no argument splits on any whitespace run and drops empty
    # edge pieces, so newlines, tabs and outer padding need no separate pass.
    return ' '.join(cleaned.split())
# Retrieve top-k documents
def retrieve_docs(query, k=5):
    """Return the documents nearest to ``query`` in the FAISS index.

    The query is embedded with the module-level ``minilm`` encoder and
    searched against the module-level ``index``; matching rows are taken
    from ``docs_df`` by position.

    Args:
        query: Query text to embed.
        k: Maximum number of neighbours to request.

    Returns:
        pandas.DataFrame: Columns ``label``, ``text``, ``source`` and
        ``distance`` (the L2 distance reported by FAISS), one row per hit.
        Contains fewer than ``k`` rows when the index holds fewer than ``k``
        vectors.
    """
    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
    distances, indices = index.search(np.array([query_embedding]), k)

    # FAISS pads the result with label -1 when it cannot find k neighbours.
    # Passing -1 to iloc would silently select the LAST row as a bogus hit,
    # so padded slots are dropped before the lookup.
    found = indices[0] >= 0
    hit_positions = indices[0][found]

    # Explicit copy: the distance column is added to an independent frame,
    # not to a selection derived from docs_df.
    retrieved_docs = docs_df.iloc[hit_positions][['label', 'text', 'source']].copy()
    retrieved_docs['distance'] = distances[0][found]
    return retrieved_docs
# Generate structured response
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer a patient query with Gemini, grounded in retrieved records.

    The query is normalized and used to retrieve the five nearest documents;
    Gemini is then prompted with the system message, the user's original
    wording and the retrieved context.

    Args:
        message: Query text entered by the user.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to Gemini.
        top_p: Nucleus-sampling threshold forwarded to Gemini.

    Returns:
        str: Markdown containing the query, the retrieved context and the
        generated answer, or a short prompt to enter a query when
        ``message`` is empty.
    """
    # Nothing to retrieve or diagnose for an empty / missing query.
    if not message or not message.strip():
        return "Please enter a query."

    # Preprocess and retrieve (the prompt itself keeps the original wording)
    preprocessed_query = preprocess_text(message)
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)

    # Combine retrieved texts
    context = "\n".join(
        f"- *{row['label']}* ({row['source']}): {row['text']}"
        for _, row in retrieved_docs.iterrows()
    )

    # Build prompt
    prompt = (
        f"{system_message}\n\n"
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        f"Generate a short, concise, and to-the-point response to the query based only on the provided context."
    )

    # Get Gemini response; top_p is forwarded so the UI control takes effect.
    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ),
    )

    # response.text raises ValueError when the response carries no usable
    # text part (e.g. the candidate was blocked); show a message instead of
    # surfacing a traceback in the UI.
    try:
        answer = response.text.strip()
    except ValueError:
        answer = "No answer could be generated for this query."

    # Output cut off by the token limit may stop mid-sentence: trim back to
    # the last period, or append one when the text contains none.
    if not answer.endswith('.'):
        last_period = answer.rfind('.')
        if last_period != -1:
            answer = answer[:last_period + 1]
        else:
            answer += "."

    # Format output with Markdown. Blank lines around each '---' make it a
    # horizontal rule; directly under a text line it would instead turn that
    # line into a heading.
    sections = (
        f"**🩺 Patient Query:**\n\n{message}",
        f"**📚 Retrieved Context:**\n\n{context}",
        f"**🧠 Diagnosis / Suggestion:**\n\n{answer}",
    )
    return "\n\n---\n\n".join(sections).strip()
# Gradio app: one query box plus generation controls, rendered as Markdown.
# The inputs are passed positionally to respond(), so their order must match
# its signature (message, system_message, max_tokens, temperature, top_p).
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Query", placeholder="Enter your medical question here..."),
        gr.Textbox(
            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
            label="System Message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
        # Gemini accepts temperatures in [0.0, 2.0]; a higher slider maximum
        # would let users submit values the API rejects.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.75, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs=gr.Markdown(label="Diagnosis"),
    title="🏥 Medical Assistant",
    description="A simple medical assistant that diagnoses patient queries using AI and past records.",
)

if __name__ == "__main__":
    demo.launch()
|