import os
import re
import asyncio
import numpy as np
import nltk
import torch
import pandas as pd
import requests
from functools import partial
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Tuple, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoModelForCausalLM,
    pipeline,
    BartForConditionalGeneration
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
from huggingface_hub import hf_hub_download
from safetensors.numpy import load_file
from safetensors.torch import safe_open
nltk.download('punkt_tab')

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

models = {}
data = {}
class QueryRequest(BaseModel):
    query: str
    language_code: int = 1

class MedicalProfile(BaseModel):
    conditions: str
    daily_symptoms: str
    count: int

class ChatQuery(BaseModel):
    query: str
    language_code: int = 1
    # conversation_id: str

class ChatMessage(BaseModel):
    role: str
    content: str
    timestamp: str
async def run_in_threadpool(func, *args, **kwargs):
    return await asyncio.get_event_loop().run_in_executor(
        None, partial(func, *args, **kwargs)
    )

def init_nltk():
    try:
        nltk.download('punkt', quiet=True)
        return True
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False
def load_models():
    try:
        print("Loading models...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device set to use {device}")
        models['embedding_model'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        models['semantic_model'] = SentenceTransformer('all-MiniLM-L6-v2')
        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        models['att_tokenizer'] = AutoTokenizer.from_pretrained("facebook/bart-base")
        models['att_model'] = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
        print("Models loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False
def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
    try:
        embeddings_path = 'embeddings.safetensors'
        if not os.path.exists(embeddings_path):
            print("File not found locally. Attempting to download from Hugging Face Hub...")
            embeddings_path = hf_hub_download(
                repo_id=os.environ.get('HF_SPACE_ID', 'thechaiexperiment/TeaRAG'),
                filename="embeddings.safetensors",
                repo_type="space"
            )
        embeddings = {}
        with safe_open(embeddings_path, framework="pt") as f:
            for key in f.keys():
                try:
                    tensor = f.get_tensor(key)
                    if not isinstance(tensor, torch.Tensor):
                        raise TypeError(f"Value for key {key} is not a valid PyTorch tensor.")
                    embeddings[key] = tensor.numpy()
                except Exception as key_error:
                    print(f"Failed to process key {key}: {key_error}")
        if embeddings:
            print("Embeddings successfully loaded.")
        else:
            print("No embeddings could be loaded. Please check the file format and content.")
        return embeddings
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None
def normalize_key(key: str) -> str:
    match = re.search(r'file_(\d+)', key)
    if match:
        return match.group(1)
    return key
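# Example: normalize_key("embeddings_file_12") -> "12"; keys without the
# "file_<n>" pattern are returned unchanged.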
def load_recipes_embeddings() -> Optional[np.ndarray]:
    try:
        embeddings_path = 'recipes_embeddings.safetensors'
        if not os.path.exists(embeddings_path):
            print("File not found locally. Attempting to download from Hugging Face Hub...")
            embeddings_path = hf_hub_download(
                repo_id=os.environ.get('HF_SPACE_ID', 'thechaiexperiment/TeaRAG'),
                filename="recipes_embeddings.safetensors",
                repo_type="space"
            )
        embeddings = load_file(embeddings_path)
        if "embeddings" not in embeddings:
            raise ValueError("Key 'embeddings' not found in the safetensors file.")
        tensor = embeddings["embeddings"]
        print("Successfully loaded recipes embeddings.")
        print(f"Shape of embeddings: {tensor.shape}")
        print(f"Dtype of embeddings: {tensor.dtype}")
        print(f"First few values of the first embedding: {tensor[0][:5]}")
        return tensor
    except Exception as e:
        print(f"Error loading recipes embeddings: {e}")
        return None
def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
    try:
        print("Loading documents data...")
        if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
            print(f"Error: Folder '{folder_path}' not found")
            return False
        html_files = [f for f in os.listdir(folder_path) if f.endswith('.html')]
        if not html_files:
            print(f"No HTML files found in folder '{folder_path}'")
            return False
        documents = []
        for file_name in html_files:
            file_path = os.path.join(folder_path, file_name)
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    soup = BeautifulSoup(file, 'html.parser')
                    text = soup.get_text(separator='\n').strip()
                    documents.append({"file_name": file_name, "content": text})
            except Exception as e:
                print(f"Error reading file {file_name}: {e}")
        data['df'] = pd.DataFrame(documents)
        if data['df'].empty:
            print("No valid documents loaded.")
            return False
        print(f"Successfully loaded {len(data['df'])} document records.")
        return True
    except Exception as e:
        print(f"Error loading documents data: {e}")
        return False
def load_data():
    # Cache the embeddings so the health check (and other callers) can see whether they loaded
    data['embeddings'] = load_embeddings()
    documents_success = load_documents_data()
    if not data['embeddings']:
        print("Warning: Failed to load embeddings, falling back to basic functionality")
    if not documents_success:
        print("Warning: Failed to load documents data, falling back to basic functionality")
    return True

print("Initializing application...")
init_success = load_models() and load_data()
def translate_text(text, source_to_target='ar_to_en'):
    try:
        if source_to_target == 'ar_to_en':
            tokenizer = models['ar_to_en_tokenizer']
            model = models['ar_to_en_model']
        else:
            tokenizer = models['en_to_ar_tokenizer']
            model = models['en_to_ar_model']
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text

def embed_query_text(query_text):
    embedding_model = models['embedding_model']
    query_embedding = embedding_model.encode([query_text])
    return query_embedding
def query_embeddings(query_embedding, embeddings_data, n_results):
    # Only hit the disk if the caller did not pass preloaded embeddings
    if not embeddings_data:
        embeddings_data = load_embeddings()
    if not embeddings_data:
        print("No embeddings data available.")
        return []
    try:
        doc_ids = list(embeddings_data.keys())
        doc_embeddings = np.array(list(embeddings_data.values()))
        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(doc_ids[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print(f"Error in query_embeddings: {e}")
        return []
def query_recipes_embeddings(query_embedding, embeddings_data, n_results):
    # Only hit the disk if the caller did not pass preloaded embeddings
    if embeddings_data is None:
        embeddings_data = load_recipes_embeddings()
    if embeddings_data is None:
        print("No embeddings data available.")
        return []
    try:
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)
        similarities = cosine_similarity(query_embedding, embeddings_data).flatten()
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(index, similarities[index]) for index in top_indices]
    except Exception as e:
        print(f"Error in query_recipes_embeddings: {e}")
        return []
def get_page_title(url):
    try:
        response = requests.get(url)
        if response.status_code == 200:
            soup = BeautifulSoup(response.text, 'html.parser')
            title = soup.find('title')
            return title.get_text() if title else "No title found"
        else:
            return None
    except requests.exceptions.RequestException:
        return None

def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded_articles'):
    texts = []
    for doc_id in doc_ids:
        file_path = os.path.join(folder_path, doc_id)
        try:
            if not os.path.exists(file_path):
                print(f"Warning: Document file not found: {file_path}")
                texts.append("")
                continue
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file, 'html.parser')
                text = soup.get_text(separator=' ', strip=True)
                texts.append(text)
        except Exception as e:
            print(f"Error retrieving document {doc_id}: {e}")
            texts.append("")
    return texts
def retrieve_rec_texts(
        document_indices,
        folder_path='downloaded_articles/downloaded_articles',
        metadata_path='recipes_metadata.xlsx'
):
    try:
        metadata_df = pd.read_excel(metadata_path)
        if "id" not in metadata_df.columns or "original_file_name" not in metadata_df.columns:
            raise ValueError("Metadata file must contain 'id' and 'original_file_name' columns.")
        metadata_df = metadata_df.sort_values(by="id").reset_index(drop=True)
        if metadata_df.index.max() < max(document_indices):
            raise ValueError("Some document indices exceed the range of metadata.")
        document_texts = []
        for idx in document_indices:
            if idx >= len(metadata_df):
                print(f"Warning: Index {idx} is out of range for metadata.")
                continue
            original_file_name = metadata_df.iloc[idx]["original_file_name"]
            if not original_file_name:
                print(f"Warning: No file name found for index {idx}")
                continue
            file_path = os.path.join(folder_path, original_file_name)
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as f:
                    document_texts.append(f.read())
            else:
                print(f"Warning: File not found at {file_path}")
        return document_texts
    except Exception as e:
        print(f"Error in retrieve_rec_texts: {e}")
        return []
def retrieve_metadata(document_indices: List[int], metadata_path: str = 'recipes_metadata.xlsx') -> Dict[int, Dict[str, str]]:
    try:
        metadata_df = pd.read_excel(metadata_path)
        required_columns = {'id', 'original_file_name', 'url'}
        if not required_columns.issubset(metadata_df.columns):
            raise ValueError(f"Metadata file must contain columns: {required_columns}")
        metadata_df['id'] = metadata_df['id'].astype(int)
        filtered_metadata = metadata_df[metadata_df['id'].isin(document_indices)]
        metadata_dict = {
            int(row['id']): {
                "original_file_name": row['original_file_name'],
                "url": row['url']
            }
            for _, row in filtered_metadata.iterrows()
        }
        return metadata_dict
    except Exception as e:
        print(f"Error retrieving metadata: {e}")
        return {}
def rerank_documents(query: str, document_ids: List[str], document_texts: List[str], cross_encoder_model) -> List[Tuple[float, str, str]]:
    try:
        # Batch-score all (query, document) pairs in one cross-encoder call
        pairs = [(query, doc) for doc in document_texts]
        scores = cross_encoder_model.predict(pairs, batch_size=8)
        scored_documents = list(zip(scores, document_ids, document_texts))
        scored_documents.sort(key=lambda x: x[0], reverse=True)
        return scored_documents
    except Exception as e:
        print(f"Error reranking documents: {e}")
        return []
def extract_entities_batch(texts: List[str], biobert_tokenizer, biobert_model, batch_size: int = 8) -> List[List[str]]:
    try:
        all_entities = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            # Tokenize the whole batch in a single call
            inputs = biobert_tokenizer(batch_texts, padding=True, truncation=True, return_tensors="pt", max_length=512)
            with torch.no_grad():  # Disable gradient calculation for inference
                outputs = biobert_model(**inputs)
            predictions = torch.argmax(outputs.logits, dim=2)
            for input_ids, preds in zip(inputs.input_ids, predictions):
                tokens = biobert_tokenizer.convert_ids_to_tokens(input_ids)
                # Keep tokens whose predicted label is not index 0 (assumed to be the "O"/outside class)
                entities = [tokens[k] for k in range(len(tokens)) if preds[k].item() != 0]
                all_entities.append(entities)
        return all_entities
    except Exception as e:
        print(f"Error in batch entity extraction: {e}")
        return [[] for _ in texts]
def extract_relevant_portions(document_texts: List[str], query: str, biobert_tokenizer, biobert_model,
                              max_portions: int = 3, portion_size: int = 1) -> Dict[str, List[str]]:
    try:
        # Run NER over the query and all documents in one batch
        all_texts = [query] + document_texts
        all_entities = extract_entities_batch(all_texts, biobert_tokenizer, biobert_model)
        query_entities = set(all_entities[0])

        def process_document(doc_idx: int) -> Tuple[str, List[str]]:
            doc_text = document_texts[doc_idx]
            doc_entities = set(all_entities[doc_idx + 1])  # +1 because the query was first
            shared_entities = query_entities.intersection(doc_entities)
            sentences = nltk.sent_tokenize(doc_text)
            doc_relevant_portions = []
            # Score each sentence by how many shared query/document entities it actually mentions
            sentence_scores = []
            for i, sentence in enumerate(sentences):
                entity_overlap = sum(1 for entity in shared_entities if entity in sentence)
                sentence_scores.append((entity_overlap, i))
            # Sort and select the top-scoring sentences
            sentence_scores.sort(reverse=True)
            for _, sent_idx in sentence_scores[:max_portions]:
                start_idx = max(0, sent_idx - portion_size // 2)
                end_idx = min(len(sentences), sent_idx + portion_size // 2 + 1)
                portion = " ".join(sentences[start_idx:end_idx])
                doc_relevant_portions.append(portion)
            return f"Document_{doc_idx}", doc_relevant_portions

        # Process documents in parallel threads
        with ThreadPoolExecutor(max_workers=4) as executor:
            results = list(executor.map(process_document, range(len(document_texts))))
        relevant_portions = dict(results)
        return relevant_portions
    except Exception as e:
        print(f"Error extracting relevant portions: {e}")
        return {f"Document_{i}": [] for i in range(len(document_texts))}
def generate_answer(prompt: str, tokenizer_f, model_f, max_length: int = 860, temperature: float = 0.2) -> str:
    try:
        inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():  # Disable gradient calculation for inference
            output_ids = model_f.generate(
                inputs.input_ids,
                max_length=max_length,
                num_return_sequences=1,
                temperature=temperature,  # Ignored with do_sample=False; kept for the function signature
                pad_token_id=tokenizer_f.eos_token_id,
                do_sample=False  # Greedy decoding for faster, deterministic generation
            )
        answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
        # Quick relevance check: the answer should share at least one word with the prompt
        if any(word in answer.lower() for word in prompt.lower().split()):
            return answer
        return "I apologize, but I cannot provide a relevant answer based on the given information."
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "I apologize, but I encountered an error while generating the answer."
def create_prompt(question: str, passage: str) -> str:
    return f"""As a medical expert, answer the following question based only on the provided passage. Be concise and direct.
Passage: {passage}
Question: {question}
Answer:"""
def process_query_and_generate_answer(
        query: str,
        relevant_documents: List[Tuple[float, str, str]],
        models: Dict,
        max_portions: int = 3
) -> str:
    try:
        # Extract relevant portions from the top 3 documents
        relevant_portions = extract_relevant_portions(
            [doc[2] for doc in relevant_documents[:3]],
            query,
            models['bio_tokenizer'],
            models['bio_model'],
            max_portions=max_portions
        )
        # Combine relevant portions
        all_portions = []
        for doc_portions in relevant_portions.values():
            all_portions.extend(doc_portions)
        # Remove duplicates while preserving order
        unique_portions = list(dict.fromkeys(all_portions))
        # Create the context from the unique portions
        context = " ".join(unique_portions[:max_portions])
        # Generate and return the answer
        prompt = create_prompt(query, context)
        return generate_answer(
            prompt,
            models['llm_tokenizer'],
            models['llm_model']
        )
    except Exception as e:
        print(f"Error in query processing pipeline: {e}")
        return "I apologize, but I encountered an error while processing your question."
def remove_answer_prefix(text):
    prefix = "Answer:"
    if prefix in text:
        return text.split(prefix, 1)[-1].strip()
    return text

def remove_incomplete_sentence(text):
    if not text.endswith('.'):
        last_period_index = text.rfind('.')
        if last_period_index != -1:
            return text[:last_period_index + 1].strip()
    return text
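# Example: remove_incomplete_sentence("Drink water. Also try to") -> "Drink water."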
def translate_ar_to_en(text):
    try:
        # Reuse the cached translation models; only load them if load_models() did not run
        if 'ar_to_en_tokenizer' not in models:
            models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
            models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        ar_to_en_tokenizer = models['ar_to_en_tokenizer']
        ar_to_en_model = models['ar_to_en_model']
        inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = ar_to_en_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        translated_text = ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None

def translate_en_to_ar(text):
    try:
        # Reuse the cached translation models; only load them if load_models() did not run
        if 'en_to_ar_tokenizer' not in models:
            models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
            models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        en_to_ar_tokenizer = models['en_to_ar_tokenizer']
        en_to_ar_model = models['en_to_ar_model']
        inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = en_to_ar_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        translated_text = en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None
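# The chat endpoint below calls enhance_passage_with_entities(), which is not defined
# in this file. A minimal sketch of what such a helper might look like, assuming it
# simply prepends the extracted entities to the passage as extra context for the LLM:
def enhance_passage_with_entities(passage: str, entities: List[str]) -> str:
    # Hypothetical implementation; replace with the original helper if it lives elsewhere
    if not entities:
        return passage
    unique_entities = list(dict.fromkeys(entities))  # de-duplicate, preserve order
    return f"Key medical entities: {', '.join(unique_entities)}.\n{passage}"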
@app.get("/")
async def root():
    return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    status = {
        'status': 'healthy',
        'models_loaded': bool(models),
        'embeddings_loaded': bool(data.get('embeddings')),
        'documents_loaded': not data.get('df', pd.DataFrame()).empty
    }
    return status
@app.post("/api/chat")  # Route path assumed; adjust to match the deployed frontend
async def chat_endpoint(chat_query: ChatQuery):
    try:
        # Record the start time for response timing
        start_time = asyncio.get_event_loop().time()
        # Extract the query and translate it to English if needed
        query_text = chat_query.query
        language_code = chat_query.language_code
        if language_code == 0:
            query_text = await run_in_threadpool(translate_ar_to_en, query_text)
        # Embed the query and load the embeddings in parallel
        query_embedding_task = run_in_threadpool(embed_query_text, query_text)
        embeddings_data_task = run_in_threadpool(load_embeddings)
        query_embedding, embeddings_data = await asyncio.gather(
            query_embedding_task,
            embeddings_data_task
        )
        # Initial document retrieval
        n_results = 5
        folder_path = 'downloaded_articles/downloaded_articles'
        initial_results = await run_in_threadpool(
            query_embeddings,
            query_embedding,
            embeddings_data,
            n_results
        )
        document_ids = [doc_id for doc_id, *_ in initial_results]
        document_texts = await run_in_threadpool(
            retrieve_document_texts,
            document_ids,
            folder_path
        )
        # Rerank documents with the cross-encoder
        cross_encoder = models['cross_encoder']
        scored_documents = await run_in_threadpool(
            rerank_documents,
            query_text,
            document_ids,
            document_texts,
            cross_encoder
        )
        # Extract entities and relevant portions in parallel (asyncio.TaskGroup requires Python 3.11+)
        async with asyncio.TaskGroup() as tg:
            entities_task = tg.create_task(
                run_in_threadpool(
                    extract_entities_batch,
                    [query_text] + [doc[2] for doc in scored_documents[:3]],
                    models['bio_tokenizer'],
                    models['bio_model']
                )
            )
            portions_task = tg.create_task(
                run_in_threadpool(
                    extract_relevant_portions,
                    [doc[2] for doc in scored_documents[:3]],
                    query_text,
                    models['bio_tokenizer'],
                    models['bio_model']
                )
            )
        entities = (await entities_task)[0]  # First item is the query's entities
        relevant_portions = await portions_task
        # Flatten and de-duplicate the selected portions
        flattened_portions = []
        for doc_portions in relevant_portions.values():
            flattened_portions.extend(doc_portions)
        unique_selected_parts = list(dict.fromkeys(flattened_portions))
        combined_parts = " ".join(unique_selected_parts)
        # Enhance the passage with entities and build the prompt
        passage = enhance_passage_with_entities(combined_parts, entities)
        prompt = create_prompt(query_text, passage)
        # Generate the answer
        answer = await run_in_threadpool(
            generate_answer,
            prompt,
            models['llm_tokenizer'],
            models['llm_model']
        )
        # Post-process the answer
        answer_part = answer.split("Answer:")[-1].strip()
        cleaned_answer = await run_in_threadpool(remove_answer_prefix, answer_part)
        final_answer = await run_in_threadpool(remove_incomplete_sentence, cleaned_answer)
        # Translate the answer back to Arabic if needed
        if language_code == 0:
            final_answer = await run_in_threadpool(translate_en_to_ar, final_answer)
        # Calculate response time
        end_time = asyncio.get_event_loop().time()
        response_time = end_time - start_time
        if final_answer:
            print(f"Answer generated in {response_time:.2f} seconds")
            print(final_answer)
            return {
                "response": f"I hope this answers your question: {final_answer}",
                "success": True,
                "response_time": response_time
            }
        else:
            return {
                "response": "Sorry, I can't help with that.",
                "success": False,
                "response_time": response_time
            }
    except Exception as e:
        print(f"Error in chat endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
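# Example request against the chat endpoint (path assumed above):
#   curl -X POST http://localhost:7860/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What foods are good for iron deficiency?", "language_code": 1}'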
@app.post("/api/resources")  # Route path assumed; adjust to match the deployed frontend
async def resources_endpoint(profile: MedicalProfile):
    try:
        query_text = profile.conditions + " " + profile.daily_symptoms
        n_results = profile.count
        print(f"Generated query text: {query_text}")
        query_embedding = embed_query_text(query_text)
        if query_embedding is None:
            raise ValueError("Failed to generate query embedding.")
        embeddings_data = load_embeddings()
        folder_path = 'downloaded_articles/downloaded_articles'
        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
        if not initial_results:
            raise ValueError("No relevant documents found.")
        document_ids = [doc_id for doc_id, _ in initial_results]
        file_path = 'finalcleaned_excel_file.xlsx'
        df = pd.read_excel(file_path)
        file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
        resources = []
        for file_name in document_ids:
            original_url = file_name_to_url.get(file_name, None)
            if original_url:
                title = get_page_title(original_url) or "Unknown Title"
                resources.append({"file_name": file_name, "title": title, "url": original_url})
            else:
                resources.append({"file_name": file_name, "title": "Unknown", "url": None})
        document_texts = retrieve_document_texts(document_ids, folder_path)
        if not document_texts:
            raise ValueError("Failed to retrieve document texts.")
        cross_encoder = models['cross_encoder']
        scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
        scores = [float(score) for score in scores]
        for i, resource in enumerate(resources):
            resource["score"] = scores[i] if i < len(scores) else 0.0
        resources.sort(key=lambda x: x["score"], reverse=True)
        output = [{"title": resource["title"], "url": resource["url"]} for resource in resources]
        return output
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        print(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
@app.post("/api/recipes")  # Route path assumed; adjust to match the deployed frontend
async def recipes_endpoint(profile: MedicalProfile):
    try:
        recipe_query = (
            f"Recipes and foods for: "
            f"{profile.conditions} and experiencing {profile.daily_symptoms}"
        )
        query_text = recipe_query
        print(f"Generated query text: {query_text}")
        n_results = profile.count
        query_embedding = embed_query_text(query_text)
        if query_embedding is None:
            raise ValueError("Failed to generate query embedding.")
        embeddings_data = load_recipes_embeddings()
        initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results)
        if not initial_results:
            raise ValueError("No relevant recipes found.")
        print("Initial results (document indices and similarities):")
        print(initial_results)
        document_indices = [doc_id for doc_id, _ in initial_results]
        print("Document indices:", document_indices)
        metadata_path = 'recipes_metadata.xlsx'
        metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
        print(f"Retrieved Metadata: {metadata}")
        recipes = []
        for item in metadata.values():
            recipes.append({
                "title": item.get("original_file_name", "Unknown Title"),
                "url": item.get("url", "")
            })
        print(recipes)
        return recipes
    except ValueError as ve:
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        print(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")
if not init_success:
    print("Warning: Application initialized with partial functionality")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)