Spaces:
Sleeping
Sleeping
Commit
·
35e1586
1
Parent(s):
59344e5
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
from pydantic import BaseModel
|
3 |
+
from typing import List, Optional, Dict
|
4 |
+
import pickle
|
5 |
+
import numpy as np
|
6 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
7 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
8 |
+
from bs4 import BeautifulSoup
|
9 |
+
import os
|
10 |
+
import nltk
|
11 |
+
import torch
|
12 |
+
from transformers import (
|
13 |
+
AutoTokenizer,
|
14 |
+
BartForConditionalGeneration,
|
15 |
+
AutoModelForCausalLM,
|
16 |
+
AutoModelForSeq2SeqLM
|
17 |
+
)
|
18 |
+
import pandas as pd
|
19 |
+
import time
|
20 |
+
|
21 |
+
# ASGI application object; the route and startup decorators in this file attach to it.
app = FastAPI()
|
22 |
+
|
23 |
+
# Models and data structures to store loaded models
|
24 |
+
class GlobalModels:
    """Plain namespace for the models and lookup tables used by the API.

    Every slot starts out as ``None``; the startup hook ``load_models``
    assigns the real objects.
    """

    # Sentence embedding / re-ranking models
    embedding_model = cross_encoder = semantic_model = None

    # BART tokenizer and model
    tokenizer = model = None

    # Causal language model used for answer generation, with its tokenizer
    tokenizer_f = model_f = None

    # Arabic <-> English translation tokenizers and models
    ar_to_en_tokenizer = ar_to_en_model = None
    en_to_ar_tokenizer = en_to_ar_model = None

    # Pre-computed document embeddings and the article file name -> URL map
    embeddings_data = file_name_to_url = None

    # Medical NER tokenizer and token-classification model
    bio_tokenizer = bio_model = None
|
40 |
+
|
41 |
+
# Single shared holder; load_models() fills its attributes and the helpers below read them.
global_models = GlobalModels()
|
42 |
+
|
43 |
+
# Download NLTK data
|
44 |
+
# Runs at import time, before the app starts serving.
# NOTE(review): nothing in the code shown here calls an NLTK tokenizer — confirm this is still needed.
nltk.download('punkt')
|
45 |
+
|
46 |
+
# Pydantic models for request validation
|
47 |
+
# Request body accepted by both POST endpoints in this file
# (/retrieve_documents and /get_answer).
class QueryInput(BaseModel):
    query_text: str  # The user's query; translated to English first when language_code is 0
    language_code: int  # 0 for Arabic, 1 for English
    query_type: str  # "profile" or "question" (not read by the endpoints shown in this file)
    # Optional list of string-to-string dicts; not read by the endpoints shown in this file.
    previous_qa: Optional[List[Dict[str, str]]] = None
|
52 |
+
|
53 |
+
# Schema for one retrieved document. Its fields mirror the keys of the dicts
# built in retrieve_documents; it is not wired in as a response_model in the
# code shown here.
class DocumentResponse(BaseModel):
    title: str  # Document identifier (retrieve_documents uses the doc id here)
    url: str  # Source URL looked up from the file-name -> URL map
    text: str  # Extracted document text
    score: float  # Relevance score
|
58 |
+
|
59 |
+
@app.on_event("startup")
async def load_models():
    """Load every model and data file into ``global_models`` at startup.

    Populates the sentence-embedding models, the cross-encoder, BART, the
    Orca causal LM, both Helsinki-NLP translation models, the Medical-NER
    token classifier, the pickled embeddings and the article file name ->
    URL map.

    Raises:
        Exception: any loading error is printed and then re-raised, so the
            application does not start with partially initialised state.
    """
    # The file-level ``from transformers import (...)`` list does not include
    # AutoModelForTokenClassification, so the NER load below raised NameError
    # on every startup. Importing it here keeps this block self-contained.
    from transformers import AutoModelForTokenClassification

    try:
        # Load embedding models
        global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Load BART models
        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Load Orca model
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

        # Load translation models
        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

        # Load Medical NER models
        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

        # Load embeddings data
        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; 'embeddings.pkl' must only ever come from a trusted source.
        with open('embeddings.pkl', 'rb') as file:
            global_models.embeddings_data = pickle.load(file)

        # Load URL mapping data: row position N of the 'Unnamed: 0' column
        # becomes the URL for "article_N.html".
        df = pd.read_excel('finalcleaned_excel_file.xlsx')
        global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}

    except Exception as e:
        print(f"Error loading models: {e}")
        raise
|
98 |
+
|
99 |
+
def translate_ar_to_en(text):
    """Translate Arabic ``text`` into English with the loaded ar->en model.

    Returns:
        The decoded translation, or ``None`` if any step raises (the error
        is printed rather than propagated).
    """
    generation_options = {"max_length": 512, "num_beams": 4, "early_stopping": True}
    try:
        ar_en_tok = global_models.ar_to_en_tokenizer
        encoded = ar_en_tok(text, return_tensors="pt", truncation=True, padding=True)
        generated = global_models.ar_to_en_model.generate(encoded.input_ids, **generation_options)
        # Only the first generated sequence is decoded.
        return ar_en_tok.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during Arabic to English translation: {e}")
        return None
|
113 |
+
|
114 |
+
def translate_en_to_ar(text):
    """Translate English ``text`` into Arabic with the loaded en->ar model.

    Returns:
        The decoded translation, or ``None`` if any step raises (the error
        is printed rather than propagated).
    """
    generation_options = {"max_length": 512, "num_beams": 4, "early_stopping": True}
    try:
        en_ar_tok = global_models.en_to_ar_tokenizer
        encoded = en_ar_tok(text, return_tensors="pt", truncation=True, padding=True)
        generated = global_models.en_to_ar_model.generate(encoded.input_ids, **generation_options)
        # Only the first generated sequence is decoded.
        return en_ar_tok.decode(generated[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during English to Arabic translation: {e}")
        return None
|
128 |
+
|
129 |
+
def process_query(query_text, language_code):
    """Return the query in English.

    A ``language_code`` of 0 means the query is Arabic and is sent through
    ``translate_ar_to_en``; any other code returns ``query_text`` unchanged.
    """
    is_arabic = language_code == 0
    if not is_arabic:
        return query_text
    return translate_ar_to_en(query_text)
|
133 |
+
|
134 |
+
def embed_query_text(query_text):
    """Encode ``query_text`` with the shared embedding model.

    The text is wrapped in a one-element list before encoding, so the
    encoder receives a batch of size one.
    """
    single_item_batch = [query_text]
    encoder = global_models.embedding_model
    return encoder.encode(single_item_batch)
|
136 |
+
|
137 |
+
def query_embeddings(query_embedding, n_results=5):
    """Return the stored documents most similar to ``query_embedding``.

    Args:
        query_embedding: embedding of the query, as produced by
            ``embed_query_text`` (presumably a single-row 2-D array —
            verify against the encoder's output).
        n_results: maximum number of documents to return.

    Returns:
        A list of ``(doc_id, similarity)`` tuples ordered from most to
        least similar. Empty when ``n_results`` is not positive or when no
        embeddings are stored.
    """
    if n_results <= 0:
        # ``[-0:]`` selects the whole array, so without this guard a request
        # for zero results returned every document (and a negative value
        # sliced from the front).
        return []

    store = global_models.embeddings_data
    doc_ids = list(store.keys())
    if not doc_ids:
        # Nothing to compare against; avoid handing an empty matrix to
        # cosine_similarity.
        return []

    doc_embeddings = np.array(list(store.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    # argsort is ascending: keep the last n_results, then reverse for best-first.
    top_indices = similarities.argsort()[-n_results:][::-1]
    return [(doc_ids[i], similarities[i]) for i in top_indices]
|
143 |
+
|
144 |
+
def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
    """Read each document's HTML file and return its visible text.

    Args:
        doc_ids: file names looked up inside ``folder_path``.
        folder_path: directory that holds the article HTML files.

    Returns:
        One string per entry of ``doc_ids``, in the same order. A file that
        does not exist contributes an empty string.
    """
    collected = []
    for name in doc_ids:
        try:
            with open(os.path.join(folder_path, name), 'r', encoding='utf-8') as handle:
                parsed = BeautifulSoup(handle, 'html.parser')
                plain = parsed.get_text(separator=' ', strip=True)
        except FileNotFoundError:
            # Keep the output aligned with doc_ids even when a file is missing.
            plain = ""
        collected.append(plain)
    return collected
|
156 |
+
|
157 |
+
def extract_entities(text):
    """Return the tokens of ``text`` that the NER model labels as non-zero.

    The text is tokenized with the Medical-NER tokenizer, run through the
    token-classification model, and every token whose arg-max label id is
    not 0 is returned (as tokenizer tokens, not re-joined words).

    NOTE(review): this treats label id 0 as the "no entity" class — confirm
    against the model's ``id2label`` mapping.
    """
    # truncation=True matches the other tokenizer calls in this file and
    # keeps over-long queries within the tokenizer's configured limit.
    inputs = global_models.bio_tokenizer(text, return_tensors="pt", truncation=True)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = global_models.bio_model(**inputs)

    label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    return [token for token, label_id in zip(tokens, label_ids) if label_id != 0]
|
163 |
+
|
164 |
+
def create_prompt(question, passage):
    """Build the instruction prompt given to the answer-generation model.

    The template tells the model to answer only from ``passage``, then
    inserts the passage and the question and ends with an ``Answer:``
    marker. ``clean_answer`` later splits the generated text on that same
    marker, so the two must stay in sync.
    """
    return f"""
As a medical expert, you are required to answer the following question based only on the provided passage.
Do not include any information not present in the passage. Your response should directly reflect the content
of the passage. Maintain accuracy and relevance to the provided information.

Passage: {passage}

Question: {question}

Answer:
"""
|
176 |
+
|
177 |
+
def generate_answer(prompt, max_length=860, temperature=0.2):
    """Generate text for ``prompt`` with the causal LM and time the call.

    Args:
        prompt: full prompt text (see ``create_prompt``).
        max_length: forwarded to ``generate`` as ``max_length``.
        temperature: sampling temperature. A positive value enables
            sampling; ``None`` or a non-positive value falls back to the
            model's default decoding without a temperature.

    Returns:
        ``(answer, duration)`` where ``answer`` is the decoded first output
        sequence and ``duration`` is the generation time in seconds.
    """
    inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)

    generation_kwargs = {
        "max_length": max_length,
        "num_return_sequences": 1,
        "pad_token_id": global_models.tokenizer_f.eos_token_id,
        # Forward the tokenizer's mask so padding is never guessed from
        # pad_token_id (which is set to the EOS id here).
        "attention_mask": inputs.get("attention_mask"),
    }
    if temperature is not None and temperature > 0:
        # temperature only affects sampling-based decoding; without
        # do_sample=True the value was silently ignored.
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature

    # NOTE(review): max_length counts the prompt tokens as well as the new
    # ones, so a prompt at or above max_length leaves no room for an answer.
    # Consider bounding the prompt or switching to max_new_tokens.

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    start_time = time.perf_counter()
    output_ids = global_models.model_f.generate(inputs.input_ids, **generation_kwargs)
    duration = time.perf_counter() - start_time

    answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
    return answer, duration
|
192 |
+
|
193 |
+
def clean_answer(answer):
    """Extract the answer text that follows the last ``Answer:`` marker.

    The extracted text is stripped. If it does not already end with a
    period, it is cut back to its last period so that a trailing unfinished
    sentence is dropped; text containing no period is returned as is.
    """
    _, _, tail = answer.rpartition("Answer:")
    tail = tail.strip()

    if tail.endswith('.'):
        return tail

    cut = tail.rfind('.')
    if cut == -1:
        return tail
    return tail[:cut + 1].strip()
|
200 |
+
|
201 |
+
@app.post("/retrieve_documents")
async def retrieve_documents(input_data: QueryInput):
    """Return the documents most relevant to the query, best match first.

    The query is translated to English when ``language_code`` is 0,
    embedded, and matched against the stored embeddings. The candidates are
    then scored with the cross-encoder and returned ordered by that score.
    For Arabic requests each document text is passed through
    ``translate_en_to_ar``.

    Returns:
        ``{"status": "success", "documents": [...]}`` where each document
        has ``title``, ``url``, ``text`` and ``score``.

    Raises:
        HTTPException: status 500 when the query cannot be translated or
            any other step fails.
    """
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)
        if processed_query is None:
            # The translation helper reports failure by returning None; stop
            # here instead of feeding None to the embedding model.
            raise HTTPException(status_code=500, detail="Query translation failed")
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)

        # Get document texts and rerank
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)
        scores = global_models.cross_encoder.predict(
            [(processed_query, doc) for doc in document_texts]
        )

        # Order by cross-encoder score; previously the scores were attached
        # but the embedding-similarity order was kept.
        ranked = sorted(
            zip(scores, document_ids, document_texts),
            key=lambda item: float(item[0]),
            reverse=True,
        )

        # Prepare response
        wants_english = input_data.language_code == 1
        documents = [
            {
                "title": doc_id,
                "url": global_models.file_name_to_url.get(doc_id, ""),
                "text": text if wants_english else translate_en_to_ar(text),
                "score": float(score),
            }
            for score, doc_id, text in ranked
        ]

        return {"status": "success", "documents": documents}

    except HTTPException:
        # Keep deliberately raised HTTP errors intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
229 |
+
|
230 |
+
@app.post("/get_answer")
async def get_answer(input_data: QueryInput):
    """Answer the query from the retrieved documents.

    The query is translated to English when ``language_code`` is 0, the
    closest documents are retrieved and joined into one context, the
    entities extracted from the query are appended, and the causal LM
    generates an answer from the resulting prompt. For Arabic requests the
    cleaned answer is translated back to Arabic.

    Returns:
        ``{"status": "success", "answer": ..., "processing_time": ...}``
        where ``processing_time`` is the generation time in seconds.

    Raises:
        HTTPException: status 500 when a translation step or any other
            step fails.
    """
    try:
        # Process query
        processed_query = process_query(input_data.query_text, input_data.language_code)
        if processed_query is None:
            # The translation helper reports failure by returning None; stop
            # here instead of feeding None to the embedding model.
            raise HTTPException(status_code=500, detail="Query translation failed")

        # Get relevant documents
        query_embedding = embed_query_text(processed_query)
        results = query_embeddings(query_embedding)
        document_ids = [doc_id for doc_id, _ in results]
        document_texts = retrieve_document_texts(document_ids)

        # Extract entities and create context
        entities = extract_entities(processed_query)
        context = " ".join(document_texts)
        enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"

        # Generate answer
        prompt = create_prompt(processed_query, enhanced_context)
        answer, duration = generate_answer(prompt)
        final_answer = clean_answer(answer)

        # Translate if needed
        if input_data.language_code == 0:
            final_answer = translate_en_to_ar(final_answer)
            if final_answer is None:
                # Do not report "success" with a null answer.
                raise HTTPException(status_code=500, detail="Answer translation failed")

        return {
            "status": "success",
            "answer": final_answer,
            "processing_time": duration
        }

    except HTTPException:
        # Keep deliberately raised HTTP errors intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
264 |
+
|
265 |
+
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)
|