import os
import pickle
import numpy as np
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoModelForCausalLM,
    pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd

app = Flask(__name__)
CORS(app)

# Global variables for models and data
models = {}
data = {}

def init_nltk():
    """Initialize NLTK resources"""
    try:
        nltk.download('punkt', quiet=True)
        return True
    except Exception as e:
        print(f"Error initializing NLTK: {e}")
        return False

def load_models():
    """Initialize all required models"""
    try:
        print("Loading models...")

        # Embedding models
        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)

        # Translation models
        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

        # NER model
        models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])

        # LLM model
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)

        print("Models loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading models: {e}")
        return False

def load_data():
    """Load embeddings and document data"""
    try:
        print("Loading data...")

        # Load embeddings
        with open('embeddings.pkl', 'rb') as f:
            data['embeddings'] = pickle.load(f)

        # Load document links
        data['df'] = pd.read_excel('finalcleaned_excel_file.xlsx')

        print("Data loaded successfully")
        return True
    except Exception as e:
        print(f"Error loading data: {e}")
        return False

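# Expected on-disk layout (inferred from the loaders and the retrieval code below;
# not documented in the original file):
#   - embeddings.pkl: dict mapping a document id (also the file name under
#     downloaded_articles/) to its precomputed sentence-embedding vector
#   - downloaded_articles/: saved HTML articles, one file per document id
#   - finalcleaned_excel_file.xlsx: spreadsheet with the document links/metadata
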
def translate_text(text, source_to_target='ar_to_en'):
    """Translate text between Arabic and English"""
    try:
        if source_to_target == 'ar_to_en':
            tokenizer = models['ar_to_en_tokenizer']
            model = models['ar_to_en_model']
        else:
            tokenizer = models['en_to_ar_tokenizer']
            model = models['en_to_ar_model']

        inputs = tokenizer(text, return_tensors="pt", truncation=True)
        outputs = model.generate(**inputs)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Translation error: {e}")
        return text

def query_embeddings(query_embedding, n_results=5):
    """Find relevant documents using embedding similarity"""
    doc_ids = list(data['embeddings'].keys())
    doc_embeddings = np.array(list(data['embeddings'].values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]

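# The CrossEncoder loaded in load_models() is not referenced anywhere else in this
# file. A minimal sketch of how such a re-ranker is typically applied to the texts
# returned by the retrieval step; the helper name and its use are an assumption,
# not part of the original pipeline:
def rerank_texts(query_text, doc_texts):
    """Re-rank candidate document texts with the cross-encoder (illustrative sketch)."""
    pairs = [(query_text, text) for text in doc_texts]
    scores = models['cross_encoder'].predict(pairs)
    # Sort texts from most to least relevant according to the cross-encoder score
    ranked = sorted(zip(scores, doc_texts), key=lambda pair: pair[0], reverse=True)
    return [text for _, text in ranked]
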
def retrieve_document_text(doc_id):
    """Retrieve document text from HTML file"""
    try:
        with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as file:
            soup = BeautifulSoup(file, 'html.parser')
            return soup.get_text(separator=' ', strip=True)
    except Exception as e:
        print(f"Error retrieving document {doc_id}: {e}")
        return ""

def extract_entities(text):
    """Extract medical entities from text"""
    try:
        results = models['ner_pipeline'](text)
        return list({result['word'] for result in results if result['entity'].startswith("B-")})
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []

def generate_answer(query, context, max_length=860, temperature=0.2):
    """Generate answer using LLM"""
    try:
        prompt = f"""
As a medical expert, answer the following question based only on the provided context:
Context: {context}
Question: {query}
Answer:"""
        inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
        outputs = models['llm_model'].generate(
            inputs.input_ids,
            max_length=max_length,
            num_return_sequences=1,
            do_sample=True,  # temperature only takes effect when sampling is enabled
            temperature=temperature,
            pad_token_id=models['llm_tokenizer'].eos_token_id
        )
        answer = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
        return answer.split("Answer:")[-1].strip()
    except Exception as e:
        print(f"Error generating answer: {e}")
        return "Sorry, I couldn't generate an answer at this time."

# Route path assumed; the original file did not include route decorators.
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({'status': 'healthy'})

# Route path assumed; the original file did not include route decorators.
@app.route('/api/query', methods=['POST'])
def process_query():
    """Main query processing endpoint"""
    try:
        # Use a local name so the module-level `data` dict is not shadowed
        payload = request.json
        if not payload or 'query' not in payload:
            return jsonify({'error': 'No query provided', 'success': False}), 400

        query_text = payload['query']
        language_code = payload.get('language_code', 0)

        # Translate if Arabic
        if language_code == 0:
            query_text = translate_text(query_text, 'ar_to_en')

        # Get query embedding and find relevant documents
        query_embedding = models['embedding'].encode([query_text])
        relevant_docs = query_embeddings(query_embedding)

        # Retrieve and process documents
        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]

        # Extract entities and generate context
        query_entities = extract_entities(query_text)
        contexts = []
        for text in doc_texts:
            doc_entities = extract_entities(text)
            if set(query_entities) & set(doc_entities):
                contexts.append(text)
        context = " ".join(contexts[:3])  # Use the top 3 most relevant contexts

        # Generate answer
        answer = generate_answer(query_text, context)

        # Translate back if needed
        if language_code == 0:
            answer = translate_text(answer, 'en_to_ar')

        return jsonify({
            'answer': answer,
            'success': True
        })
    except Exception as e:
        return jsonify({
            'error': str(e),
            'success': False
        }), 500

# Initialize everything when the app starts
print("Initializing application...")
init_success = init_nltk() and load_models() and load_data()

if not init_success:
    print("Failed to initialize application")
    exit(1)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)
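
# Example client call (illustrative only; the '/api/query' route path above is an
# assumption, since the original file had no route decorators). language_code 0 means
# the query is in Arabic and will be translated before retrieval.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/api/query",
#       json={"query": "What are the symptoms of diabetes?", "language_code": 1},
#   )
#   print(resp.json())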