thechaiexperiment committed on
Commit 6e8f9d4 · verified · 1 Parent(s): 6e0f0b6

Update app.py

Files changed (1):
  1. app.py  +68 -55
app.py CHANGED
@@ -346,52 +346,45 @@ def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
         print(f"Error reranking documents: {e}")
         return []
 
-def extract_entities(text, ner_pipeline=None):
-    try:
-        if ner_pipeline is None:
-            ner_pipeline = models['ner_pipeline']
-        ner_results = ner_pipeline(text)
-        entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
-        return list(entities)
-    except Exception as e:
-        print(f"Error extracting entities: {e}")
-        return []
-
-def match_entities(query_entities, sentence_entities):
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import nltk
+
+def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
     try:
-        query_set, sentence_set = set(query_entities), set(sentence_entities)
-        matches = query_set.intersection(sentence_set)
-        return len(matches)
+        relevant_portions = {}
+
+        for _, doc_id, doc_text in top_documents:
+            if doc_id not in embeddings_data:
+                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
+                continue
+
+            # Retrieve the precomputed embedding for this document
+            doc_embedding = np.array(embeddings_data[doc_id])
+
+            # Compute similarity between the query embedding and the document embedding
+            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
+
+            # Split the document into sentences
+            sentences = nltk.sent_tokenize(doc_text)
+
+            # Rank sentences based on their length (proxy for importance) or other heuristic.
+            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
+            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
+
+            relevant_portions[doc_id] = sorted_sentences
+
+            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
+            for i, sentence in enumerate(sorted_sentences, start=1):
+                print(f"  Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
+
+        return relevant_portions
+
     except Exception as e:
-        print(f"Error matching entities: {e}")
-        return 0
-
-def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
-    relevant_portions = {}
-    query_entities = extract_entities(query)
-    print(f"Extracted Query Entities: {query_entities}")
-    for doc_id, doc_text in enumerate(document_texts):
-        sentences = nltk.sent_tokenize(doc_text)
-        doc_relevant_portions = []
-        doc_entities = extract_entities(doc_text)
-        print(f"Document {doc_id} Entities: {doc_entities}")
-        for i, sentence in enumerate(sentences):
-            sentence_entities = extract_entities(sentence)
-            relevance_score = match_entities(query_entities, sentence_entities)
-            if relevance_score >= min_query_words:
-                start_idx = max(0, i - portion_size // 2)
-                end_idx = min(len(sentences), i + portion_size // 2 + 1)
-                portion = " ".join(sentences[start_idx:end_idx])
-                doc_relevant_portions.append(portion)
-            if len(doc_relevant_portions) >= max_portions:
-                break
-        if not doc_relevant_portions and len(doc_entities) > 0:
-            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
-            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
-            for fallback_sentence in sorted_sentences[:max_portions]:
-                doc_relevant_portions.append(fallback_sentence)
-        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
-    return relevant_portions
+        print(f"Error in extract_relevant_portions: {e}")
+        return {}
+
 
 def remove_duplicates(selected_parts):
     unique_sentences = set()
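For orientation, a minimal sketch of how the rewritten helper could be exercised on its own. The embedding dimension, document ID, and text below are made-up values; it assumes NLTK's punkt data is installed (nltk.sent_tokenize depends on it) and that query_embedding is a 2-D array, which is what scikit-learn's cosine_similarity expects:

# Illustrative only: toy inputs standing in for load_embeddings() and the reranked documents.
import numpy as np
import nltk
nltk.download('punkt')  # sentence splitting inside extract_relevant_portions relies on this

query_embedding = np.random.rand(1, 384)                  # shape (1, dim): cosine_similarity expects 2-D input
embeddings_data = {"doc1": np.random.rand(384).tolist()}  # assumed layout: dict keyed by document ID
top_documents = [(0.92, "doc1", "Short sentence. A considerably longer second sentence about anemia symptoms.")]

portions = extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=2)
print(portions)  # e.g. {"doc1": ["A considerably longer second sentence...", "Short sentence."]}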
@@ -426,11 +419,8 @@ def enhance_passage_with_entities(passage, entities):
 def create_prompt(question, passage):
     prompt = ("""
     As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
-
     Passage: {passage}
-
     Question: {question}
-
     Answer:
     """)
     return prompt.format(passage=passage, question=question)
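As a quick illustration of the tightened template (the passage and question below are placeholders, not taken from the app):

example = create_prompt(
    question="What are common symptoms of anemia?",        # placeholder
    passage="Anemia often causes fatigue and pale skin."    # placeholder
)
print(example)
# Passage, Question and Answer now sit on consecutive lines; the blank separator lines are removed by this commit.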
@@ -520,46 +510,69 @@ async def health_check():
 async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
-        language_code = chat_query.language_code
+        language_code = chat_query.language_code
+
+        # Translate Arabic to English if language_code is 0
         if language_code == 0:
             query_text = translate_ar_to_en(query_text)
+
+        # Generate query embedding
         query_embedding = embed_query_text(query_text)
         n_results = 5
-        embeddings_data = load_embeddings ()
+
+        # Load embeddings and retrieve initial results
+        embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
-        document_ids = [doc_id for doc_id, _ in initial_results]
+
+        # Extract document IDs and texts
+        document_ids = [doc_id for doc_id, *_ in initial_results]
         document_texts = retrieve_document_texts(document_ids, folder_path)
+
+        # Use cross-encoder to score documents
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+
+        # Score and sort documents
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
-        relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=2)
-        flattened_relevant_portions = []
-        for doc_id, portions in relevant_portions.items():
-            flattened_relevant_portions.extend(portions)
-        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
+
+        # Extract relevant portions
+        relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
+        unique_selected_parts = remove_duplicates(relevant_portions)
         combined_parts = " ".join(unique_selected_parts)
+
+        # Build context and enhance passage with entities
         context = [query_text] + unique_selected_parts
         entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
+
+        # Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
         answer = generate_answer(prompt)
        answer_part = answer.split("Answer:")[-1].strip()
+
+        # Clean and finalize the answer
         cleaned_answer = remove_answer_prefix(answer_part)
         final_answer = remove_incomplete_sentence(cleaned_answer)
+
+        # Translate English back to Arabic if needed
         if language_code == 0:
             final_answer = translate_en_to_ar(final_answer)
+
+        # Print and return the answer
         if final_answer:
             print("Answer:")
             print(final_answer)
         else:
             print("Sorry, I can't help with that.")
+
         return {
             "response": f"I hope this answers your question: {final_answer}",
-            # "conversation_id": chat_query.conversation_id,
+            # "conversation_id": chat_query.conversation_id,  # Uncomment if needed
             "success": True
         }
+
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
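One detail worth noting, shown here only as a sketch and not as part of the commit: extract_relevant_portions now returns a dict keyed by document ID, and remove_duplicates receives that dict directly. If remove_duplicates is meant to deduplicate sentences, as it did with the old flattened list, the values would still need to be flattened first, for example:

# Hypothetical flattening step, mirroring the loop this commit removed.
flattened = [portion for portions in relevant_portions.values() for portion in portions]
unique_selected_parts = remove_duplicates(flattened)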
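For manual testing, a sketch of how the endpoint might be called once the app is running. The route path and port are assumptions (they are outside this diff), while query and language_code mirror the ChatQuery fields read above (language_code 0 triggers the Arabic round-trip translation):

import requests

payload = {"query": "What are the symptoms of anemia?", "language_code": 1}
resp = requests.post("http://localhost:8000/chat", json=payload)  # path and port are assumptions
print(resp.json())  # expected shape: {"response": "I hope this answers your question: ...", "success": True}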