Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -51,11 +51,12 @@ class QueryRequest(BaseModel):
|
|
51 |
class MedicalProfile(BaseModel):
|
52 |
conditions: str
|
53 |
daily_symptoms: str
|
|
|
54 |
|
55 |
class ChatQuery(BaseModel):
|
56 |
query: str
|
57 |
language_code: int = 1
|
58 |
-
conversation_id: str
|
59 |
|
60 |
class ChatMessage(BaseModel):
|
61 |
role: str
|
@@ -219,7 +220,7 @@ def embed_query_text(query_text):
|
|
219 |
query_embedding = embedding.encode([query_text])
|
220 |
return query_embedding
|
221 |
|
222 |
-
def query_embeddings(query_embedding, embeddings_data=None, n_results
|
223 |
embeddings_data = load_embeddings()
|
224 |
if not embeddings_data:
|
225 |
print("No embeddings data available.")
|
@@ -234,7 +235,7 @@ def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
|
|
234 |
print(f"Error in query_embeddings: {e}")
|
235 |
return []
|
236 |
|
237 |
-
def query_recipes_embeddings(query_embedding, embeddings_data, n_results
|
238 |
embeddings_data = load_recipes_embeddings()
|
239 |
if embeddings_data is None:
|
240 |
print("No embeddings data available.")
|
@@ -365,7 +366,7 @@ def match_entities(query_entities, sentence_entities):
|
|
365 |
print(f"Error matching entities: {e}")
|
366 |
return 0
|
367 |
|
368 |
-
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=
|
369 |
relevant_portions = {}
|
370 |
query_entities = extract_entities(query)
|
371 |
print(f"Extracted Query Entities: {query_entities}")
|
@@ -466,6 +467,40 @@ def remove_incomplete_sentence(text):
|
|
466 |
return text[:last_period_index + 1].strip()
|
467 |
return text
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
@app.get("/")
|
470 |
async def root():
|
471 |
return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
|
@@ -485,18 +520,20 @@ async def health_check():
|
|
485 |
async def chat_endpoint(chat_query: ChatQuery):
|
486 |
try:
|
487 |
query_text = chat_query.query
|
488 |
-
language_code = chat_query.language_code
|
|
|
|
|
489 |
query_embedding = embed_query_text(query_text)
|
490 |
embeddings_data = load_embeddings ()
|
491 |
folder_path = 'downloaded_articles/downloaded_articles'
|
492 |
-
initial_results = query_embeddings(query_embedding, embeddings_data, n_results
|
493 |
document_ids = [doc_id for doc_id, _ in initial_results]
|
494 |
document_texts = retrieve_document_texts(document_ids, folder_path)
|
495 |
cross_encoder = models['cross_encoder']
|
496 |
scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
|
497 |
scored_documents = list(zip(scores, document_ids, document_texts))
|
498 |
scored_documents.sort(key=lambda x: x[0], reverse=True)
|
499 |
-
relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=
|
500 |
flattened_relevant_portions = []
|
501 |
for doc_id, portions in relevant_portions.items():
|
502 |
flattened_relevant_portions.extend(portions)
|
@@ -518,8 +555,8 @@ async def chat_endpoint(chat_query: ChatQuery):
|
|
518 |
else:
|
519 |
print("Sorry, I can't help with that.")
|
520 |
return {
|
521 |
-
"response": final_answer,
|
522 |
-
"conversation_id": chat_query.conversation_id,
|
523 |
"success": True
|
524 |
}
|
525 |
except Exception as e:
|
@@ -529,13 +566,14 @@ async def chat_endpoint(chat_query: ChatQuery):
|
|
529 |
async def resources_endpoint(profile: MedicalProfile):
|
530 |
try:
|
531 |
query_text = profile.conditions + " " + profile.daily_symptoms
|
|
|
532 |
print(f"Generated query text: {query_text}")
|
533 |
query_embedding = embed_query_text(query_text)
|
534 |
if query_embedding is None:
|
535 |
raise ValueError("Failed to generate query embedding.")
|
536 |
embeddings_data = load_embeddings()
|
537 |
folder_path = 'downloaded_articles/downloaded_articles'
|
538 |
-
initial_results = query_embeddings(query_embedding, embeddings_data, n_results
|
539 |
if not initial_results:
|
540 |
raise ValueError("No relevant documents found.")
|
541 |
document_ids = [doc_id for doc_id, _ in initial_results]
|
@@ -570,17 +608,18 @@ async def resources_endpoint(profile: MedicalProfile):
|
|
570 |
async def recipes_endpoint(profile: MedicalProfile):
|
571 |
try:
|
572 |
recipe_query = (
|
573 |
-
f"Recipes
|
574 |
f"{profile.conditions} and experiencing {profile.daily_symptoms}"
|
575 |
)
|
576 |
query_text = recipe_query
|
577 |
print(f"Generated query text: {query_text}")
|
|
|
578 |
query_embedding = embed_query_text(query_text)
|
579 |
if query_embedding is None:
|
580 |
raise ValueError("Failed to generate query embedding.")
|
581 |
embeddings_data = load_recipes_embeddings()
|
582 |
folder_path = 'downloaded_articles/downloaded_articles'
|
583 |
-
initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results
|
584 |
if not initial_results:
|
585 |
raise ValueError("No relevant recipes found.")
|
586 |
print("Initial results (document indices and similarities):")
|
@@ -590,8 +629,19 @@ async def recipes_endpoint(profile: MedicalProfile):
|
|
590 |
metadata_path = 'recipes_metadata.xlsx'
|
591 |
metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
|
592 |
print(f"Retrieved Metadata: {metadata}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
response = {
|
594 |
-
"
|
595 |
}
|
596 |
return response
|
597 |
except ValueError as ve:
|
|
|
51 |
class MedicalProfile(BaseModel):
|
52 |
conditions: str
|
53 |
daily_symptoms: str
|
54 |
+
count: int
|
55 |
|
56 |
class ChatQuery(BaseModel):
|
57 |
query: str
|
58 |
language_code: int = 1
|
59 |
+
#conversation_id: str
|
60 |
|
61 |
class ChatMessage(BaseModel):
|
62 |
role: str
|
|
|
220 |
query_embedding = embedding.encode([query_text])
|
221 |
return query_embedding
|
222 |
|
223 |
+
def query_embeddings(query_embedding, embeddings_data=None, n_results):
|
224 |
embeddings_data = load_embeddings()
|
225 |
if not embeddings_data:
|
226 |
print("No embeddings data available.")
|
|
|
235 |
print(f"Error in query_embeddings: {e}")
|
236 |
return []
|
237 |
|
238 |
+
def query_recipes_embeddings(query_embedding, embeddings_data, n_results):
|
239 |
embeddings_data = load_recipes_embeddings()
|
240 |
if embeddings_data is None:
|
241 |
print("No embeddings data available.")
|
|
|
366 |
print(f"Error matching entities: {e}")
|
367 |
return 0
|
368 |
|
369 |
+
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
|
370 |
relevant_portions = {}
|
371 |
query_entities = extract_entities(query)
|
372 |
print(f"Extracted Query Entities: {query_entities}")
|
|
|
467 |
return text[:last_period_index + 1].strip()
|
468 |
return text
|
469 |
|
470 |
+
def translate_ar_to_en(text):
|
471 |
+
try:
|
472 |
+
ar_to_en_tokenizer = models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
|
473 |
+
ar_to_en_model= models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
|
474 |
+
inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
475 |
+
translated_ids = ar_to_en_model.generate(
|
476 |
+
inputs.input_ids,
|
477 |
+
max_length=512,
|
478 |
+
num_beams=4,
|
479 |
+
early_stopping=True
|
480 |
+
)
|
481 |
+
translated_text = ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
482 |
+
return translated_text
|
483 |
+
except Exception as e:
|
484 |
+
print(f"Error during Arabic to English translation: {e}")
|
485 |
+
return None
|
486 |
+
|
487 |
+
def translate_en_to_ar(text):
|
488 |
+
try:
|
489 |
+
en_to_ar_tokenizer = models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
|
490 |
+
en_to_ar_model = models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
|
491 |
+
inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
492 |
+
translated_ids = en_to_ar_model.generate(
|
493 |
+
inputs.input_ids,
|
494 |
+
max_length=512,
|
495 |
+
num_beams=4,
|
496 |
+
early_stopping=True
|
497 |
+
)
|
498 |
+
translated_text = en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
|
499 |
+
return translated_text
|
500 |
+
except Exception as e:
|
501 |
+
print(f"Error during English to Arabic translation: {e}")
|
502 |
+
return None
|
503 |
+
|
504 |
@app.get("/")
|
505 |
async def root():
|
506 |
return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
|
|
|
520 |
async def chat_endpoint(chat_query: ChatQuery):
|
521 |
try:
|
522 |
query_text = chat_query.query
|
523 |
+
language_code = chat_query.language_code
|
524 |
+
if language_code == 0:
|
525 |
+
query_text = translate_ar_to_en(query_text)
|
526 |
query_embedding = embed_query_text(query_text)
|
527 |
embeddings_data = load_embeddings ()
|
528 |
folder_path = 'downloaded_articles/downloaded_articles'
|
529 |
+
initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
|
530 |
document_ids = [doc_id for doc_id, _ in initial_results]
|
531 |
document_texts = retrieve_document_texts(document_ids, folder_path)
|
532 |
cross_encoder = models['cross_encoder']
|
533 |
scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
|
534 |
scored_documents = list(zip(scores, document_ids, document_texts))
|
535 |
scored_documents.sort(key=lambda x: x[0], reverse=True)
|
536 |
+
relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=2)
|
537 |
flattened_relevant_portions = []
|
538 |
for doc_id, portions in relevant_portions.items():
|
539 |
flattened_relevant_portions.extend(portions)
|
|
|
555 |
else:
|
556 |
print("Sorry, I can't help with that.")
|
557 |
return {
|
558 |
+
"response": "I hope this answers your question:" {final_answer},
|
559 |
+
#"conversation_id": chat_query.conversation_id,
|
560 |
"success": True
|
561 |
}
|
562 |
except Exception as e:
|
|
|
566 |
async def resources_endpoint(profile: MedicalProfile):
|
567 |
try:
|
568 |
query_text = profile.conditions + " " + profile.daily_symptoms
|
569 |
+
n_results = profile.count
|
570 |
print(f"Generated query text: {query_text}")
|
571 |
query_embedding = embed_query_text(query_text)
|
572 |
if query_embedding is None:
|
573 |
raise ValueError("Failed to generate query embedding.")
|
574 |
embeddings_data = load_embeddings()
|
575 |
folder_path = 'downloaded_articles/downloaded_articles'
|
576 |
+
initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
|
577 |
if not initial_results:
|
578 |
raise ValueError("No relevant documents found.")
|
579 |
document_ids = [doc_id for doc_id, _ in initial_results]
|
|
|
608 |
async def recipes_endpoint(profile: MedicalProfile):
|
609 |
try:
|
610 |
recipe_query = (
|
611 |
+
f"Recipes and foods for: "
|
612 |
f"{profile.conditions} and experiencing {profile.daily_symptoms}"
|
613 |
)
|
614 |
query_text = recipe_query
|
615 |
print(f"Generated query text: {query_text}")
|
616 |
+
n_results = profile.count
|
617 |
query_embedding = embed_query_text(query_text)
|
618 |
if query_embedding is None:
|
619 |
raise ValueError("Failed to generate query embedding.")
|
620 |
embeddings_data = load_recipes_embeddings()
|
621 |
folder_path = 'downloaded_articles/downloaded_articles'
|
622 |
+
initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results)
|
623 |
if not initial_results:
|
624 |
raise ValueError("No relevant recipes found.")
|
625 |
print("Initial results (document indices and similarities):")
|
|
|
629 |
metadata_path = 'recipes_metadata.xlsx'
|
630 |
metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
|
631 |
print(f"Retrieved Metadata: {metadata}")
|
632 |
+
recipes = metadata
|
633 |
+
document_texts = retrieve_rec_texts(document_indices, folder_path)
|
634 |
+
if not document_texts:
|
635 |
+
raise ValueError("Failed to retrieve document texts.")
|
636 |
+
cross_encoder = models['cross_encoder']
|
637 |
+
scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
|
638 |
+
scores = [float(score) for score in scores]
|
639 |
+
for i, recipe in enumerate(recipes):
|
640 |
+
recipe["score"] = scores[i] if i < len(scores) else 0.0
|
641 |
+
recipes.sort(key=lambda x: x["score"], reverse=True)
|
642 |
+
return {"recipes": recipes[:5], "success": True}
|
643 |
response = {
|
644 |
+
"recipes": recipes,
|
645 |
}
|
646 |
return response
|
647 |
except ValueError as ve:
|