Update app.py
app.py
CHANGED
@@ -28,12 +28,9 @@ from sklearn.metrics.pairwise import cosine_similarity
 from bs4 import BeautifulSoup
 from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
-from typing import List, Dict,
+from typing import List, Dict, Optional
 from safetensors.numpy import load_file
 from safetensors.torch import safe_open
-from concurrent.futures import ThreadPoolExecutor
-import asyncio
-from functools import partial
 nltk.download('punkt_tab')
 
 app = FastAPI()
@@ -66,11 +63,6 @@ class ChatMessage(BaseModel):
     content: str
     timestamp: str
 
-async def run_in_threadpool(func, *args, **kwargs):
-    return await asyncio.get_event_loop().run_in_executor(
-        None, partial(func, *args, **kwargs)
-    )
-
 def init_nltk():
     try:
         nltk.download('punkt', quiet=True)
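Reviewer note on the deletions above: the hand-rolled run_in_threadpool helper (and the asyncio/functools imports that served it) duplicated a utility FastAPI already re-exports from Starlette. If the blocking model calls ever need to come off the event loop again, a minimal sketch of the built-in replacement (score_pairs and its arguments are illustrative names, not part of this commit):

    from fastapi.concurrency import run_in_threadpool  # re-exported from starlette.concurrency

    async def score_pairs(cross_encoder, pairs):
        # cross_encoder.predict is CPU-bound and blocking; off-loading it to
        # the threadpool keeps the event loop responsive.
        return await run_in_threadpool(cross_encoder.predict, pairs)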
@@ -354,44 +346,52 @@ def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
         print(f"Error reranking documents: {e}")
         return []
 
-
-from sklearn.metrics.pairwise import cosine_similarity
-import nltk
-
-def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
+def extract_entities(text, ner_pipeline=None):
     try:
-
-
-
-
-
-            continue
-
-        # Retrieve the precomputed embedding for this document
-        doc_embedding = np.array(embeddings_data[doc_id])
-
-        # Compute similarity between the query embedding and the document embedding
-        similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
-
-        # Split the document into sentences
-        sentences = nltk.sent_tokenize(doc_text)
-
-        # Rank sentences based on their length (proxy for importance) or other heuristic
-        # Since we're using document-level embeddings, we assume all sentences are equally relevant.
-        sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
-
-        relevant_portions[doc_id] = sorted_sentences
-
-        print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
-        for i, sentence in enumerate(sorted_sentences, start=1):
-            print(f"  Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
-
-        return relevant_portions
-
+        if ner_pipeline is None:
+            ner_pipeline = models['ner_pipeline']
+        ner_results = ner_pipeline(text)
+        entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
+        return list(entities)
     except Exception as e:
-        print(f"Error extracting relevant portions: {e}")
-        return
+        print(f"Error extracting entities: {e}")
+        return []
 
+def match_entities(query_entities, sentence_entities):
+    try:
+        query_set, sentence_set = set(query_entities), set(sentence_entities)
+        matches = query_set.intersection(sentence_set)
+        return len(matches)
+    except Exception as e:
+        print(f"Error matching entities: {e}")
+        return 0
+
+def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
+    relevant_portions = {}
+    query_entities = extract_entities(query)
+    print(f"Extracted Query Entities: {query_entities}")
+    for doc_id, doc_text in enumerate(document_texts):
+        sentences = nltk.sent_tokenize(doc_text)
+        doc_relevant_portions = []
+        doc_entities = extract_entities(doc_text)
+        print(f"Document {doc_id} Entities: {doc_entities}")
+        for i, sentence in enumerate(sentences):
+            sentence_entities = extract_entities(sentence)
+            relevance_score = match_entities(query_entities, sentence_entities)
+            if relevance_score >= min_query_words:
+                start_idx = max(0, i - portion_size // 2)
+                end_idx = min(len(sentences), i + portion_size // 2 + 1)
+                portion = " ".join(sentences[start_idx:end_idx])
+                doc_relevant_portions.append(portion)
+                if len(doc_relevant_portions) >= max_portions:
+                    break
+        if not doc_relevant_portions and len(doc_entities) > 0:
+            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
+            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s)), reverse=True)
+            for fallback_sentence in sorted_sentences[:max_portions]:
+                doc_relevant_portions.append(fallback_sentence)
+        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
+    return relevant_portions
 
 def remove_duplicates(selected_parts):
     unique_sentences = set()
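Reviewer note: the rewritten extract_relevant_portions is entity-driven (NER overlap between the query and each sentence) rather than embedding-driven. A usage sketch with toy inputs, assuming models['ner_pipeline'] is loaded as elsewhere in app.py:

    documents = [
        "Aspirin reduces fever. Aspirin can also thin the blood.",
        "Ibuprofen is an NSAID used for pain and inflammation.",
    ]
    portions = extract_relevant_portions(documents, "Does aspirin thin the blood?",
                                         max_portions=3, portion_size=1, min_query_words=2)
    # Keys are "Document_<index>"; each value holds up to max_portions sentence
    # windows whose entities overlap the query entities at least min_query_words
    # times, with an entity-count fallback when nothing clears that bar.
    for doc_key, parts in portions.items():
        print(doc_key, parts)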
@@ -426,8 +426,11 @@ def enhance_passage_with_entities(passage, entities):
 def create_prompt(question, passage):
     prompt = ("""
 As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
+
 Passage: {passage}
+
 Question: {question}
+
 Answer:
 """)
     return prompt.format(passage=passage, question=question)
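The only change in this hunk is the blank lines separating the template sections. For reference, a quick check with hypothetical inputs:

    prompt = create_prompt("What is aspirin used for?", "Aspirin treats pain and fever.")
    print(prompt)  # Passage:, Question:, and Answer: now render as separate blocks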
@@ -517,69 +520,46 @@ async def health_check():
 async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
-        language_code = chat_query.language_code
-
-        # Translate Arabic to English if language_code is 0
+        language_code = chat_query.language_code
         if language_code == 0:
             query_text = translate_ar_to_en(query_text)
-
-        # Generate query embedding
         query_embedding = embed_query_text(query_text)
         n_results = 5
-
-        # Load embeddings and retrieve initial results
-        embeddings_data = load_embeddings()
+        embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
-
-        # Extract document IDs and texts
-        document_ids = [doc_id for doc_id, *_ in initial_results]
+        document_ids = [doc_id for doc_id, _ in initial_results]
         document_texts = retrieve_document_texts(document_ids, folder_path)
-
-        # Use cross-encoder to score documents
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
-
-        # Score and sort documents
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
-
-
-
-
+        relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=2)
+        flattened_relevant_portions = []
+        for doc_id, portions in relevant_portions.items():
+            flattened_relevant_portions.extend(portions)
+        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
         combined_parts = " ".join(unique_selected_parts)
-
-        # Build context and enhance passage with entities
         context = [query_text] + unique_selected_parts
         entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
-
-        # Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
         answer = generate_answer(prompt)
         answer_part = answer.split("Answer:")[-1].strip()
-
-        # Clean and finalize the answer
         cleaned_answer = remove_answer_prefix(answer_part)
         final_answer = remove_incomplete_sentence(cleaned_answer)
-
-        # Translate English back to Arabic if needed
         if language_code == 0:
             final_answer = translate_en_to_ar(final_answer)
-
-        # Print and return the answer
         if final_answer:
             print("Answer:")
             print(final_answer)
         else:
             print("Sorry, I can't help with that.")
-
         return {
             "response": f"I hope this answers your question: {final_answer}",
-            # "conversation_id": chat_query.conversation_id,
+            # "conversation_id": chat_query.conversation_id,
             "success": True
         }
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
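End to end, a hypothetical client call for manual testing. The route path is not visible in this diff, so "/chat" is an assumption; ChatQuery is assumed to expose query and language_code (0 = Arabic input, per the translation branches above):

    import requests

    payload = {"query": "ما هي استخدامات الأسبرين؟", "language_code": 0}
    resp = requests.post("http://localhost:8000/chat", json=payload)
    print(resp.json()["response"])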