# TeaRAG / app.py
# Committed by thechaiexperiment -- commit b2bdaba: "Update app.py" (11.9 kB)
import os
import pickle
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForTokenClassification,
AutoModelForCausalLM,
pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd
import subprocess
from typing import Dict, Optional
import codecs
from huggingface_hub import hf_hub_download
# FastAPI application object; served by uvicorn when this file is run directly.
app = FastAPI()

# CORS policy: every origin, method and header is accepted and credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any site
# issue credentialed requests; narrow allow_origins if cookies/auth are ever used.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Process-wide registries populated at start-up:
#   models -- models, tokenizers and pipelines keyed by name (filled by load_models)
#   data   -- loaded artifacts such as the documents DataFrame under 'df'
models: Dict[str, object] = {}
data: Dict[str, object] = {}
# Request body accepted by the POST /api/query endpoint.
class QueryRequest(BaseModel):
    # The user's question text.
    query: str
    # Language of `query`: process_query treats 0 (the default) as Arabic -- the query is
    # translated to English before retrieval and the answer is translated back to Arabic.
    # Any other value skips both translations.
    language_code: int = 0
def init_nltk(resources=('punkt', 'punkt_tab')):
    """Download the NLTK sentence-tokenizer data used when post-processing answers.

    NLTK 3.8.2+ makes ``nltk.sent_tokenize`` load the ``punkt_tab`` resource
    instead of the pickled ``punkt`` model, so both are fetched by default to
    work across NLTK versions.

    Args:
        resources: NLTK resource identifiers to download, in order.

    Returns:
        False if downloading raised an exception, True otherwise. ``nltk.download``
        reports most failures by returning False instead of raising; those are
        logged per resource but do not turn the result into False, so a single
        unavailable optional resource does not block start-up.
    """
    try:
        for name in resources:
            if not nltk.download(name, quiet=True):
                print(f"Warning: NLTK resource '{name}' could not be downloaded")
        return True
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False
def load_models():
    """Populate the global ``models`` registry with every model the service uses.

    Loads, in this order: the sentence-embedding model and the cross-encoder
    (max_length=512) used for retrieval and reranking, the Arabic->English and
    English->Arabic Helsinki-NLP translation tokenizer/model pairs, the
    blaze999/Medical-NER tokenizer/model plus an "ner" pipeline built from them,
    and the M4-ai/Orca-2.0-Tau-1.8B causal LLM and its tokenizer.

    Returns:
        True if every model loaded; False if any step raised. The error is
        printed, and entries loaded before the failure remain in ``models``.
    """
    try:
        print("Loading models...")
        # NOTE(review): `device` is only reported here; it is not passed to any loader below.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device set to use {device}")

        # Retrieval: bi-encoder for query/document embeddings, cross-encoder for reranking.
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)

        # Translation: for each direction, load the tokenizer first and then the model.
        translation_specs = (
            ('ar_to_en_tokenizer', 'ar_to_en_model', "Helsinki-NLP/opus-mt-ar-en"),
            ('en_to_ar_tokenizer', 'en_to_ar_model', "Helsinki-NLP/opus-mt-en-ar"),
        )
        for tokenizer_key, model_key, checkpoint in translation_specs:
            models[tokenizer_key] = AutoTokenizer.from_pretrained(checkpoint)
            models[model_key] = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

        # Medical named-entity recognition.
        ner_checkpoint = "blaze999/Medical-NER"
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained(ner_checkpoint)
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained(ner_checkpoint)
        models['ner_pipeline'] = pipeline(
            "ner", model=models['bio_model'], tokenizer=models['bio_tokenizer']
        )

        # Answer-generation LLM.
        llm_checkpoint = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(llm_checkpoint)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(llm_checkpoint)

        print("Models loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False
def load_embeddings(embeddings_path='embeddings.pkl') -> Optional[Dict[str, np.ndarray]]:
    """Load document embeddings from a local pickle, falling back to the HF Hub.

    If ``embeddings_path`` does not exist, ``embeddings.pkl`` is downloaded
    from the Hugging Face Space named by the ``HF_SPACE_ID`` environment
    variable, or by ``SPACE_ID`` (set automatically inside HF Spaces) when
    ``HF_SPACE_ID`` is unset.

    SECURITY: unpickling can execute arbitrary code. This is only acceptable
    because the file is this project's own artifact; never point it at
    untrusted data.

    Args:
        embeddings_path: Local pickle file to try first.

    Returns:
        A dict mapping each document id to a float32 NumPy array, or None if the
        pickle does not contain a dict or any step fails (the reason is printed).
    """
    try:
        if not os.path.exists(embeddings_path):
            repo_id = os.environ.get('HF_SPACE_ID') or os.environ.get('SPACE_ID')
            if not repo_id:
                # Fail with a clear message instead of calling the Hub with an empty repo id.
                raise RuntimeError(
                    f"{embeddings_path} not found and neither HF_SPACE_ID nor SPACE_ID is set"
                )
            from huggingface_hub import hf_hub_download
            embeddings_path = hf_hub_download(
                repo_id=repo_id,
                filename="embeddings.pkl",
                repo_type="space"
            )

        class _MainToNumpyUnpickler(pickle.Unpickler):
            """Resolve classes recorded under ``__main__`` against ``numpy`` instead."""

            def find_class(self, module, name):
                if module == "__main__":
                    module = "numpy"
                return super().find_class(module, name)

        with open(embeddings_path, 'rb') as f:
            embeddings = _MainToNumpyUnpickler(f).load()
        if not isinstance(embeddings, dict):
            print(f"Error loading embeddings: expected a dict, got {type(embeddings).__name__}")
            return None
        return {k: np.array(v, dtype=np.float32) for k, v in embeddings.items()}
    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None
def load_documents_data(docs_path='finalcleaned_excel_file.xlsx'):
    """Load the documents spreadsheet into ``data['df']``.

    Args:
        docs_path: Path of the Excel workbook to read with ``pd.read_excel``.

    Returns:
        True on success. False when the file is missing or cannot be read; in
        both failure cases ``data['df']`` is set to an empty DataFrame, so the
        key always exists after this call.
    """
    try:
        print("Loading documents data...")
        if not os.path.exists(docs_path):
            print(f"Error: {docs_path} not found")
            # Same "not loaded" state as the exception path below.
            data['df'] = pd.DataFrame()
            return False
        data['df'] = pd.read_excel(docs_path)
        print(f"Successfully loaded {len(data['df'])} document records")
        return True
    except Exception as e:
        print(f"Error loading documents data: {e}")
        data['df'] = pd.DataFrame()
        return False
def load_data():
    """Load the retrieval embeddings and the documents table into ``data``.

    The embeddings returned by ``load_embeddings`` are stored under
    ``data['embeddings']``; an empty dict is stored when loading failed, so
    truthiness checks on that key work. The documents table is loaded by
    ``load_documents_data`` into ``data['df']``.

    Returns:
        True if both the embeddings (non-empty) and the documents loaded,
        False otherwise. A warning is printed for each part that failed.
    """
    embeddings = load_embeddings()
    # Query handling reads data['embeddings']; without this the loaded vectors are lost.
    data['embeddings'] = embeddings if embeddings is not None else {}
    documents_success = load_documents_data()
    if not embeddings:
        print("Warning: Failed to load embeddings, falling back to basic functionality")
    if not documents_success:
        print("Warning: Failed to load documents data, falling back to basic functionality")
    return bool(embeddings) and bool(documents_success)
def translate_text(text, source_to_target='ar_to_en'):
    """Translate ``text`` between Arabic and English with the Marian models.

    Args:
        text: The text to translate; the tokenizer is called with truncation=True.
        source_to_target: 'ar_to_en' selects Arabic->English; any other value
            selects English->Arabic.

    Returns:
        The decoded translation of the first generated sequence, or ``text``
        unchanged if anything fails (the error is printed).
    """
    try:
        arabic_source = source_to_target == 'ar_to_en'
        tokenizer = models['ar_to_en_tokenizer' if arabic_source else 'en_to_ar_tokenizer']
        model = models['ar_to_en_model' if arabic_source else 'en_to_ar_model']
        encoded = tokenizer(text, return_tensors="pt", truncation=True)
        generated = model.generate(**encoded)
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text
def extract_entities(text):
    """Return the distinct entity words the medical NER pipeline finds in ``text``.

    Keeps the ``word`` of every pipeline prediction whose ``entity`` label
    starts with "B-" (presumably the BIO "beginning" tag of an entity span).
    Duplicates are removed while preserving first-occurrence order, so the
    result is deterministic across interpreter runs.

    Returns:
        List of entity words; an empty list on error (the error is printed).
    """
    try:
        predictions = models['ner_pipeline'](text)
        # dict.fromkeys dedupes in insertion order; iterating a set of strings
        # would give an order that depends on the per-process hash seed.
        return list(dict.fromkeys(
            prediction['word']
            for prediction in predictions
            if prediction['entity'].startswith("B-")
        ))
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
def generate_answer(query, context, max_length=860, temperature=0.2, min_new_tokens=256):
    """Generate an answer to ``query`` from ``context`` with the causal LLM.

    Args:
        query: The question, in English.
        context: Retrieved document text inserted into the prompt.
        max_length: Total token budget (prompt + generated text), honoured
            whenever it leaves at least ``min_new_tokens`` for generation.
        temperature: Sampling temperature (sampling uses top_p=0.9).
        min_new_tokens: Generation budget used when the prompt alone would
            leave less room than this. Long retrieved contexts easily exceed
            ``max_length``. Without the floor, ``generate`` would get a
            ``max_length`` no larger than the prompt, which yields no new
            text or raises, depending on the transformers version.

    Returns:
        The text after the last "Answer:" in the decoded output. Sentences are
        re-joined with single spaces when NLTK sentence data is available. On
        any failure, a fixed apology message is returned.
    """
    try:
        prompt = f"""
As a medical expert, please provide a clear and accurate answer to the following question based solely on the provided context.
Context: {context}
Question: {query}
Answer: Let me help you with accurate information from reliable medical sources."""
        tokenizer = models['llm_tokenizer']
        model = models['llm_model']
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        prompt_length = inputs.input_ids.shape[-1]
        # Same total length as max_length for short prompts; guaranteed room for long ones.
        max_new_tokens = max(max_length - prompt_length, min_new_tokens)
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "Answer:" in response:
            response = response.split("Answer:")[-1].strip()
        try:
            sentences = nltk.sent_tokenize(response)
        except LookupError:
            # Missing punkt/punkt_tab data must not discard an otherwise valid answer.
            sentences = []
        if sentences:
            return " ".join(sentences)
        return response
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "I apologize, but I'm unable to generate an answer at this time. Please try again later."
def query_embeddings(query_embedding, n_results=5):
    """Return the ``n_results`` stored documents most similar to ``query_embedding``.

    Args:
        query_embedding: A 2-D array with a single row, e.g. the output of
            ``models['embedding'].encode([text])``; compared against every
            vector in ``data['embeddings']`` by cosine similarity.
        n_results: Maximum number of documents to return; values <= 0 give [].

    Returns:
        List of ``(doc_id, similarity)`` pairs, most similar first. Empty when
        no embeddings are loaded, or on error (the error is printed).
    """
    try:
        # Guard explicitly: argsort()[-0:] would select every document.
        if n_results <= 0:
            return []
        embeddings = data.get('embeddings')
        if not embeddings:
            return []
        doc_ids = list(embeddings.keys())
        doc_embeddings = np.array(list(embeddings.values()))
        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(doc_ids[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print(f"Error in query_embeddings: {e}")
        return []
def retrieve_document_text(doc_id):
    """Return the visible text of the article saved as ``downloaded_articles/<doc_id>``.

    The file is read as UTF-8 and parsed with BeautifulSoup's 'html.parser'.
    Text fragments are stripped and joined with single spaces.

    Returns:
        The extracted text. Returns "" if the file does not exist (a warning is
        printed) or if anything raises (the error is printed).
    """
    try:
        file_path = os.path.join('downloaded_articles', doc_id)
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as handle:
                parsed = BeautifulSoup(handle, 'html.parser')
                return parsed.get_text(separator=' ', strip=True)
        print(f"Warning: Document file not found: {file_path}")
        return ""
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return ""
def rerank_documents(query, doc_texts):
    """Score each document's relevance to ``query`` with the cross-encoder.

    Returns:
        The output of ``models['cross_encoder'].predict`` on the
        ``(query, document)`` pairs, one score per entry of ``doc_texts`` in
        order. On any error the error is printed and an all-zero NumPy array
        of the same length is returned.
    """
    try:
        query_doc_pairs = [(query, text) for text in doc_texts]
        scorer = models['cross_encoder']
        return scorer.predict(query_doc_pairs)
    except Exception as e:
        print(f"Error reranking documents: {e}")
        return np.zeros(len(doc_texts))
@app.get("/health")
async def health_check():
    """Report liveness plus which start-up resources are available.

    ``status`` is always 'healthy'. The flags report whether the ``models``
    registry is non-empty, whether ``data['embeddings']`` is present and
    truthy, and whether ``data['df']`` is present and non-empty.
    """
    documents = data.get('df', pd.DataFrame())
    return dict(
        status='healthy',
        models_loaded=bool(models),
        embeddings_loaded=bool(data.get('embeddings')),
        documents_loaded=not documents.empty,
    )
@app.post("/api/query")
async def process_query(request: QueryRequest):
    """Answer a question with retrieval-augmented generation.

    Steps:
        1. Translate the query to English when ``language_code == 0``.
        2. Embed the query and take the top documents by cosine similarity.
        3. Load the documents' text and rerank it with the cross-encoder.
        4. Pass the best three as context to the LLM.
        5. Translate the answer back to Arabic when ``language_code == 0``.

    Returns:
        ``{'answer': str, 'success': True}``.

    Raises:
        HTTPException: 503 while models or embeddings are not loaded; 500 for
            any other failure while processing the query.
    """
    # NOTE(review): the model calls below are synchronous and run on the event loop
    # inside this async handler, so other requests wait while one query is processed.
    try:
        query_text = request.query
        language_code = request.language_code
        if not models or not data.get('embeddings'):
            raise HTTPException(
                status_code=503,
                detail="The system is currently initializing. Please try again in a few minutes."
            )
        try:
            if language_code == 0:
                query_text = translate_text(query_text, 'ar_to_en')
            query_embedding = models['embedding'].encode([query_text])
            relevant_docs = query_embeddings(query_embedding)
            if not relevant_docs:
                return {
                    'answer': 'No relevant information found. Please try a different query.',
                    'success': True
                }
            doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
            doc_texts = [text for text in doc_texts if text.strip()]
            if not doc_texts:
                return {
                    'answer': 'Unable to retrieve relevant documents. Please try again.',
                    'success': True
                }
            rerank_scores = rerank_documents(query_text, doc_texts)
            # Sort on the score only (stable): ties keep similarity order instead of
            # falling back to comparing the document texts.
            ranked = sorted(
                zip(rerank_scores, doc_texts), key=lambda pair: pair[0], reverse=True
            )
            context = " ".join(text for _, text in ranked[:3])
            answer = generate_answer(query_text, context)
            if language_code == 0:
                answer = translate_text(answer, 'en_to_ar')
            return {
                'answer': answer,
                'success': True
            }
        except Exception as e:
            print(f"Error processing query: {e}")
            raise HTTPException(
                status_code=500,
                detail="An error occurred while processing your query"
            ) from e
    except HTTPException:
        # Propagate intended HTTP errors (e.g. the 503 above) unchanged instead of
        # re-wrapping them as a generic 500 below.
        raise
    except Exception as e:
        print(f"Error in process_query: {e}")
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
# Initialize application.
# Every initializer is called unconditionally (not chained with `and`), so a failure
# in one step -- e.g. NLTK data that cannot be downloaded -- does not stop the later
# steps from loading the models and data.
print("Initializing application...")
_init_results = [init_nltk(), load_models(), load_data()]
init_success = all(_init_results)
if not init_success:
    print("Warning: Application initialized with partial functionality")
# For running locally
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)