burhan112 commited on
Commit
e573d3e
·
verified ·
1 Parent(s): cf1f39a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -0
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import faiss
3
+ import numpy as np
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer
6
+ import google.generativeai as genai
7
+ import re
8
+ import os
9
+
10
+ # Load data and FAISS index
11
+ def load_data_and_index():
12
+ docs_df = pd.read_pickle("docs_with_embeddings (1).pkl") # Adjust path for HF Spaces
13
+ embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
14
+ dimension = embeddings.shape[1]
15
+ index = faiss.IndexFlatL2(dimension)
16
+ index.add(embeddings)
17
+ return docs_df, index
18
+
19
+ docs_df, index = load_data_and_index()
20
+
21
+ # Load SentenceTransformer
22
+ minilm = SentenceTransformer('all-MiniLM-L6-v2')
23
+
24
+ # Configure Gemini API using Hugging Face Secrets
25
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
26
+ if not GEMINI_API_KEY:
27
+ raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
28
+ genai.configure(api_key=GEMINI_API_KEY)
29
+ model = genai.GenerativeModel('gemini-2.0-flash')
30
+
31
+ # Preprocess text function
32
+ def preprocess_text(text):
33
+ text = text.lower()
34
+ text = text.replace('\n', ' ').replace('\t', ' ')
35
+ text = re.sub(r'[^\w\s.,;:>-]', ' ', text)
36
+ text = ' '.join(text.split()).strip()
37
+ return text
38
+
39
+ # Retrieve documents
40
+ def retrieve_docs(query, k=5):
41
+ query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
42
+ distances, indices = index.search(np.array([query_embedding]), k)
43
+ retrieved_docs = docs_df.iloc[indices[0]][['label', 'text', 'source']]
44
+ retrieved_docs['distance'] = distances[0]
45
+ return retrieved_docs
46
+
47
+ # RAG pipeline integrated into respond function
48
+ def respond(
49
+ message,
50
+ history: list[tuple[str, str]],
51
+ system_message,
52
+ max_tokens,
53
+ temperature,
54
+ top_p, # Keeping top_p as an input, though Gemini doesn’t use it directly
55
+ ):
56
+ # Preprocess the user message
57
+ preprocessed_query = preprocess_text(message)
58
+
59
+ # Retrieve relevant documents
60
+ retrieved_docs = retrieve_docs(preprocessed_query, k=5)
61
+ context = "\n".join(retrieved_docs['text'].tolist())
62
+
63
+ # Construct the prompt with system message, history, and RAG context
64
+ prompt = f"{system_message}\n\n"
65
+ for user_msg, assistant_msg in history:
66
+ if user_msg:
67
+ prompt += f"User: {user_msg}\n"
68
+ if assistant_msg:
69
+ prompt += f"Assistant: {assistant_msg}\n"
70
+ prompt += (
71
+ f"Query: {message}\n"
72
+ f"Relevant Context: {context}\n"
73
+ f"Generate a short, concise, and to-the-point response to the query based only on the provided context."
74
+ )
75
+
76
+ # Generate response with Gemini
77
+ response = model.generate_content(
78
+ prompt,
79
+ generation_config=genai.types.GenerationConfig(
80
+ max_output_tokens=max_tokens,
81
+ temperature=temperature
82
+ )
83
+ )
84
+ answer = response.text.strip()
85
+ if not answer.endswith('.'):
86
+ last_period = answer.rfind('.')
87
+ if last_period != -1:
88
+ answer = answer[:last_period + 1]
89
+ else:
90
+ answer += "."
91
+
92
+ # Yield the full response (no streaming, as Gemini API doesn’t support it here)
93
+ yield answer
94
+
95
+ # Gradio Chat Interface
96
+ demo = gr.ChatInterface(
97
+ respond,
98
+ additional_inputs=[
99
+ gr.Textbox(
100
+ value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
101
+ label="System message"
102
+ ),
103
+ gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
104
+ gr.Slider(minimum=0.1, maximum=4.0, value=0.75, step=0.1, label="Temperature"),
105
+ gr.Slider(
106
+ minimum=0.1,
107
+ maximum=1.0,
108
+ value=0.95,
109
+ step=0.05,
110
+ label="Top-p (nucleus sampling)", # Included but not used by Gemini
111
+ ),
112
+ ],
113
+ title="🏥 Medical Chat Assistant",
114
+ description="A chat-based medical assistant that diagnoses patient queries using AI and past records."
115
+ )
116
+
117
+ if __name__ == "__main__":
118
+ demo.launch()