thechaiexperiment committed
Commit 9593a8e · verified · 1 Parent(s): bca5800

Update medical_rag.py

Files changed (1)
  1. medical_rag.py +171 -79
medical_rag.py CHANGED
@@ -32,9 +32,10 @@ def load_medical_models():
         print(f"Error loading medical models: {e}")
         return False
 
-def extract_entities(text):
     try:
-        ner_pipeline = models['ner_pipeline']
         ner_results = ner_pipeline(text)
         entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
         return list(entities)
@@ -55,13 +56,11 @@ def extract_relevant_portions(document_texts, query, max_portions=3, portion_siz
     relevant_portions = {}
     query_entities = extract_entities(query)
     print(f"Extracted Query Entities: {query_entities}")
-
     for doc_id, doc_text in enumerate(document_texts):
         sentences = nltk.sent_tokenize(doc_text)
         doc_relevant_portions = []
         doc_entities = extract_entities(doc_text)
         print(f"Document {doc_id} Entities: {doc_entities}")
-
         for i, sentence in enumerate(sentences):
             sentence_entities = extract_entities(sentence)
             relevance_score = match_entities(query_entities, sentence_entities)
@@ -72,26 +71,47 @@ def extract_relevant_portions(document_texts, query, max_portions=3, portion_siz
                 doc_relevant_portions.append(portion)
                 if len(doc_relevant_portions) >= max_portions:
                     break
-
         if not doc_relevant_portions and len(doc_entities) > 0:
             print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
-            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s)), reverse=True)
             for fallback_sentence in sorted_sentences[:max_portions]:
                 doc_relevant_portions.append(fallback_sentence)
-
         relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
     return relevant_portions
 
 def enhance_passage_with_entities(passage, entities):
     return f"{passage}\n\nEntities: {', '.join(entities)}"
 
-def create_medical_prompt(question, passage):
     prompt = ("""
-    As a medical expert, you are required to answer the following question based only on the provided passage.
-    Do not include any information not present in the passage.
-    Your response should directly reflect the content of the passage.
-    Maintain accuracy and relevance to the provided information.
-    Provide a medically reliable answer in no more than 250 words.
 
     Passage: {passage}
 
@@ -101,112 +121,184 @@ def create_medical_prompt(question, passage):
     """)
     return prompt.format(passage=passage, question=question)
 
 @app.post("/api/chat")
 async def chat_endpoint(chat_query: ChatQuery):
     try:
-        print("\n=== STARTING CHAT REQUEST PROCESSING ===")
-        print(f"Initial query: {chat_query.query} (language_code: {chat_query.language_code})")
-
-        # Step 1: Handle translation if needed
         query_text = chat_query.query
-        language_code = chat_query.language_code
         if language_code == 0:
-            print("Translating from Arabic to English...")
-            query_text = translate_text(query_text, 'ar_to_en')
-            print(f"Translated query: {query_text}")
 
-        # Step 2: Generate embeddings
-        print("\nGenerating query embeddings...")
         query_embedding = embed_query_text(query_text)
-        print(f"Embedding generated. Shape: {query_embedding.shape}")
-
-        # Step 3: Load embeddings and query them
-        print("\nLoading document embeddings...")
-        embeddings_data = load_embeddings()
-        if not embeddings_data:
-            raise HTTPException(status_code=500, detail="Failed to load embeddings data")
-        print(f"Loaded embeddings for {len(embeddings_data)} documents")
-
-        print("\nQuerying embeddings...")
         n_results = 5
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
-        print(f"Initial results: {initial_results}")
-
         document_ids = [doc_id for doc_id, _ in initial_results]
-        print(f"Document IDs to retrieve: {document_ids}")
-
-        # Step 4: Retrieve document texts
-        print("\nRetrieving document texts...")
-        folder_path = 'downloaded_articles/downloaded_articles'
         document_texts = retrieve_document_texts(document_ids, folder_path)
-        print(f"Retrieved {len(document_texts)} documents")
-
-        # Step 5: Rerank documents
-        print("\nReranking documents...")
-        if 'cross_encoder' not in models:
-            raise HTTPException(status_code=500, detail="Cross-encoder model not loaded")
 
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
-        print("Top 3 reranked documents:")
-        for i, (score, doc_id, _) in enumerate(scored_documents[:3]):
-            print(f"{i+1}. Doc {doc_id} (score: {score:.4f})")
-
-        # Step 6: Extract relevant portions
-        print("\nExtracting relevant portions...")
-        relevant_portions = extract_relevant_portions(document_texts, query_text)
-        print(f"Found relevant portions in {len(relevant_portions)} documents")
 
         flattened_relevant_portions = []
         for doc_id, portions in relevant_portions.items():
             flattened_relevant_portions.extend(portions)
 
-        combined_parts = " ".join(flattened_relevant_portions)
-        print(f"Combined relevant text length: {len(combined_parts)} characters")
-
-        # Step 7: Extract and enhance with entities
-        print("\nExtracting entities...")
         entities = extract_entities(query_text)
-        print(f"Found entities: {entities}")
-
         passage = enhance_passage_with_entities(combined_parts, entities)
-        print(f"Enhanced passage length: {len(passage)} characters")
 
-        # Step 8: Generate response
-        print("\nCreating prompt...")
-        prompt = create_medical_prompt(query_text, passage)
-        print(f"Prompt length: {len(prompt)} characters")
 
-        print("\nGetting completion from DeepSeek...")
-        answer = get_completion(prompt)
-        print(f"Raw answer received: {answer[:200]}...")  # Print first 200 chars
 
-        # Step 9: Final processing
         final_answer = answer.strip()
         if language_code == 0:
-            print("\nTranslating answer to Arabic...")
-            final_answer = translate_text(final_answer, 'en_to_ar')
-            print(f"Translated answer: {final_answer[:200]}...")
 
         if not final_answer:
             final_answer = "Sorry, I can't help with that."
-            print("Warning: Empty answer received")
-
-        print("\n=== REQUEST PROCESSING COMPLETE ===")
         return {
             "response": f"I hope this answers your question: {final_answer}",
             "success": True
         }
 
     except HTTPException as e:
-        print(f"\n!!! HTTPException: {e.detail}")
         raise e
     except Exception as e:
-        print(f"\n!!! Unexpected error: {str(e)}")
-        print(f"Error type: {type(e).__name__}")
         raise HTTPException(status_code=500, detail=str(e))
 
 # Initialize medical models when this module is imported
 load_medical_models()
 
         print(f"Error loading medical models: {e}")
         return False
 
+def extract_entities(text, ner_pipeline=None):
     try:
+        if ner_pipeline is None:
+            ner_pipeline = models['ner_pipeline']
         ner_results = ner_pipeline(text)
         entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
         return list(entities)
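
For context, a minimal sketch (not part of this commit) of what the "B-" filter in the updated extract_entities keeps, using a hand-written stand-in for the pipeline output:

# Illustration only: tokens tagged as the beginning of an entity span ("B-")
# are collected into a set, so duplicates collapse and order is not guaranteed.
sample_ner_results = [
    {"word": "diabetes", "entity": "B-Disease"},
    {"word": "mellitus", "entity": "I-Disease"},
    {"word": "metformin", "entity": "B-Chemical"},
]
entities = {r["word"] for r in sample_ner_results if r["entity"].startswith("B-")}
print(list(entities))  # ['diabetes', 'metformin'] in some order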
 
     relevant_portions = {}
     query_entities = extract_entities(query)
     print(f"Extracted Query Entities: {query_entities}")
     for doc_id, doc_text in enumerate(document_texts):
         sentences = nltk.sent_tokenize(doc_text)
         doc_relevant_portions = []
         doc_entities = extract_entities(doc_text)
         print(f"Document {doc_id} Entities: {doc_entities}")
         for i, sentence in enumerate(sentences):
             sentence_entities = extract_entities(sentence)
             relevance_score = match_entities(query_entities, sentence_entities)
 
                 doc_relevant_portions.append(portion)
                 if len(doc_relevant_portions) >= max_portions:
                     break
         if not doc_relevant_portions and len(doc_entities) > 0:
             print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
+            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
             for fallback_sentence in sorted_sentences[:max_portions]:
                 doc_relevant_portions.append(fallback_sentence)
         relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
     return relevant_portions
 
+def remove_duplicates(selected_parts):
+    unique_sentences = set()
+    unique_selected_parts = []
+    for sentence in selected_parts:
+        if sentence not in unique_sentences:
+            unique_selected_parts.append(sentence)
+            unique_sentences.add(sentence)
+    return unique_selected_parts
+
+def extract_entities(text):
+    try:
+        biobert_tokenizer = models['bio_tokenizer']
+        biobert_model = models['bio_model']
+        inputs = biobert_tokenizer(text, return_tensors="pt")
+        outputs = biobert_model(**inputs)
+        predictions = torch.argmax(outputs.logits, dim=2)
+        tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+        entities = [
+            tokens[i]
+            for i in range(len(tokens))
+            if predictions[0][i].item() != 0  # Assuming 0 is the label for non-entity
+        ]
+        return entities
+    except Exception as e:
+        print(f"Error extracting entities: {e}")
+        return []
+
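
As an aside (not from the commit), the new remove_duplicates helper is an order-preserving de-duplication; a small self-contained sketch of the same logic on invented sentences:

# Illustration only: mirrors remove_duplicates above.
selected_parts = ["Drink water.", "Eat fiber.", "Drink water.", "Rest well."]
unique_sentences, unique_selected_parts = set(), []
for sentence in selected_parts:
    if sentence not in unique_sentences:
        unique_selected_parts.append(sentence)
        unique_sentences.add(sentence)
print(unique_selected_parts)  # ['Drink water.', 'Eat fiber.', 'Rest well.']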
 def enhance_passage_with_entities(passage, entities):
     return f"{passage}\n\nEntities: {', '.join(entities)}"
 
+def create_prompt(question, passage):
     prompt = ("""
+    As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
 
     Passage: {passage}
 
 
     """)
     return prompt.format(passage=passage, question=question)
 
+def generate_answer(prompt, max_length=860, temperature=0.2):
+    tokenizer_f = models['llm_tokenizer']
+    model_f = models['llm_model']
+    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
+    output_ids = model_f.generate(
+        inputs.input_ids,
+        max_length=max_length,
+        num_return_sequences=1,
+        temperature=temperature,
+        pad_token_id=tokenizer_f.eos_token_id
+    )
+    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
+    passage_keywords = set(prompt.lower().split())
+    answer_keywords = set(answer.lower().split())
+    if passage_keywords.intersection(answer_keywords):
+        return answer
+    else:
+        return "Sorry, I can't help with that."
+
+def remove_answer_prefix(text):
+    prefix = "Answer:"
+    if prefix in text:
+        return text.split(prefix, 1)[-1].strip()
+    return text
+
+def remove_incomplete_sentence(text):
+    if not text.endswith('.'):
+        last_period_index = text.rfind('.')
+        if last_period_index != -1:
+            return text[:last_period_index + 1].strip()
+    return text
+
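
For context, a minimal sketch (not part of this commit) of the two text-cleanup helpers applied to an assumed raw completion:

# Illustration only: strip an "Answer:" prefix, then drop a trailing incomplete
# sentence, mirroring remove_answer_prefix and remove_incomplete_sentence above.
raw = "Answer: Iron-rich foods such as lentils and spinach can help. Also cons"

text = raw.split("Answer:", 1)[-1].strip()
if not text.endswith('.'):
    last_period_index = text.rfind('.')
    if last_period_index != -1:
        text = text[:last_period_index + 1].strip()
print(text)  # "Iron-rich foods such as lentils and spinach can help."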
 @app.post("/api/chat")
 async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
+        language_code = chat_query.language_code
         if language_code == 0:
+            query_text = translate_ar_to_en(query_text)
 
+        # Generate embeddings and retrieve relevant documents (original RAG logic)
         query_embedding = embed_query_text(query_text)
         n_results = 5
+        embeddings_data = load_embeddings()
+        folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
         document_ids = [doc_id for doc_id, _ in initial_results]
         document_texts = retrieve_document_texts(document_ids, folder_path)
 
+        # Rerank documents with cross-encoder
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
 
+        # Extract relevant portions from documents
+        relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=2)
         flattened_relevant_portions = []
         for doc_id, portions in relevant_portions.items():
             flattened_relevant_portions.extend(portions)
+        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
+        combined_parts = " ".join(unique_selected_parts)
 
+        # Enhance context with entities
         entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
 
+        # Create prompt and generate answer using OpenRouter
+        prompt = create_prompt(query_text, passage)
 
+        # Add constraints similar to /api/ask endpoint
+        constraints = "Provide a medically reliable answer in no more than 250 words."
+        full_prompt = f"{prompt} {constraints}"
 
+        # Use the same OpenRouter model as /api/ask
+        answer = get_completion(full_prompt)
+
+        # Process the answer
         final_answer = answer.strip()
         if language_code == 0:
+            final_answer = translate_en_to_ar(final_answer)
 
         if not final_answer:
             final_answer = "Sorry, I can't help with that."
+
         return {
             "response": f"I hope this answers your question: {final_answer}",
             "success": True
         }
 
     except HTTPException as e:
         raise e
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+
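
A minimal usage sketch for the rewritten /api/chat route (illustration only, not part of the commit); it assumes the module's FastAPI app, a ChatQuery body with the query and language_code fields seen above, and that a nonzero language_code skips the Arabic round-trip:

# Illustration only; the models and OpenRouter credentials must be set up for this to run.
from fastapi.testclient import TestClient
import medical_rag

client = TestClient(medical_rag.app)
payload = {"query": "What foods are recommended for iron deficiency?", "language_code": 1}
r = client.post("/api/chat", json=payload)
print(r.json())  # {"response": "I hope this answers your question: ...", "success": True}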
+@app.post("/api/resources")
+async def resources_endpoint(profile: MedicalProfile):
+    try:
+        query_text = profile.conditions + " " + profile.daily_symptoms
+        n_results = profile.count
+        print(f"Generated query text: {query_text}")
+        query_embedding = embed_query_text(query_text)
+        if query_embedding is None:
+            raise ValueError("Failed to generate query embedding.")
+        embeddings_data = load_embeddings()
+        folder_path = 'downloaded_articles/downloaded_articles'
+        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
+        if not initial_results:
+            raise ValueError("No relevant documents found.")
+        document_ids = [doc_id for doc_id, _ in initial_results]
+        file_path = 'finalcleaned_excel_file.xlsx'
+        df = pd.read_excel(file_path)
+        file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
+        resources = []
+        for file_name in document_ids:
+            original_url = file_name_to_url.get(file_name, None)
+            if original_url:
+                title = get_page_title(original_url) or "Unknown Title"
+                resources.append({"file_name": file_name, "title": title, "url": original_url})
+            else:
+                resources.append({"file_name": file_name, "title": "Unknown", "url": None})
+        document_texts = retrieve_document_texts(document_ids, folder_path)
+        if not document_texts:
+            raise ValueError("Failed to retrieve document texts.")
+        cross_encoder = models['cross_encoder']
+        scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+        scores = [float(score) for score in scores]
+        for i, resource in enumerate(resources):
+            resource["score"] = scores[i] if i < len(scores) else 0.0
+        resources.sort(key=lambda x: x["score"], reverse=True)
+        output = [{"title": resource["title"], "url": resource["url"]} for resource in resources]
+        return output
+    except ValueError as ve:
+        raise HTTPException(status_code=400, detail=str(ve))
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
+
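
One subtle bit above is the file_name_to_url mapping: the row position in the spreadsheet's 'Unnamed: 0' column becomes the article file name. A tiny sketch with made-up URLs (illustration only):

import pandas as pd

df = pd.DataFrame({"Unnamed: 0": ["https://example.org/a", "https://example.org/b"]})
file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df["Unnamed: 0"])}
print(file_name_to_url)
# {'article_0.html': 'https://example.org/a', 'article_1.html': 'https://example.org/b'}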
+@app.post("/api/recipes")
+async def recipes_endpoint(profile: MedicalProfile):
+    try:
+        recipe_query = (
+            f"Recipes and foods for: "
+            f"{profile.conditions} and experiencing {profile.daily_symptoms}"
+        )
+        query_text = recipe_query
+        print(f"Generated query text: {query_text}")
+        n_results = profile.count
+        query_embedding = embed_query_text(query_text)
+        if query_embedding is None:
+            raise ValueError("Failed to generate query embedding.")
+        embeddings_data = load_recipes_embeddings()
+        folder_path = 'downloaded_articles/downloaded_articles'
+        initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results)
+        if not initial_results:
+            raise ValueError("No relevant recipes found.")
+        print("Initial results (document indices and similarities):")
+        print(initial_results)
+        document_indices = [doc_id for doc_id, _ in initial_results]
+        print("Document indices:", document_indices)
+        metadata_path = 'recipes_metadata.xlsx'
+        metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
+        print(f"Retrieved Metadata: {metadata}")
+        recipes = []
+        for item in metadata.values():
+            recipes.append({
+                "title": item["original_file_name"] if "original_file_name" in item else "Unknown Title",
+                "url": item["url"] if "url" in item else ""
+            })
+        print(recipes)
+        return recipes
+    except ValueError as ve:
+        raise HTTPException(status_code=400, detail=str(ve))
+    except Exception as e:
+        print(f"Unexpected error: {e}")
+        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
+
 
 # Initialize medical models when this module is imported
 load_medical_models()
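
For completeness, a hedged sketch (not part of the commit) of calling the two new endpoints; the MedicalProfile field names (conditions, daily_symptoms, count) are taken from the code above, and the sample values are invented:

# Illustration only; assumes the app's models and data files are in place.
from fastapi.testclient import TestClient
import medical_rag

client = TestClient(medical_rag.app)
profile = {"conditions": "type 2 diabetes", "daily_symptoms": "fatigue", "count": 5}

resources = client.post("/api/resources", json=profile).json()  # [{"title": ..., "url": ...}, ...]
recipes = client.post("/api/recipes", json=profile).json()
print(resources, recipes)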