# TeaRAG / app.py
import os
import pickle
from typing import Dict, Optional

import numpy as np
import pandas as pd
import torch
import nltk
from bs4 import BeautifulSoup
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoModelForCausalLM,
    pipeline,
)

# Initialize FastAPI app
app = FastAPI()
# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for models and data
models = {}
data = {}
class QueryRequest(BaseModel):
    query: str
    language_code: int = 0  # 0 = Arabic (translated for retrieval), any other value = English
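# Example request bodies for /api/query (illustrative only):
#   {"query": "ما هي أعراض مرض السكري؟", "language_code": 0}             # Arabic in, Arabic out
#   {"query": "What are the symptoms of diabetes?", "language_code": 1}  # English in, English out
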
def init_nltk():
    """Initialize NLTK resources"""
    try:
        nltk.download('punkt', quiet=True)
        return True
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False

def load_models():
    """Initialize all required models"""
    try:
        print("Loading models...")
        # Report the device; note the models are left on CPU unless moved explicitly
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device set to use {device}")
        # Embedding models
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        # Translation models
        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        # Medical NER model
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
        # LLM used for answer generation
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
        print("Models loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False

def load_embeddings(embeddings_path: str = 'embeddings.pkl') -> Optional[Dict[str, np.ndarray]]:
    """
    Load embeddings from a pickle file containing a dictionary of numpy arrays.

    Args:
        embeddings_path (str): Path to the pickle file containing embeddings

    Returns:
        Optional[Dict[str, np.ndarray]]: Dictionary of embeddings, or None if loading fails
    """
    if not os.path.exists(embeddings_path):
        print(f"Error: {embeddings_path} not found")
        return None
    try:
        with open(embeddings_path, 'rb') as f:
            embeddings = pickle.load(f)
        # Validate the loaded data
        if not isinstance(embeddings, dict):
            print(f"Error: Expected dict, got {type(embeddings)}")
            return None
        # Convert values to numpy arrays if they aren't already
        for key in embeddings:
            if not isinstance(embeddings[key], np.ndarray):
                embeddings[key] = np.array(embeddings[key])
        # Print a sample entry for verification
        sample_key = next(iter(embeddings))
        print(f"Key: {sample_key}, Value: {embeddings[sample_key][:20]}")  # First 20 values
        print(f"Successfully loaded {len(embeddings)} embeddings")
        return embeddings
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None

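# A minimal sketch of how an 'embeddings.pkl' in the expected format could be produced.
# Assumptions (not part of the original app): document IDs are the HTML filenames under
# 'downloaded_articles/' (as retrieve_document_text below implies), and build_embeddings
# is a hypothetical helper name.
def build_embeddings(articles_dir: str = 'downloaded_articles',
                     out_path: str = 'embeddings.pkl') -> None:
    encoder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    embeddings: Dict[str, np.ndarray] = {}
    for filename in os.listdir(articles_dir):
        with open(os.path.join(articles_dir, filename), 'r', encoding='utf-8') as f:
            text = BeautifulSoup(f, 'html.parser').get_text(separator=' ', strip=True)
        embeddings[filename] = encoder.encode(text)  # one vector per document
    with open(out_path, 'wb') as f:
        pickle.dump(embeddings, f)
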
def load_documents_data():
    """Load document data with error handling"""
    try:
        print("Loading documents data...")
        docs_path = 'finalcleaned_excel_file.xlsx'
        if not os.path.exists(docs_path):
            print(f"Error: {docs_path} not found")
            return False
        data['df'] = pd.read_excel(docs_path)
        print(f"Successfully loaded {len(data['df'])} document records")
        return True
    except Exception as e:
        print(f"Error loading documents data: {e}")
        data['df'] = pd.DataFrame()
        return False

def load_data():
    """Load all required data"""
    # Store the embeddings in the shared data dict so query_embeddings can find them
    data['embeddings'] = load_embeddings()
    embeddings_success = data['embeddings'] is not None
    documents_success = load_documents_data()
    if not embeddings_success:
        print("Warning: Failed to load embeddings, falling back to basic functionality")
    if not documents_success:
        print("Warning: Failed to load documents data, falling back to basic functionality")
    return True

def translate_text(text, source_to_target='ar_to_en'):
    """Translate text between Arabic and English"""
    try:
        if source_to_target == 'ar_to_en':
            tokenizer = models['ar_to_en_tokenizer']
            model = models['ar_to_en_model']
        else:
            tokenizer = models['en_to_ar_tokenizer']
            model = models['en_to_ar_model']
        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text
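# Usage (illustrative): on any failure the input is returned unchanged, so callers
# never need to handle translation errors themselves.
#   english = translate_text("ما هي أعراض الانفلونزا؟", 'ar_to_en')
#   arabic = translate_text("Drink plenty of fluids.", 'en_to_ar')
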
def extract_entities(text):
    """Extract medical entities from text using NER"""
    try:
        results = models['ner_pipeline'](text)
        # Keep only entity-beginning tokens ("B-" tags), deduplicated via a set
        return list({result['word'] for result in results if result['entity'].startswith("B-")})
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
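# Note: extract_entities is defined but not called anywhere in the query flow below.
# Illustrative (unverified) output: extract_entities("Patient presents with diabetes
# and hypertension") might yield ['diabetes', 'hypertension'].
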
def generate_answer(query, context, max_length=860, temperature=0.2):
    """Generate answer using LLM"""
    try:
        prompt = f"""
As a medical expert, please provide a clear and accurate answer to the following question based solely on the provided context.
Context: {context}
Question: {query}
Answer: Let me help you with accurate information from reliable medical sources."""
        inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
        with torch.no_grad():
            # Pass the full encoding so the attention mask accompanies the input IDs
            outputs = models['llm_model'].generate(
                **inputs,
                max_length=max_length,
                num_return_sequences=1,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=models['llm_tokenizer'].eos_token_id
            )
        response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the final "Answer:" marker
        if "Answer:" in response:
            response = response.split("Answer:")[-1].strip()
        # Normalize whitespace via sentence tokenization
        sentences = nltk.sent_tokenize(response)
        if sentences:
            return " ".join(sentences)
        return response
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "I apologize, but I'm unable to generate an answer at this time. Please try again later."

def query_embeddings(query_embedding, n_results=5):
    """Find relevant documents using embedding similarity"""
    # Use .get() so a failed embeddings load degrades gracefully instead of raising KeyError
    if not data.get('embeddings'):
        return []
    try:
        doc_ids = list(data['embeddings'].keys())
        doc_embeddings = np.array(list(data['embeddings'].values()))
        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(doc_ids[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print(f"Error in query_embeddings: {e}")
        return []
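# Shape note (inferred from usage in process_query): query_embedding is the (1, d)
# array returned by models['embedding'].encode([query_text]) and each stored embedding
# is a length-d vector, so cosine_similarity yields a (1, N) matrix that .flatten()
# turns into one score per document. argsort()[-n_results:][::-1] then selects the
# n_results highest-scoring doc IDs, best first.
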
def retrieve_document_text(doc_id):
    """Retrieve document text from HTML file"""
    try:
        file_path = os.path.join('downloaded_articles', doc_id)
        if not os.path.exists(file_path):
            print(f"Warning: Document file not found: {file_path}")
            return ""
        with open(file_path, 'r', encoding='utf-8') as file:
            soup = BeautifulSoup(file, 'html.parser')
            return soup.get_text(separator=' ', strip=True)
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return ""

def rerank_documents(query, doc_texts):
    """Rerank documents using cross-encoder"""
    try:
        pairs = [(query, doc) for doc in doc_texts]
        scores = models['cross_encoder'].predict(pairs)
        return scores
    except Exception as e:
        print(f"Error reranking documents: {e}")
        return np.zeros(len(doc_texts))
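# Design note: the bi-encoder retrieval above scores documents against precomputed
# vectors, which is fast but approximate; the cross-encoder reads each (query, document)
# pair jointly, which is slower but more accurate, so it is applied only to the
# handful of retrieved candidates.
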
@app.get("/health")
async def health_check():
"""Health check endpoint"""
status = {
'status': 'healthy',
'models_loaded': bool(models),
'embeddings_loaded': bool(data.get('embeddings')),
'documents_loaded': not data.get('df', pd.DataFrame()).empty
}
return status
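# Example (illustrative, assuming the app runs on the default port set below):
#   curl http://localhost:7860/health
#   -> {"status": "healthy", "models_loaded": true,
#       "embeddings_loaded": true, "documents_loaded": true}
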
@app.post("/api/query")
async def process_query(request: QueryRequest):
"""Main query processing endpoint"""
try:
query_text = request.query
language_code = request.language_code
if not models or not data.get('embeddings'):
raise HTTPException(
status_code=503,
detail="The system is currently initializing. Please try again in a few minutes."
)
try:
if language_code == 0:
query_text = translate_text(query_text, 'ar_to_en')
query_embedding = models['embedding'].encode([query_text])
relevant_docs = query_embeddings(query_embedding)
if not relevant_docs:
return {
'answer': 'No relevant information found. Please try a different query.',
'success': True
}
doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
doc_texts = [text for text in doc_texts if text.strip()]
if not doc_texts:
return {
'answer': 'Unable to retrieve relevant documents. Please try again.',
'success': True
}
rerank_scores = rerank_documents(query_text, doc_texts)
ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
context = " ".join(ranked_texts[:3])
answer = generate_answer(query_text, context)
if language_code == 0:
answer = translate_text(answer, 'en_to_ar')
return {
'answer': answer,
'success': True
}
except Exception as e:
print(f"Error processing query: {e}")
raise HTTPException(
status_code=500,
detail="An error occurred while processing your query"
)
except Exception as e:
print(f"Error in process_query: {e}")
raise HTTPException(
status_code=500,
detail=str(e)
)
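# Example (illustrative): querying the endpoint once the app is running locally:
#   curl -X POST http://localhost:7860/api/query \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What are the symptoms of diabetes?", "language_code": 1}'
# Expected shape of a successful response: {"answer": "...", "success": true}
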
# Initialize application
print("Initializing application...")
# Note: `and` short-circuits, so model and data loading are skipped if an earlier step fails
init_success = init_nltk() and load_models() and load_data()
if not init_success:
    print("Warning: Application initialized with partial functionality")

# For running locally
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)