Commit 783c1a9 · Parent(s): ed65920
Update app.py

app.py CHANGED
@@ -227,9 +227,6 @@ def translate_text(text, source_to_target='ar_to_en'):
 def embed_query_text(query_text):
     query_embedding = embedding.encode([query_text])
     return query_embedding
-
-from sklearn.metrics.pairwise import cosine_similarity
-import numpy as np
 
 def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
     embeddings_data = embeddings_data or data.get('embeddings', {})
@@ -446,20 +443,45 @@ async def health_check():
 async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
+
+        # Step 1: Embed the query
         query_embedding = embed_query_text(query_text)
+
+        # Step 2: Retrieve top results using embeddings similarity
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
         document_ids = [doc_id for doc_id, _ in initial_results]
+
+        # Step 3: Fetch document texts
         document_texts = retrieve_document_texts(document_ids, folder_path)
+
+        # Step 4: Re-rank documents (optional, if reranking is used)
+        reranked_documents = rerank_documents(query_text, document_ids, document_texts, cross_encoder_model)
+
+        # Step 5: Extract relevant portions (if enabled)
+        relevant_portions = extract_relevant_portions(
+            document_texts,
+            query=query_text,
+            max_portions=3,
+            portion_size=1,
+            min_query_words=1
+        )
+
+        # Step 6: Flatten and clean relevant portions
         flattened_relevant_portions = []
         for doc_id, portions in relevant_portions.items():
             flattened_relevant_portions.extend(portions)
         unique_selected_parts = remove_duplicates(flattened_relevant_portions)
         combined_parts = " ".join(unique_selected_parts)
-
+
+        # Step 7: Extract entities and enhance passage
         entities = extract_entities(query_text)
         passage = enhance_passage_with_entities(combined_parts, entities)
+
+        # Step 8: Create prompt and generate answer
         prompt = create_prompt(query_text, passage)
         answer, generation_time = generate_answer(prompt)
+
+        # Step 9: Clean the generated answer
         answer_part = answer.split("Answer:")[-1].strip()
         cleaned_answer = remove_answer_prefix(answer_part)
         final_answer = remove_incomplete_sentence(cleaned_answer)
@@ -469,9 +491,11 @@ async def chat_endpoint(chat_query: ChatQuery):
             "conversation_id": chat_query.conversation_id,
             "success": True
         }
+
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
+
 @app.post("/api/resources")
 async def resources_endpoint(profile: MedicalProfile):
     try: