thechaiexperiment committed
Commit 9b08b8e · verified · 1 Parent(s): f9e3554

Update app.py

Files changed (1)
  1. app.py +136 -239
app.py CHANGED
@@ -340,159 +340,117 @@ def retrieve_metadata(document_indices: List[int], metadata_path: str = 'recipes
         print(f"Error retrieving metadata: {e}")
         return {}
 
-def rerank_documents(query: str, document_ids: List[str], document_texts: List[str], cross_encoder_model) -> List[Tuple[float, str, str]]:
+def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
     try:
-        # Batch process all documents at once
         pairs = [(query, doc) for doc in document_texts]
-        scores = cross_encoder_model.predict(pairs, batch_size=8)  # Increased batch size
+        scores = cross_encoder_model.predict(pairs)
         scored_documents = list(zip(scores, document_ids, document_texts))
         scored_documents.sort(key=lambda x: x[0], reverse=True)
+        print("Reranked results:")
+        for idx, (score, doc_id, doc) in enumerate(scored_documents):
+            print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id})")
         return scored_documents
     except Exception as e:
         print(f"Error reranking documents: {e}")
         return []
 
-def extract_entities_batch(texts: List[str], biobert_tokenizer, biobert_model, batch_size: int = 8) -> List[List[str]]:
-    try:
-        all_entities = []
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i + batch_size]
-            # Process multiple texts in parallel
-            inputs = biobert_tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
-            with torch.no_grad():  # Disable gradient calculation
-                outputs = biobert_model(**inputs)
-
-            predictions = torch.argmax(outputs.logits, dim=2)
-
-            for j, (input_ids, preds) in enumerate(zip(inputs.input_ids, predictions)):
-                tokens = biobert_tokenizer.convert_ids_to_tokens(input_ids)
-                entities = [tokens[k] for k in range(len(tokens)) if preds[k].item() != 0]
-                all_entities.append(entities)
-
-        return all_entities
-    except Exception as e:
-        print(f"Error in batch entity extraction: {e}")
-        return [[] for _ in texts]
-
-def extract_relevant_portions(document_texts: List[str], query: str, biobert_tokenizer, biobert_model,
-                              max_portions: int = 3, portion_size: int = 1) -> Dict[str, List[str]]:
+from sentence_transformers import SentenceTransformer
+from sklearn.metrics.pairwise import cosine_similarity
+import nltk
+
+def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
     try:
-        # Process query and all documents in one batch
-        all_texts = [query] + document_texts
-        all_entities = extract_entities_batch(all_texts, biobert_tokenizer, biobert_model)
-
-        query_entities = set(all_entities[0])
         relevant_portions = {}
-
-        def process_document(doc_idx: int) -> Tuple[str, List[str]]:
-            doc_text = document_texts[doc_idx]
-            doc_entities = set(all_entities[doc_idx + 1])  # +1 because query was first
-
-            sentences = nltk.sent_tokenize(doc_text)
-            doc_relevant_portions = []
-
-            # Score sentences based on entity overlap
-            sentence_scores = []
-            for i, sentence in enumerate(sentences):
-                entity_overlap = len(query_entities.intersection(doc_entities))
-                sentence_scores.append((entity_overlap, i))
-
-            # Sort and select top sentences
-            sentence_scores.sort(reverse=True)
-            for _, sent_idx in sentence_scores[:max_portions]:
-                start_idx = max(0, sent_idx - portion_size // 2)
-                end_idx = min(len(sentences), sent_idx + portion_size // 2 + 1)
-                portion = " ".join(sentences[start_idx:end_idx])
-                doc_relevant_portions.append(portion)
-
-            return f"Document_{doc_idx}", doc_relevant_portions
-
-        # Process documents in parallel
-        with ThreadPoolExecutor(max_workers=4) as executor:
-            results = list(executor.map(lambda x: process_document(x), range(len(document_texts))))
-
-        relevant_portions = dict(results)
+
+        for _, doc_id, doc_text in top_documents:
+            if doc_id not in embeddings_data:
+                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
+                continue
+
+            # Retrieve the precomputed embedding for this document
+            doc_embedding = np.array(embeddings_data[doc_id])
+
+            # Compute similarity between the query embedding and the document embedding
+            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
+
+            # Split the document into sentences
+            sentences = nltk.sent_tokenize(doc_text)
+
+            # Rank sentences based on their length (proxy for importance) or other heuristic
+            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
+            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
+
+            relevant_portions[doc_id] = sorted_sentences
+
+            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
+            for i, sentence in enumerate(sorted_sentences, start=1):
+                print(f"  Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
+
         return relevant_portions
-
     except Exception as e:
-        print(f"Error extracting relevant portions: {e}")
-        return {f"Document_{i}": [] for i in range(len(document_texts))}
+        print(f"Error in extract_relevant_portions: {e}")
+        return {}
+
+
+def remove_duplicates(selected_parts):
+    unique_sentences = set()
+    unique_selected_parts = []
+    for sentence in selected_parts:
+        if sentence not in unique_sentences:
+            unique_selected_parts.append(sentence)
+            unique_sentences.add(sentence)
+    return unique_selected_parts
 
-def generate_answer(prompt: str, tokenizer_f, model_f, max_length: int = 860, temperature: float = 0.2) -> str:
+def extract_entities(text):
     try:
-        # Optimize input processing
-        inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True, max_length=512)
-
-        with torch.no_grad():  # Disable gradient calculation
-            output_ids = model_f.generate(
-                inputs.input_ids,
-                max_length=max_length,
-                num_return_sequences=1,
-                temperature=temperature,
-                pad_token_id=tokenizer_f.eos_token_id,
-                do_sample=False,  # Use greedy decoding for faster generation
-                early_stopping=True
-            )
-
-        answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-
-        # Quick relevance check
-        if any(word in answer.lower() for word in prompt.lower().split()):
-            return answer
-        return "I apologize, but I cannot provide a relevant answer based on the given information."
-
+        biobert_tokenizer = models['bio_tokenizer']
+        biobert_model = models['bio_model']
+        inputs = biobert_tokenizer(text, return_tensors="pt")
+        outputs = biobert_model(**inputs)
+        predictions = torch.argmax(outputs.logits, dim=2)
+        tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+        entities = [
+            tokens[i]
+            for i in range(len(tokens))
+            if predictions[0][i].item() != 0  # Assuming 0 is the label for non-entity
+        ]
+        return entities
     except Exception as e:
-        print(f"Error generating answer: {e}")
-        return "I apologize, but I encountered an error while generating the answer."
+        print(f"Error extracting entities: {e}")
+        return []
 
 def enhance_passage_with_entities(passage, entities):
     return f"{passage}\n\nEntities: {', '.join(entities)}"
 
-
-def create_prompt(question: str, passage: str) -> str:
-    return f"""As a medical expert, answer the following question based only on the provided passage. Be concise and direct.
-    Passage: {passage}
-    Question: {question}
-    Answer:"""
-
-def process_query_and_generate_answer(
-    query: str,
-    relevant_documents: List[Tuple[float, str, str]],
-    models: Dict,
-    max_portions: int = 3
-) -> str:
-    try:
-        # Extract relevant portions from top documents
-        relevant_portions = extract_relevant_portions(
-            [doc[2] for doc in relevant_documents[:3]],  # Use top 3 documents
-            query,
-            models['bio_tokenizer'],
-            models['bio_model'],
-            max_portions=max_portions
-        )
-
-        # Combine relevant portions
-        all_portions = []
-        for doc_portions in relevant_portions.values():
-            all_portions.extend(doc_portions)
-
-        # Remove duplicates while preserving order
-        unique_portions = list(dict.fromkeys(all_portions))
-
-        # Create context from unique portions
-        context = " ".join(unique_portions[:max_portions])
-
-        # Generate and return answer
-        prompt = create_prompt(query, context)
-        return generate_answer(
-            prompt,
-            models['llm_tokenizer'],
-            models['llm_model']
-        )
-
-    except Exception as e:
-        print(f"Error in query processing pipeline: {e}")
-        return "I apologize, but I encountered an error while processing your question."
+def create_prompt(question, passage):
+    prompt = ("""
+    As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
+    Passage: {passage}
+    Question: {question}
+    Answer:
+    """)
+    return prompt.format(passage=passage, question=question)
+
+def generate_answer(prompt, max_length=860, temperature=0.2):
+    tokenizer_f = models['llm_tokenizer']
+    model_f = models['llm_model']
+    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
+    output_ids = model_f.generate(
+        inputs.input_ids,
+        max_length=max_length,
+        num_return_sequences=1,
+        temperature=temperature,
+        pad_token_id=tokenizer_f.eos_token_id
+    )
+    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
+    passage_keywords = set(prompt.lower().split())
+    answer_keywords = set(answer.lower().split())
+    if passage_keywords.intersection(answer_keywords):
+        return answer
+    else:
+        return "Sorry, I can't help with that."
+
 def remove_answer_prefix(text):
     prefix = "Answer:"
     if prefix in text:
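For reference, here is a minimal sketch of how the reworked retrieval helpers above (extract_relevant_portions and remove_duplicates) might be exercised on their own. The document IDs, texts, embedding dimension, and random vectors are placeholders; it assumes the functions from app.py are in scope, that numpy is available as np, and that NLTK's punkt tokenizer data has already been downloaded. Note that it flattens the per-document sentence lists before deduplicating, whereas the endpoint in the next hunk passes the returned dict straight to remove_duplicates.

import numpy as np

# Placeholder inputs: hypothetical document IDs, texts, and 384-dim random embeddings.
query_embedding = np.random.rand(1, 384)  # kept 2-D, since cosine_similarity expects a matrix-like input
top_documents = [
    (0.92, "doc_1", "First article text. It has several sentences about diet."),
    (0.85, "doc_2", "Second article text. Shorter, with two sentences."),
]
embeddings_data = {"doc_1": np.random.rand(384), "doc_2": np.random.rand(384)}

portions = extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=2)

# Flatten the per-document sentence lists, then drop repeats while preserving order.
flattened = [sentence for sentences in portions.values() for sentence in sentences]
print(remove_duplicates(flattened))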
 
@@ -558,132 +516,71 @@ async def health_check():
 @app.post("/api/chat")
 async def chat_endpoint(chat_query: ChatQuery):
     try:
-        # Initialize response timing
-        start_time = asyncio.get_event_loop().time()
-
-        # Extract query and handle translation
         query_text = chat_query.query
         language_code = chat_query.language_code
-
+
+        # Translate Arabic to English if language_code is 0
         if language_code == 0:
-            query_text = await run_in_threadpool(translate_ar_to_en, query_text)
-
-        # Embed query and load embeddings in parallel
-        query_embedding_task = run_in_threadpool(embed_query_text, query_text)
-        embeddings_data_task = run_in_threadpool(load_embeddings)
-
-        # Wait for both tasks to complete
-        query_embedding, embeddings_data = await asyncio.gather(
-            query_embedding_task,
-            embeddings_data_task
-        )
-
-        # Initial document retrieval
+            query_text = translate_ar_to_en(query_text)
+
+        # Generate query embedding
+        query_embedding = embed_query_text(query_text)
         n_results = 5
+
+        # Load embeddings and retrieve initial results
+        embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
-
-        # Get initial results and retrieve documents
-        initial_results = await run_in_threadpool(
-            query_embeddings,
-            query_embedding,
-            embeddings_data,
-            n_results
-        )
-
+        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
+
+        # Extract document IDs and texts
         document_ids = [doc_id for doc_id, *_ in initial_results]
-        document_texts = await run_in_threadpool(
-            retrieve_document_texts,
-            document_ids,
-            folder_path
-        )
-
-        # Rerank documents
+        document_texts = retrieve_document_texts(document_ids, folder_path)
+
+        # Use cross-encoder to score documents
         cross_encoder = models['cross_encoder']
-        scored_documents = await run_in_threadpool(
-            rerank_documents,
-            query_text,
-            document_ids,
-            document_texts,
-            cross_encoder
-        )
-
-        # Process documents and generate answer
-        async with asyncio.TaskGroup() as tg:
-            # Extract entities in parallel
-            entities_task = tg.create_task(
-                run_in_threadpool(
-                    extract_entities_batch,
-                    [query_text] + [doc[2] for doc in scored_documents[:3]],
-                    models['bio_tokenizer'],
-                    models['bio_model']
-                )
-            )
-
-            # Extract relevant portions
-            portions_task = tg.create_task(
-                run_in_threadpool(
-                    extract_relevant_portions,
-                    [doc[2] for doc in scored_documents[:3]],
-                    query_text,
-                    models['bio_tokenizer'],
-                    models['bio_model']
-                )
-            )
-
-        entities = (await entities_task)[0]  # First item is query entities
-        relevant_portions = await portions_task
-
-        # Flatten and process portions
-        flattened_portions = []
-        for doc_portions in relevant_portions.values():
-            flattened_portions.extend(doc_portions)
-
-        unique_selected_parts = list(dict.fromkeys(flattened_portions))
+        scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+
+        # Score and sort documents
+        scored_documents = list(zip(scores, document_ids, document_texts))
+        scored_documents.sort(key=lambda x: x[0], reverse=True)
+
+        # Extract relevant portions
+        relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
+        unique_selected_parts = remove_duplicates(relevant_portions)
         combined_parts = " ".join(unique_selected_parts)
-
-        # Enhance passage and create prompt
+
+        # Build context and enhance passage with entities
+        context = [query_text] + unique_selected_parts
+        entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
+
+        # Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
-
-        # Generate answer
-        answer = await run_in_threadpool(
-            generate_answer,
-            prompt,
-            models['llm_tokenizer'],
-            models['llm_model']
-        )
-
-        # Process answer
+        answer = generate_answer(prompt)
         answer_part = answer.split("Answer:")[-1].strip()
-        cleaned_answer = await run_in_threadpool(remove_answer_prefix, answer_part)
-        final_answer = await run_in_threadpool(remove_incomplete_sentence, cleaned_answer)
-
-        # Handle translation if needed
+
+        # Clean and finalize the answer
+        cleaned_answer = remove_answer_prefix(answer_part)
+        final_answer = remove_incomplete_sentence(cleaned_answer)
+
+        # Translate English back to Arabic if needed
         if language_code == 0:
-            final_answer = await run_in_threadpool(translate_en_to_ar, final_answer)
-
-        # Calculate response time
-        end_time = asyncio.get_event_loop().time()
-        response_time = end_time - start_time
-
+            final_answer = translate_en_to_ar(final_answer)
+
+        # Print and return the answer
         if final_answer:
-            print(f"Answer generated in {response_time:.2f} seconds")
+            print("Answer:")
             print(final_answer)
-
-            return {
-                "response": f"I hope this answers your question: {final_answer}",
-                "success": True,
-                "response_time": response_time
-            }
         else:
-            return {
-                "response": "Sorry, I can't help with that.",
-                "success": False,
-                "response_time": response_time
-            }
-
+            print("Sorry, I can't help with that.")
+
+        return {
+            "response": f"I hope this answers your question: {final_answer}",
+            # "conversation_id": chat_query.conversation_id,  # Uncomment if needed
+            "success": True
+        }
+
     except Exception as e:
-        print(f"Error in chat endpoint: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
 @app.post("/api/resources")
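To close, a hedged example of calling the simplified /api/chat endpoint once the app is running. The base URL is an assumption (adjust it to wherever the FastAPI app is served), and it assumes ChatQuery only requires the query and language_code fields visible in the diff; a language_code of 0 triggers the Arabic-to-English translation of the query and the translation of the answer back to Arabic, while any other value skips both.

import requests

BASE_URL = "http://localhost:7860"  # hypothetical host/port for the running FastAPI app

payload = {
    "query": "What foods are rich in vitamin D?",
    "language_code": 1,  # 0 would route the request through the Arabic translation path
}

resp = requests.post(f"{BASE_URL}/api/chat", json=payload)
resp.raise_for_status()
data = resp.json()
print(data["success"])
print(data["response"])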