burhan112 commited on
Commit
848cd9e
·
verified ·
1 Parent(s): 914eefe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -42
app.py CHANGED
@@ -9,7 +9,7 @@ import os
9
 
10
  # Load data and FAISS index
11
  def load_data_and_index():
12
- docs_df = pd.read_pickle("data.pkl") # Adjust path for HF Spaces
13
  embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
14
  dimension = embeddings.shape[1]
15
  index = faiss.IndexFlatL2(dimension)
@@ -28,7 +28,7 @@ if not GEMINI_API_KEY:
28
  genai.configure(api_key=GEMINI_API_KEY)
29
  model = genai.GenerativeModel('gemini-2.0-flash')
30
 
31
- # Preprocess text
32
  def preprocess_text(text):
33
  text = text.lower()
34
  text = text.replace('\n', ' ').replace('\t', ' ')
@@ -36,7 +36,7 @@ def preprocess_text(text):
36
  text = ' '.join(text.split()).strip()
37
  return text
38
 
39
- # Retrieve top-k documents
40
  def retrieve_docs(query, k=5):
41
  query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
42
  distances, indices = index.search(np.array([query_embedding]), k)
@@ -44,24 +44,36 @@ def retrieve_docs(query, k=5):
44
  retrieved_docs['distance'] = distances[0]
45
  return retrieved_docs
46
 
47
- # Generate structured response
48
- def respond(message, system_message, max_tokens, temperature, top_p):
49
- # Preprocess and retrieve
 
 
 
 
 
 
 
50
  preprocessed_query = preprocess_text(message)
 
 
51
  retrieved_docs = retrieve_docs(preprocessed_query, k=5)
52
-
53
- # Combine retrieved texts
54
- context = "\n".join([f"- *{row['label']}* ({row['source']}): {row['text']}" for _, row in retrieved_docs.iterrows()])
55
-
56
- # Build prompt
57
  prompt = f"{system_message}\n\n"
 
 
 
 
 
58
  prompt += (
59
  f"Query: {message}\n"
60
  f"Relevant Context: {context}\n"
61
  f"Generate a short, concise, and to-the-point response to the query based only on the provided context."
62
  )
63
-
64
- # Get Gemini response
65
  response = model.generate_content(
66
  prompt,
67
  generation_config=genai.types.GenerationConfig(
@@ -77,41 +89,37 @@ def respond(message, system_message, max_tokens, temperature, top_p):
77
  else:
78
  answer += "."
79
 
80
- # Format output with Markdown
81
  formatted_answer = f"""
82
- **🩺 Patient Query:**
83
- {message}
84
-
85
- ---
86
-
87
- **📚 Retrieved Context:**
88
- {context}
89
-
90
- ---
91
-
92
- **🧠 Diagnosis / Suggestion:**
93
- {answer}
94
- """
95
-
96
- return formatted_answer.strip()
97
-
98
- # Gradio app
99
- demo = gr.Interface(
100
- fn=respond,
101
- inputs=[
102
- gr.Textbox(label="Your Query", placeholder="Enter your medical question here..."),
103
  gr.Textbox(
104
  value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
105
- label="System Message"
106
  ),
107
- gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max New Tokens"),
108
  gr.Slider(minimum=0.1, maximum=4.0, value=0.75, step=0.1, label="Temperature"),
109
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
 
 
 
 
 
 
110
  ],
111
- outputs=gr.Markdown(label="Diagnosis"),
112
- title="🏥 Medical Assistant",
113
- description="A simple medical assistant that diagnoses patient queries using AI and past records."
114
  )
115
 
116
  if __name__ == "__main__":
117
- demo.launch()
 
9
 
10
  # Load data and FAISS index
11
  def load_data_and_index():
12
+ docs_df = pd.read_pickle("docs_with_embeddings (1).pkl") # Adjust path for HF Spaces
13
  embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
14
  dimension = embeddings.shape[1]
15
  index = faiss.IndexFlatL2(dimension)
 
28
  genai.configure(api_key=GEMINI_API_KEY)
29
  model = genai.GenerativeModel('gemini-2.0-flash')
30
 
31
+ # Preprocess text function
32
  def preprocess_text(text):
33
  text = text.lower()
34
  text = text.replace('\n', ' ').replace('\t', ' ')
 
36
  text = ' '.join(text.split()).strip()
37
  return text
38
 
39
+ # Retrieve documents
40
  def retrieve_docs(query, k=5):
41
  query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
42
  distances, indices = index.search(np.array([query_embedding]), k)
 
44
  retrieved_docs['distance'] = distances[0]
45
  return retrieved_docs
46
 
47
+ # RAG pipeline integrated into respond function
48
+ def respond(
49
+ message,
50
+ history: list[tuple[str, str]],
51
+ system_message,
52
+ max_tokens,
53
+ temperature,
54
+ top_p, # Keeping top_p as an input, though Gemini doesn’t use it directly
55
+ ):
56
+ # Preprocess the user message
57
  preprocessed_query = preprocess_text(message)
58
+
59
+ # Retrieve relevant documents
60
  retrieved_docs = retrieve_docs(preprocessed_query, k=5)
61
+ context = "\n".join(retrieved_docs['text'].tolist())
62
+
63
+ # Construct the prompt with system message, history, and RAG context
 
 
64
  prompt = f"{system_message}\n\n"
65
+ for user_msg, assistant_msg in history:
66
+ if user_msg:
67
+ prompt += f"User: {user_msg}\n"
68
+ if assistant_msg:
69
+ prompt += f"Assistant: {assistant_msg}\n"
70
  prompt += (
71
  f"Query: {message}\n"
72
  f"Relevant Context: {context}\n"
73
  f"Generate a short, concise, and to-the-point response to the query based only on the provided context."
74
  )
75
+
76
+ # Generate response with Gemini
77
  response = model.generate_content(
78
  prompt,
79
  generation_config=genai.types.GenerationConfig(
 
89
  else:
90
  answer += "."
91
 
92
+ # Format the output with Gradio markdown for better readability
93
  formatted_answer = f"""
94
+ <div style='background-color:#f0f0f0; padding: 10px; border-radius: 5px;'>
95
+ <h3 style='color:#333; font-weight:bold;'>Assistant's Response:</h3>
96
+ <p style='color:#555;'>{answer}</p>
97
+ </div>
98
+ """
99
+ # Yield the formatted response
100
+ yield formatted_answer
101
+
102
+ # Gradio Chat Interface
103
+ demo = gr.ChatInterface(
104
+ respond,
105
+ additional_inputs=[
 
 
 
 
 
 
 
 
 
106
  gr.Textbox(
107
  value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
108
+ label="System message"
109
  ),
110
+ gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens"),
111
  gr.Slider(minimum=0.1, maximum=4.0, value=0.75, step=0.1, label="Temperature"),
112
+ gr.Slider(
113
+ minimum=0.1,
114
+ maximum=1.0,
115
+ value=0.95,
116
+ step=0.05,
117
+ label="Top-p (nucleus sampling)", # Included but not used by Gemini
118
+ ),
119
  ],
120
+ title="🏥 Medical Chat Assistant",
121
+ description="A chat-based medical assistant that diagnoses patient queries using AI and past records."
 
122
  )
123
 
124
  if __name__ == "__main__":
125
+ demo.launch()