Commit 8e3b5f7 · Update app.py
Parent(s): 7b16750

app.py CHANGED
@@ -1,4 +1,4 @@
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Query
 from pydantic import BaseModel
 from typing import List, Optional, Dict
 import pickle
@@ -19,9 +19,9 @@ from transformers import (
 import pandas as pd
 import time

-# Initialize FastAPI app first
 app = FastAPI()

+# ArticleEmbeddingUnpickler and safe_load_embeddings functions
 class ArticleEmbeddingUnpickler(pickle.Unpickler):
     """Custom unpickler for article embeddings with enhanced persistence handling"""
     def find_class(self, module: str, name: str) -> any:
@@ -35,7 +35,6 @@ class ArticleEmbeddingUnpickler(pickle.Unpickler):
     def persistent_load(self, pid: any) -> str:
         """Enhanced persistent ID handler with better encoding management"""
         try:
-            # Handle different types of persistent IDs
             if isinstance(pid, bytes):
                 return pid.decode('utf-8', errors='replace')
             if isinstance(pid, (str, int, float)):
@@ -48,7 +47,6 @@ class ArticleEmbeddingUnpickler(pickle.Unpickler):
 def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
     """Load embeddings with enhanced error handling, validation, and persistent ID support."""
     def persistent_load(pid):
-        """Handle persistent ID references during unpickling."""
         print(f"Warning: Persistent ID encountered: {pid}")
         raise ValueError("Persistent IDs are not supported in this application")

@@ -64,7 +62,7 @@ def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndar
         if not isinstance(embeddings_data, dict):
             raise ValueError(f"Invalid data structure: expected dict, got {type(embeddings_data)}")

-        # Process and validate embeddings
+        # Process and validate embeddings
         valid_embeddings = {}
         for key, value in embeddings_data.items():
             try:
@@ -101,17 +99,15 @@ def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndar
         raise

 def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
-    # Ensure all keys are ASCII-safe strings
     cleaned_embeddings = {
-        str(key): value
+        str(key): value
         for key, value in embeddings_dict.items()
     }
-
     with open(file_path, 'wb') as f:
-        # Use a newer protocol for better compatibility
         pickle.dump(cleaned_embeddings, f, protocol=4)

-
+
+# GlobalModels and load_models
 class GlobalModels:
     embedding_model = None
     cross_encoder = None
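Note: `safe_save_embeddings` coerces keys to `str` and pins pickle protocol 4, and the loader above validates that the payload is a dict. A round-trip sketch of how these two functions fit together (assumptions: the pickle holds a plain `{filename: np.ndarray}` mapping, as the `isinstance(..., dict)` check implies, and 384 is the all-MiniLM-L6-v2 embedding width used elsewhere in this file):

```python
import numpy as np

# Hypothetical usage, not part of the commit.
embs = {"article_0.html": np.random.rand(384).astype(np.float32)}
safe_save_embeddings(embs, "embeddings.pkl")     # keys coerced to str, protocol 4
loaded = safe_load_embeddings("embeddings.pkl")  # validated dict of arrays
assert set(loaded) == set(embs)
```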
@@ -124,154 +120,35 @@ class GlobalModels:
     ar_to_en_model = None
     en_to_ar_tokenizer = None
     en_to_ar_model = None
-    embeddings_data = None
-    file_name_to_url = None
     bio_tokenizer = None
     bio_model = None
-
-# Initialize global models
-global_models = GlobalModels()
-
-# Download NLTK data
-nltk.download('punkt')
-
-# Pydantic models for request validation
-class QueryInput(BaseModel):
-    query_text: str
-    language_code: int  # 0 for Arabic, 1 for English
-    query_type: str  # "profile" or "question"
-    previous_qa: Optional[List[Dict[str, str]]] = None
-
-class DocumentResponse(BaseModel):
-    title: str
-    url: str
-    text: str
-    score: float
-
-# Modified startup event handler
-@app.on_event("startup")
-@app.on_event("startup")
-async def load_models():
-    try:
-        print("Starting to load embeddings...")
-        embeddings_data = safe_load_embeddings()
-        print(f"Embeddings data type: {type(embeddings_data)}")
-        if embeddings_data:
-            print(f"Number of embeddings: {len(embeddings_data)}")
-            # Print sample of keys
-            print("Sample keys:", list(embeddings_data.keys())[:3])
-        # Load embedding models first
-        global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-
-        # Load embeddings data with new safe loader
-        embeddings_data = safe_load_embeddings()
-        if embeddings_data is None:
-            raise HTTPException(status_code=500, detail="Failed to load embeddings data")
-        global_models.embeddings_data = embeddings_data
-
-        # Load remaining models
-        global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
-        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-        # Load BART models
-        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
-        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
-
-        # Load Orca model
-        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
-        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
-        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
-
-        # Load translation models
-        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-
-        # Load Medical NER models
-        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
-        # Load URL mapping data
-        try:
-            df = pd.read_excel('finalcleaned_excel_file.xlsx')
-            global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-        except Exception as e:
-            print(f"Error loading URL mapping data: {e}")
-            raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")
-
-        print("All models loaded successfully")
-
-    except Exception as e:
-        print(f"Error during startup: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Failed to initialize application: {str(e)}")
-
-
-# Models and data structures to store loaded models
-class GlobalModels:
-    embedding_model = None
-    cross_encoder = None
-    semantic_model = None
-    tokenizer = None
-    model = None
-    tokenizer_f = None
-    model_f = None
-    ar_to_en_tokenizer = None
-    ar_to_en_model = None
-    en_to_ar_tokenizer = None
-    en_to_ar_model = None
     embeddings_data = None
     file_name_to_url = None
-    bio_tokenizer = None
-    bio_model = None

 global_models = GlobalModels()

-# Download NLTK data
-nltk.download('punkt')
-
-# Pydantic models for request validation
-class QueryInput(BaseModel):
-    query_text: str
-    language_code: int  # 0 for Arabic, 1 for English
-    query_type: str  # "profile" or "question"
-    previous_qa: Optional[List[Dict[str, str]]] = None
-
-class DocumentResponse(BaseModel):
-    title: str
-    url: str
-    text: str
-    score: float
-
 @app.on_event("startup")
 async def load_models():
-    """Initialize all models and data on startup"""
     try:
-        # Load embedding models
         global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
         global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
         global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

-        # Load BART models
         global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
         global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

-        # Load Orca model
         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
         global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
         global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

-        # Load translation models
         global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
         global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
         global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
         global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

-        # Load Medical NER models
         global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
         global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

-        # Load embeddings data with better error handling
         try:
             with open('embeddings.pkl', 'rb') as file:
                 global_models.embeddings_data = pickle.load(file)
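Note: the deleted block stacked `@app.on_event("startup")` twice, which registers the handler twice and would load every model two times; the commit rightly collapses the duplicated definitions into one. The surviving handler still uses `on_event`, which recent FastAPI releases deprecate in favor of a lifespan context manager. A minimal sketch of that pattern, not part of this commit (`load_all_models` is a hypothetical wrapper around the loads above):

```python
from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once before the app begins serving requests.
    load_all_models()  # hypothetical wrapper around the model loads above
    yield
    # Shutdown: release GPU memory / file handles here if needed.

app = FastAPI(lifespan=lifespan)
```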
@@ -279,135 +156,38 @@ async def load_models():
             print(f"Error loading embeddings data: {e}")
             raise HTTPException(status_code=500, detail="Failed to load embeddings data.")

-
-
-        df = pd.read_excel('finalcleaned_excel_file.xlsx')
-        global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-        except Exception as e:
-            print(f"Error loading URL mapping data: {e}")
-            raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")
+        df = pd.read_excel('finalcleaned_excel_file.xlsx')
+        global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}

     except Exception as e:
         print(f"Error loading models: {e}")
         raise HTTPException(status_code=500, detail="Failed to load models.")


-
-
-
-
-
-            num_beams=4,
-            early_stopping=True
-        )
-        translated_text = global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-        return translated_text
-    except Exception as e:
-        print(f"Error during Arabic to English translation: {e}")
-        return None
-
-def translate_en_to_ar(text):
-    try:
-        inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-        translated_ids = global_models.en_to_ar_model.generate(
-            inputs.input_ids,
-            max_length=512,
-            num_beams=4,
-            early_stopping=True
-        )
-        translated_text = global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-        return translated_text
-    except Exception as e:
-        print(f"Error during English to Arabic translation: {e}")
-        return None
-
-def process_query(query_text, language_code):
-    if language_code == 0:
-        return translate_ar_to_en(query_text)
-    return query_text
-
-def embed_query_text(query_text):
-    return global_models.embedding_model.encode([query_text])
-
-def query_embeddings(query_embedding, n_results=5):
-    doc_ids = list(global_models.embeddings_data.keys())
-    doc_embeddings = np.array(list(global_models.embeddings_data.values()))
-    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
-    top_indices = similarities.argsort()[-n_results:][::-1]
-    return [(doc_ids[i], similarities[i]) for i in top_indices]
-
-def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
-    texts = []
-    for doc_id in doc_ids:
-        file_path = os.path.join(folder_path, doc_id)
-        try:
-            with open(file_path, 'r', encoding='utf-8') as file:
-                soup = BeautifulSoup(file, 'html.parser')
-                text = soup.get_text(separator=' ', strip=True)
-                texts.append(text)
-        except FileNotFoundError:
-            texts.append("")
-    return texts
-
-def extract_entities(text):
-    inputs = global_models.bio_tokenizer(text, return_tensors="pt")
-    outputs = global_models.bio_model(**inputs)
-    predictions = torch.argmax(outputs.logits, dim=2)
-    tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-    return [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]
-
-def create_prompt(question, passage):
-    return f"""
-    As a medical expert, you are required to answer the following question based only on the provided passage.
-    Do not include any information not present in the passage. Your response should directly reflect the content
-    of the passage. Maintain accuracy and relevance to the provided information.
-
-    Passage: {passage}
-
-    Question: {question}
-
-    Answer:
-    """
-
-def generate_answer(prompt, max_length=860, temperature=0.2):
-    inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)
-
-    start_time = time.time()
-    output_ids = global_models.model_f.generate(
-        inputs.input_ids,
-        max_length=max_length,
-        num_return_sequences=1,
-        temperature=temperature,
-        pad_token_id=global_models.tokenizer_f.eos_token_id
-    )
-    duration = time.time() - start_time
-
-    answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-    return answer, duration
+# Query and Document Retrieval Endpoint
+class QueryInput(BaseModel):
+    query_text: str
+    language_code: int  # 0 for Arabic, 1 for English
+    query_type: str  # "profile" or "question"
+    previous_qa: Optional[List[Dict[str, str]]] = None

-
-
-
-
-
-    answer_part = answer_part[:last_period_index + 1].strip()
-    return answer_part
+class DocumentResponse(BaseModel):
+    title: str
+    url: str
+    text: str
+    score: float

 @app.post("/retrieve_documents")
 async def retrieve_documents(input_data: QueryInput):
     try:
-        # Process query
         processed_query = process_query(input_data.query_text, input_data.language_code)
         query_embedding = embed_query_text(processed_query)
         results = query_embeddings(query_embedding)

-        # Get document texts and rerank
         document_ids = [doc_id for doc_id, _ in results]
         document_texts = retrieve_document_texts(document_ids)
         scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])

-        # Prepare response
         documents = []
         for score, doc_id, text in zip(scores, document_ids, document_texts):
             url = global_models.file_name_to_url.get(doc_id, "")
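Note: the orphaned `answer_part = answer_part[:last_period_index + 1].strip()` fragment deleted above is the tail of a `clean_answer` helper whose header was already missing before this commit, yet the `/get_answer` endpoint (removed in the next hunk) still called it. From the surviving fragment it apparently trimmed the generation to the last complete sentence. A speculative reconstruction, assuming the prompt's trailing `Answer:` marker is used to drop the echoed prompt:

```python
def clean_answer(answer: str) -> str:
    # Hypothetical reconstruction from the surviving fragment.
    # Causal LMs echo the prompt, so keep only what follows "Answer:".
    answer_part = answer.split("Answer:")[-1].strip()
    # Trim to the last complete sentence, as the fragment suggests.
    last_period_index = answer_part.rfind('.')
    if last_period_index != -1:
        answer_part = answer_part[:last_period_index + 1].strip()
    return answer_part
```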
@@ -417,53 +197,30 @@ async def retrieve_documents(input_data: QueryInput):
                 "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
                 "score": float(score)
             })
-
-        return {"status": "success", "documents": documents}
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        return documents

-@app.post("/get_answer")
-async def get_answer(input_data: QueryInput):
-    try:
-        # Process query
-        processed_query = process_query(input_data.query_text, input_data.language_code)
-
-        # Get relevant documents
-        query_embedding = embed_query_text(processed_query)
-        results = query_embeddings(query_embedding)
-        document_ids = [doc_id for doc_id, _ in results]
-        document_texts = retrieve_document_texts(document_ids)
-
-        # Extract entities and create context
-        entities = extract_entities(processed_query)
-        context = " ".join(document_texts)
-        enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"
-
-        # Generate answer
-        prompt = create_prompt(processed_query, enhanced_context)
-        answer, duration = generate_answer(prompt)
-        final_answer = clean_answer(answer)
-
-        # Translate if needed
-        if input_data.language_code == 0:
-            final_answer = translate_en_to_ar(final_answer)
-
-        return {
-            "status": "success",
-            "answer": final_answer,
-            "processing_time": duration
-        }
-
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(status_code=500, detail=f"Error retrieving documents: {str(e)}")
+
+def process_query(query_text: str, language_code: int) -> str:
+    if language_code == 0:
+        return translate_ar_to_en(query_text)  # Translate Arabic to English if required
+    return query_text
+
+def embed_query_text(query_text: str) -> np.ndarray:
+    return global_models.embedding_model.encode(query_text, convert_to_tensor=True)
+
+def query_embeddings(query_embedding: np.ndarray, top_n: int = 10) -> List[tuple]:
+    doc_embeddings = list(global_models.embeddings_data.values())
+    document_ids = list(global_models.embeddings_data.keys())
+    similarities = cosine_similarity(query_embedding, doc_embeddings)
+    top_indices = np.argsort(similarities[0])[-top_n:]
+    return [(document_ids[idx], similarities[0][idx]) for idx in reversed(top_indices)]

-
-    return {"message": "Server is running"}
+def retrieve_document_texts(document_ids: List[str]) -> List[str]:
+    return [global_models.file_name_to_url[doc_id] for doc_id in document_ids]

-
-
+def translate_en_to_ar(text: str) -> str:
+    # Translation logic here, possibly using `transformers` or another library
+    pass

-
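Note: the new `embed_query_text` returns a 1-D torch tensor (`convert_to_tensor=True`), but scikit-learn's `cosine_similarity` expects 2-D array-like input, so `query_embeddings` will most likely raise on the first request. The helpers deleted in the previous hunk handled this by encoding a one-element list and flattening the result; a sketch along those lines, keeping the new version's parameter names:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def embed_query_text(query_text: str) -> np.ndarray:
    # Encoding a one-element list yields a (1, dim) numpy array,
    # which is the 2-D shape cosine_similarity expects.
    return global_models.embedding_model.encode([query_text])

def query_embeddings(query_embedding: np.ndarray, top_n: int = 10) -> list:
    doc_ids = list(global_models.embeddings_data.keys())
    doc_embeddings = np.array(list(global_models.embeddings_data.values()))
    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
    top_indices = similarities.argsort()[-top_n:][::-1]  # highest scores first
    return [(doc_ids[i], similarities[i]) for i in top_indices]
```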
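Note: three of the new helpers diverge from the implementations this commit deletes. `translate_en_to_ar` is now a `pass` stub even though `/retrieve_documents` still calls it for Arabic responses; `translate_ar_to_en`, called by `process_query`, is no longer defined anywhere in the file; and `retrieve_document_texts` now returns URL strings from `file_name_to_url`, which the cross-encoder is then asked to score as if they were article text. A sketch restoring the deleted behaviour, reusing the MarianMT models loaded at startup and the `downloaded_articles` folder from the removed version:

```python
import os
from typing import List
from bs4 import BeautifulSoup

def translate_en_to_ar(text: str) -> str:
    # Restored from the deleted implementation.
    inputs = global_models.en_to_ar_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True)
    translated_ids = global_models.en_to_ar_model.generate(
        inputs.input_ids, max_length=512, num_beams=4, early_stopping=True)
    return global_models.en_to_ar_tokenizer.decode(
        translated_ids[0], skip_special_tokens=True)

def retrieve_document_texts(document_ids: List[str],
                            folder_path: str = 'downloaded_articles') -> List[str]:
    # Restored from the deleted implementation: read each article's HTML
    # from disk and strip the markup, so the cross-encoder scores real text.
    texts = []
    for doc_id in document_ids:
        try:
            with open(os.path.join(folder_path, doc_id), 'r', encoding='utf-8') as f:
                texts.append(BeautifulSoup(f, 'html.parser')
                             .get_text(separator=' ', strip=True))
        except FileNotFoundError:
            texts.append("")
    return texts
```

`translate_ar_to_en` would mirror `translate_en_to_ar` with the `ar_to_en_*` models, matching the orphaned fragment visible in the deleted code.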