from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
import pickle
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from bs4 import BeautifulSoup
import os
import nltk
import torch
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
)
import pandas as pd
import time

# Initialize FastAPI app first
app = FastAPI()
class ArticleEmbeddingUnpickler(pickle.Unpickler):
    """Custom unpickler for article embeddings with robust persistence handling."""

    def find_class(self, module: str, name: str) -> Any:
        if module == 'numpy':
            return getattr(np, name)
        if module == 'sentence_transformers.SentenceTransformer':
            from sentence_transformers import SentenceTransformer
            return SentenceTransformer
        return super().find_class(module, name)

    def persistent_load(self, pid: Any) -> str:
        """Resolve persistent IDs defensively, normalising them to strings."""
        try:
            # Handle the common persistent ID types explicitly
            if isinstance(pid, bytes):
                return pid.decode('utf-8', errors='replace')
            if isinstance(pid, (str, int, float)):
                return str(pid)
            return repr(pid)
        except Exception as e:
            print(f"Warning: Error in persistent_load: {str(e)}")
            return repr(pid)
def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
    """Load embeddings with error handling and per-entry validation."""
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Embeddings file not found at {file_path}")

        with open(file_path, 'rb') as file:
            unpickler = ArticleEmbeddingUnpickler(file)
            embeddings_data = unpickler.load()

        if not isinstance(embeddings_data, dict):
            raise ValueError(f"Invalid data structure: expected dict, got {type(embeddings_data)}")

        # Process and validate embeddings
        valid_embeddings = {}
        for key, value in embeddings_data.items():
            try:
                # Ensure key is a valid string
                key_str = str(key).strip()
                if not key_str:
                    continue

                # Convert value to a float32 numpy array if needed
                if isinstance(value, list):
                    value = np.array(value, dtype=np.float32)
                elif isinstance(value, np.ndarray):
                    value = value.astype(np.float32)
                else:
                    print(f"Skipping invalid embedding type for key {key_str}: {type(value)}")
                    continue

                # Validate array dimensions and values
                if value.ndim != 1:
                    print(f"Skipping invalid embedding shape for key {key_str}: {value.shape}")
                    continue
                if np.isnan(value).any() or np.isinf(value).any():
                    print(f"Skipping embedding with invalid values for key {key_str}")
                    continue

                valid_embeddings[key_str] = value
            except Exception as e:
                print(f"Error processing embedding for key {key}: {str(e)}")
                continue

        if not valid_embeddings:
            raise ValueError("No valid embeddings found in file")

        print(f"Successfully loaded {len(valid_embeddings)} valid embeddings")
        return valid_embeddings
    except Exception as e:
        print(f"Error loading embeddings: {str(e)}")
        raise
def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
    """Persist embeddings with ASCII-safe string keys."""
    # Convert all keys to ASCII-safe strings
    cleaned_embeddings = {
        str(key).encode('ascii', errors='replace').decode('ascii'): value
        for key, value in embeddings_dict.items()
    }
    with open(file_path, 'wb') as f:
        pickle.dump(cleaned_embeddings, f, protocol=0)
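
# Minimal sketch of how the two helpers above round-trip. It is defined but never
# called by the app; the file name and the 384-dimensional random vectors are
# assumptions for illustration only.
def _demo_embeddings_round_trip(demo_path: str = 'demo_embeddings.pkl') -> None:
    demo = {"article_0.html": np.random.rand(384).astype(np.float32)}
    safe_save_embeddings(demo, file_path=demo_path)
    restored = safe_load_embeddings(demo_path)
    assert set(restored) == set(demo)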
# Models and data structures
class GlobalModels:
    embedding_model = None
    cross_encoder = None
    semantic_model = None
    tokenizer = None
    model = None
    tokenizer_f = None
    model_f = None
    ar_to_en_tokenizer = None
    ar_to_en_model = None
    en_to_ar_tokenizer = None
    en_to_ar_model = None
    embeddings_data = None
    file_name_to_url = None
    bio_tokenizer = None
    bio_model = None

# Initialize global models
global_models = GlobalModels()

# Download NLTK data
nltk.download('punkt')

# Pydantic models for request validation
class QueryInput(BaseModel):
    query_text: str
    language_code: int  # 0 for Arabic, 1 for English
    query_type: str  # "profile" or "question"
    previous_qa: Optional[List[Dict[str, str]]] = None

class DocumentResponse(BaseModel):
    title: str
    url: str
    text: str
    score: float
# Startup event handler: loads all models and data once when the app boots
@app.on_event("startup")
async def load_models():
    try:
        print("Starting to load embeddings...")
        embeddings_data = safe_load_embeddings()
        print(f"Embeddings data type: {type(embeddings_data)}")
        print(f"Number of embeddings: {len(embeddings_data)}")
        print("Sample keys:", list(embeddings_data.keys())[:3])
        global_models.embeddings_data = embeddings_data

        # Load embedding and reranking models
        global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load BART models
        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Load Orca model
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

        # Load translation models
        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

        # Load Medical NER models
        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

        # Load URL mapping data
        try:
            df = pd.read_excel('finalcleaned_excel_file.xlsx')
            global_models.file_name_to_url = {
                f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])
            }
        except Exception as e:
            print(f"Error loading URL mapping data: {e}")
            raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")

        print("All models loaded successfully")
    except Exception as e:
        print(f"Error during startup: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to initialize application: {str(e)}")
def translate_ar_to_en(text):
    try:
        inputs = global_models.ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.ar_to_en_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        return global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None

def translate_en_to_ar(text):
    try:
        inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.en_to_ar_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        return global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None

def process_query(query_text, language_code):
    if language_code == 0:
        return translate_ar_to_en(query_text)
    return query_text
def embed_query_text(query_text):
    return global_models.embedding_model.encode([query_text])

def query_embeddings(query_embedding, n_results=5):
    doc_ids = list(global_models.embeddings_data.keys())
    doc_embeddings = np.array(list(global_models.embeddings_data.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]
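
# Toy sketch of the ranking step in query_embeddings, independent of the global
# state. The 3-dimensional vectors and file names are assumptions for
# illustration only; real embeddings come from the SentenceTransformer above.
def _demo_rank_toy_vectors():
    toy = {
        "article_0.html": np.array([1.0, 0.0, 0.0], dtype=np.float32),
        "article_1.html": np.array([0.0, 1.0, 0.0], dtype=np.float32),
    }
    query = np.array([[0.9, 0.1, 0.0]], dtype=np.float32)
    sims = cosine_similarity(query, np.array(list(toy.values()))).flatten()
    ids = list(toy.keys())
    return [(ids[i], float(sims[i])) for i in sims.argsort()[::-1]]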
def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
    texts = []
    for doc_id in doc_ids:
        file_path = os.path.join(folder_path, doc_id)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file, 'html.parser')
                texts.append(soup.get_text(separator=' ', strip=True))
        except FileNotFoundError:
            texts.append("")
    return texts

def extract_entities(text):
    inputs = global_models.bio_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = global_models.bio_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    # Keep tokens whose predicted label index is non-zero (index 0 is treated as the non-entity tag)
    return [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]
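
# Label-aware variant of extract_entities, kept as a sketch only. It assumes the
# NER checkpoint exposes config.id2label and uses an "O" tag for non-entities;
# the app itself relies on the simpler index-based filter above.
def _extract_entities_with_labels(text):
    inputs = global_models.bio_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = global_models.bio_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)[0]
    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    id2label = global_models.bio_model.config.id2label
    return [
        (token, id2label[pred.item()])
        for token, pred in zip(tokens, predictions)
        if id2label[pred.item()] != "O"
    ]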
def create_prompt(question, passage):
    return f"""
As a medical expert, you are required to answer the following question based only on the provided passage.
Do not include any information not present in the passage. Your response should directly reflect the content
of the passage. Maintain accuracy and relevance to the provided information.
Passage: {passage}
Question: {question}
Answer:
"""
def generate_answer(prompt, max_length=860, temperature=0.2):
    inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)
    start_time = time.time()
    # Note: temperature only takes effect when sampling is enabled; with the
    # default greedy decoding transformers ignores it.
    output_ids = global_models.model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=global_models.tokenizer_f.eos_token_id
    )
    duration = time.time() - start_time
    answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
    return answer, duration
def clean_answer(answer):
    answer_part = answer.split("Answer:")[-1].strip()
    if not answer_part.endswith('.'):
        last_period_index = answer_part.rfind('.')
        if last_period_index != -1:
            answer_part = answer_part[:last_period_index + 1].strip()
    return answer_part
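
# Example of the trimming behaviour above (illustrative strings, not real model
# output):
#     clean_answer("Passage: ... Answer: Drink fluids and rest. Partial trail")
#     -> "Drink fluids and rest."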
# Route path inferred from the function name; the original listing omitted the decorator.
@app.post("/retrieve_documents")
async def retrieve_documents(input_data: QueryInput):
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)

        # Get document texts and rerank
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)
        scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])

        # Prepare response
        documents = []
        for score, doc_id, text in zip(scores, document_ids, document_texts):
            url = global_models.file_name_to_url.get(doc_id, "")
            documents.append({
                "title": doc_id,
                "url": url,
                "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
                "score": float(score)
            })
        return {"status": "success", "documents": documents}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# Route path inferred from the function name; the original listing omitted the decorator.
@app.post("/get_answer")
async def get_answer(input_data: QueryInput):
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)

        # Get relevant documents
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)

        # Extract entities and create context
        entities = extract_entities(processed_query)
        context = " ".join(document_texts)
        enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"

        # Generate answer
        prompt = create_prompt(processed_query, enhanced_context)
        answer, duration = generate_answer(prompt)
        final_answer = clean_answer(answer)

        # Translate back to Arabic if needed
        if input_data.language_code == 0:
            final_answer = translate_en_to_ar(final_answer)

        return {
            "status": "success",
            "answer": final_answer,
            "processing_time": duration
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
    return {"message": "Server is running"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
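
# Example request against the /get_answer route defined above (route path follows
# the function name; payload values are illustrative only):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:7860/get_answer",
#         json={"query_text": "What are the symptoms of anemia?",
#               "language_code": 1, "query_type": "question"},
#     )
#     print(resp.json())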