Update app.py
app.py CHANGED
@@ -167,9 +167,23 @@ def normalize_key(key: str) -> str:
     return key


-def load_recipes_embeddings() -> Optional[Dict[str, np.ndarray]]:
+import os
+import numpy as np
+from typing import Optional
+from safetensors.numpy import load_file
+from huggingface_hub import hf_hub_download
+
+def load_recipes_embeddings() -> Optional[np.ndarray]:
+    """
+    Loads recipe embeddings from a .safetensors file, handling local and remote downloads.
+
+    Returns:
+        Optional[np.ndarray]: A numpy array containing all embeddings (shape: (num_recipes, embedding_dim)).
+    """
     try:
         embeddings_path = 'recipes_embeddings.safetensors'
+
+        # Check if file exists locally, otherwise download from Hugging Face Hub
         if not os.path.exists(embeddings_path):
             print("File not found locally. Attempting to download from Hugging Face Hub...")
             embeddings_path = hf_hub_download(
@@ -178,30 +192,30 @@ def load_recipes_embeddings() -> Optional[Dict[str, np.ndarray]]:
             repo_type="space"
         )

-        #
-        embeddings = {}
-        from safetensors.numpy import safe_open
-        with safe_open(embeddings_path, framework="pt") as f:
-            keys = list(f.keys())
-            for key in keys:
-                try:
-                    normalized_key = normalize_key(key)
-                    tensor = f.get_tensor(key)
-                    embeddings[normalized_key] = tensor.numpy()
-                except Exception as key_error:
-                    print(f"Failed to process key {key}: {key_error}")
+        # Load the embeddings tensor from the .safetensors file
+        embeddings = load_file(embeddings_path)

-
-
-
-
+        # Ensure the key 'embeddings' exists in the file
+        if "embeddings" not in embeddings:
+            raise ValueError("Key 'embeddings' not found in the safetensors file.")
+
+        # Retrieve the tensor as a numpy array
+        tensor = embeddings["embeddings"]

-
+        # Print information about the embeddings
+        print(f"Successfully loaded embeddings.")
+        print(f"Shape of embeddings: {tensor.shape}")
+        print(f"Dtype of embeddings: {tensor.dtype}")
+        print(f"First few values of the first embedding: {tensor[0][:5]}")
+
+        return tensor

     except Exception as e:
         print(f"Error loading embeddings: {e}")
         return None

+
+
 def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
     """Load document data from HTML articles in a specified folder."""
     try:
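For reference, a minimal smoke test of the reworked loader, assuming the Space's recipes_embeddings.safetensors stores a single tensor under the key "embeddings" as the new code expects (this snippet is not part of the commit):

# Sketch: exercise the new loader and inspect the returned array.
embeddings = load_recipes_embeddings()
if embeddings is not None:
    # The new loader returns one (num_recipes, embedding_dim) array
    # instead of the old per-key dict.
    num_recipes, embedding_dim = embeddings.shape
    print(num_recipes, embedding_dim)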
@@ -295,19 +309,42 @@ def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
         print(f"Error in query_embeddings: {e}")
         return []

-
+from sklearn.metrics.pairwise import cosine_similarity
+import numpy as np
+
+def query_recipes_embeddings(query_embedding: np.ndarray, n_results: int = 5):
+    """
+    Query the recipes embeddings to find the most similar recipes based on cosine similarity.
+
+    Args:
+        query_embedding (np.ndarray): A 1D numpy array representing the query embedding.
+        n_results (int): Number of top results to return.
+
+    Returns:
+        List[Tuple[int, float]]: A list of tuples containing the indices of the top results and their similarity scores.
+    """
+    # Load embeddings
     embeddings_data = load_recipes_embeddings()
-    if
+    if embeddings_data is None:
         print("No embeddings data available.")
         return []
+
     try:
-
-
-
+        # Ensure query_embedding is 2D for cosine similarity computation
+        if query_embedding.ndim == 1:
+            query_embedding = query_embedding.reshape(1, -1)
+
+        # Compute cosine similarity
+        similarities = cosine_similarity(query_embedding, embeddings_data).flatten()
+
+        # Get the indices of the top N most similar embeddings
         top_indices = similarities.argsort()[-n_results:][::-1]
-
+
+        # Return the indices and similarity scores of the top results
+        return [(index, similarities[index]) for index in top_indices]
+
     except Exception as e:
-        print(f"Error in
+        print(f"Error in query_recipes_embeddings: {e}")
         return []

 def get_page_title(url):
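A quick sanity check for the new similarity query; the 384-dimensional random vector is only a placeholder for a real sentence embedding, and the dimension must match the stored tensor:

import numpy as np

# Placeholder query vector; in the app it comes from the embedding model.
query_embedding = np.random.rand(384).astype(np.float32)

# Returns [(row_index, similarity_score), ...], best match first.
results = query_recipes_embeddings(query_embedding, n_results=5)
for index, score in results:
    print(index, score)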
@@ -342,31 +379,68 @@ def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded_articles'):
             texts.append("")
     return texts

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+import os
+import pandas as pd
+
+def retrieve_rec_texts(
+    document_indices,
+    folder_path='downloaded_articles/downloaded_articles',
+    metadata_path='recipes_metadata.xlsx'
+):
+    """
+    Retrieve the texts of documents corresponding to the given indices.
+
+    Args:
+        document_indices (List[int]): A list of document indices to retrieve.
+        folder_path (str): Path to the folder containing the article files.
+        metadata_path (str): Path to the metadata file mapping indices to file names.
+
+    Returns:
+        List[str]: A list of document texts corresponding to the given indices.
+    """
+    try:
+        # Load metadata file to map indices to original file names
+        metadata_df = pd.read_excel(metadata_path)
+
+        # Ensure the metadata file has the required columns
+        if "id" not in metadata_df.columns or "original_file_name" not in metadata_df.columns:
+            raise ValueError("Metadata file must contain 'id' and 'original_file_name' columns.")
+
+        # Ensure the 'id' column aligns with the embeddings row indices
+        metadata_df = metadata_df.sort_values(by="id").reset_index(drop=True)
+
+        # Verify the alignment of metadata with embeddings indices
+        if metadata_df.index.max() < max(document_indices):
+            raise ValueError("Some document indices exceed the range of metadata.")
+
+        # Retrieve file names for the given indices
+        document_texts = []
+        for idx in document_indices:
+            if idx >= len(metadata_df):
+                print(f"Warning: Index {idx} is out of range for metadata.")
+                continue
+
+            original_file_name = metadata_df.iloc[idx]["original_file_name"]
+            if not original_file_name:
+                print(f"Warning: No file name found for index {idx}")
+                continue
+
+            # Construct the file path using the original file name
+            file_path = os.path.join(folder_path, original_file_name)
+
+            # Check if the file exists and read its content
+            if os.path.exists(file_path):
+                with open(file_path, "r", encoding="utf-8") as f:
+                    document_texts.append(f.read())
+            else:
+                print(f"Warning: File not found at {file_path}")
+
+        return document_texts
+
+    except Exception as e:
+        print(f"Error in retrieve_rec_texts: {e}")
+        return []
+



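retrieve_rec_texts assumes recipes_metadata.xlsx carries an 'id' column aligned with the embedding row order and an 'original_file_name' column. A sketch of that contract with illustrative file names (example_metadata.xlsx and the article names are hypothetical):

import pandas as pd

# Illustrative metadata; the real recipes_metadata.xlsx ships with the Space.
pd.DataFrame({
    "id": [0, 1, 2],
    "original_file_name": ["article_0.html", "article_1.html", "article_2.html"],
}).to_excel("example_metadata.xlsx", index=False)

# Row indices from query_recipes_embeddings map directly onto metadata rows.
texts = retrieve_rec_texts([0, 2], metadata_path="example_metadata.xlsx")
print(len(texts))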
@@ -773,6 +847,19 @@ async def resources_endpoint(profile: MedicalProfile):



+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+import pandas as pd
+import numpy as np
+from typing import Optional, List
+
+app = FastAPI()
+
+# Define your profile model for input data
+class MedicalProfile(BaseModel):
+    conditions: str
+    daily_symptoms: str
+
 @app.post("/api/recipes")
 async def recipes_endpoint(profile: MedicalProfile):
     try:
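With the app running, the endpoint can be exercised with any HTTP client; a sketch assuming the Space's default local port 7860 (the JSON fields mirror the MedicalProfile model above):

import requests

# Hypothetical client call; adjust host/port to wherever the app is served.
response = requests.post(
    "http://localhost:7860/api/recipes",
    json={"conditions": "type 2 diabetes", "daily_symptoms": "fatigue"},
)
print(response.json())  # expected shape: {"response": "..."}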
@@ -795,55 +882,80 @@ async def recipes_endpoint(profile: MedicalProfile):
         initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results=10)
         if not initial_results:
             raise ValueError("No relevant recipes found.")
+        print("Initial results (document indices and similarities):")
         print(initial_results)
-
-
-
-
-
+
+        # Extract document indices from the results
+        document_indices = [doc_id for doc_id, _ in initial_results]
+        print("Document indices:", document_indices)
+
+        # Retrieve document texts using the indices
+        document_texts = retrieve_rec_texts(document_indices, folder_path)
         if not document_texts:
             raise ValueError("Failed to retrieve document texts.")
+        print("Document texts retrieved:")
         print(document_texts)
-
-
-        file_path = 'recipes_metadata.xlsx'
-        metadata_path = 'recipes_metadata.xlsx'
-        metadata_df = pd.read_excel(file_path)
+
+        # Extract relevant portions from documents using the query text
         relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=1)
+        print("Relevant portions extracted:")
         print(relevant_portions)
+
         flattened_relevant_portions = []
         for doc_id, portions in relevant_portions.items():
             flattened_relevant_portions.extend(portions)
         unique_selected_parts = remove_duplicates(flattened_relevant_portions)
+        print("Unique selected parts:")
         print(unique_selected_parts)
+
         combined_parts = " ".join(unique_selected_parts)
+        print("Combined text for context:")
         print(combined_parts)
+
         context = [query_text] + unique_selected_parts
+        print("Final context for answering:")
         print(context)
+
+        # Extract entities from the query
         entities = extract_entities(query_text)
+        print("Extracted entities:")
         print(entities)
+
+        # Enhance the passage with the extracted entities
         passage = enhance_passage_with_entities(combined_parts, entities)
+        print("Enhanced passage with entities:")
         print(passage)
+
+        # Create the prompt for the model
         prompt = create_prompt(query_text, passage)
+        print("Generated prompt:")
         print(prompt)
+
+        # Generate the answer from the model
         answer = generate_answer(prompt)
+        print("Generated answer:")
         print(answer)
+
+        # Clean up the answer to extract the relevant part
         answer_part = answer.split("Answer:")[-1].strip()
-        print(answer_part)
         cleaned_answer = remove_answer_prefix(answer_part)
+        print("Cleaned answer:")
         print(cleaned_answer)
+
         final_answer = remove_incomplete_sentence(cleaned_answer)
-        print(
+        print("Final answer:")
+        print(final_answer)
+
         if language_code == 0:
             final_answer = translate_en_to_ar(final_answer)
+
         if final_answer:
             print("Answer:")
             print(final_answer)
         else:
             print("Sorry, I can't help with that.")
-
-
-        }
+
+        return {"response": final_answer}

     except ValueError as ve:
         # Handle expected errors
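The tail of the handler trims the generated text down to a clean final answer. A standalone sketch of that string cleanup, using a stand-in model output and simplified stand-ins for the app's remove_answer_prefix and remove_incomplete_sentence helpers:

# Stand-in model output; in the app this comes from generate_answer(prompt).
answer = "Some context. Answer: Eat more leafy greens. They are rich in"

answer_part = answer.split("Answer:")[-1].strip()

def remove_answer_prefix(text: str) -> str:
    # Simplified stand-in: drop a leading "Answer:" label if one survives the split.
    return text[len("Answer:"):].strip() if text.startswith("Answer:") else text

def remove_incomplete_sentence(text: str) -> str:
    # Simplified stand-in: cut any trailing fragment after the last period.
    return text[:text.rfind(".") + 1] if "." in text else text

print(remove_incomplete_sentence(remove_answer_prefix(answer_part)))
# -> "Eat more leafy greens."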
@@ -853,6 +965,7 @@ async def recipes_endpoint(profile: MedicalProfile):
         print(f"Unexpected error: {e}")
         raise HTTPException(status_code=500, detail="An unexpected error occurred.")

+
 if not init_success:
     print("Warning: Application initialized with partial functionality")
