from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
import pickle
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from bs4 import BeautifulSoup
import os
import nltk
import torch
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,  # needed for the Medical-NER model loaded below
)
import pandas as pd
import time

app = FastAPI()
# Container for loaded models and data
class GlobalModels:
    embedding_model = None
    cross_encoder = None
    semantic_model = None
    tokenizer = None
    model = None
    tokenizer_f = None
    model_f = None
    ar_to_en_tokenizer = None
    ar_to_en_model = None
    en_to_ar_tokenizer = None
    en_to_ar_model = None
    embeddings_data = None
    file_name_to_url = None
    bio_tokenizer = None
    bio_model = None

global_models = GlobalModels()

# Download NLTK data
nltk.download('punkt')
# Pydantic models for request validation
class QueryInput(BaseModel):
    query_text: str
    language_code: int  # 0 for Arabic, 1 for English
    query_type: str  # "profile" or "question"
    previous_qa: Optional[List[Dict[str, str]]] = None

class DocumentResponse(BaseModel):
    title: str
    url: str
    text: str
    score: float
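# Illustrative only (not taken from the original source): a request body that
# validates against QueryInput above -- an English question with one prior
# Q/A turn supplied as context. The key names inside previous_qa are not fixed
# by the model, so they are an assumption.
#
#   {
#       "query_text": "What are the symptoms of iron-deficiency anemia?",
#       "language_code": 1,
#       "query_type": "question",
#       "previous_qa": [
#           {"question": "What is anemia?",
#            "answer": "A condition with too few healthy red blood cells."}
#       ]
#   }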
@app.on_event("startup")
async def load_models():
    """Initialize all models and data on startup."""
    try:
        # Load embedding models
        global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load BART models
        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Load Orca model
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

        # Load translation models
        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

        # Load Medical NER models
        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

        # Load embeddings data
        with open('embeddings.pkl', 'rb') as file:
            global_models.embeddings_data = pickle.load(file)

        # Load URL mapping data
        df = pd.read_excel('finalcleaned_excel_file.xlsx')
        global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
    except Exception as e:
        print(f"Error loading models: {e}")
        raise
def translate_ar_to_en(text):
    try:
        inputs = global_models.ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.ar_to_en_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        translated_text = global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None

def translate_en_to_ar(text):
    try:
        inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        translated_ids = global_models.en_to_ar_model.generate(
            inputs.input_ids,
            max_length=512,
            num_beams=4,
            early_stopping=True
        )
        translated_text = global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
        return translated_text
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None
def process_query(query_text, language_code):
    # Arabic queries (language_code == 0) are translated to English before retrieval
    if language_code == 0:
        return translate_ar_to_en(query_text)
    return query_text

def embed_query_text(query_text):
    return global_models.embedding_model.encode([query_text])

def query_embeddings(query_embedding, n_results=5):
    # Rank all precomputed document embeddings by cosine similarity to the query
    doc_ids = list(global_models.embeddings_data.keys())
    doc_embeddings = np.array(list(global_models.embeddings_data.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]
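# Note (assumption inferred from usage, not stated in the original source):
# embeddings.pkl is expected to unpickle into a dict mapping HTML file names
# such as "article_0.html" to precomputed sentence-embedding vectors. The keys
# are reused as file names in retrieve_document_texts() below, and the values
# are stacked into the similarity matrix scored above.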
def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
    # Read the HTML file for each document id and return its plain text
    texts = []
    for doc_id in doc_ids:
        file_path = os.path.join(folder_path, doc_id)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file, 'html.parser')
                text = soup.get_text(separator=' ', strip=True)
                texts.append(text)
        except FileNotFoundError:
            texts.append("")
    return texts

def extract_entities(text):
    # Token-level medical NER: keep every sub-word token whose predicted label id
    # is non-zero (id 0 is treated as the non-entity class)
    inputs = global_models.bio_tokenizer(text, return_tensors="pt")
    outputs = global_models.bio_model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    return [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]
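# Illustrative only: for a query like "chest pain and shortness of breath",
# extract_entities() returns the raw sub-word tokens tagged with a non-zero
# label, so the output may contain word-piece fragments (e.g. "short", "##ness");
# the exact tokens depend on the Medical-NER tokenizer and label set.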
def create_prompt(question, passage):
    return f"""
As a medical expert, you are required to answer the following question based only on the provided passage.
Do not include any information not present in the passage. Your response should directly reflect the content
of the passage. Maintain accuracy and relevance to the provided information.
Passage: {passage}
Question: {question}
Answer:
"""

def generate_answer(prompt, max_length=860, temperature=0.2):
    # Generate an answer with the Orca model and time the generation step
    inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)
    start_time = time.time()
    output_ids = global_models.model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=global_models.tokenizer_f.eos_token_id
    )
    duration = time.time() - start_time
    answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
    return answer, duration

def clean_answer(answer):
    # Keep only the text after the final "Answer:" marker and drop any trailing
    # incomplete sentence (everything after the last full stop)
    answer_part = answer.split("Answer:")[-1].strip()
    if not answer_part.endswith('.'):
        last_period_index = answer_part.rfind('.')
        if last_period_index != -1:
            answer_part = answer_part[:last_period_index + 1].strip()
    return answer_part
@app.post("/retrieve_documents")  # route path assumed; not specified in the original source
async def retrieve_documents(input_data: QueryInput):
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)

        # Get document texts and rerank
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)
        scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])

        # Prepare response
        documents = []
        for score, doc_id, text in zip(scores, document_ids, document_texts):
            url = global_models.file_name_to_url.get(doc_id, "")
            documents.append({
                "title": doc_id,
                "url": url,
                "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
                "score": float(score)
            })
        return {"status": "success", "documents": documents}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/get_answer")  # route path assumed; not specified in the original source
async def get_answer(input_data: QueryInput):
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)

        # Get relevant documents
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)

        # Extract entities and create context
        entities = extract_entities(processed_query)
        context = " ".join(document_texts)
        enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"

        # Generate answer
        prompt = create_prompt(processed_query, enhanced_context)
        answer, duration = generate_answer(prompt)
        final_answer = clean_answer(answer)

        # Translate if needed
        if input_data.language_code == 0:
            final_answer = translate_en_to_ar(final_answer)

        return {
            "status": "success",
            "answer": final_answer,
            "processing_time": duration
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
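# Example client call (illustrative only; /get_answer is the path assumed in
# the decorator above, and the payload values are made up):
#
#   curl -X POST http://localhost:7860/get_answer \
#        -H "Content-Type: application/json" \
#        -d '{"query_text": "ما هي أعراض فقر الدم؟", "language_code": 0, "query_type": "question"}'
#
# With language_code = 0 the query is translated to English for retrieval and
# generation, and the cleaned answer is translated back to Arabic before it is
# returned.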