File size: 3,705 Bytes
e573d3e
 
 
 
 
 
 
 
 
 
 
45daa66
e573d3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee32f79
 
e573d3e
 
 
 
 
 
 
ee32f79
e573d3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee32f79
e573d3e
ee32f79
 
 
 
 
e573d3e
 
ee32f79
e573d3e
ee32f79
e573d3e
 
 
 
 
 
 
 
 
ee32f79
 
 
e573d3e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import re
import os

# Load data and FAISS index
def load_data_and_index(data_path="data.pkl"):
    """Load the document table and build an exact L2 FAISS index over it.

    Args:
        data_path: Path of a pickled pandas DataFrame that has an
            ``embeddings`` column holding one vector per row. Defaults to
            ``"data.pkl"`` (relative to the working directory).

    Returns:
        A ``(docs_df, index)`` tuple: the loaded DataFrame and a
        ``faiss.IndexFlatL2`` containing every row's embedding, added in
        row order so that FAISS ids equal positional row numbers.

    Raises:
        ValueError: If the ``embeddings`` column is empty or its values do
            not stack into a 2-D (rows x dimension) matrix.
    """
    # NOTE(security): unpickling can execute arbitrary code, so this must
    # only ever be pointed at a trusted file shipped with the app.
    docs_df = pd.read_pickle(data_path)  # Adjust path for HF Spaces
    embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)

    # Without this check an empty table fails later with an opaque
    # "tuple index out of range" when reading shape[1].
    if embeddings.ndim != 2 or embeddings.shape[0] == 0:
        raise ValueError(
            f"Expected a non-empty 2-D embeddings matrix in {data_path!r}, "
            f"got an array of shape {embeddings.shape}."
        )

    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return docs_df, index

# Everything below runs once, at import time, and populates the module-level
# objects (docs_df, index, minilm, model) that the functions further down use.
docs_df, index = load_data_and_index()

# Load SentenceTransformer used to embed incoming queries.
# NOTE(review): presumably the same model that produced the stored
# embeddings in the data file - TODO confirm.
minilm = SentenceTransformer('all-MiniLM-L6-v2')

# Configure Gemini API using Hugging Face Secrets.
# The key is read from the GEMINI_API_KEY environment variable; a missing or
# empty value aborts the import with a ValueError.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
# Generative model used to produce the final answer text.
model = genai.GenerativeModel('gemini-2.0-flash')

# Preprocess text function
def preprocess_text(text):
    """Normalise raw text into a lowercase, single-spaced string.

    Newlines and tabs become spaces, every character that is not a word
    character, whitespace, or one of ``. , ; : > -`` is replaced by a space,
    and runs of whitespace are collapsed (leading/trailing removed).
    """
    # Lowercase and flatten line breaks / tabs in a single pass.
    flattened = text.lower().translate(str.maketrans('\n\t', '  '))
    # Blank out punctuation outside the small allow-list.
    cleaned = re.sub(r'[^\w\s.,;:>-]', ' ', flattened)
    # split() with no argument drops empty pieces, so the joined result
    # carries no leading, trailing, or repeated whitespace.
    tokens = cleaned.split()
    return ' '.join(tokens)

# Retrieve documents
def retrieve_docs(query, k=5):
    """Return the stored documents nearest to ``query``.

    The query is embedded with the module-level ``minilm`` model and looked
    up in the module-level FAISS ``index``; hits are mapped back to rows of
    ``docs_df`` by position.

    Args:
        query: Query text to embed.
        k: Maximum number of documents to return.

    Returns:
        A new DataFrame with the columns ``label``, ``text``, ``source`` and
        ``distance`` (the distance reported by the index), holding at most
        ``k`` rows. It is empty when nothing can be retrieved.
    """
    columns = ['label', 'text', 'source']

    # FAISS pads its result with -1 labels when the index holds fewer than k
    # vectors, and iloc[-1] would silently return the *last* row as a bogus
    # hit - so never ask for more neighbours than exist.
    k = min(k, index.ntotal)
    if k <= 0:
        empty = docs_df.iloc[0:0][columns].copy()
        empty['distance'] = np.empty(0, dtype=np.float32)
        return empty

    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
    distances, indices = index.search(np.array([query_embedding]), k)

    # Defensive: drop any remaining "no result" (-1) slots.
    found = indices[0] >= 0

    # Explicit copy so adding the distance column never writes into (or
    # warns about) a view of docs_df.
    retrieved_docs = docs_df.iloc[indices[0][found]][columns].copy()
    retrieved_docs['distance'] = distances[0][found]
    return retrieved_docs

# A sentence terminator followed by whitespace, i.e. a real sentence boundary
# rather than a decimal point such as the one in "2.5 mg".
_SENTENCE_BOUNDARY = re.compile(r'[.!?](?=\s)')

# Shown when the model returns nothing usable.
_NO_ANSWER_MESSAGE = "No answer could be generated for this query. Please try rephrasing it."


def _trim_to_last_sentence(answer):
    """Cut a possibly truncated answer back to its last complete sentence."""
    if not answer or answer[-1] in '.!?':
        return answer
    boundaries = [match.end() for match in _SENTENCE_BOUNDARY.finditer(answer)]
    if boundaries:
        return answer[:boundaries[-1]]
    # No complete sentence at all: keep the text and just close it off.
    return answer + "."


# Simplified respond function (no history)
def respond(message, system_message, max_tokens, temperature, top_p):
    """Answer a user query with Gemini, grounded in retrieved documents.

    Args:
        message: The user's question.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on generated tokens (coerced to ``int``).
        temperature: Sampling temperature forwarded to Gemini.
        top_p: Accepted to match the UI inputs; not forwarded to the model.

    Returns:
        The answer text, trimmed to its last complete sentence, or a short
        fallback message when the query is blank or the model returns no
        usable text.
    """
    # Nothing to retrieve or answer for a blank query.
    if not message or not message.strip():
        return "Please enter a question."

    # Preprocess the user message
    preprocessed_query = preprocess_text(message)

    # Retrieve relevant documents
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
    context = "\n".join(retrieved_docs['text'].tolist())

    # Construct the prompt with system message and RAG context
    prompt = f"{system_message}\n\n"
    prompt += (
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        "Generate a short, concise, and to-the-point response to the query based only on the provided context."
    )

    # Generate response with Gemini
    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            # UI sliders may deliver floats; the API expects an integer.
            max_output_tokens=int(max_tokens),
            temperature=temperature,
        ),
    )

    # response.text raises ValueError when the response carries no usable
    # candidate (for example when it was blocked), so fall back gracefully.
    try:
        answer = response.text.strip()
    except ValueError:
        return _NO_ANSWER_MESSAGE
    if not answer:
        return _NO_ANSWER_MESSAGE

    # The token limit can cut the output mid-sentence; drop the dangling part.
    return _trim_to_last_sentence(answer)

# Simple Gradio Interface
# The five inputs are passed positionally to respond() in the order listed:
# message, system_message, max_tokens, temperature, top_p.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Query", placeholder="Enter your medical question here..."),
        gr.Textbox(
            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
            label="System Message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
        # Gemini's documented temperature range ends at 2.0, so the slider
        # must not offer values the API would refuse.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.75, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",  # Shown in the UI, but respond() does not forward it to Gemini
        ),
    ],
    outputs=gr.Textbox(label="Diagnosis"),
    title="🏥 Medical Assistant",
    description="A simple medical assistant that diagnoses patient queries using AI and past records."
)

# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()