thechaiexperiment committed on
Commit 6e0f0b6 · verified · 1 Parent(s): 484270b

Update app.py

Files changed (1)
  1. app.py +56 -76
app.py CHANGED
@@ -28,12 +28,9 @@ from sklearn.metrics.pairwise import cosine_similarity
 from bs4 import BeautifulSoup
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
-from typing import List, Dict,Any,Tuple, Optional
+from typing import List, Dict, Optional
 from safetensors.numpy import load_file
 from safetensors.torch import safe_open
-from concurrent.futures import ThreadPoolExecutor
-import asyncio
-from functools import partial
 nltk.download('punkt_tab')
 
 app = FastAPI()
@@ -66,11 +63,6 @@ class ChatMessage(BaseModel):
     content: str
     timestamp: str
 
-async def run_in_threadpool(func, *args, **kwargs):
-    return await asyncio.get_event_loop().run_in_executor(
-        None, partial(func, *args, **kwargs)
-    )
-
 def init_nltk():
     try:
         nltk.download('punkt', quiet=True)
@@ -354,44 +346,52 @@ def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
         print(f"Error reranking documents: {e}")
         return []
 
-from sentence_transformers import SentenceTransformer
-from sklearn.metrics.pairwise import cosine_similarity
-import nltk
-
-def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
+def extract_entities(text, ner_pipeline=None):
     try:
-        relevant_portions = {}
-
-        for _, doc_id, doc_text in top_documents:
-            if doc_id not in embeddings_data:
-                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
-                continue
-
-            # Retrieve the precomputed embedding for this document
-            doc_embedding = np.array(embeddings_data[doc_id])
-
-            # Compute similarity between the query embedding and the document embedding
-            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
-
-            # Split the document into sentences
-            sentences = nltk.sent_tokenize(doc_text)
-
-            # Rank sentences based on their length (proxy for importance) or other heuristic
-            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
-            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
-
-            relevant_portions[doc_id] = sorted_sentences
-
-            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
-            for i, sentence in enumerate(sorted_sentences, start=1):
-                print(f" Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
-
-        return relevant_portions
-
+        if ner_pipeline is None:
+            ner_pipeline = models['ner_pipeline']
+        ner_results = ner_pipeline(text)
+        entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
+        return list(entities)
     except Exception as e:
-        print(f"Error in extract_relevant_portions: {e}")
-        return {}
+        print(f"Error extracting entities: {e}")
+        return []
 
+def match_entities(query_entities, sentence_entities):
+    try:
+        query_set, sentence_set = set(query_entities), set(sentence_entities)
+        matches = query_set.intersection(sentence_set)
+        return len(matches)
+    except Exception as e:
+        print(f"Error matching entities: {e}")
+        return 0
+
+def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
+    relevant_portions = {}
+    query_entities = extract_entities(query)
+    print(f"Extracted Query Entities: {query_entities}")
+    for doc_id, doc_text in enumerate(document_texts):
+        sentences = nltk.sent_tokenize(doc_text)
+        doc_relevant_portions = []
+        doc_entities = extract_entities(doc_text)
+        print(f"Document {doc_id} Entities: {doc_entities}")
+        for i, sentence in enumerate(sentences):
+            sentence_entities = extract_entities(sentence)
+            relevance_score = match_entities(query_entities, sentence_entities)
+            if relevance_score >= min_query_words:
+                start_idx = max(0, i - portion_size // 2)
+                end_idx = min(len(sentences), i + portion_size // 2 + 1)
+                portion = " ".join(sentences[start_idx:end_idx])
+                doc_relevant_portions.append(portion)
+                if len(doc_relevant_portions) >= max_portions:
+                    break
+        if not doc_relevant_portions and len(doc_entities) > 0:
+            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
+            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
+            for fallback_sentence in sorted_sentences[:max_portions]:
+                doc_relevant_portions.append(fallback_sentence)
+        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
+    return relevant_portions
 
 def remove_duplicates(selected_parts):
     unique_sentences = set()
@@ -426,8 +426,11 @@ def enhance_passage_with_entities(passage, entities):
 def create_prompt(question, passage):
     prompt = ("""
     As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
+
     Passage: {passage}
+
     Question: {question}
+
     Answer:
     """)
     return prompt.format(passage=passage, question=question)
@@ -517,69 +520,46 @@ async def health_check():
 async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
-        language_code = chat_query.language_code
-
-        # Translate Arabic to English if language_code is 0
+        language_code = chat_query.language_code
         if language_code == 0:
             query_text = translate_ar_to_en(query_text)
-
-        # Generate query embedding
         query_embedding = embed_query_text(query_text)
         n_results = 5
-
-        # Load embeddings and retrieve initial results
-        embeddings_data = load_embeddings()
+        embeddings_data = load_embeddings ()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
-
-        # Extract document IDs and texts
-        document_ids = [doc_id for doc_id, *_ in initial_results]
+        document_ids = [doc_id for doc_id, _ in initial_results]
         document_texts = retrieve_document_texts(document_ids, folder_path)
-
-        # Use cross-encoder to score documents
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
-
-        # Score and sort documents
         scored_documents = list(zip(scores, document_ids, document_texts))
        scored_documents.sort(key=lambda x: x[0], reverse=True)
-
-        # Extract relevant portions
-        relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
-        unique_selected_parts = remove_duplicates(relevant_portions)
+        relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=2)
+        flattened_relevant_portions = []
+        for doc_id, portions in relevant_portions.items():
+            flattened_relevant_portions.extend(portions)
+        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
         combined_parts = " ".join(unique_selected_parts)
-
-        # Build context and enhance passage with entities
         context = [query_text] + unique_selected_parts
         entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
-
-        # Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
         answer = generate_answer(prompt)
         answer_part = answer.split("Answer:")[-1].strip()
-
-        # Clean and finalize the answer
        cleaned_answer = remove_answer_prefix(answer_part)
        final_answer = remove_incomplete_sentence(cleaned_answer)
-
-        # Translate English back to Arabic if needed
        if language_code == 0:
            final_answer = translate_en_to_ar(final_answer)
-
-        # Print and return the answer
        if final_answer:
            print("Answer:")
            print(final_answer)
        else:
            print("Sorry, I can't help with that.")
-
        return {
            "response": f"I hope this answers your question: {final_answer}",
-            # "conversation_id": chat_query.conversation_id, # Uncomment if needed
+            # "conversation_id": chat_query.conversation_id,
            "success": True
        }
-
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
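
For context, the new selection step keeps a sentence when the named entities it shares with the query reach the min_query_words threshold. Below is a minimal, self-contained sketch of that scoring idea: stub_ner_pipeline, the sample query, and the document text are illustrative stand-ins for models['ner_pipeline'] and the real article corpus, not part of the commit.

import nltk

nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)

def stub_ner_pipeline(text):
    # Toy stand-in for the NER model: treat every capitalized token as a "B-" entity,
    # mimicking the token-classification output format the committed code expects.
    return [{'word': tok.strip('.,?'), 'entity': 'B-ENT'}
            for tok in text.split() if tok[:1].isupper()]

def extract_entities(text, ner_pipeline=stub_ner_pipeline):
    # Same shape as the committed helper: collect unique "B-" entity words.
    ner_results = ner_pipeline(text)
    return list({r['word'] for r in ner_results if r['entity'].startswith('B-')})

def match_entities(query_entities, sentence_entities):
    # Relevance score = size of the entity overlap between query and sentence.
    return len(set(query_entities) & set(sentence_entities))

query = "What are the side effects of Insulin and Metformin?"
doc = ("Metformin is a first-line drug for type 2 diabetes. "
       "Insulin and Metformin can both cause hypoglycemia. "
       "Regular exercise is also recommended.")

query_entities = extract_entities(query)
for sentence in nltk.sent_tokenize(doc):
    overlap = match_entities(query_entities, extract_entities(sentence))
    print(overlap, "|", sentence)

With min_query_words=2, only the second sentence (overlap 2: Insulin, Metformin) would be kept as a relevant portion; the others fall back to the entity-count heuristic in the committed code.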