from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
import pickle
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, CrossEncoder
from bs4 import BeautifulSoup
import os
import nltk
import torch
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification
)
import pandas as pd
import time

# Initialize FastAPI app first
app = FastAPI()


class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves the persistent ID stored in embeddings.pkl,
    substituting a freshly loaded SentenceTransformer for the pickled one."""

    def persistent_load(self, pid):
        try:
            if isinstance(pid, bytes):
                pid = pid.decode('utf-8', errors='ignore')
            pid = str(pid).encode('ascii', errors='ignore').decode('ascii')
            if pid == "sentence_transformer_model":
                return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
            return pid
        except Exception as e:
            raise pickle.UnpicklingError(f"Error handling persistent ID: {e}")


def safe_load_embeddings():
    """Load and validate the pickled document embeddings; return None on failure."""
    try:
        with open('embeddings.pkl', 'rb') as file:
            unpickler = CustomUnpickler(file)
            embeddings_data = unpickler.load()
        if not isinstance(embeddings_data, dict):
            raise ValueError("Loaded data is not a dictionary")
        first_key = next(iter(embeddings_data))
        if not isinstance(embeddings_data[first_key], (np.ndarray, list)):
            raise ValueError("Embeddings are not in the expected format")
        return embeddings_data
    except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
        print(f"Error loading embeddings: {str(e)}")
        return None


# Container for all models and data loaded at startup
class GlobalModels:
    embedding_model = None
    cross_encoder = None
    semantic_model = None
    tokenizer = None
    model = None
    tokenizer_f = None
    model_f = None
    ar_to_en_tokenizer = None
    ar_to_en_model = None
    en_to_ar_tokenizer = None
    en_to_ar_model = None
    embeddings_data = None
    file_name_to_url = None
    bio_tokenizer = None
    bio_model = None


global_models = GlobalModels()

# Download NLTK data
nltk.download('punkt')


# Pydantic models for request validation
class QueryInput(BaseModel):
    query_text: str
    language_code: int  # 0 for Arabic, 1 for English
    query_type: str  # "profile" or "question"
    previous_qa: Optional[List[Dict[str, str]]] = None


class DocumentResponse(BaseModel):
    title: str
    url: str
    text: str
    score: float
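
# NOTE: embeddings.pkl, consumed by safe_load_embeddings() above, is assumed
# (its producer is not part of this file) to map article file names to 1-D
# embedding vectors, e.g. built offline by a hypothetical script along these
# lines:
#
#     model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#     embeddings = {name: model.encode(text) for name, text in corpus.items()}
#     with open("embeddings.pkl", "wb") as f:
#         pickle.dump(embeddings, f)
#
# where `corpus` is a {file_name: plain_text} dict extracted from the same
# HTML articles read back by retrieve_document_texts() below.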

@app.on_event("startup")
async def load_models():
    """Initialize all models and data on startup."""
    try:
        # Load the embedding model first, then the pickled embeddings that depend on it
        global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        embeddings_data = safe_load_embeddings()
        if embeddings_data is None:
            raise RuntimeError("Failed to load embeddings data")
        global_models.embeddings_data = embeddings_data

        # Load reranking models
        global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load BART models
        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Load the Orca model used for answer generation
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

        # Load translation models
        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

        # Load medical NER models
        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

        # Load URL mapping data
        try:
            df = pd.read_excel('finalcleaned_excel_file.xlsx')
            global_models.file_name_to_url = {
                f"article_{index}.html": url
                for index, url in enumerate(df['Unnamed: 0'])
            }
        except Exception as e:
            print(f"Error loading URL mapping data: {e}")
            raise RuntimeError("Failed to load URL mapping data.")

        print("All models loaded successfully")
    except Exception as e:
        # Raise a plain RuntimeError: an HTTPException has no meaning during
        # startup, before any request exists to respond to.
        print(f"Error during startup: {str(e)}")
        raise RuntimeError(f"Failed to initialize application: {str(e)}")

def translate_ar_to_en(text):
    """Translate Arabic text to English; return None on failure."""
    try:
        inputs = global_models.ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.ar_to_en_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        return global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None


def translate_en_to_ar(text):
    """Translate English text to Arabic; return None on failure."""
    try:
        inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.en_to_ar_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        return global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None


def process_query(query_text, language_code):
    """Normalize the query to English: Arabic input (code 0) is translated."""
    if language_code == 0:
        return translate_ar_to_en(query_text)
    return query_text


def embed_query_text(query_text):
    return global_models.embedding_model.encode([query_text])


def query_embeddings(query_embedding, n_results=5):
    """Return the n_results (doc_id, cosine similarity) pairs closest to the query."""
    doc_ids = list(global_models.embeddings_data.keys())
    doc_embeddings = np.array(list(global_models.embeddings_data.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]


def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
    """Read each article's HTML file and return its plain text (empty if missing)."""
    texts = []
    for doc_id in doc_ids:
        file_path = os.path.join(folder_path, doc_id)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file, 'html.parser')
                texts.append(soup.get_text(separator=' ', strip=True))
        except FileNotFoundError:
            texts.append("")
    return texts


def extract_entities(text):
    """Return tokens the medical NER model tags with any non-"O" label
    (label id 0 is assumed to be the outside/"O" class)."""
    inputs = global_models.bio_tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = global_models.bio_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    return [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]


def create_prompt(question, passage):
    return f"""
As a medical expert, you are required to answer the following question based only on the provided passage.
Do not include any information not present in the passage. Your response should directly reflect the
content of the passage. Maintain accuracy and relevance to the provided information.

Passage: {passage}

Question: {question}

Answer:
"""
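
# Illustrative retrieval smoke test (hypothetical query; requires the models
# and embeddings loaded at startup). The real entry points are the FastAPI
# routes below; this is only a sketch of how the helpers compose:
#
#     emb = embed_query_text("early symptoms of anemia")
#     for doc_id, score in query_embeddings(emb, n_results=3):
#         print(doc_id, round(float(score), 3))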

def generate_answer(prompt, max_length=860, temperature=0.2):
    """Generate an answer with the Orca model; returns (answer, seconds taken)."""
    inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)
    start_time = time.time()
    output_ids = global_models.model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        do_sample=True,  # required for temperature to take effect; otherwise decoding is greedy
        temperature=temperature,
        pad_token_id=global_models.tokenizer_f.eos_token_id
    )
    duration = time.time() - start_time
    answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
    return answer, duration


def clean_answer(answer):
    """Keep only the text after "Answer:" and trim any trailing partial sentence."""
    answer_part = answer.split("Answer:")[-1].strip()
    if not answer_part.endswith('.'):
        last_period_index = answer_part.rfind('.')
        if last_period_index != -1:
            answer_part = answer_part[:last_period_index + 1].strip()
    return answer_part


@app.post("/retrieve_documents")
async def retrieve_documents(input_data: QueryInput):
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)

        # Get document texts and rerank with the cross-encoder
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)
        scores = global_models.cross_encoder.predict(
            [(processed_query, doc) for doc in document_texts]
        )

        # Prepare response, translating back to Arabic when requested
        documents = []
        for score, doc_id, text in zip(scores, document_ids, document_texts):
            url = global_models.file_name_to_url.get(doc_id, "")
            documents.append({
                "title": doc_id,
                "url": url,
                "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
                "score": float(score)
            })

        return {"status": "success", "documents": documents}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/get_answer")
async def get_answer(input_data: QueryInput):
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)

        # Get relevant documents
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)

        # Extract entities and build the context passed to the generator
        entities = extract_entities(processed_query)
        context = " ".join(document_texts)
        enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"

        # Generate answer
        prompt = create_prompt(processed_query, enhanced_context)
        answer, duration = generate_answer(prompt)
        final_answer = clean_answer(answer)

        # Translate back to Arabic if needed
        if input_data.language_code == 0:
            final_answer = translate_en_to_ar(final_answer)

        return {
            "status": "success",
            "answer": final_answer,
            "processing_time": duration
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def root():
    return {"message": "Server is running"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
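
# Example client call once the server is running (hypothetical query text;
# field names match the QueryInput model above):
#
#     curl -X POST http://localhost:7860/get_answer \
#         -H "Content-Type: application/json" \
#         -d '{"query_text": "What are early signs of anemia?",
#              "language_code": 1, "query_type": "question"}'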