# TeaRAG / app.py
# thechaiexperiment's picture
# Update app.py
# 31bad44
# raw
# history blame
# 7.61 kB
import os
import pickle
import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
AutoModelForTokenClassification,
AutoModelForCausalLM,
pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd
# Module-level registries shared by the loaders and the request handlers:
# `models` is filled by load_models(), `data` by load_data().
models = dict()
data = dict()

# Flask application object, passed through flask_cors.CORS.
app = Flask(__name__)
CORS(app)
def init_nltk():
    """Request the NLTK 'punkt' resource.

    Returns:
        True when the download call completes without raising, False when
        it raises (the error is printed to stdout).
    """
    try:
        nltk.download('punkt', quiet=True)
    except Exception as exc:
        print(f"Error initializing NLTK: {exc}")
        return False
    return True
def load_models():
    """Load every model the service uses into the global ``models`` dict.

    Entries are stored in this order: sentence embedder, cross-encoder,
    Arabic->English and English->Arabic translation tokenizer/model pairs,
    the medical NER tokenizer/model/pipeline, and the causal LLM
    tokenizer/model.

    Returns:
        True when all loads succeed, False as soon as any load raises
        (the error is printed; entries stored before the failure remain).
    """
    try:
        print("Loading models...")

        # Retrieval: bi-encoder for embeddings, cross-encoder alongside it
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)

        # Translation, Arabic -> English
        ar_en_checkpoint = "Helsinki-NLP/opus-mt-ar-en"
        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained(ar_en_checkpoint)
        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained(ar_en_checkpoint)

        # Translation, English -> Arabic
        en_ar_checkpoint = "Helsinki-NLP/opus-mt-en-ar"
        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained(en_ar_checkpoint)
        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained(en_ar_checkpoint)

        # Medical named-entity recognition
        ner_checkpoint = "blaze999/Medical-NER"
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained(ner_checkpoint)
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained(ner_checkpoint)
        models['ner_pipeline'] = pipeline(
            "ner",
            model=models['bio_model'],
            tokenizer=models['bio_tokenizer'],
        )

        # Answer-generation LLM
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)

        print("Models loaded successfully")
    except Exception as exc:
        print(f"Error loading models: {exc}")
        return False
    return True
def load_data():
    """Populate the global ``data`` dict from files in the working directory.

    ``data['embeddings']`` receives the unpickled contents of
    ``embeddings.pkl`` and ``data['df']`` the DataFrame read from
    ``finalcleaned_excel_file.xlsx``.

    Returns:
        True when both files load, False when anything raises (the error
        is printed).
    """
    try:
        print("Loading data...")

        # NOTE(review): pickle.load can execute arbitrary code; this is only
        # acceptable while embeddings.pkl is a trusted, locally built file.
        with open('embeddings.pkl', 'rb') as handle:
            stored_embeddings = pickle.load(handle)
        data['embeddings'] = stored_embeddings

        # Spreadsheet with the document links
        data['df'] = pd.read_excel('finalcleaned_excel_file.xlsx')

        print("Data loaded successfully")
    except Exception as exc:
        print(f"Error loading data: {exc}")
        return False
    return True
def translate_text(text, source_to_target='ar_to_en'):
    """Translate *text* with one of the loaded translation models.

    Args:
        text: Text to translate.
        source_to_target: ``'ar_to_en'`` selects the Arabic->English pair;
            any other value selects the English->Arabic pair.

    Returns:
        The decoded translation, or *text* unchanged if anything raises
        (the error is printed).
    """
    try:
        arabic_to_english = source_to_target == 'ar_to_en'
        tokenizer = (
            models['ar_to_en_tokenizer'] if arabic_to_english
            else models['en_to_ar_tokenizer']
        )
        model = (
            models['ar_to_en_model'] if arabic_to_english
            else models['en_to_ar_model']
        )

        encoded = tokenizer(text, return_tensors="pt", truncation=True)
        generated = model.generate(**encoded)
        # Only the first generated sequence is decoded.
        return tokenizer.decode(generated[0], skip_special_tokens=True)
    except Exception as exc:
        print(f"Translation error: {exc}")
        return text
def query_embeddings(query_embedding, n_results=5):
    """Find the documents whose embeddings are most similar to the query.

    Args:
        query_embedding: Query vector(s) in the form accepted by
            ``cosine_similarity`` (the caller passes the output of
            ``encode([text])``; assumed to be a single row - TODO confirm).
        n_results: Maximum number of documents to return. Values <= 0
            yield an empty list.

    Returns:
        A list of ``(doc_id, similarity)`` tuples ordered from most to
        least similar, containing at most ``n_results`` entries.
    """
    # Guard first: with n_results == 0 the slice below would be [-0:],
    # i.e. the whole array, returning every document instead of none.
    if n_results <= 0:
        return []

    embeddings = data['embeddings']
    if not embeddings:
        # Nothing indexed; cosine_similarity would raise on an empty matrix.
        return []

    doc_ids = list(embeddings.keys())
    doc_embeddings = np.array(list(embeddings.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()

    # argsort is ascending: take the last n_results and reverse them.
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]
def retrieve_document_text(doc_id):
    """Return the visible text of a stored HTML article.

    Args:
        doc_id: File name of the article inside ``downloaded_articles/``.

    Returns:
        The document's text with pieces joined by single spaces and
        stripped, or an empty string if the file cannot be read or parsed
        (the error is printed).
    """
    try:
        with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as html_file:
            parsed = BeautifulSoup(html_file, 'html.parser')
            plain_text = parsed.get_text(separator=' ', strip=True)
        return plain_text
    except Exception as exc:
        print(f"Error retrieving document {doc_id}: {exc}")
        return ""
def extract_entities(text):
    """Collect the distinct entity words the NER pipeline finds in *text*.

    Only predictions whose ``entity`` label starts with ``"B-"`` are kept.

    Returns:
        A list of unique ``word`` values (order not guaranteed), or an
        empty list if the pipeline raises (the error is printed).
    """
    try:
        predictions = models['ner_pipeline'](text)
        entity_words = set()
        for prediction in predictions:
            if prediction['entity'].startswith("B-"):
                entity_words.add(prediction['word'])
        return list(entity_words)
    except Exception as exc:
        print(f"Error extracting entities: {exc}")
        return []
def generate_answer(query, context, max_length=860, temperature=0.2):
    """Generate an answer to *query* from *context* with the causal LLM.

    Args:
        query: The question to answer.
        context: Text the model is instructed to base its answer on.
        max_length: Maximum number of tokens to generate for the answer.
            It is applied as ``max_new_tokens`` so a long prompt cannot use
            up the whole budget and leave no room for the answer.
        temperature: Sampling temperature. A positive value enables
            sampling so the setting actually takes effect; ``0`` or
            ``None`` selects greedy decoding.

    Returns:
        The generated answer text, or a fixed apology string if anything
        raises (the error is printed).
    """
    try:
        tokenizer = models['llm_tokenizer']
        prompt = f"""
As a medical expert, answer the following question based only on the provided context:
Context: {context}
Question: {query}
Answer:"""
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
        input_ids = inputs["input_ids"]
        prompt_length = input_ids.shape[-1]

        generation_kwargs = {
            # Passing the mask avoids ambiguity between padding and EOS,
            # since pad_token_id is set to the EOS token below.
            "attention_mask": inputs.get("attention_mask"),
            "max_new_tokens": max_length,
            "num_return_sequences": 1,
            "pad_token_id": tokenizer.eos_token_id,
        }
        if temperature is not None and temperature > 0:
            # temperature is only honoured in sampling mode.
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
        else:
            generation_kwargs["do_sample"] = False

        outputs = models['llm_model'].generate(input_ids, **generation_kwargs)

        # Decode only the newly generated tokens so the echoed prompt
        # (instructions + context) can never leak into the answer, even
        # when the prompt was truncated and lost its "Answer:" marker.
        new_tokens = outputs[0][prompt_length:]
        answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
        # Keep the original post-processing in case the model repeats the
        # "Answer:" marker in its own output.
        return answer.split("Answer:")[-1].strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Sorry, I couldn't generate an answer at this time."
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always responds with a JSON ``status`` of ``healthy``."""
    status_payload = {'status': 'healthy'}
    return jsonify(status_payload)
@app.route('/api/query', methods=['POST'])
def process_query():
    """Answer a question posted as JSON.

    Expected request body: a JSON object with a ``query`` string and an
    optional ``language_code`` (default ``0``). A ``language_code`` of ``0``
    means the query is translated Arabic->English before retrieval and the
    answer is translated English->Arabic before being returned.

    Returns:
        ``200`` with ``{'answer': ..., 'success': True}`` on success,
        ``400`` with ``{'error': ..., 'success': False}`` when the body is
        missing, is not a JSON object, or has no non-empty ``query`` string,
        ``500`` with ``{'error': ..., 'success': False}`` on any other error.
    """
    try:
        # silent=True turns a missing/invalid JSON body into None instead of
        # an exception, so client mistakes are reported as 400 rather than
        # being caught below and reported as 500. The local name also no
        # longer shadows the module-level `data` store.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or 'query' not in payload:
            return jsonify({'error': 'No query provided', 'success': False}), 400

        query_text = payload['query']
        if not isinstance(query_text, str) or not query_text.strip():
            return jsonify({'error': 'No query provided', 'success': False}), 400

        language_code = payload.get('language_code', 0)
        is_arabic = language_code == 0

        # Translate if Arabic
        if is_arabic:
            query_text = translate_text(query_text, 'ar_to_en')

        # Get query embedding and find relevant documents
        query_embedding = models['embedding'].encode([query_text])
        relevant_docs = query_embeddings(query_embedding)

        # Retrieve document texts, most similar document first
        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]

        # Keep only documents sharing at least one entity with the query.
        # The NER pass over the documents is skipped when the query has no
        # entities, because nothing could match in that case.
        query_entities = set(extract_entities(query_text))
        contexts = []
        if query_entities:
            contexts = [
                text for text in doc_texts
                if not query_entities.isdisjoint(extract_entities(text))
            ]

        # If the entity filter removed everything, fall back to the
        # retrieved documents so the LLM is not asked to answer from an
        # empty context.
        if not contexts:
            contexts = [text for text in doc_texts if text]

        context = " ".join(contexts[:3])  # Use top 3 most relevant contexts

        # Generate answer
        answer = generate_answer(query_text, context)

        # Translate back if needed
        if is_arabic:
            answer = translate_text(answer, 'en_to_ar')

        return jsonify({
            'answer': answer,
            'success': True
        })
    except Exception as e:
        return jsonify({
            'error': str(e),
            'success': False
        }), 500
# Initialize everything when the module is loaded (this also runs when a
# WSGI server imports the module, not only under `python app.py`).
# The `and` chain short-circuits: a failing step skips the later ones.
print("Initializing application...")
init_success = init_nltk() and load_models() and load_data()
if not init_success:
    print("Failed to initialize application")
    # exit() is a helper installed by the `site` module for interactive use
    # and may be missing (e.g. `python -S`); raise SystemExit directly.
    raise SystemExit(1)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)