thechaiexperiment committed on
Commit
e1033ec
·
verified ·
1 Parent(s): 089f890

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -13
app.py CHANGED
@@ -51,11 +51,12 @@ class QueryRequest(BaseModel):
51
  class MedicalProfile(BaseModel):
52
  conditions: str
53
  daily_symptoms: str
 
54
 
55
  class ChatQuery(BaseModel):
56
  query: str
57
  language_code: int = 1
58
- conversation_id: str
59
 
60
  class ChatMessage(BaseModel):
61
  role: str
@@ -219,7 +220,7 @@ def embed_query_text(query_text):
219
  query_embedding = embedding.encode([query_text])
220
  return query_embedding
221
 
222
- def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
223
  embeddings_data = load_embeddings()
224
  if not embeddings_data:
225
  print("No embeddings data available.")
@@ -234,7 +235,7 @@ def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
234
  print(f"Error in query_embeddings: {e}")
235
  return []
236
 
237
- def query_recipes_embeddings(query_embedding, embeddings_data, n_results = 5):
238
  embeddings_data = load_recipes_embeddings()
239
  if embeddings_data is None:
240
  print("No embeddings data available.")
@@ -365,7 +366,7 @@ def match_entities(query_entities, sentence_entities):
365
  print(f"Error matching entities: {e}")
366
  return 0
367
 
368
- def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
369
  relevant_portions = {}
370
  query_entities = extract_entities(query)
371
  print(f"Extracted Query Entities: {query_entities}")
@@ -466,6 +467,40 @@ def remove_incomplete_sentence(text):
466
  return text[:last_period_index + 1].strip()
467
  return text
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  @app.get("/")
470
  async def root():
471
  return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
@@ -485,18 +520,20 @@ async def health_check():
485
  async def chat_endpoint(chat_query: ChatQuery):
486
  try:
487
  query_text = chat_query.query
488
- language_code = chat_query.language_code
 
 
489
  query_embedding = embed_query_text(query_text)
490
  embeddings_data = load_embeddings ()
491
  folder_path = 'downloaded_articles/downloaded_articles'
492
- initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
493
  document_ids = [doc_id for doc_id, _ in initial_results]
494
  document_texts = retrieve_document_texts(document_ids, folder_path)
495
  cross_encoder = models['cross_encoder']
496
  scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
497
  scored_documents = list(zip(scores, document_ids, document_texts))
498
  scored_documents.sort(key=lambda x: x[0], reverse=True)
499
- relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=1)
500
  flattened_relevant_portions = []
501
  for doc_id, portions in relevant_portions.items():
502
  flattened_relevant_portions.extend(portions)
@@ -518,8 +555,8 @@ async def chat_endpoint(chat_query: ChatQuery):
518
  else:
519
  print("Sorry, I can't help with that.")
520
  return {
521
- "response": final_answer,
522
- "conversation_id": chat_query.conversation_id,
523
  "success": True
524
  }
525
  except Exception as e:
@@ -529,13 +566,14 @@ async def chat_endpoint(chat_query: ChatQuery):
529
  async def resources_endpoint(profile: MedicalProfile):
530
  try:
531
  query_text = profile.conditions + " " + profile.daily_symptoms
 
532
  print(f"Generated query text: {query_text}")
533
  query_embedding = embed_query_text(query_text)
534
  if query_embedding is None:
535
  raise ValueError("Failed to generate query embedding.")
536
  embeddings_data = load_embeddings()
537
  folder_path = 'downloaded_articles/downloaded_articles'
538
- initial_results = query_embeddings(query_embedding, embeddings_data, n_results=6)
539
  if not initial_results:
540
  raise ValueError("No relevant documents found.")
541
  document_ids = [doc_id for doc_id, _ in initial_results]
@@ -570,17 +608,18 @@ async def resources_endpoint(profile: MedicalProfile):
570
  async def recipes_endpoint(profile: MedicalProfile):
571
  try:
572
  recipe_query = (
573
- f"Recipes foods and meals suitable for someone with: "
574
  f"{profile.conditions} and experiencing {profile.daily_symptoms}"
575
  )
576
  query_text = recipe_query
577
  print(f"Generated query text: {query_text}")
 
578
  query_embedding = embed_query_text(query_text)
579
  if query_embedding is None:
580
  raise ValueError("Failed to generate query embedding.")
581
  embeddings_data = load_recipes_embeddings()
582
  folder_path = 'downloaded_articles/downloaded_articles'
583
- initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results=5)
584
  if not initial_results:
585
  raise ValueError("No relevant recipes found.")
586
  print("Initial results (document indices and similarities):")
@@ -590,8 +629,19 @@ async def recipes_endpoint(profile: MedicalProfile):
590
  metadata_path = 'recipes_metadata.xlsx'
591
  metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
592
  print(f"Retrieved Metadata: {metadata}")
 
 
 
 
 
 
 
 
 
 
 
593
  response = {
594
- "metadata": metadata,
595
  }
596
  return response
597
  except ValueError as ve:
 
51
class MedicalProfile(BaseModel):
    """Request payload describing a user's medical situation.

    Used by the resources and recipes endpoints to build a retrieval query.
    """

    # Free-text description of diagnosed conditions.
    conditions: str
    # Free-text description of symptoms experienced day to day.
    daily_symptoms: str
    # Number of results the caller wants back (passed as n_results downstream).
    count: int
55
 
56
class ChatQuery(BaseModel):
    """Request payload for the chat endpoint.

    language_code selects the query language (defaults to 1; 0 triggers
    Arabic-to-English translation in the chat endpoint).
    """

    query: str
    language_code: int = 1
    # conversation_id: str  # disabled in this revision; response no longer echoes it
60
 
61
  class ChatMessage(BaseModel):
62
  role: str
 
220
  query_embedding = embedding.encode([query_text])
221
  return query_embedding
222
 
223
+ def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
224
  embeddings_data = load_embeddings()
225
  if not embeddings_data:
226
  print("No embeddings data available.")
 
235
  print(f"Error in query_embeddings: {e}")
236
  return []
237
 
238
+ def query_recipes_embeddings(query_embedding, embeddings_data, n_results):
239
  embeddings_data = load_recipes_embeddings()
240
  if embeddings_data is None:
241
  print("No embeddings data available.")
 
366
  print(f"Error matching entities: {e}")
367
  return 0
368
 
369
+ def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
370
  relevant_portions = {}
371
  query_entities = extract_entities(query)
372
  print(f"Extracted Query Entities: {query_entities}")
 
467
  return text[:last_period_index + 1].strip()
468
  return text
469
 
470
def translate_ar_to_en(text):
    """Translate Arabic text to English with the Helsinki-NLP opus-mt-ar-en model.

    Args:
        text: Arabic source string.

    Returns:
        The English translation as a string, or None if translation fails
        (errors are printed, matching the file's error-handling style).
    """
    try:
        # Load the tokenizer/model once and cache them in the module-level
        # `models` dict. The original reloaded both from the Hugging Face hub
        # on every call, which is slow and network-bound; the cache lookup
        # makes repeat calls cheap. (Assumes `models` is a plain dict — it is
        # indexed elsewhere in this file, e.g. models['cross_encoder'].)
        if 'ar_to_en_tokenizer' not in models or 'ar_to_en_model' not in models:
            models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
            models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        ar_to_en_tokenizer = models['ar_to_en_tokenizer']
        ar_to_en_model = models['ar_to_en_model']
        inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = ar_to_en_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True,
        )
        translated_text = ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None
486
+
487
def translate_en_to_ar(text):
    """Translate English text to Arabic with the Helsinki-NLP opus-mt-en-ar model.

    Args:
        text: English source string.

    Returns:
        The Arabic translation as a string, or None if translation fails
        (errors are printed, matching the file's error-handling style).
    """
    try:
        # Cache the tokenizer/model in the module-level `models` dict so they
        # are downloaded and constructed only once; the original re-fetched
        # both from the Hugging Face hub on every call.
        if 'en_to_ar_tokenizer' not in models or 'en_to_ar_model' not in models:
            models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
            models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        en_to_ar_tokenizer = models['en_to_ar_tokenizer']
        en_to_ar_model = models['en_to_ar_model']
        inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = en_to_ar_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True,
        )
        translated_text = en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None
503
+
504
  @app.get("/")
505
  async def root():
506
  return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
 
520
  async def chat_endpoint(chat_query: ChatQuery):
521
  try:
522
  query_text = chat_query.query
523
+ language_code = chat_query.language_code
524
+ if language_code == 0:
525
+ query_text = translate_ar_to_en(query_text)
526
  query_embedding = embed_query_text(query_text)
527
  embeddings_data = load_embeddings ()
528
  folder_path = 'downloaded_articles/downloaded_articles'
529
+ initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
530
  document_ids = [doc_id for doc_id, _ in initial_results]
531
  document_texts = retrieve_document_texts(document_ids, folder_path)
532
  cross_encoder = models['cross_encoder']
533
  scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
534
  scored_documents = list(zip(scores, document_ids, document_texts))
535
  scored_documents.sort(key=lambda x: x[0], reverse=True)
536
+ relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=2)
537
  flattened_relevant_portions = []
538
  for doc_id, portions in relevant_portions.items():
539
  flattened_relevant_portions.extend(portions)
 
555
  else:
556
  print("Sorry, I can't help with that.")
557
  return {
558
+ "response": "I hope this answers your question:" {final_answer},
559
+ #"conversation_id": chat_query.conversation_id,
560
  "success": True
561
  }
562
  except Exception as e:
 
566
  async def resources_endpoint(profile: MedicalProfile):
567
  try:
568
  query_text = profile.conditions + " " + profile.daily_symptoms
569
+ n_results = profile.count
570
  print(f"Generated query text: {query_text}")
571
  query_embedding = embed_query_text(query_text)
572
  if query_embedding is None:
573
  raise ValueError("Failed to generate query embedding.")
574
  embeddings_data = load_embeddings()
575
  folder_path = 'downloaded_articles/downloaded_articles'
576
+ initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
577
  if not initial_results:
578
  raise ValueError("No relevant documents found.")
579
  document_ids = [doc_id for doc_id, _ in initial_results]
 
608
  async def recipes_endpoint(profile: MedicalProfile):
609
  try:
610
  recipe_query = (
611
+ f"Recipes and foods for: "
612
  f"{profile.conditions} and experiencing {profile.daily_symptoms}"
613
  )
614
  query_text = recipe_query
615
  print(f"Generated query text: {query_text}")
616
+ n_results = profile.count
617
  query_embedding = embed_query_text(query_text)
618
  if query_embedding is None:
619
  raise ValueError("Failed to generate query embedding.")
620
  embeddings_data = load_recipes_embeddings()
621
  folder_path = 'downloaded_articles/downloaded_articles'
622
+ initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results)
623
  if not initial_results:
624
  raise ValueError("No relevant recipes found.")
625
  print("Initial results (document indices and similarities):")
 
629
  metadata_path = 'recipes_metadata.xlsx'
630
  metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
631
  print(f"Retrieved Metadata: {metadata}")
632
+ recipes = metadata
633
+ document_texts = retrieve_rec_texts(document_indices, folder_path)
634
+ if not document_texts:
635
+ raise ValueError("Failed to retrieve document texts.")
636
+ cross_encoder = models['cross_encoder']
637
+ scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
638
+ scores = [float(score) for score in scores]
639
+ for i, recipe in enumerate(recipes):
640
+ recipe["score"] = scores[i] if i < len(scores) else 0.0
641
+ recipes.sort(key=lambda x: x["score"], reverse=True)
642
+ return {"recipes": recipes[:5], "success": True}
643
  response = {
644
+ "recipes": recipes,
645
  }
646
  return response
647
  except ValueError as ve: