thechaiexperiment committed on
Commit
88b8fb2
·
1 Parent(s): c77cccd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -50
app.py CHANGED
@@ -170,54 +170,10 @@ def translate_text(text, source_to_target='ar_to_en'):
170
  print(f"Translation error: {e}")
171
  return text
172
 
173
def extract_entities(text):
    """Return the distinct words that the NER pipeline tags as entity starts.

    ``models['ner_pipeline']`` is run on *text*; the ``word`` of every result
    whose ``entity`` label begins with ``"B-"`` is kept. Duplicates are
    collapsed through a set, so the order of the returned list is not defined.
    Any exception is printed and an empty list is returned instead.
    """
    try:
        begin_words = set()
        for hit in models['ner_pipeline'](text):
            if hit['entity'].startswith("B-"):
                begin_words.add(hit['word'])
        return list(begin_words)
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
181
-
182
def generate_answer(query, context, max_length=860, temperature=0.2):
    """Generate an answer to *query* from *context* with the loaded LLM.

    A prompt embedding the context and the question is tokenized with
    ``models['llm_tokenizer']`` and passed to ``models['llm_model'].generate``
    using sampling (``do_sample=True``, ``top_p=0.9``). The decoded text is cut
    down to whatever follows the last ``"Answer:"`` marker and then re-joined
    sentence by sentence. Any exception is printed and a fixed apology string
    is returned instead.
    """
    try:
        tokenizer = models['llm_tokenizer']
        # NOTE(review): the diff rendering dropped leading whitespace, so the
        # template lines are reproduced flush-left — confirm against app.py.
        prompt = f"""
As a medical expert, please provide a clear and accurate answer to the following question based solely on the provided context.

Context: {context}

Question: {query}

Answer: Let me help you with accurate information from reliable medical sources."""

        encoded = tokenizer(prompt, return_tensors="pt", truncation=True)

        with torch.no_grad():
            generated = models['llm_model'].generate(
                encoded.input_ids,
                max_length=max_length,
                num_return_sequences=1,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )

        response = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Keep only the text after the last "Answer:" marker, if there is one.
        if "Answer:" in response:
            response = response.split("Answer:")[-1].strip()

        sentences = nltk.sent_tokenize(response)
        return " ".join(sentences) if sentences else response

    except Exception as e:
        print(f"Error generating answer: {e}")
        return "I apologize, but I'm unable to generate an answer at this time. Please try again later."
220
-
221
  def query_embeddings(query_embedding, n_results=5):
222
  """Find relevant documents using embedding similarity"""
223
  if not data['embeddings']:
@@ -258,6 +214,160 @@ def rerank_documents(query, doc_texts):
258
  print(f"Error reranking documents: {e}")
259
  return np.zeros(len(doc_texts))
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  @app.get("/")
262
  async def root():
263
  return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
@@ -287,8 +397,8 @@ async def chat_endpoint(chat_query: ChatQuery):
287
  rerank_scores = rerank_documents(query_text, doc_texts)
288
  ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
289
 
290
- context = " ".join(ranked_texts[:3])
291
- answer = generate_answer(query_text, context)
292
 
293
  return {
294
  "response": answer,
 
170
  print(f"Translation error: {e}")
171
  return text
172
 
173
def embed_query_text(query_text):
    """Encode a single query string with ``embedding_model``.

    The text is wrapped in a one-element list before being handed to
    ``embedding_model.encode``; that call's result is returned unchanged.
    """
    return embedding_model.encode([query_text])
176
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def query_embeddings(query_embedding, n_results=5):
178
  """Find relevant documents using embedding similarity"""
179
  if not data['embeddings']:
 
214
  print(f"Error reranking documents: {e}")
215
  return np.zeros(len(doc_texts))
216
 
217
def extract_entities(text):
    """Collect the unique entity-start words found in *text* by the NER pipeline.

    Calls ``models['ner_pipeline']`` and keeps each result's ``word`` when its
    ``entity`` label starts with ``"B-"``. The words pass through a set, so
    duplicates disappear and the list order is unspecified. On any exception
    the error is printed and ``[]`` is returned.
    """
    try:
        ner_hits = models['ner_pipeline'](text)
        unique_words = set()
        for hit in ner_hits:
            if not hit['entity'].startswith("B-"):
                continue
            unique_words.add(hit['word'])
        return list(unique_words)
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
225
def match_entities(query_entities, sentence_entities):
    """Count the distinct entities present in both collections.

    Both arguments are converted to sets, so repeated entries count once.
    """
    shared = set(query_entities) & set(sentence_entities)
    return len(shared)
229
+
230
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
    """Select, per document, the sentence windows that share entities with the query.

    Args:
        document_texts: Iterable of document strings.
        query: Query string whose entities drive the selection.
        max_portions: Maximum number of portions kept per document.
        portion_size: Window width in sentences; ``portion_size // 2``
            neighbours are taken on each side of a matching sentence.
        min_query_words: Minimum number of entities a sentence must share
            with the query to be selected.

    Returns:
        Dict mapping ``"Document_<index>"`` to the list of selected portions.
        When no sentence matches but the document has entities, the sentences
        with the most entities are used instead.
    """
    relevant_portions = {}

    # extract_entities() takes only the text. The previous two-argument calls
    # (passing an undefined ``ner_biobert``) failed before any work was done.
    query_entities = extract_entities(query)
    print(f"Extracted Query Entities: {query_entities}")

    half_window = portion_size // 2

    for doc_id, doc_text in enumerate(document_texts):
        sentences = nltk.sent_tokenize(doc_text)  # Split document into sentences
        doc_relevant_portions = []

        # Entities of the whole document gate the fallback below.
        doc_entities = extract_entities(doc_text)
        print(f"Document {doc_id} Entities: {doc_entities}")

        # Per-sentence entities are remembered so the fallback does not have
        # to run the NER model a second time on the same sentence.
        entities_by_index = {}

        for i, sentence in enumerate(sentences):
            sentence_entities = extract_entities(sentence)
            entities_by_index[i] = sentence_entities

            relevance_score = match_entities(query_entities, sentence_entities)

            # Select sentences with at least `min_query_words` matching entities
            if relevance_score >= min_query_words:
                start_idx = max(0, i - half_window)
                end_idx = min(len(sentences), i + half_window + 1)
                doc_relevant_portions.append(" ".join(sentences[start_idx:end_idx]))
                if len(doc_relevant_portions) >= max_portions:
                    break

        # Fallback: include the most entity-dense sentences if nothing matched
        if not doc_relevant_portions and len(doc_entities) > 0:
            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")

            def entity_count(index):
                if index not in entities_by_index:
                    entities_by_index[index] = extract_entities(sentences[index])
                return len(entities_by_index[index])

            # sorted() is stable, so ties keep document order as before.
            ranked = sorted(range(len(sentences)), key=entity_count, reverse=True)
            doc_relevant_portions.extend(sentences[idx] for idx in ranked[:max_portions])

        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions

    return relevant_portions
270
def remove_duplicates(selected_parts):
    """Return the items of *selected_parts* without repeats, first occurrence first.

    Uses ``dict.fromkeys``, which keeps insertion order, instead of a manual
    seen-set plus result list. Items must be hashable.
    """
    return list(dict.fromkeys(selected_parts))
280
+
281
def build_context(relevant_portions, query_text):
    """Build the context list: the query followed by the unique selected portions.

    These statements previously ran at module level, where neither
    ``relevant_portions`` nor ``query_text`` exists, so importing the module
    raised NameError. They are wrapped in a function so the caller that owns
    both values can invoke them.

    Args:
        relevant_portions: Dict mapping a document key to a list of portions.
        query_text: The query string placed first in the context.

    Returns:
        ``[query_text]`` followed by the de-duplicated portions in first-seen
        order.
    """
    # Flatten the dictionary of relevant portions
    flattened_relevant_portions = []
    for portions in relevant_portions.values():
        flattened_relevant_portions.extend(portions)

    # Remove duplicate portions
    unique_selected_parts = remove_duplicates(flattened_relevant_portions)

    # Construct context as a list: first the query, then the unique selected portions
    return [query_text] + unique_selected_parts
294
+
295
def extract_entities(text):
    """Return the tokens of *text* that the BioBERT model labels as entities.

    The text is tokenized with ``biobert_tokenizer``, classified by
    ``biobert_model``, and every token whose arg-max label id is not 0 is
    returned (label 0 is assumed to be the non-entity class). The returned
    items are tokenizer tokens, not re-assembled words.

    NOTE(review): this redefines ``extract_entities``; an earlier definition
    in this file uses ``models['ner_pipeline']`` and is replaced by this one
    at import time — confirm which one is intended.
    """
    # Truncate so over-long inputs do not exceed the model's maximum length.
    inputs = biobert_tokenizer(text, return_tensors="pt", truncation=True)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = biobert_model(**inputs)

    label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    return [token for token, label_id in zip(tokens, label_ids) if label_id != 0]
302
+
303
def enhance_passage_with_entities(passage, entities):
    """Append a comma-separated ``Entities:`` line to *passage* after a blank line."""
    entity_list = ', '.join(entities)
    return f"{passage}\n\nEntities: {entity_list}"
306
+
307
def create_prompt(question, passage):
    """Fill the medical-expert prompt template with *passage* and *question*.

    The returned text holds the instruction paragraph, a ``Passage:`` line, a
    ``Question:`` line and a trailing ``Answer:`` line.
    """
    # NOTE(review): the diff rendering dropped leading whitespace, so the
    # template lines are reproduced flush-left — confirm against app.py.
    return f"""
As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.

Passage: {passage}

Question: {question}

Answer:
"""
318
+
319
def generate_answer(prompt, max_length=860, temperature=0.2, passage=None):
    """Generate text for *prompt* and report how long generation took.

    Args:
        prompt: Full prompt text; tokenized with ``tokenizer_f`` with
            truncation enabled.
        max_length: Forwarded to ``model_f.generate`` as ``max_length``.
        temperature: Forwarded to ``model_f.generate`` as ``temperature``.
        passage: Optional reference text for the keyword-overlap check. When
            omitted, the prompt itself is used. The previous body read an
            undefined ``passage`` name and raised NameError on every call.

    Returns:
        ``(answer, duration)``: the decoded output, or a fixed refusal string
        when it shares no lower-cased word with the reference text, together
        with the elapsed generation time in seconds.
    """
    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)

    # Time only the generation call.
    start_time = time.time()
    # NOTE(review): no do_sample flag is passed, so whether ``temperature``
    # has any effect depends on the model's generation config — confirm.
    output_ids = model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=tokenizer_f.eos_token_id
    )
    duration = time.time() - start_time

    # Decode the answer
    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)

    reference_text = prompt if passage is None else passage
    passage_keywords = set(reference_text.lower().split())
    answer_keywords = set(answer.lower().split())

    if passage_keywords & answer_keywords:
        return answer, duration
    return "Sorry, I can't help with that.", duration
349
+
350
def remove_answer_prefix(text):
    """Return what follows the last ``"Answer:"`` marker, stripped of whitespace.

    Text that does not contain the marker is returned unchanged.
    """
    marker = "Answer:"
    if marker not in text:
        return text
    pieces = text.split(marker)
    return pieces[-1].strip()
355
+
356
def remove_incomplete_sentence(text):
    """Drop a trailing sentence fragment from *text*.

    A sentence is considered complete when it ends with ``.``, ``!`` or ``?``.
    Text already ending in one of these is returned unchanged; otherwise
    everything after the last terminator is removed and the result stripped.
    Text with no terminator at all is returned unchanged.

    Previously only ``.`` was recognised, so text ending in ``!`` or ``?``
    lost its final, complete sentence.
    """
    terminators = (".", "!", "?")
    if text.endswith(terminators):
        return text

    last_end = max(text.rfind(mark) for mark in terminators)
    if last_end != -1:
        # Keep everything up to and including the last terminator.
        return text[:last_end + 1].strip()
    return text
365
+
366
def postprocess_answer(answer):
    """Clean raw generated text into the final answer string.

    Keeps what follows the last ``"Answer:"`` marker, strips surrounding
    whitespace and removes a trailing incomplete sentence.

    These statements previously ran at module level, where ``answer`` is not
    defined, so importing the module raised NameError. They are wrapped in a
    function so they run on an actual generated answer.
    """
    # remove_answer_prefix() already takes the text after the last marker;
    # the extra strip covers the case where no marker is present.
    cleaned_answer = remove_answer_prefix(answer).strip()
    return remove_incomplete_sentence(cleaned_answer)
369
+
370
+
371
  @app.get("/")
372
  async def root():
373
  return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
 
397
  rerank_scores = rerank_documents(query_text, doc_texts)
398
  ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
399
 
400
+ context = [query_text] + unique_selected_parts
401
+ answer = remove_incomplete_sentence(query_text, context)
402
 
403
  return {
404
  "response": answer,