thechaiexperiment committed (verified)
Commit 462ad54 · Parent(s): f37cde7

Update app.py

Files changed (1):
  1. app.py  +250 -123
app.py CHANGED
@@ -28,9 +28,12 @@ from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
-from typing import List, Dict, Optional
+from typing import List, Dict,Any,Tuple, Optional
from safetensors.numpy import load_file
from safetensors.torch import safe_open
+from concurrent.futures import ThreadPoolExecutor
+import asyncio
+from functools import partial
nltk.download('punkt_tab')

app = FastAPI()
@@ -63,6 +66,11 @@ class ChatMessage(BaseModel):
    content: str
    timestamp: str

+async def run_in_threadpool(func, *args, **kwargs):
+    return await asyncio.get_event_loop().run_in_executor(
+        None, partial(func, *args, **kwargs)
+    )
+
def init_nltk():
    try:
        nltk.download('punkt', quiet=True)
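
Note on the new helper: run_in_threadpool pushes a blocking call onto the default executor via run_in_executor, so translation, embedding, and file loading no longer stall the FastAPI event loop. A minimal, self-contained sketch of the pattern follows; slow_translate is a made-up stand-in for the app's own blocking functions, not something defined in app.py:

import asyncio
from functools import partial

async def run_in_threadpool(func, *args, **kwargs):
    # Run a blocking function on the default thread pool without blocking the event loop.
    return await asyncio.get_event_loop().run_in_executor(
        None, partial(func, *args, **kwargs)
    )

def slow_translate(text: str) -> str:
    # Hypothetical stand-in for blocking work such as translate_ar_to_en.
    return text.upper()

async def main():
    result = await run_in_threadpool(slow_translate, "hello")
    print(result)  # HELLO

asyncio.run(main())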
@@ -332,120 +340,155 @@ def retrieve_metadata(document_indices: List[int], metadata_path: str = 'recipes
        print(f"Error retrieving metadata: {e}")
        return {}

-def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
+def rerank_documents(query: str, document_ids: List[str], document_texts: List[str], cross_encoder_model) -> List[Tuple[float, str, str]]:
    try:
+        # Batch process all documents at once
        pairs = [(query, doc) for doc in document_texts]
-        scores = cross_encoder_model.predict(pairs)
+        scores = cross_encoder_model.predict(pairs, batch_size=8)  # Increased batch size
        scored_documents = list(zip(scores, document_ids, document_texts))
        scored_documents.sort(key=lambda x: x[0], reverse=True)
-        print("Reranked results:")
-        for idx, (score, doc_id, doc) in enumerate(scored_documents):
-            print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id})")
        return scored_documents
    except Exception as e:
        print(f"Error reranking documents: {e}")
        return []

-from sentence_transformers import SentenceTransformer
-from sklearn.metrics.pairwise import cosine_similarity
-import nltk
-
-def extract_relevant_portions(query_embedding, top_documents, embeddings_data, max_portions=3):
+def extract_entities_batch(texts: List[str], biobert_tokenizer, biobert_model, batch_size: int = 8) -> List[List[str]]:
    try:
-        relevant_portions = {}
-
-        for _, doc_id, doc_text in top_documents:
-            if doc_id not in embeddings_data:
-                print(f"Warning: No embedding available for Document ID {doc_id}. Skipping...")
-                continue
+        all_entities = []
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i + batch_size]
+            # Process multiple texts in parallel
+            inputs = biobert_tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
+            with torch.no_grad():  # Disable gradient calculation
+                outputs = biobert_model(**inputs)

-            # Retrieve the precomputed embedding for this document
-            doc_embedding = np.array(embeddings_data[doc_id])
+            predictions = torch.argmax(outputs.logits, dim=2)

-            # Compute similarity between the query embedding and the document embedding
-            similarity = cosine_similarity(query_embedding, [doc_embedding]).flatten()[0]
+            for j, (input_ids, preds) in enumerate(zip(inputs.input_ids, predictions)):
+                tokens = biobert_tokenizer.convert_ids_to_tokens(input_ids)
+                entities = [tokens[k] for k in range(len(tokens)) if preds[k].item() != 0]
+                all_entities.append(entities)
+
+        return all_entities
+    except Exception as e:
+        print(f"Error in batch entity extraction: {e}")
+        return [[] for _ in texts]

-            # Split the document into sentences
+def extract_relevant_portions(document_texts: List[str], query: str, biobert_tokenizer, biobert_model,
+                              max_portions: int = 3, portion_size: int = 1) -> Dict[str, List[str]]:
+    try:
+        # Process query and all documents in one batch
+        all_texts = [query] + document_texts
+        all_entities = extract_entities_batch(all_texts, biobert_tokenizer, biobert_model)
+
+        query_entities = set(all_entities[0])
+        relevant_portions = {}
+
+        def process_document(doc_idx: int) -> Tuple[str, List[str]]:
+            doc_text = document_texts[doc_idx]
+            doc_entities = set(all_entities[doc_idx + 1])  # +1 because query was first
+
            sentences = nltk.sent_tokenize(doc_text)
-
-            # Rank sentences based on their length (proxy for importance) or other heuristic
-            # Since we're using document-level embeddings, we assume all sentences are equally relevant.
-            sorted_sentences = sorted(sentences, key=lambda x: len(x), reverse=True)[:max_portions]
-
-            relevant_portions[doc_id] = sorted_sentences
+            doc_relevant_portions = []

-            print(f"Extracted relevant portions for Document ID {doc_id} (Similarity: {similarity:.4f}):")
-            for i, sentence in enumerate(sorted_sentences, start=1):
-                print(f" Portion {i}: {sentence[:100]}...")  # Print first 100 characters for preview
-
+            # Score sentences based on entity overlap
+            sentence_scores = []
+            for i, sentence in enumerate(sentences):
+                entity_overlap = len(query_entities.intersection(doc_entities))
+                sentence_scores.append((entity_overlap, i))
+
+            # Sort and select top sentences
+            sentence_scores.sort(reverse=True)
+            for _, sent_idx in sentence_scores[:max_portions]:
+                start_idx = max(0, sent_idx - portion_size // 2)
+                end_idx = min(len(sentences), sent_idx + portion_size // 2 + 1)
+                portion = " ".join(sentences[start_idx:end_idx])
+                doc_relevant_portions.append(portion)
+
+            return f"Document_{doc_idx}", doc_relevant_portions
+
+        # Process documents in parallel
+        with ThreadPoolExecutor(max_workers=4) as executor:
+            results = list(executor.map(lambda x: process_document(x), range(len(document_texts))))
+
+        relevant_portions = dict(results)
        return relevant_portions
-
+
    except Exception as e:
-        print(f"Error in extract_relevant_portions: {e}")
-        return {}
-
+        print(f"Error extracting relevant portions: {e}")
+        return {f"Document_{i}": [] for i in range(len(document_texts))}

-def remove_duplicates(selected_parts):
-    unique_sentences = set()
-    unique_selected_parts = []
-    for sentence in selected_parts:
-        if sentence not in unique_sentences:
-            unique_selected_parts.append(sentence)
-            unique_sentences.add(sentence)
-    return unique_selected_parts
-
-def extract_entities(text):
+def generate_answer(prompt: str, tokenizer_f, model_f, max_length: int = 860, temperature: float = 0.2) -> str:
    try:
-        biobert_tokenizer = models['bio_tokenizer']
-        biobert_model = models['bio_model']
-        inputs = biobert_tokenizer(text, return_tensors="pt")
-        outputs = biobert_model(**inputs)
-        predictions = torch.argmax(outputs.logits, dim=2)
-        tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-        entities = [
-            tokens[i]
-            for i in range(len(tokens))
-            if predictions[0][i].item() != 0  # Assuming 0 is the label for non-entity
-        ]
-        return entities
+        # Optimize input processing
+        inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True, max_length=512)
+
+        with torch.no_grad():  # Disable gradient calculation
+            output_ids = model_f.generate(
+                inputs.input_ids,
+                max_length=max_length,
+                num_return_sequences=1,
+                temperature=temperature,
+                pad_token_id=tokenizer_f.eos_token_id,
+                do_sample=False,  # Use greedy decoding for faster generation
+                early_stopping=True
+            )
+
+        answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
+
+        # Quick relevance check
+        if any(word in answer.lower() for word in prompt.lower().split()):
+            return answer
+        return "I apologize, but I cannot provide a relevant answer based on the given information."
+
    except Exception as e:
-        print(f"Error extracting entities: {e}")
-        return []
-
-def enhance_passage_with_entities(passage, entities):
-    return f"{passage}\n\nEntities: {', '.join(entities)}"
-
-def create_prompt(question, passage):
-    prompt = ("""
-    As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
-
-    Passage: {passage}
-
-    Question: {question}
-
-    Answer:
-    """)
-    return prompt.format(passage=passage, question=question)
-
-def generate_answer(prompt, max_length=860, temperature=0.2):
-    tokenizer_f = models['llm_tokenizer']
-    model_f = models['llm_model']
-    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
-    output_ids = model_f.generate(
-        inputs.input_ids,
-        max_length=max_length,
-        num_return_sequences=1,
-        temperature=temperature,
-        pad_token_id=tokenizer_f.eos_token_id
-    )
-    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-    passage_keywords = set(prompt.lower().split())
-    answer_keywords = set(answer.lower().split())
-    if passage_keywords.intersection(answer_keywords):
-        return answer
-    else:
-        return "Sorry, I can't help with that."
-
+        print(f"Error generating answer: {e}")
+        return "I apologize, but I encountered an error while generating the answer."
+
+def create_prompt(question: str, passage: str) -> str:
+    return f"""As a medical expert, answer the following question based only on the provided passage. Be concise and direct.
+Passage: {passage}
+Question: {question}
+Answer:"""
+
+def process_query_and_generate_answer(
+    query: str,
+    relevant_documents: List[Tuple[float, str, str]],
+    models: Dict,
+    max_portions: int = 3
+) -> str:
+    try:
+        # Extract relevant portions from top documents
+        relevant_portions = extract_relevant_portions(
+            [doc[2] for doc in relevant_documents[:3]],  # Use top 3 documents
+            query,
+            models['bio_tokenizer'],
+            models['bio_model'],
+            max_portions=max_portions
+        )
+
+        # Combine relevant portions
+        all_portions = []
+        for doc_portions in relevant_portions.values():
+            all_portions.extend(doc_portions)
+
+        # Remove duplicates while preserving order
+        unique_portions = list(dict.fromkeys(all_portions))
+
+        # Create context from unique portions
+        context = " ".join(unique_portions[:max_portions])
+
+        # Generate and return answer
+        prompt = create_prompt(query, context)
+        return generate_answer(
+            prompt,
+            models['llm_tokenizer'],
+            models['llm_model']
+        )
+
+    except Exception as e:
+        print(f"Error in query processing pipeline: {e}")
+        return "I apologize, but I encountered an error while processing your question."
def remove_answer_prefix(text):
    prefix = "Answer:"
    if prefix in text:
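
Taken together, the helpers added above form a small pipeline: extract_entities_batch runs BioBERT over the query and the top documents in one padded batch, extract_relevant_portions scores sentences by query/document entity overlap and keeps up to max_portions portions per document, and process_query_and_generate_answer joins the portions into a prompt for the LLM. A hedged sketch of how the entry point expects to be called, assuming the functions above are in scope; the checkpoints and documents below are placeholders for illustration, not what app.py actually loads:

import nltk
from transformers import AutoTokenizer, AutoModelForTokenClassification, AutoModelForCausalLM

nltk.download("punkt", quiet=True)

# Placeholder checkpoints; app.py wires its own models into this dict at startup.
models = {
    "bio_tokenizer": AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1"),
    "bio_model": AutoModelForTokenClassification.from_pretrained("dmis-lab/biobert-v1.1"),
    "llm_tokenizer": AutoTokenizer.from_pretrained("gpt2"),
    "llm_model": AutoModelForCausalLM.from_pretrained("gpt2"),
}

# Reranked documents arrive as (score, doc_id, text) tuples, the shape rerank_documents returns.
scored_documents = [
    (0.92, "doc_1", "Aspirin is commonly used to reduce fever and relieve mild pain."),
    (0.71, "doc_2", "Ibuprofen is a nonsteroidal anti-inflammatory drug used for pain relief."),
]

answer = process_query_and_generate_answer(
    "What is aspirin used for?",
    scored_documents,
    models,
    max_portions=3,
)
print(answer)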
@@ -511,48 +554,132 @@ async def health_check():
@app.post("/api/chat")
async def chat_endpoint(chat_query: ChatQuery):
    try:
+        # Initialize response timing
+        start_time = asyncio.get_event_loop().time()
+
+        # Extract query and handle translation
        query_text = chat_query.query
-        language_code = chat_query.language_code
+        language_code = chat_query.language_code
+
        if language_code == 0:
-            query_text = translate_ar_to_en(query_text)
-        query_embedding = embed_query_text(query_text)
+            query_text = await run_in_threadpool(translate_ar_to_en, query_text)
+
+        # Embed query and load embeddings in parallel
+        query_embedding_task = run_in_threadpool(embed_query_text, query_text)
+        embeddings_data_task = run_in_threadpool(load_embeddings)
+
+        # Wait for both tasks to complete
+        query_embedding, embeddings_data = await asyncio.gather(
+            query_embedding_task,
+            embeddings_data_task
+        )
+
+        # Initial document retrieval
        n_results = 5
-        embeddings_data = load_embeddings ()
        folder_path = 'downloaded_articles/downloaded_articles'
-        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
-        document_ids = [doc_id for doc_id, _ in initial_results]
-        document_texts = retrieve_document_texts(document_ids, folder_path)
+
+        # Get initial results and retrieve documents
+        initial_results = await run_in_threadpool(
+            query_embeddings,
+            query_embedding,
+            embeddings_data,
+            n_results
+        )
+
+        document_ids = [doc_id for doc_id, *_ in initial_results]
+        document_texts = await run_in_threadpool(
+            retrieve_document_texts,
+            document_ids,
+            folder_path
+        )
+
+        # Rerank documents
        cross_encoder = models['cross_encoder']
-        scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
-        scored_documents = list(zip(scores, document_ids, document_texts))
-        scored_documents.sort(key=lambda x: x[0], reverse=True)
-        relevant_portions = extract_relevant_portions(query_embedding, scored_documents, embeddings_data, max_portions=3)
-        #flattened_relevant_portions = []
-        #for doc_id, portions in relevant_portions.items():
-        #flattened_relevant_portions.extend(portions)
-        unique_selected_parts = remove_duplicates(relevant_portions)
+        scored_documents = await run_in_threadpool(
+            rerank_documents,
+            query_text,
+            document_ids,
+            document_texts,
+            cross_encoder
+        )
+
+        # Process documents and generate answer
+        async with asyncio.TaskGroup() as tg:
+            # Extract entities in parallel
+            entities_task = tg.create_task(
+                run_in_threadpool(
+                    extract_entities_batch,
+                    [query_text] + [doc[2] for doc in scored_documents[:3]],
+                    models['bio_tokenizer'],
+                    models['bio_model']
+                )
+            )
+
+            # Extract relevant portions
+            portions_task = tg.create_task(
+                run_in_threadpool(
+                    extract_relevant_portions,
+                    [doc[2] for doc in scored_documents[:3]],
+                    query_text,
+                    models['bio_tokenizer'],
+                    models['bio_model']
+                )
+            )
+
+            entities = (await entities_task)[0]  # First item is query entities
+            relevant_portions = await portions_task
+
+        # Flatten and process portions
+        flattened_portions = []
+        for doc_portions in relevant_portions.values():
+            flattened_portions.extend(doc_portions)
+
+        unique_selected_parts = list(dict.fromkeys(flattened_portions))
        combined_parts = " ".join(unique_selected_parts)
-        context = [query_text] + unique_selected_parts
-        entities = extract_entities(query_text)
+
+        # Enhance passage and create prompt
        passage = enhance_passage_with_entities(combined_parts, entities)
        prompt = create_prompt(query_text, passage)
-        answer = generate_answer(prompt)
+
+        # Generate answer
+        answer = await run_in_threadpool(
+            generate_answer,
+            prompt,
+            models['llm_tokenizer'],
+            models['llm_model']
+        )
+
+        # Process answer
        answer_part = answer.split("Answer:")[-1].strip()
-        cleaned_answer = remove_answer_prefix(answer_part)
-        final_answer = remove_incomplete_sentence(cleaned_answer)
+        cleaned_answer = await run_in_threadpool(remove_answer_prefix, answer_part)
+        final_answer = await run_in_threadpool(remove_incomplete_sentence, cleaned_answer)
+
+        # Handle translation if needed
        if language_code == 0:
-            final_answer = translate_en_to_ar(final_answer)
+            final_answer = await run_in_threadpool(translate_en_to_ar, final_answer)
+
+        # Calculate response time
+        end_time = asyncio.get_event_loop().time()
+        response_time = end_time - start_time
+
        if final_answer:
-            print("Answer:")
+            print(f"Answer generated in {response_time:.2f} seconds")
            print(final_answer)
+
+            return {
+                "response": f"I hope this answers your question: {final_answer}",
+                "success": True,
+                "response_time": response_time
+            }
        else:
-            print("Sorry, I can't help with that.")
-        return {
-            "response": f"I hope this answers your question: {final_answer}",
-            # "conversation_id": chat_query.conversation_id,
-            "success": True
-        }
+            return {
+                "response": "Sorry, I can't help with that.",
+                "success": False,
+                "response_time": response_time
+            }
+
    except Exception as e:
+        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/api/resources")
 