TeaRAG / app.py
thechaiexperiment's picture
Update app.py
0a3c7e7
raw
history blame
17.5 kB
import os
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForTokenClassification,
AutoModelForCausalLM,
pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file # Import Safetensors loader
from typing import List, Dict, Optional
# FastAPI application object served by uvicorn.
app = FastAPI()

# Permissive CORS policy: any origin, method and header is accepted and
# credentialed requests are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting allow_origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Process-wide registries filled at import time by the loader functions:
# `models` holds tokenizers/models/pipelines keyed by role name, `data` holds
# loaded datasets (read elsewhere as data['df'] and data['embeddings']).
models, data = {}, {}
class QueryRequest(BaseModel):
    """Request body carrying a free-text query and a language selector.

    NOTE(review): not referenced by any endpoint visible in this file.
    """

    query: str
    # Integer language selector, default 0. Its meaning is not defined in this
    # file (presumably 0 = English; TODO confirm against the client).
    language_code: int = 0
class MedicalProfile(BaseModel):
    """User health profile posted to /api/resources and /api/recipes.

    Every field is a required list of free-text strings.
    """

    chronic_conditions: List[str]
    # NOTE(review): `symptoms` is accepted but not read by the endpoints in this
    # file; /api/resources reads `daily_symptoms` for its "Current symptoms" line.
    symptoms: List[str]
    food_restrictions: List[str]
    mental_conditions: List[str]
    daily_symptoms: List[str]
class ChatQuery(BaseModel):
    """Request body for /api/chat."""

    query: str
    # Opaque identifier echoed back unchanged in the /api/chat response.
    conversation_id: str
class ChatMessage(BaseModel):
    """A single chat message record: who said it, what was said, and when.

    NOTE(review): not referenced by any endpoint visible in this file.
    """

    role: str
    content: str
    # Plain string; the expected timestamp format is not defined in this file.
    timestamp: str
def init_nltk(resources=("punkt", "punkt_tab")):
    """Download the NLTK tokenizer data needed by ``nltk.sent_tokenize``.

    Recent NLTK releases load the sentence tokenizer from ``punkt_tab`` while
    older ones use ``punkt``, so both packages are fetched by default.

    Args:
        resources: NLTK data package ids to download.

    Returns:
        True if every download reported success, False otherwise (including
        when an exception is raised; the error is printed).
    """
    try:
        # Build the full list first (no short-circuit) so one failed package
        # does not stop the others from being fetched. nltk.download returns
        # a success flag, so the result must be checked rather than assumed.
        results = [nltk.download(name, quiet=True) for name in resources]
        return all(results)
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False
def load_models():
    """Load every model the service uses into the module-level ``models`` dict.

    Returns True when all loads succeed. On any exception the error is printed
    and False is returned; entries registered before the failure remain in
    ``models``.
    """
    try:
        print("Loading models...")
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
        # NOTE(review): `device` is only reported; no model is moved onto it.
        print(f"Device set to use {device}")

        # Retrieval: bi-encoder embeddings plus a cross-encoder for reranking.
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)

        # Arabic <-> English translation, one tokenizer/model pair per direction.
        for direction, checkpoint in (
            ('ar_to_en', "Helsinki-NLP/opus-mt-ar-en"),
            ('en_to_ar', "Helsinki-NLP/opus-mt-en-ar"),
        ):
            models[f'{direction}_tokenizer'] = AutoTokenizer.from_pretrained(checkpoint)
            models[f'{direction}_model'] = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        # Medical named-entity recognition, also wrapped in a HF pipeline.
        ner_checkpoint = "blaze999/Medical-NER"
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained(ner_checkpoint)
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained(ner_checkpoint)
        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])

        # Causal LLM used to generate answers.
        llm_checkpoint = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(llm_checkpoint)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(llm_checkpoint)

        print("Models loaded successfully")
    except Exception as e:
        print(f"Error loading models: {e}")
        return False
    return True
def load_embeddings(
    embeddings_path: str = 'embeddings.safetensors',
    repo_id: str = 'thechaiexperiment/TeaRAG',
) -> Optional[Dict[str, np.ndarray]]:
    """Load document embeddings from a Safetensors file.

    The local file is used when it exists. Otherwise a file with the same
    base name is downloaded from the Hugging Face Hub Space ``repo_id``.

    Args:
        embeddings_path: Local path checked first.
        repo_id: Hub Space repository to download from when the local file is
            missing. Previously this was read from an environment variable
            literally named after the repo, which always yielded ''.

    Returns:
        Mapping of key -> numpy array, or None if loading failed for any
        reason (the error is printed).
    """
    try:
        path = embeddings_path
        if not os.path.exists(path):
            path = hf_hub_download(
                repo_id=repo_id,
                filename=os.path.basename(embeddings_path),
                repo_type="space",
            )
        embeddings = load_file(path)
        if not isinstance(embeddings, dict):
            raise ValueError("Invalid format for embeddings in Safetensors file.")
        # load_file places tensors on CPU by default, which .numpy() requires.
        return {k: tensor.numpy() for k, tensor in embeddings.items()}
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None
def load_documents_data(docs_path='finalcleaned_excel_file.xlsx'):
    """Load document metadata from an Excel file into ``data['df']``.

    On every failure path (missing file or read error) ``data['df']`` is set
    to an empty DataFrame. Code that indexes ``data['df']`` therefore never
    hits a missing key.

    Args:
        docs_path: Path of the Excel workbook to read.

    Returns:
        True if the file was read, False otherwise.
    """
    try:
        print("Loading documents data...")
        if not os.path.exists(docs_path):
            print(f"Error: {docs_path} not found")
            # Match the exception path below: always leave a DataFrame behind.
            data['df'] = pd.DataFrame()
            return False
        data['df'] = pd.read_excel(docs_path)
        print(f"Successfully loaded {len(data['df'])} document records")
        return True
    except Exception as e:
        print(f"Error loading documents data: {e}")
        data['df'] = pd.DataFrame()
        return False
def load_data():
    """Load embeddings and document metadata into the module-level ``data`` dict.

    ``data['embeddings']`` is always assigned (an empty dict when loading
    failed), so retrieval code can read it without a KeyError. Failures only
    produce warnings.

    Returns:
        Always True: the service starts with reduced functionality instead of
        aborting when data is unavailable.
    """
    embeddings = load_embeddings()
    # Keep the loaded embeddings: retrieval reads them from data['embeddings'].
    data['embeddings'] = embeddings or {}
    documents_success = load_documents_data()
    if not embeddings:
        print("Warning: Failed to load embeddings, falling back to basic functionality")
    if not documents_success:
        print("Warning: Failed to load documents data, falling back to basic functionality")
    return True
def translate_text(text, source_to_target='ar_to_en'):
    """Translate ``text`` between Arabic and English with the opus-mt models.

    ``source_to_target == 'ar_to_en'`` selects Arabic -> English; any other
    value selects English -> Arabic. On any failure the error is printed and
    the input text is returned unchanged.
    """
    try:
        direction = 'ar_to_en' if source_to_target == 'ar_to_en' else 'en_to_ar'
        tokenizer = models[f'{direction}_tokenizer']
        model = models[f'{direction}_model']
        encoded = tokenizer(text, return_tensors="pt", truncation=True)
        generated = model.generate(**encoded)
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text
def embed_query_text(query_text):
    """Encode one query string with the sentence-embedding model.

    Uses ``models['embedding']``; the original referenced an undefined global
    ``embedding``.

    Returns:
        The result of ``models['embedding'].encode([query_text])``: a batch
        holding a single embedding row, the shape ``query_embeddings`` expects.
    """
    return models['embedding'].encode([query_text])
def query_embeddings(query_embedding, n_results=5, embeddings=None):
    """Rank stored document embeddings by cosine similarity to a query.

    Args:
        query_embedding: 2-D array-like holding one query row (e.g. the output
            of ``SentenceTransformer.encode([text])``).
        n_results: Maximum number of hits; values <= 0 return [] (previously
            ``argsort()[-0:]`` returned every document).
        embeddings: Optional mapping doc_id -> vector to search. When omitted,
            ``data['embeddings']`` is used; a missing key is treated as
            "nothing loaded" instead of raising KeyError.

    Returns:
        List of ``(doc_id, similarity)`` pairs, best first. Returns [] when
        nothing is loaded or on error (the error is printed).
    """
    if embeddings is None:
        embeddings = data.get('embeddings')
    if not embeddings or n_results <= 0:
        return []
    try:
        doc_ids = list(embeddings.keys())
        doc_embeddings = np.array(list(embeddings.values()))
        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
        # argsort is ascending: take the last n_results and reverse for best-first.
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(doc_ids[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print(f"Error in query_embeddings: {e}")
        return []
def retrieve_document_text(doc_id):
    """Return the visible text of the article ``downloaded_articles/<doc_id>``.

    The file is parsed as HTML and its text joined with single spaces. Returns
    "" when the file does not exist or anything fails while reading/parsing;
    problems are reported with print().
    """
    try:
        article_path = os.path.join('downloaded_articles', doc_id)
        if os.path.exists(article_path):
            with open(article_path, 'r', encoding='utf-8') as handle:
                parsed = BeautifulSoup(handle, 'html.parser')
                return parsed.get_text(separator=' ', strip=True)
        print(f"Warning: Document file not found: {article_path}")
        return ""
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return ""
def rerank_documents(query, doc_texts):
    """Score each document against ``query`` with the cross-encoder.

    Returns the cross-encoder scores (one per document, in input order). If
    scoring fails, an all-zeros numpy array of the same length is returned
    and the error is printed.
    """
    try:
        query_doc_pairs = list(map(lambda text: (query, text), doc_texts))
        return models['cross_encoder'].predict(query_doc_pairs)
    except Exception as e:
        print(f"Error reranking documents: {e}")
        return np.zeros(len(doc_texts))
def extract_entities(text):
    """Return the distinct words the NER pipeline tags with a ``B-`` label.

    NOTE(review): a second ``def extract_entities`` later in this module
    rebinds the name, so this version is shadowed at import time.

    Returns [] if the pipeline call fails (the error is printed).
    """
    try:
        tagged = models['ner_pipeline'](text)
        begin_words = set()
        for item in tagged:
            if item['entity'].startswith("B-"):
                begin_words.add(item['word'])
        return list(begin_words)
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
def match_entities(query_entities, sentence_entities):
    """Count the distinct entities present in both collections."""
    shared = set(query_entities) & set(sentence_entities)
    return len(shared)
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
    """Select sentence windows from each document that share entities with the query.

    Entities come from ``extract_entities``. The original passed an extra,
    undefined ``ner_biobert`` argument, which raised NameError.

    Args:
        document_texts: Sequence of document strings. Results are keyed by
            position as ``"Document_<i>"``.
        query: Query text whose entities drive the selection.
        max_portions: Maximum number of portions kept per document.
        portion_size: Window width in sentences, centred on a matching sentence.
        min_query_words: Minimum number of shared (distinct) entities a
            sentence needs to be selected.

    Returns:
        Dict mapping ``"Document_<i>"`` to a list of selected text portions.
        If no sentence matches but the document has entities, the sentences
        with the most entities are used as a fallback.
    """
    relevant_portions = {}

    query_entities = extract_entities(query)
    print(f"Extracted Query Entities: {query_entities}")

    for doc_id, doc_text in enumerate(document_texts):
        sentences = nltk.sent_tokenize(doc_text)
        doc_relevant_portions = []

        # Document-level entities only gate the fallback below.
        doc_entities = extract_entities(doc_text)
        print(f"Document {doc_id} Entities: {doc_entities}")

        # Entity counts per visited sentence, cached so the fallback does not
        # re-run NER on every sentence.
        sentence_entity_counts = []
        for i, sentence in enumerate(sentences):
            sentence_entities = extract_entities(sentence)
            sentence_entity_counts.append(len(sentence_entities))
            if match_entities(query_entities, sentence_entities) >= min_query_words:
                start_idx = max(0, i - portion_size // 2)
                end_idx = min(len(sentences), i + portion_size // 2 + 1)
                doc_relevant_portions.append(" ".join(sentences[start_idx:end_idx]))
                if len(doc_relevant_portions) >= max_portions:
                    break

        if not doc_relevant_portions and doc_entities:
            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
            # No portion was added, so the loop never broke early and every
            # sentence has a cached count. sorted() is stable, which preserves
            # document order among ties.
            ranked = sorted(
                range(len(sentences)),
                key=lambda idx: sentence_entity_counts[idx],
                reverse=True,
            )
            doc_relevant_portions.extend(sentences[idx] for idx in ranked[:max_portions])

        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
    return relevant_portions
def remove_duplicates(selected_parts):
    """Return ``selected_parts`` without repeats, keeping first-occurrence order."""
    # dict keys are unique and keep insertion order, giving an ordered de-dup.
    return list(dict.fromkeys(selected_parts))
def extract_entities(text):
    """Return the tokens the medical NER model labels as entities.

    Uses ``models['bio_tokenizer']`` and ``models['bio_model']``. The original
    referenced undefined ``biobert_tokenizer``/``biobert_model`` globals. A
    token is kept when its arg-max label id is non-zero.
    NOTE(review): label id 0 is assumed to be the non-entity class; confirm
    against ``models['bio_model'].config.id2label``.

    Returned items are tokenizer tokens (possibly subword pieces) in input
    order and may repeat.

    Returns [] if tokenization or inference fails (the error is printed).
    Callers therefore degrade gracefully instead of failing the request.
    """
    try:
        tokenizer = models['bio_tokenizer']
        model = models['bio_model']
        # Truncate so long documents cannot exceed the model's maximum length.
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2)
        tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
        return [token for token, label in zip(tokens, predictions[0].tolist()) if label != 0]
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
def enhance_passage_with_entities(passage, entities):
    """Append a blank line and an ``Entities: a, b, ...`` line to ``passage``."""
    entity_list = ', '.join(entities)
    return "{}\n\nEntities: {}".format(passage, entity_list)
def create_prompt(question, passage):
    """Build the grounded-QA prompt for the LLM.

    Layout, one item per line: the instruction paragraph, ``Passage: ...``,
    ``Question: ...`` and a trailing ``Answer:`` cue. The result starts and
    ends with a newline.
    """
    instructions = (
        "As a medical expert, you are required to answer the following question "
        "based only on the provided passage. Do not include any information not "
        "present in the passage. Your response should directly reflect the content "
        "of the passage. Maintain accuracy and relevance to the provided information."
    )
    return "\n".join(
        ["", instructions, f"Passage: {passage}", f"Question: {question}", "Answer:", ""]
    )
def generate_answer(prompt, max_length=860, temperature=0.2, passage=None):
    """Generate an answer for ``prompt`` with the causal LLM and sanity-check it.

    Uses ``models['llm_tokenizer']``/``models['llm_model']``. The original
    referenced undefined ``tokenizer_f``/``model_f``, an unimported ``time``
    and an undefined ``passage``.

    Args:
        prompt: Full prompt text.
        max_length: Forwarded to ``generate``. In transformers this bounds the
            prompt plus the generated tokens together.
        temperature: Forwarded to ``generate``. NOTE(review): transformers
            only applies temperature when sampling (``do_sample=True``), which
            is not enabled here, so decoding is greedy.
        passage: Optional reference text that the generated continuation must
            share at least one whitespace-separated word with. Defaults to
            ``prompt``.

    Returns:
        ``(answer, duration)``. ``answer`` is the decoded output sequence,
        prompt echo included, so callers can split on "Answer:". It is a
        refusal message instead when the continuation shares no words with
        the reference. ``duration`` is the generation wall time in seconds.
    """
    import time

    tokenizer = models['llm_tokenizer']
    model = models['llm_model']
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)

    start_time = time.perf_counter()
    output_ids = model.generate(
        inputs.input_ids,
        attention_mask=inputs.get("attention_mask"),
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    duration = time.perf_counter() - start_time

    answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # The output sequence starts with the prompt tokens. Check grounding on the
    # newly generated tokens only; otherwise the overlap test always passes.
    prompt_length = inputs.input_ids.shape[1]
    continuation = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    reference = prompt if passage is None else passage
    if set(reference.lower().split()) & set(continuation.lower().split()):
        return answer, duration
    return "Sorry, I can't help with that.", duration
def remove_answer_prefix(text):
    """Return the stripped text after the last "Answer:" marker.

    If the marker is absent, ``text`` is returned unchanged (not stripped).
    """
    marker = "Answer:"
    if marker not in text:
        return text
    _, _, tail = text.rpartition(marker)
    return tail.strip()
def remove_incomplete_sentence(text):
    """Drop a trailing fragment that follows the last period.

    Text already ending in '.' or containing no '.' at all is returned
    unchanged. Otherwise the text is cut just after its last '.' and stripped.
    """
    if text.endswith('.'):
        return text
    cut = text.rfind('.')
    if cut == -1:
        return text
    return text[:cut + 1].strip()
@app.get("/")
async def root():
    """Landing endpoint returning a short usage message.

    NOTE(review): the message mentions /api/query, but no such route is
    defined in this file.
    """
    message = "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."
    return {"message": message}
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Always reports ``status: 'healthy'``. The flags show whether the model
    registry is non-empty, embeddings are present, and the documents
    DataFrame is non-empty.
    """
    documents = data.get('df', pd.DataFrame())
    return {
        'status': 'healthy',
        'models_loaded': bool(models),
        'embeddings_loaded': bool(data.get('embeddings')),
        'documents_loaded': not documents.empty,
    }
@app.post("/api/chat")
async def chat_endpoint(chat_query: ChatQuery):
    """Answer a chat query with retrieval-augmented generation.

    Steps:
        1. Embed the query.
        2. Retrieve the top 5 documents by embedding similarity.
        3. Select entity-matching passages from them.
        4. Build a grounded prompt and generate an answer with the LLM.

    Returns:
        ``{"response": str, "conversation_id": str, "success": True}``.

    Raises:
        HTTPException: status 500 carrying the error text if any step fails.
    """
    # NOTE(review): this async handler runs blocking model inference on the
    # event loop; a plain `def` handler would let FastAPI use its threadpool.
    try:
        query_text = chat_query.query

        # Retrieval, using the same embedding call as the other endpoints.
        query_embedding = models['embedding'].encode([query_text])
        initial_results = query_embeddings(query_embedding, n_results=5)
        document_ids = [doc_id for doc_id, _ in initial_results]
        document_texts = [retrieve_document_text(doc_id) for doc_id in document_ids]
        # Missing/unreadable articles come back as "": drop them before NER.
        document_texts = [text for text in document_texts if text.strip()]

        # Passage selection.
        relevant_portions = extract_relevant_portions(document_texts, query_text)
        flattened_relevant_portions = [
            portion for portions in relevant_portions.values() for portion in portions
        ]
        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
        combined_parts = " ".join(unique_selected_parts)

        # Generation.
        entities = extract_entities(query_text)
        passage = enhance_passage_with_entities(combined_parts, entities)
        prompt = create_prompt(query_text, passage)
        answer, _generation_time = generate_answer(prompt)

        # Keep only the text after the last "Answer:" cue, then trim a
        # trailing incomplete sentence.
        answer_part = answer.split("Answer:")[-1].strip()
        cleaned_answer = remove_answer_prefix(answer_part)
        final_answer = remove_incomplete_sentence(cleaned_answer)

        return {
            "response": final_answer,
            "conversation_id": chat_query.conversation_id,
            "success": True
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/resources")
async def resources_endpoint(profile: MedicalProfile):
    """Recommend articles matching a user's medical profile.

    Steps:
        1. Summarize the profile into a text context and embed it.
        2. Retrieve candidate documents and rerank them with the cross-encoder.
        3. Return up to 10 results with title, a 200-character excerpt and the
           rerank score.

    Documents with empty text or without a metadata row in ``data['df']`` are
    skipped instead of failing the whole request.

    Raises:
        HTTPException: status 500 carrying the error text on unexpected failure.
    """
    try:
        context = "\n".join([
            "",
            f"Medical conditions: {', '.join(profile.chronic_conditions)}",
            f"Current symptoms: {', '.join(profile.daily_symptoms)}",
            f"Restrictions: {', '.join(profile.food_restrictions)}",
            f"Mental health: {', '.join(profile.mental_conditions)}",
            "",
        ])
        query_embedding = models['embedding'].encode([context])
        relevant_docs = query_embeddings(query_embedding)

        # Keep each (doc_id, similarity) paired with its own text while
        # dropping empty documents. Filtering the texts alone would misalign
        # them with relevant_docs in the zip below.
        candidates = []
        for doc in relevant_docs:
            text = retrieve_document_text(doc[0])
            if text.strip():
                candidates.append((doc, text))

        rerank_scores = rerank_documents(context, [text for _, text in candidates])
        ranked_docs = sorted(
            [(doc, score, text) for (doc, text), score in zip(candidates, rerank_scores)],
            key=lambda item: item[1],
            reverse=True,
        )

        df = data.get('df', pd.DataFrame())
        resources = []
        for (doc_id, _), score, text in ranked_docs[:10]:
            # NOTE(review): assumes df['id'] values have the same type as the
            # embedding keys.
            matches = df[df['id'] == doc_id] if 'id' in df.columns else df.iloc[0:0]
            if matches.empty:
                continue  # no metadata row, so there is no title to report
            resources.append({
                "id": doc_id,
                "title": matches.iloc[0]['title'],
                "content": text[:200],
                "score": float(score)
            })
        return {"resources": resources, "success": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/recipes")
async def recipes_endpoint(profile: MedicalProfile):
    """Recommend recipe/meal articles suitable for a user's conditions and restrictions.

    Candidate documents are retrieved by embedding similarity and reranked
    with the cross-encoder. Up to 10 top-ranked documents are considered.
    Only those whose text mentions "recipe" or "meal" are kept, and at most 5
    are returned.

    Documents with empty text or without a metadata row in ``data['df']`` are
    skipped instead of failing the whole request.

    Raises:
        HTTPException: status 500 carrying the error text on unexpected failure.
    """
    try:
        recipe_query = f"Recipes and meals suitable for someone with: {', '.join(profile.chronic_conditions + profile.food_restrictions)}"
        query_embedding = models['embedding'].encode([recipe_query])
        relevant_docs = query_embeddings(query_embedding)

        # Filter empty texts together with their doc entries so that rerank
        # scores and texts stay aligned with the right document.
        candidates = []
        for doc in relevant_docs:
            text = retrieve_document_text(doc[0])
            if text.strip():
                candidates.append((doc, text))

        rerank_scores = rerank_documents(recipe_query, [text for _, text in candidates])
        ranked_docs = sorted(
            [(doc, score, text) for (doc, text), score in zip(candidates, rerank_scores)],
            key=lambda item: item[1],
            reverse=True,
        )

        df = data.get('df', pd.DataFrame())
        recipes = []
        for (doc_id, _), score, text in ranked_docs[:10]:
            lowered = text.lower()
            if 'recipe' not in lowered and 'meal' not in lowered:
                continue
            # Metadata lookup happens only for documents that are kept.
            # NOTE(review): assumes df['id'] values have the same type as the
            # embedding keys.
            matches = df[df['id'] == doc_id] if 'id' in df.columns else df.iloc[0:0]
            if matches.empty:
                continue  # no metadata row, so there is no title to report
            recipes.append({
                "id": doc_id,
                "title": matches.iloc[0]['title'],
                "content": text[:200],
                "score": float(score)
            })
        return {"recipes": recipes[:5], "success": True}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Initialize application.
# Each loader is called unconditionally. The previous
# `load_models() and load_data()` skipped data loading whenever model loading
# failed, and the NLTK data used by sentence splitting was never fetched.
print("Initializing application...")
nltk_ready = init_nltk()
models_ready = load_models()
data_ready = load_data()
init_success = nltk_ready and models_ready and data_ready
if not init_success:
    print("Warning: Application initialized with partial functionality")
# Local entry point: serve `app` with uvicorn on all interfaces, port 7860.
if __name__ == "__main__":
    import uvicorn  # only needed when executed as a script

    uvicorn.run(app, port=7860, host="0.0.0.0")