import os
import time
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoModelForCausalLM,
    pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file  # Safetensors loader for the precomputed embeddings
from typing import List, Dict, Optional
# Initialize FastAPI app
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for models and data
models = {}
data = {}


class QueryRequest(BaseModel):
    query: str
    language_code: int = 0


class MedicalProfile(BaseModel):
    chronic_conditions: List[str]
    symptoms: List[str]
    food_restrictions: List[str]
    mental_conditions: List[str]
    daily_symptoms: List[str]


class ChatQuery(BaseModel):
    query: str
    conversation_id: str


class ChatMessage(BaseModel):
    role: str
    content: str
    timestamp: str
def init_nltk():
    """Download the NLTK resources required for sentence tokenization."""
    try:
        nltk.download('punkt', quiet=True)
        return True
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False
def load_models():
    """Initialize all required models."""
    try:
        print("Loading models...")
        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device set to use {device}")
        # Embedding models
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        # Translation models
        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        # Medical NER model
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
        # LLM used for answer generation
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
        print("Models loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False
def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
    """Load precomputed document embeddings from a Safetensors file."""
    try:
        embeddings_path = 'embeddings.safetensors'
        if not os.path.exists(embeddings_path):
            # Fall back to downloading the file from the Space repository
            embeddings_path = hf_hub_download(
                repo_id="thechaiexperiment/TeaRAG",
                filename="embeddings.safetensors",
                repo_type="space"
            )
        embeddings = load_file(embeddings_path)
        if not isinstance(embeddings, dict):
            raise ValueError("Invalid format for embeddings in Safetensors file.")
        # Convert to a dictionary of numpy arrays keyed by document id
        return {k: tensor.numpy() for k, tensor in embeddings.items()}
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None
def load_documents_data():
    """Load document metadata with error handling."""
    try:
        print("Loading documents data...")
        docs_path = 'finalcleaned_excel_file.xlsx'
        if not os.path.exists(docs_path):
            print(f"Error: {docs_path} not found")
            return False
        data['df'] = pd.read_excel(docs_path)
        print(f"Successfully loaded {len(data['df'])} document records")
        return True
    except Exception as e:
        print(f"Error loading documents data: {e}")
        data['df'] = pd.DataFrame()
        return False
def load_data():
    """Load all required data and record it in the global `data` dict."""
    data['embeddings'] = load_embeddings()
    embeddings_success = data['embeddings'] is not None
    documents_success = load_documents_data()
    if not embeddings_success:
        print("Warning: Failed to load embeddings, falling back to basic functionality")
    if not documents_success:
        print("Warning: Failed to load documents data, falling back to basic functionality")
    return True
def translate_text(text, source_to_target='ar_to_en'):
    """Translate text between Arabic and English."""
    try:
        if source_to_target == 'ar_to_en':
            tokenizer = models['ar_to_en_tokenizer']
            model = models['ar_to_en_model']
        else:
            tokenizer = models['en_to_ar_tokenizer']
            model = models['en_to_ar_model']
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text
def embed_query_text(query_text):
    """Encode the query with the sentence-transformer embedding model."""
    return models['embedding'].encode([query_text])
def query_embeddings(query_embedding, n_results=5):
    """Find relevant documents using cosine similarity over the stored embeddings."""
    if not data.get('embeddings'):
        return []
    try:
        doc_ids = list(data['embeddings'].keys())
        doc_embeddings = np.array(list(data['embeddings'].values()))
        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(doc_ids[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print(f"Error in query_embeddings: {e}")
        return []
def retrieve_document_text(doc_id):
    """Retrieve the plain text of a document from its HTML file."""
    try:
        file_path = os.path.join('downloaded_articles', doc_id)
        if not os.path.exists(file_path):
            print(f"Warning: Document file not found: {file_path}")
            return ""
        with open(file_path, 'r', encoding='utf-8') as file:
            soup = BeautifulSoup(file, 'html.parser')
            return soup.get_text(separator=' ', strip=True)
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return ""
def rerank_documents(query, doc_texts):
    """Rerank documents using the cross-encoder."""
    try:
        pairs = [(query, doc) for doc in doc_texts]
        scores = models['cross_encoder'].predict(pairs)
        return scores
    except Exception as e:
        print(f"Error reranking documents: {e}")
        return np.zeros(len(doc_texts))
def extract_entities(text):
    """Extract medical entities from text using the NER pipeline."""
    try:
        results = models['ner_pipeline'](text)
        return list({result['word'] for result in results if result['entity'].startswith("B-")})
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
def match_entities(query_entities, sentence_entities):
    """Count how many query entities also appear in a sentence."""
    query_set, sentence_set = set(query_entities), set(sentence_entities)
    matches = query_set.intersection(sentence_set)
    return len(matches)
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
    """Select the sentences from each document that share the most entities with the query."""
    relevant_portions = {}
    # Extract entities from the query
    query_entities = extract_entities(query)
    print(f"Extracted Query Entities: {query_entities}")
    for doc_id, doc_text in enumerate(document_texts):
        sentences = nltk.sent_tokenize(doc_text)  # Split document into sentences
        doc_relevant_portions = []
        # Extract entities from the entire document
        doc_entities = extract_entities(doc_text)
        print(f"Document {doc_id} Entities: {doc_entities}")
        for i, sentence in enumerate(sentences):
            # Extract entities from the sentence
            sentence_entities = extract_entities(sentence)
            # Compute relevance score
            relevance_score = match_entities(query_entities, sentence_entities)
            # Select sentences with at least `min_query_words` matching entities
            if relevance_score >= min_query_words:
                start_idx = max(0, i - portion_size // 2)
                end_idx = min(len(sentences), i + portion_size // 2 + 1)
                portion = " ".join(sentences[start_idx:end_idx])
                doc_relevant_portions.append(portion)
                if len(doc_relevant_portions) >= max_portions:
                    break
        # Fallback: include the most entity-dense sentences if no sentence matched
        if not doc_relevant_portions and len(doc_entities) > 0:
            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s)), reverse=True)
            for fallback_sentence in sorted_sentences[:max_portions]:
                doc_relevant_portions.append(fallback_sentence)
        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
    return relevant_portions
def remove_duplicates(selected_parts):
    """Drop duplicate sentences while preserving their original order."""
    unique_sentences = set()
    unique_selected_parts = []
    for sentence in selected_parts:
        if sentence not in unique_sentences:
            unique_selected_parts.append(sentence)
            unique_sentences.add(sentence)
    return unique_selected_parts
def extract_entity_tokens(text):
    """Token-level variant of entity extraction.

    Renamed so it no longer shadows the pipeline-based extract_entities above,
    and pointed at the loaded NER model/tokenizer instead of undefined globals.
    """
    inputs = models['bio_tokenizer'](text, return_tensors="pt")
    outputs = models['bio_model'](**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = models['bio_tokenizer'].convert_ids_to_tokens(inputs.input_ids[0])
    # Assume label id 0 corresponds to the non-entity ("O") class
    entities = [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]
    return entities
def enhance_passage_with_entities(passage, entities):
    """Append the extracted entities to the passage to give the LLM extra context."""
    return f"{passage}\n\nEntities: {', '.join(entities)}"
def create_prompt(question, passage):
    """Build the answer-generation prompt from the question and retrieved passage."""
    prompt = ("""
    As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
    Passage: {passage}
    Question: {question}
    Answer:
    """)
    return prompt.format(passage=passage, question=question)
def generate_answer(prompt, max_length=860, temperature=0.2):
    """Generate an answer with the LLM and time how long generation takes."""
    tokenizer_f = models['llm_tokenizer']
    model_f = models['llm_model']
    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
    # Start timing
    start_time = time.time()
    output_ids = model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=tokenizer_f.eos_token_id
    )
    # End timing
    end_time = time.time()
    # Calculate the duration
    duration = end_time - start_time
    # Decode the answer
    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
    # Sanity check: the answer should share vocabulary with the prompt, which contains the passage
    prompt_keywords = set(prompt.lower().split())
    answer_keywords = set(answer.lower().split())
    if prompt_keywords.intersection(answer_keywords):
        return answer, duration
    else:
        return "Sorry, I can't help with that.", duration
def remove_answer_prefix(text):
    """Strip everything up to and including the final "Answer:" marker."""
    prefix = "Answer:"
    if prefix in text:
        return text.split(prefix)[-1].strip()
    return text
def remove_incomplete_sentence(text):
    """Trim a trailing unfinished sentence so the answer ends at a full stop."""
    # Check if the text ends with a period
    if not text.endswith('.'):
        # Find the last period in the string
        last_period_index = text.rfind('.')
        if last_period_index != -1:
            # Remove everything after the last period
            return text[:last_period_index + 1].strip()
    return text
@app.get("/")
async def root():
    return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
@app.get("/health")
async def health_check():
    """Health check endpoint."""
    status = {
        'status': 'healthy',
        'models_loaded': bool(models),
        'embeddings_loaded': bool(data.get('embeddings')),
        'documents_loaded': not data.get('df', pd.DataFrame()).empty
    }
    return status
@app.post("/api/chat")  # Route path assumed; the original file did not include route decorators
async def chat_endpoint(chat_query: ChatQuery):
    try:
        query_text = chat_query.query
        query_embedding = embed_query_text(query_text)
        initial_results = query_embeddings(query_embedding, n_results=5)
        document_ids = [doc_id for doc_id, _ in initial_results]
        document_texts = [retrieve_document_text(doc_id) for doc_id in document_ids]
        relevant_portions = extract_relevant_portions(document_texts, query_text)
        flattened_relevant_portions = []
        for doc_id, portions in relevant_portions.items():
            flattened_relevant_portions.extend(portions)
        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
        combined_parts = " ".join(unique_selected_parts)
        context = [query_text] + unique_selected_parts
        entities = extract_entities(query_text)
        passage = enhance_passage_with_entities(combined_parts, entities)
        prompt = create_prompt(query_text, passage)
        answer, generation_time = generate_answer(prompt)
        answer_part = answer.split("Answer:")[-1].strip()
        cleaned_answer = remove_answer_prefix(answer_part)
        final_answer = remove_incomplete_sentence(cleaned_answer)
        return {
            "response": final_answer,
            "conversation_id": chat_query.conversation_id,
            "success": True
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/resources")  # Route path assumed; the original file did not include route decorators
async def resources_endpoint(profile: MedicalProfile):
    try:
        context = f"""
        Medical conditions: {', '.join(profile.chronic_conditions)}
        Current symptoms: {', '.join(profile.daily_symptoms)}
        Restrictions: {', '.join(profile.food_restrictions)}
        Mental health: {', '.join(profile.mental_conditions)}
        """
        query_embedding = models['embedding'].encode([context])
        relevant_docs = query_embeddings(query_embedding)
        retrieved = [(doc_id, retrieve_document_text(doc_id)) for doc_id, _ in relevant_docs]
        # Drop documents with no text while keeping ids and texts aligned
        retrieved = [(doc_id, text) for doc_id, text in retrieved if text.strip()]
        doc_texts = [text for _, text in retrieved]
        rerank_scores = rerank_documents(context, doc_texts)
        ranked_docs = sorted(zip(retrieved, rerank_scores), key=lambda x: x[1], reverse=True)
        resources = []
        for (doc_id, text), score in ranked_docs[:10]:
            doc_info = data['df'][data['df']['id'] == doc_id].iloc[0]
            resources.append({
                "id": doc_id,
                "title": doc_info['title'],
                "content": text[:200],
                "score": float(score)
            })
        return {"resources": resources, "success": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/recipes")  # Route path assumed; the original file did not include route decorators
async def recipes_endpoint(profile: MedicalProfile):
    try:
        recipe_query = f"Recipes and meals suitable for someone with: {', '.join(profile.chronic_conditions + profile.food_restrictions)}"
        query_embedding = models['embedding'].encode([recipe_query])
        relevant_docs = query_embeddings(query_embedding)
        retrieved = [(doc_id, retrieve_document_text(doc_id)) for doc_id, _ in relevant_docs]
        # Drop documents with no text while keeping ids and texts aligned
        retrieved = [(doc_id, text) for doc_id, text in retrieved if text.strip()]
        doc_texts = [text for _, text in retrieved]
        rerank_scores = rerank_documents(recipe_query, doc_texts)
        ranked_docs = sorted(zip(retrieved, rerank_scores), key=lambda x: x[1], reverse=True)
        recipes = []
        for (doc_id, text), score in ranked_docs[:10]:
            doc_info = data['df'][data['df']['id'] == doc_id].iloc[0]
            if 'recipe' in text.lower() or 'meal' in text.lower():
                recipes.append({
                    "id": doc_id,
                    "title": doc_info['title'],
                    "content": text[:200],
                    "score": float(score)
                })
        return {"recipes": recipes[:5], "success": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Initialize application
print("Initializing application...")
init_success = init_nltk() and load_models() and load_data()
if not init_success:
    print("Warning: Application initialized with partial functionality")

# For running locally
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
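
# A minimal usage sketch for calling this API once it is running, assuming the route paths
# declared above (/health, /api/chat, /api/resources, /api/recipes); adjust them if your
# deployment registers different routes.
#
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What foods help with high blood pressure?", "conversation_id": "demo-1"}'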