thechaiexperiment commited on
Commit
783c1a9
·
1 Parent(s): ed65920

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -227,9 +227,6 @@ def translate_text(text, source_to_target='ar_to_en'):
227
  def embed_query_text(query_text):
228
  query_embedding = embedding.encode([query_text])
229
  return query_embedding
230
-
231
- from sklearn.metrics.pairwise import cosine_similarity
232
- import numpy as np
233
 
234
  def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
235
  embeddings_data = embeddings_data or data.get('embeddings', {})
@@ -446,20 +443,45 @@ async def health_check():
446
  async def chat_endpoint(chat_query: ChatQuery):
447
  try:
448
  query_text = chat_query.query
 
 
449
  query_embedding = embed_query_text(query_text)
 
 
450
  initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
451
  document_ids = [doc_id for doc_id, _ in initial_results]
 
 
452
  document_texts = retrieve_document_texts(document_ids, folder_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  flattened_relevant_portions = []
454
  for doc_id, portions in relevant_portions.items():
455
  flattened_relevant_portions.extend(portions)
456
  unique_selected_parts = remove_duplicates(flattened_relevant_portions)
457
  combined_parts = " ".join(unique_selected_parts)
458
- context = [query_text] + unique_selected_parts
 
459
  entities = extract_entities(query_text)
460
  passage = enhance_passage_with_entities(combined_parts, entities)
 
 
461
  prompt = create_prompt(query_text, passage)
462
  answer, generation_time = generate_answer(prompt)
 
 
463
  answer_part = answer.split("Answer:")[-1].strip()
464
  cleaned_answer = remove_answer_prefix(answer_part)
465
  final_answer = remove_incomplete_sentence(cleaned_answer)
@@ -469,9 +491,11 @@ async def chat_endpoint(chat_query: ChatQuery):
469
  "conversation_id": chat_query.conversation_id,
470
  "success": True
471
  }
 
472
  except Exception as e:
473
  raise HTTPException(status_code=500, detail=str(e))
474
 
 
475
  @app.post("/api/resources")
476
  async def resources_endpoint(profile: MedicalProfile):
477
  try:
 
227
  def embed_query_text(query_text):
228
  query_embedding = embedding.encode([query_text])
229
  return query_embedding
 
 
 
230
 
231
  def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
232
  embeddings_data = embeddings_data or data.get('embeddings', {})
 
443
  async def chat_endpoint(chat_query: ChatQuery):
444
  try:
445
  query_text = chat_query.query
446
+
447
+ # Step 1: Embed the query
448
  query_embedding = embed_query_text(query_text)
449
+
450
+ # Step 2: Retrieve top results using embeddings similarity
451
  initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
452
  document_ids = [doc_id for doc_id, _ in initial_results]
453
+
454
+ # Step 3: Fetch document texts
455
  document_texts = retrieve_document_texts(document_ids, folder_path)
456
+
457
+ # Step 4: Re-rank documents (optional, if reranking is used)
458
+ reranked_documents = rerank_documents(query_text, document_ids, document_texts, cross_encoder_model)
459
+
460
+ # Step 5: Extract relevant portions (if enabled)
461
+ relevant_portions = extract_relevant_portions(
462
+ document_texts,
463
+ query=query_text,
464
+ max_portions=3,
465
+ portion_size=1,
466
+ min_query_words=1
467
+ )
468
+
469
+ # Step 6: Flatten and clean relevant portions
470
  flattened_relevant_portions = []
471
  for doc_id, portions in relevant_portions.items():
472
  flattened_relevant_portions.extend(portions)
473
  unique_selected_parts = remove_duplicates(flattened_relevant_portions)
474
  combined_parts = " ".join(unique_selected_parts)
475
+
476
+ # Step 7: Extract entities and enhance passage
477
  entities = extract_entities(query_text)
478
  passage = enhance_passage_with_entities(combined_parts, entities)
479
+
480
+ # Step 8: Create prompt and generate answer
481
  prompt = create_prompt(query_text, passage)
482
  answer, generation_time = generate_answer(prompt)
483
+
484
+ # Step 9: Clean the generated answer
485
  answer_part = answer.split("Answer:")[-1].strip()
486
  cleaned_answer = remove_answer_prefix(answer_part)
487
  final_answer = remove_incomplete_sentence(cleaned_answer)
 
491
  "conversation_id": chat_query.conversation_id,
492
  "success": True
493
  }
494
+
495
  except Exception as e:
496
  raise HTTPException(status_code=500, detail=str(e))
497
 
498
+
499
  @app.post("/api/resources")
500
  async def resources_endpoint(profile: MedicalProfile):
501
  try: