Update app.py
app.py
CHANGED
@@ -346,52 +346,45 @@ def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
         print(f"Error reranking documents: {e}")
         return []
 
-def extract_entities(text, ner_pipeline=None):
-    try:
-        if ner_pipeline is None:
-            ner_pipeline = models['ner_pipeline']
-        ner_results = ner_pipeline(text)
-        entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
-        return list(entities)
-    except Exception as e:
-        print(f"Error extracting entities: {e}")
-        return []
 
-
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import nltk
+
+def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
     try:
-
-
-
+        relevant_portions = {}
+
+        for _, doc_id, doc_text in top_documents:
+            if doc_id not in embeddings_data:
+                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
+                continue
+
+            # Retrieve the precomputed embedding for this document
+            doc_embedding = np.array(embeddings_data[doc_id])
+
+            # Compute similarity between the query embedding and the document embedding
+            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
+
+            # Split the document into sentences
+            sentences = nltk.sent_tokenize(doc_text)
+
+            # Rank sentences based on their length (proxy for importance) or other heuristic
+            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
+            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
+
+            relevant_portions[doc_id] = sorted_sentences
+
+            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
+            for i, sentence in enumerate(sorted_sentences, start=1):
+                print(f"  Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
+
+        return relevant_portions
+
     except Exception as e:
-        print(f"Error
-        return
-
-def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
-    relevant_portions = {}
-    query_entities = extract_entities(query)
-    print(f"Extracted Query Entities: {query_entities}")
-    for doc_id, doc_text in enumerate(document_texts):
-        sentences = nltk.sent_tokenize(doc_text)
-        doc_relevant_portions = []
-        doc_entities = extract_entities(doc_text)
-        print(f"Document {doc_id} Entities: {doc_entities}")
-        for i, sentence in enumerate(sentences):
-            sentence_entities = extract_entities(sentence)
-            relevance_score = match_entities(query_entities, sentence_entities)
-            if relevance_score >= min_query_words:
-                start_idx = max(0, i - portion_size // 2)
-                end_idx = min(len(sentences), i + portion_size // 2 + 1)
-                portion = " ".join(sentences[start_idx:end_idx])
-                doc_relevant_portions.append(portion)
-                if len(doc_relevant_portions) >= max_portions:
-                    break
-        if not doc_relevant_portions and len(doc_entities) > 0:
-            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
-            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
-            for fallback_sentence in sorted_sentences[:max_portions]:
-                doc_relevant_portions.append(fallback_sentence)
-        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
-    return relevant_portions
+        print(f"Error in extract_relevant_portions: {e}")
+        return {}
+
 
 def remove_duplicates(selected_parts):
     unique_sentences = set()
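For reference, a minimal sketch of how the reworked extract_relevant_portions above can be exercised on its own. The SentenceTransformer model name, the sample text, and the "doc_1" identifier are illustrative assumptions, not values taken from app.py; the (score, doc_id, doc_text) tuple shape follows the scored_documents list built in the endpoint further down.

# Standalone sketch, assuming extract_relevant_portions is defined as in the hunk above.
import numpy as np
import nltk
from sentence_transformers import SentenceTransformer

nltk.download("punkt", quiet=True)       # sentence tokenizer used by nltk.sent_tokenize
nltk.download("punkt_tab", quiet=True)   # additionally required by newer NLTK releases

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative embedding model choice
doc_text = "Aspirin reduces fever. It can also thin the blood. Dosage depends on age and condition."

embeddings_data = {"doc_1": model.encode(doc_text)}               # precomputed document-level embedding
query_embedding = model.encode(["Does aspirin thin the blood?"])  # shape (1, dim), as cosine_similarity expects
top_documents = [(0.9, "doc_1", doc_text)]                        # (score, doc_id, doc_text), like scored_documents

portions = extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=2)
print(portions)  # {'doc_1': [the two longest sentences]}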
@@ -426,11 +419,8 @@ def enhance_passage_with_entities(passage, entities):
 def create_prompt(question, passage):
     prompt = ("""
     As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
-
     Passage: {passage}
-
     Question: {question}
-
     Answer:
     """)
     return prompt.format(passage=passage, question=question)
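A quick sanity check of the tightened prompt template, assuming create_prompt is defined as above; the passage and question strings are made up purely for illustration.

# Hypothetical inputs, for illustration only.
sample_passage = "Aspirin is commonly used to reduce fever and relieve mild to moderate pain."
sample_question = "What is aspirin used for?"
print(create_prompt(sample_question, sample_passage))
# create_prompt fills the {passage} and {question} placeholders via str.format.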
@@ -520,46 +510,69 @@ async def health_check():
 async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
-        language_code = chat_query.language_code
+        language_code = chat_query.language_code
+
+        # Translate Arabic to English if language_code is 0
         if language_code == 0:
             query_text = translate_ar_to_en(query_text)
+
+        # Generate query embedding
         query_embedding = embed_query_text(query_text)
         n_results = 5
-
+
+        # Load embeddings and retrieve initial results
+        embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
-
+
+        # Extract document IDs and texts
+        document_ids = [doc_id for doc_id, *_ in initial_results]
         document_texts = retrieve_document_texts(document_ids, folder_path)
+
+        # Use cross-encoder to score documents
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+
+        # Score and sort documents
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
-
-
-
-
-        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
+
+        # Extract relevant portions
+        relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
+        unique_selected_parts = remove_duplicates([p for parts in relevant_portions.values() for p in parts])  # flatten per-document lists before deduplication
         combined_parts = " ".join(unique_selected_parts)
+
+        # Build context and enhance passage with entities
         context = [query_text] + unique_selected_parts
         entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
+
+        # Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
         answer = generate_answer(prompt)
         answer_part = answer.split("Answer:")[-1].strip()
+
+        # Clean and finalize the answer
         cleaned_answer = remove_answer_prefix(answer_part)
         final_answer = remove_incomplete_sentence(cleaned_answer)
+
+        # Translate English back to Arabic if needed
         if language_code == 0:
             final_answer = translate_en_to_ar(final_answer)
+
+        # Print and return the answer
         if final_answer:
             print("Answer:")
             print(final_answer)
         else:
             print("Sorry, I can't help with that.")
+
         return {
             "response": f"I hope this answers your question: {final_answer}",
-            # "conversation_id": chat_query.conversation_id,
+            # "conversation_id": chat_query.conversation_id,  # Uncomment if needed
             "success": True
         }
+
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
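The reshaped endpoint can be smoke-tested with FastAPI's TestClient. A minimal sketch follows; the "/chat" route path and the name of the FastAPI instance are assumptions (neither is visible in this diff), while the query and language_code fields come from the ChatQuery handling shown above (language_code == 0 triggers the Arabic/English translation steps).

# Hypothetical client call; route path and app object name are assumed.
from fastapi.testclient import TestClient
from app import app  # assumes the FastAPI instance in app.py is named "app"

client = TestClient(app)
payload = {"query": "What are the side effects of aspirin?", "language_code": 1}  # non-zero skips translation
resp = client.post("/chat", json=payload)
print(resp.status_code)
print(resp.json())  # expected shape: {"response": "...", "success": True}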