File size: 3,423 Bytes
e573d3e
 
 
 
 
 
 
 
 
fa2a7ef
 
 
 
 
 
 
 
e573d3e
fa2a7ef
e573d3e
fa2a7ef
 
e573d3e
fa2a7ef
 
 
 
 
 
e573d3e
fa2a7ef
 
e573d3e
fa2a7ef
 
 
 
e573d3e
fa2a7ef
 
 
 
 
 
 
e573d3e
fa2a7ef
 
 
 
848cd9e
fa2a7ef
 
 
 
 
 
 
 
 
 
 
e573d3e
fa2a7ef
 
 
e573d3e
 
 
fa2a7ef
e573d3e
 
fa2a7ef
 
 
 
 
 
 
 
 
 
 
 
 
848cd9e
ca2f154
fa2a7ef
ca2f154
fa2a7ef
e573d3e
fa2a7ef
 
 
e573d3e
 
 
fa2a7ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import re
import os

# Load data and FAISS index
def load_data_and_index(path="data.pkl"):
    """Load the document table and build an exact L2 FAISS index over its embeddings.

    Args:
        path: Location of the pickled DataFrame. It must have an ``embeddings``
            column holding one equal-length vector per row. Defaults to
            ``"data.pkl"``; adjust for the deployment (e.g. HF Spaces).

    Returns:
        ``(docs_df, index)``. Vectors are added to ``index`` in row order, so
        FAISS id ``i`` refers to ``docs_df.iloc[i]``.

    Raises:
        ValueError: If the table has no rows, so no embedding dimension can be
            inferred to size the index.
    """
    # NOTE(security): read_pickle can execute arbitrary code while unpickling;
    # only ever point this at a trusted file shipped with the app.
    docs_df = pd.read_pickle(path)
    if docs_df.empty:
        # Without this guard, np.array([]) has shape (0,) and shape[1] fails
        # with an unhelpful IndexError.
        raise ValueError(f"No documents found in {path!r}; cannot build a FAISS index.")
    # Stack the per-row vectors into one (n_docs, dim) float32 matrix; FAISS
    # requires float32 input.
    embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return docs_df, index

# Load the document table and build the FAISS index once, at import time, so
# every request served by retrieve_docs reuses the same in-memory objects.
docs_df, index = load_data_and_index()

# Load SentenceTransformer used to embed user queries.
# NOTE(review): presumably the same model that produced data.pkl's
# 'embeddings' column; a mismatch would make FAISS distances meaningless — confirm.
minilm = SentenceTransformer('all-MiniLM-L6-v2')

# Configure Gemini API using Hugging Face Secrets.
# Fail fast at startup when the key is missing, instead of on the first request.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if not GEMINI_API_KEY:
    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
genai.configure(api_key=GEMINI_API_KEY)
# Shared generative model used by respond() for every query.
model = genai.GenerativeModel('gemini-2.0-flash')

# Preprocess text function
def preprocess_text(text):
    """Normalise free text before it is embedded for retrieval.

    Steps:
        1. Lower-case the input.
        2. Replace every character that is not a word character, whitespace,
           or one of ``. , ; : > -`` with a space.
        3. Collapse runs of whitespace into single spaces, with no leading or
           trailing padding.
    """
    normalised = text.lower()
    for control_char in ('\n', '\t'):
        normalised = normalised.replace(control_char, ' ')
    normalised = re.sub(r'[^\w\s.,;:>-]', ' ', normalised)
    return ' '.join(normalised.split()).strip()

# Retrieve documents
def retrieve_docs(query, k=5):
    """Return the documents nearest to *query* in embedding space.

    Args:
        query: Text to embed with ``minilm`` and look up in ``index``.
        k: Maximum number of neighbours to return.

    Returns:
        A new DataFrame with columns ``label``, ``text``, ``source`` and
        ``distance``. ``distance`` is the squared L2 distance reported by
        ``IndexFlatL2``; smaller means closer. Rows are in FAISS result order,
        nearest first. The frame has fewer than *k* rows when the index holds
        fewer than *k* vectors.
    """
    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
    distances, indices = index.search(np.array([query_embedding]), k)
    # FAISS pads missing neighbours with id -1. Passing that to iloc would
    # silently return the LAST document with a garbage distance, so drop them.
    found = indices[0] >= 0
    # .copy() yields an independent frame, so adding 'distance' below never
    # writes into a chained-indexing slice of docs_df.
    retrieved_docs = docs_df.iloc[indices[0][found]][['label', 'text', 'source']].copy()
    retrieved_docs['distance'] = distances[0][found]
    return retrieved_docs

# RAG pipeline integrated into respond function
def _build_prompt(message, system_message, context):
    """Assemble the Gemini prompt: system message, raw query, retrieved context, output instructions."""
    return (
        f"{system_message}\n\n"
        f"Query: {message}\n"
        f"Relevant Context: {context}\n"
        "Generate a short, concise response to the query based only on the provided context. "
        "Format the response as a structured with headings and information write in the form of points not paragraph"
    )


def _hit_token_limit(response):
    """Return True when Gemini reports that generation stopped at ``max_output_tokens``."""
    candidates = getattr(response, "candidates", None) or []
    if not candidates:
        return False
    reason = getattr(candidates[0], "finish_reason", None)
    return getattr(reason, "name", None) == "MAX_TOKENS"


def _trim_to_last_sentence(answer):
    """Cut a truncated answer back to its last '.', or end it with '.' if it has none."""
    if not answer or answer.endswith('.'):
        return answer
    last_period = answer.rfind('.')
    return answer[:last_period + 1] if last_period != -1 else answer + "."


def respond(message, system_message, max_tokens, temperature):
    """Answer *message* with Gemini, grounded on the 5 nearest retrieved documents.

    Args:
        message: Raw user query. It is normalised with ``preprocess_text`` for
            retrieval but sent verbatim inside the prompt.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on generated tokens. Coerced to ``int``
            because UI sliders may deliver numbers as floats.
        temperature: Sampling temperature passed to Gemini.

    Returns:
        The model's answer. If generation was cut off by the token limit, the
        answer is trimmed back to the last complete sentence. A short apology
        string is returned when Gemini produces no usable text.
    """
    preprocessed_query = preprocess_text(message)
    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
    context = "\n".join(retrieved_docs['text'].tolist())
    prompt = _build_prompt(message, system_message, context)

    response = model.generate_content(
        prompt,
        generation_config=genai.types.GenerationConfig(
            max_output_tokens=int(max_tokens),
            temperature=temperature,
        ),
    )
    try:
        answer = response.text.strip()
    except ValueError:
        # `.text` raises when the candidate carries no text parts, e.g. the
        # prompt or answer was blocked by safety filters.
        return "Sorry, no answer could be generated for this query. Please try rephrasing it."

    # Only trim when the answer really was truncated. The prompt asks for
    # headings and bullet points, which legitimately do not end with '.', and
    # trimming unconditionally used to delete complete trailing bullets.
    if _hit_token_limit(response):
        answer = _trim_to_last_sentence(answer)
    return answer

# Simple Gradio Interface
def chatbot_interface(
    message,
    system_message=(
        "You are a medical information assistant. Answer only from the provided "
        "context and recommend consulting a qualified healthcare professional."
    ),
    max_tokens=512,
    temperature=0.7,
):
    """Gradio callback: answer *message* through the RAG pipeline in ``respond``.

    Gradio calls this with one positional argument per input component. The
    demo Interface exposes only the query textbox, so without defaults every
    submission failed with a missing-argument ``TypeError``. The defaults
    below supply the remaining generation settings; callers that pass them
    explicitly are unaffected.

    Args:
        message: The user's question.
        system_message: Instruction text placed at the top of the prompt.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature for Gemini.

    Returns:
        The answer text produced by ``respond``.
    """
    return respond(message, system_message, max_tokens, temperature)

# Build the single-page UI: one query textbox in, one response textbox out.
# Components are created in the same order as before (query, then response).
_query_box = gr.Textbox(label="Your Query", placeholder="Enter your medical question here...")
_answer_box = gr.Textbox(label="Response")

demo = gr.Interface(
    fn=chatbot_interface,
    inputs=[_query_box],
    outputs=_answer_box,
    title="🏥 Medical Chat Assistant",
    description="A simple medical assistant that diagnoses patient queries using AI and past records, providing structured responses.",
)

# Start the Gradio server only when executed as a script. Importing this
# module still builds `demo` (and loads data/models) but does not launch it.
if __name__ == "__main__":
    demo.launch()