Update app.py
app.py CHANGED
@@ -340,159 +340,117 @@ def retrieve_metadata(document_indices: List[int], metadata_path: str = 'recipes
         print(f"Error retrieving metadata: {e}")
         return {}

-def rerank_documents(query
+def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
     try:
-        # Batch process all documents at once
         pairs = [(query, doc) for doc in document_texts]
-        scores = cross_encoder_model.predict(pairs
+        scores = cross_encoder_model.predict(pairs)
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
+        print("Reranked results:")
+        for idx, (score, doc_id, doc) in enumerate(scored_documents):
+            print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id})")
         return scored_documents
     except Exception as e:
         print(f"Error reranking documents: {e}")
         return []

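The reranking step added above can be exercised standalone. A minimal sketch, assuming a generic sentence-transformers CrossEncoder; the checkpoint name is an assumption, not necessarily the one this Space loads into models['cross_encoder']:

    from sentence_transformers import CrossEncoder

    # Hypothetical checkpoint; the Space's actual cross-encoder is not shown in this diff.
    cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    query = "Is honey safe for infants?"
    docs = [
        "Honey should not be given to children under one year of age.",
        "Routine maintenance schedules for diesel engines.",
    ]
    scores = cross_encoder.predict([(query, doc) for doc in docs])
    for score, doc in sorted(zip(scores, docs), key=lambda x: x[0], reverse=True):
        print(f"{score:.4f}  {doc}")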
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i + batch_size]
-            # Process multiple texts in parallel
-            inputs = biobert_tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
-            with torch.no_grad():  # Disable gradient calculation
-                outputs = biobert_model(**inputs)
-            predictions = torch.argmax(outputs.logits, dim=2)
-            for j, (input_ids, preds) in enumerate(zip(inputs.input_ids, predictions)):
-                tokens = biobert_tokenizer.convert_ids_to_tokens(input_ids)
-                entities = [tokens[k] for k in range(len(tokens)) if preds[k].item() != 0]
-                all_entities.append(entities)
-        return all_entities
-    except Exception as e:
-        print(f"Error in batch entity extraction: {e}")
-        return [[] for _ in texts]
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import nltk

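The new imports pull in nltk for sentence splitting; nltk.sent_tokenize needs the punkt tokenizer data installed once. A setup sketch (recent NLTK releases may additionally request the punkt_tab resource):

    import nltk

    nltk.download("punkt")  # one-time download of the sentence-tokenizer models
    print(nltk.sent_tokenize("BioBERT handles biomedical text. NLTK splits sentences."))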
-def extract_relevant_portions(
-                              max_portions: int = 3, portion_size: int = 1) -> Dict[str, List[str]]:
+def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
     try:
-        # Process query and all documents in one batch
-        all_texts = [query] + document_texts
-        all_entities = extract_entities_batch(all_texts, biobert_tokenizer, biobert_model)
-        query_entities = set(all_entities[0])
         relevant_portions = {}
-            sentences = nltk.sent_tokenize(doc_text)
-            doc_relevant_portions = []
-            for i, sentence in enumerate(sentences):
-                entity_overlap = len(query_entities.intersection(doc_entities))
-                sentence_scores.append((entity_overlap, i))
-            results = list(executor.map(lambda x: process_document(x), range(len(document_texts))))
-        relevant_portions = dict(results)
+        for _, doc_id, doc_text in top_documents:
+            if doc_id not in embeddings_data:
+                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
+                continue
+
+            # Retrieve the precomputed embedding for this document
+            doc_embedding = np.array(embeddings_data[doc_id])
+
+            # Compute similarity between the query embedding and the document embedding
+            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
+
+            # Split the document into sentences
+            sentences = nltk.sent_tokenize(doc_text)
+
+            # Rank sentences based on their length (proxy for importance) or other heuristic
+            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
+            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
+
+            relevant_portions[doc_id] = sorted_sentences
+
+            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
+            for i, sentence in enumerate(sorted_sentences, start=1):
+                print(f"  Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
+
         return relevant_portions
+
     except Exception as e:
-        print(f"Error
-        return {
+        print(f"Error in extract_relevant_portions: {e}")
+        return {}
+
+
+def remove_duplicates(selected_parts):
+    unique_sentences = set()
+    unique_selected_parts = []
+    for sentence in selected_parts:
+        if sentence not in unique_sentences:
+            unique_selected_parts.append(sentence)
+            unique_sentences.add(sentence)
+    return unique_selected_parts

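The cosine_similarity call in the new extract_relevant_portions assumes the query embedding is already 2-D, since scikit-learn expects matrices on both sides. A toy shape check with random vectors (the dimension 384 is an assumption for illustration):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    query_embedding = np.random.rand(1, 384)   # shape (1, dim), passed through unchanged
    doc_embedding = np.random.rand(384)        # flat vector, as stored per document ID
    similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
    print(f"similarity: {similarity:.4f}")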
-def
+def extract_entities(text):
     try:
-        )
-        answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-        # Quick relevance check
-        if any(word in answer.lower() for word in prompt.lower().split()):
-            return answer
-        return "I apologize, but I cannot provide a relevant answer based on the given information."
+        biobert_tokenizer = models['bio_tokenizer']
+        biobert_model = models['bio_model']
+        inputs = biobert_tokenizer(text, return_tensors="pt")
+        outputs = biobert_model(**inputs)
+        predictions = torch.argmax(outputs.logits, dim=2)
+        tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+        entities = [
+            tokens[i]
+            for i in range(len(tokens))
+            if predictions[0][i].item() != 0  # Assuming 0 is the label for non-entity
+        ]
+        return entities
     except Exception as e:
-        print(f"Error
-        return
+        print(f"Error extracting entities: {e}")
+        return []

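extract_entities assumes models['bio_model'] is a token-classification model whose label 0 means "no entity". A self-contained sketch of the same pattern; the BioBERT checkpoint named here is an assumption, and without a fine-tuned NER head its predictions are not meaningful:

    import torch
    from transformers import AutoTokenizer, AutoModelForTokenClassification

    name = "dmis-lab/biobert-base-cased-v1.1"  # hypothetical; the Space's checkpoint is not shown
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForTokenClassification.from_pretrained(name)

    inputs = tokenizer("Aspirin can reduce fever.", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    entities = [tok for tok, pred in zip(tokens, predictions[0]) if pred.item() != 0]
    print(entities)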
 def enhance_passage_with_entities(passage, entities):
     return f"{passage}\n\nEntities: {', '.join(entities)}"

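A usage sketch for the helper above, self-contained for illustration:

    def enhance_passage_with_entities(passage, entities):
        return f"{passage}\n\nEntities: {', '.join(entities)}"

    print(enhance_passage_with_entities("Ginger tea may ease nausea.", ["ginger", "nausea"]))
    # Ginger tea may ease nausea.
    #
    # Entities: ginger, nausea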
-        Passage: {passage}
-        Question: {question}
-        Answer:
-        models
-        # Remove duplicates while preserving order
-        unique_portions = list(dict.fromkeys(all_portions))
-        # Create context from unique portions
-        context = " ".join(unique_portions[:max_portions])
-        # Generate and return answer
-        prompt = create_prompt(query, context)
-        return generate_answer(
-            prompt,
-            models['llm_tokenizer'],
-            models['llm_model']
-        )
-    except Exception as e:
-        print(f"Error in query processing pipeline: {e}")
-        return "I apologize, but I encountered an error while processing your question."
+def create_prompt(question, passage):
+    prompt = ("""
+    As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
+    Passage: {passage}
+    Question: {question}
+    Answer:
+    """)
+    return prompt.format(passage=passage, question=question)
+
+def generate_answer(prompt, max_length=860, temperature=0.2):
+    tokenizer_f = models['llm_tokenizer']
+    model_f = models['llm_model']
+    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
+    output_ids = model_f.generate(
+        inputs.input_ids,
+        max_length=max_length,
+        num_return_sequences=1,
+        temperature=temperature,
+        pad_token_id=tokenizer_f.eos_token_id
+    )
+    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
+    passage_keywords = set(prompt.lower().split())
+    answer_keywords = set(answer.lower().split())
+    if passage_keywords.intersection(answer_keywords):
+        return answer
+    else:
+        return "Sorry, I can't help with that."
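One behavioral note on the new generate_answer: in transformers, temperature only affects sampling, and generate defaults to greedy decoding, so temperature=0.2 is effectively ignored unless do_sample=True is also passed. A standalone sketch with a small stand-in checkpoint (the Space's actual LLM is not shown in this diff):

    from transformers import AutoTokenizer, AutoModelForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Question: What is aspirin used for?\nAnswer:", return_tensors="pt", truncation=True)
    output_ids = model.generate(
        inputs.input_ids,
        max_length=60,
        do_sample=True,        # required for temperature to take effect
        temperature=0.2,
        pad_token_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))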
 def remove_answer_prefix(text):
     prefix = "Answer:"
     if prefix in text:
@@ -558,132 +516,71 @@ async def health_check():
 @app.post("/api/chat")
 async def chat_endpoint(chat_query: ChatQuery):
     try:
-        # Initialize response timing
-        start_time = asyncio.get_event_loop().time()
-        # Extract query and handle translation
         query_text = chat_query.query
         language_code = chat_query.language_code
+
+        # Translate Arabic to English if language_code is 0
         if language_code == 0:
-            query_text =
-        embeddings_data_task = run_in_threadpool(load_embeddings)
-        # Wait for both tasks to complete
-        query_embedding, embeddings_data = await asyncio.gather(
-            query_embedding_task,
-            embeddings_data_task
-        )
-        # Initial document retrieval
+            query_text = translate_ar_to_en(query_text)
+
+        # Generate query embedding
+        query_embedding = embed_query_text(query_text)
         n_results = 5
+
+        # Load embeddings and retrieve initial results
+        embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
-            query_embeddings,
-            query_embedding,
-            embeddings_data,
-            n_results
-        )
+        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
+
+        # Extract document IDs and texts
         document_ids = [doc_id for doc_id, *_ in initial_results]
-        document_texts =
-            folder_path
-        )
+        document_texts = retrieve_document_texts(document_ids, folder_path)
-        # Rerank documents
+
+        # Use cross-encoder to score documents
         cross_encoder = models['cross_encoder']
-        async with asyncio.TaskGroup() as tg:
-            # Extract entities in parallel
-            entities_task = tg.create_task(
-                run_in_threadpool(
-                    extract_entities_batch,
-                    [query_text] + [doc[2] for doc in scored_documents[:3]],
-                    models['bio_tokenizer'],
-                    models['bio_model']
-                )
-            )
-            # Extract relevant portions
-            portions_task = tg.create_task(
-                run_in_threadpool(
-                    extract_relevant_portions,
-                    [doc[2] for doc in scored_documents[:3]],
-                    query_text,
-                    models['bio_tokenizer'],
-                    models['bio_model']
-                )
-            )
-        entities = (await entities_task)[0]  # First item is query entities
-        relevant_portions = await portions_task
-        # Flatten and process portions
-        flattened_portions = []
-        for doc_portions in relevant_portions.values():
-            flattened_portions.extend(doc_portions)
-        unique_selected_parts = list(dict.fromkeys(flattened_portions))
+        scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+
+        # Score and sort documents
+        scored_documents = list(zip(scores, document_ids, document_texts))
+        scored_documents.sort(key=lambda x: x[0], reverse=True)
+
+        # Extract relevant portions
+        relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
+        unique_selected_parts = remove_duplicates(relevant_portions)
         combined_parts = " ".join(unique_selected_parts)
+
+        # Build context and enhance passage with entities
+        context = [query_text] + unique_selected_parts
+        entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
+
+        # Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
-
-        # Generate answer
-        answer = await run_in_threadpool(
-            generate_answer,
-            prompt,
-            models['llm_tokenizer'],
-            models['llm_model']
-        )
-
-        # Process answer
+        answer = generate_answer(prompt)
         answer_part = answer.split("Answer:")[-1].strip()
+
+        # Clean and finalize the answer
+        cleaned_answer = remove_answer_prefix(answer_part)
+        final_answer = remove_incomplete_sentence(cleaned_answer)
+
+        # Translate English back to Arabic if needed
         if language_code == 0:
-            final_answer =
-        end_time = asyncio.get_event_loop().time()
-        response_time = end_time - start_time
+            final_answer = translate_en_to_ar(final_answer)
+
+        # Print and return the answer
         if final_answer:
-            print(
+            print("Answer:")
             print(final_answer)
-            return {
-                "response": f"I hope this answers your question: {final_answer}",
-                "success": True,
-                "response_time": response_time
-            }
         else:
+            print("Sorry, I can't help with that.")
+
+        return {
+            "response": f"I hope this answers your question: {final_answer}",
+            # "conversation_id": chat_query.conversation_id,  # Uncomment if needed
+            "success": True
+        }
+
     except Exception as e:
-        print(f"Error in chat endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))

 @app.post("/api/resources")
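Two notes on the updated endpoint. First, extract_relevant_portions returns a dict of doc_id -> sentences, so passing it directly to remove_duplicates iterates its keys rather than the sentences; flattening the values first, as the removed code did with dict.fromkeys, keeps the sentences themselves. A sketch of that flattening:

    relevant_portions = {"doc1": ["A fact.", "Shared fact."], "doc2": ["Shared fact."]}
    flattened = [s for portions in relevant_portions.values() for s in portions]
    unique_selected_parts = list(dict.fromkeys(flattened))  # dedupe while preserving order
    print(unique_selected_parts)  # ['A fact.', 'Shared fact.']

Second, a hypothetical client call against the endpoint, assuming the Space is served locally on port 7860 and that a language_code other than 0 skips the Arabic translation step:

    import requests

    resp = requests.post(
        "http://localhost:7860/api/chat",  # assumed local URL, not from the diff
        json={"query": "Is honey safe for infants?", "language_code": 1},
    )
    print(resp.json())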