"""FastAPI service for bilingual (Arabic/English) medical document retrieval
and question answering.

Pipeline:
  * Arabic queries (language_code == 0) are translated to English with
    Helsinki-NLP MarianMT models; English queries pass through unchanged.
  * Candidate documents are retrieved by cosine similarity over precomputed
    embeddings, then reranked with a cross-encoder.
  * Answers are generated by a causal LM over the retrieved context and
    translated back to Arabic when the caller asked in Arabic.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
import pickle
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from bs4 import BeautifulSoup
import os
import nltk
import torch
from transformers import (
    AutoTokenizer,
    # BUG FIX: AutoModelForTokenClassification was used in load_models() but
    # never imported, so startup raised NameError before serving any request.
    AutoModelForTokenClassification,
    BartForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)
import pandas as pd
import time

app = FastAPI()


class GlobalModels:
    """Process-wide holder for models and data loaded once at startup."""

    embedding_model = None      # SentenceTransformer used to embed queries
    cross_encoder = None        # reranker scoring (query, document) pairs
    semantic_model = None       # secondary SentenceTransformer (not used by current endpoints)
    tokenizer = None            # BART tokenizer (not used by current endpoints)
    model = None                # BART model (not used by current endpoints)
    tokenizer_f = None          # tokenizer for the answer-generation LM
    model_f = None              # causal LM that generates answers
    ar_to_en_tokenizer = None   # Arabic -> English translation tokenizer
    ar_to_en_model = None       # Arabic -> English translation model
    en_to_ar_tokenizer = None   # English -> Arabic translation tokenizer
    en_to_ar_model = None       # English -> Arabic translation model
    embeddings_data = None      # dict: article file name -> embedding vector
    file_name_to_url = None     # dict: article file name -> source URL
    bio_tokenizer = None        # tokenizer for the medical NER model
    bio_model = None            # token-classification model for medical NER


global_models = GlobalModels()

# Sentence tokenizer data used by nltk (fetched once at import time).
nltk.download('punkt')


class QueryInput(BaseModel):
    """Request body shared by both endpoints."""
    query_text: str
    language_code: int  # 0 for Arabic, 1 for English
    query_type: str  # "profile" or "question"
    previous_qa: Optional[List[Dict[str, str]]] = None


class DocumentResponse(BaseModel):
    """Shape of a single retrieved document (declared for the API schema)."""
    title: str
    url: str
    text: str
    score: float


@app.on_event("startup")
async def load_models():
    """Initialize all models and data on startup.

    Fails fast: any loading error is printed and re-raised so the service
    does not start with missing models.
    """
    try:
        # Embedding and reranking models.
        global_models.embedding_model = SentenceTransformer(
            'sentence-transformers/all-MiniLM-L6-v2')
        global_models.cross_encoder = CrossEncoder(
            'cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # BART models (loaded here; not referenced by the current endpoints).
        global_models.tokenizer = AutoTokenizer.from_pretrained(
            "facebook/bart-base")
        global_models.model = BartForConditionalGeneration.from_pretrained(
            "facebook/bart-base")

        # Answer-generation LM.
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

        # Translation models (Arabic <-> English).
        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained(
            "Helsinki-NLP/opus-mt-ar-en")
        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Helsinki-NLP/opus-mt-ar-en")
        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained(
            "Helsinki-NLP/opus-mt-en-ar")
        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained(
            "Helsinki-NLP/opus-mt-en-ar")

        # Medical NER model (requires AutoModelForTokenClassification import).
        global_models.bio_tokenizer = AutoTokenizer.from_pretrained(
            "blaze999/Medical-NER")
        global_models.bio_model = AutoModelForTokenClassification.from_pretrained(
            "blaze999/Medical-NER")

        # Precomputed document embeddings.
        # SECURITY NOTE(review): pickle.load executes arbitrary code if the
        # file is attacker-controlled — load embeddings.pkl only from a
        # trusted source.
        with open('embeddings.pkl', 'rb') as file:
            global_models.embeddings_data = pickle.load(file)

        # Map each stored article file name to its original URL.
        # NOTE(review): assumes the sheet's 'Unnamed: 0' column holds URLs in
        # article order — confirm against the data file.
        df = pd.read_excel('finalcleaned_excel_file.xlsx')
        global_models.file_name_to_url = {
            f"article_{index}.html": url
            for index, url in enumerate(df['Unnamed: 0'])
        }
    except Exception as e:
        print(f"Error loading models: {e}")
        raise


def translate_ar_to_en(text):
    """Translate Arabic text to English; return None if translation fails."""
    try:
        inputs = global_models.ar_to_en_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.ar_to_en_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        return global_models.ar_to_en_tokenizer.decode(
            translated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None


def translate_en_to_ar(text):
    """Translate English text to Arabic; return None if translation fails."""
    try:
        inputs = global_models.en_to_ar_tokenizer(
            text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.en_to_ar_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        return global_models.en_to_ar_tokenizer.decode(
            translated_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None


def process_query(query_text, language_code):
    """Normalize the query to English (translate when language_code == 0)."""
    if language_code == 0:
        return translate_ar_to_en(query_text)
    return query_text


def embed_query_text(query_text):
    """Encode the query into a (1, dim) embedding matrix."""
    return global_models.embedding_model.encode([query_text])


def query_embeddings(query_embedding, n_results=5):
    """Return the top n_results (doc_id, cosine_similarity) pairs, best first."""
    doc_ids = list(global_models.embeddings_data.keys())
    doc_embeddings = np.array(list(global_models.embeddings_data.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]


def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
    """Load and strip the HTML of each document; missing files yield ""."""
    texts = []
    for doc_id in doc_ids:
        file_path = os.path.join(folder_path, doc_id)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file, 'html.parser')
                texts.append(soup.get_text(separator=' ', strip=True))
        except FileNotFoundError:
            # Keep positional alignment with doc_ids even when a file is gone.
            texts.append("")
    return texts


def extract_entities(text):
    """Return tokens the medical NER model tags with a non-O label.

    Label id 0 is treated as the "outside" class; every other prediction
    keeps its (sub)token.
    """
    inputs = global_models.bio_tokenizer(text, return_tensors="pt")
    outputs = global_models.bio_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    return [tokens[i] for i in range(len(tokens))
            if predictions[0][i].item() != 0]


def create_prompt(question, passage):
    """Build the answer-generation prompt; ends with "Answer:" so
    clean_answer() can split the model output on that marker."""
    return f"""
As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.

Passage: {passage}

Question: {question}

Answer:
"""


def generate_answer(prompt, max_length=860, temperature=0.2):
    """Generate an answer with the causal LM.

    Returns:
        (answer_text, generation_seconds)
    """
    inputs = global_models.tokenizer_f(prompt, return_tensors="pt",
                                       truncation=True)
    start_time = time.time()
    # NOTE(review): temperature has no effect unless do_sample=True is also
    # passed to generate() — confirm whether sampling was intended.
    output_ids = global_models.model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=global_models.tokenizer_f.eos_token_id
    )
    duration = time.time() - start_time
    answer = global_models.tokenizer_f.decode(output_ids[0],
                                              skip_special_tokens=True)
    return answer, duration


def clean_answer(answer):
    """Strip the prompt echo (everything up to "Answer:") and trim any
    trailing unfinished sentence."""
    answer_part = answer.split("Answer:")[-1].strip()
    if not answer_part.endswith('.'):
        last_period_index = answer_part.rfind('.')
        if last_period_index != -1:
            answer_part = answer_part[:last_period_index + 1].strip()
    return answer_part


@app.post("/retrieve_documents")
async def retrieve_documents(input_data: QueryInput):
    """Retrieve, rerank, and return the documents most relevant to the query.

    Arabic requests get document text translated back to Arabic.
    """
    try:
        # Normalize the query and fetch cosine-similarity candidates.
        processed_query = process_query(input_data.query_text,
                                        input_data.language_code)
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)

        # Rerank the candidates with the cross-encoder.
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)
        scores = global_models.cross_encoder.predict(
            [(processed_query, doc) for doc in document_texts])

        # Build the response.
        documents = []
        for score, doc_id, text in zip(scores, document_ids, document_texts):
            url = global_models.file_name_to_url.get(doc_id, "")
            documents.append({
                "title": doc_id,
                "url": url,
                "text": text if input_data.language_code == 1
                        else translate_en_to_ar(text),
                "score": float(score)
            })

        # BUG FIX: the cross-encoder scores were computed but never applied —
        # documents came back in cosine-similarity order. Sort best-first so
        # the rerank actually takes effect.
        documents.sort(key=lambda d: d["score"], reverse=True)

        return {"status": "success", "documents": documents}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/get_answer")
async def get_answer(input_data: QueryInput):
    """Answer the query from retrieved documents, translating the final
    answer back to Arabic when the request was Arabic."""
    try:
        # Normalize the query.
        processed_query = process_query(input_data.query_text,
                                        input_data.language_code)

        # Retrieve the supporting documents.
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)

        # Enrich the context with medical entities found in the query.
        entities = extract_entities(processed_query)
        context = " ".join(document_texts)
        enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"

        # Generate and clean the answer.
        prompt = create_prompt(processed_query, enhanced_context)
        answer, duration = generate_answer(prompt)
        final_answer = clean_answer(answer)

        # Translate back to Arabic if the caller asked in Arabic.
        if input_data.language_code == 0:
            final_answer = translate_en_to_ar(final_answer)

        return {
            "status": "success",
            "answer": final_answer,
            "processing_time": duration
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)