thechaiexperiment committed
Commit 8e3b5f7 · 1 Parent(s): 7b16750

Update app.py

Files changed (1)
  1. app.py +41 -284
app.py CHANGED
@@ -1,4 +1,4 @@
- from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from typing import List, Optional, Dict
  import pickle
@@ -19,9 +19,9 @@ from transformers import (
  import pandas as pd
  import time

- # Initialize FastAPI app first
  app = FastAPI()

  class ArticleEmbeddingUnpickler(pickle.Unpickler):
      """Custom unpickler for article embeddings with enhanced persistence handling"""
      def find_class(self, module: str, name: str) -> any:
@@ -35,7 +35,6 @@ class ArticleEmbeddingUnpickler(pickle.Unpickler):
      def persistent_load(self, pid: any) -> str:
          """Enhanced persistent ID handler with better encoding management"""
          try:
-             # Handle different types of persistent IDs
              if isinstance(pid, bytes):
                  return pid.decode('utf-8', errors='replace')
              if isinstance(pid, (str, int, float)):
@@ -48,7 +47,6 @@ class ArticleEmbeddingUnpickler(pickle.Unpickler):
  def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
      """Load embeddings with enhanced error handling, validation, and persistent ID support."""
      def persistent_load(pid):
-         """Handle persistent ID references during unpickling."""
          print(f"Warning: Persistent ID encountered: {pid}")
          raise ValueError("Persistent IDs are not supported in this application")

@@ -64,7 +62,7 @@ def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndar
      if not isinstance(embeddings_data, dict):
          raise ValueError(f"Invalid data structure: expected dict, got {type(embeddings_data)}")

-     # Process and validate embeddings (same as before)
      valid_embeddings = {}
      for key, value in embeddings_data.items():
          try:
@@ -101,17 +99,15 @@ def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndar
          raise

  def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
-     # Ensure all keys are ASCII-safe strings
      cleaned_embeddings = {
-         str(key): value # Use str(key) instead of encoding to ASCII
          for key, value in embeddings_dict.items()
      }
-
      with open(file_path, 'wb') as f:
-         # Use a newer protocol for better compatibility
          pickle.dump(cleaned_embeddings, f, protocol=4)

- # Models and data structures
  class GlobalModels:
      embedding_model = None
      cross_encoder = None
@@ -124,154 +120,35 @@ class GlobalModels:
      ar_to_en_model = None
      en_to_ar_tokenizer = None
      en_to_ar_model = None
-     embeddings_data = None
-     file_name_to_url = None
      bio_tokenizer = None
      bio_model = None
-
- # Initialize global models
- global_models = GlobalModels()
-
- # Download NLTK data
- nltk.download('punkt')
-
- # Pydantic models for request validation
- class QueryInput(BaseModel):
-     query_text: str
-     language_code: int # 0 for Arabic, 1 for English
-     query_type: str # "profile" or "question"
-     previous_qa: Optional[List[Dict[str, str]]] = None
-
- class DocumentResponse(BaseModel):
-     title: str
-     url: str
-     text: str
-     score: float
-
- # Modified startup event handler
- @app.on_event("startup")
- @app.on_event("startup")
- async def load_models():
-     try:
-         print("Starting to load embeddings...")
-         embeddings_data = safe_load_embeddings()
-         print(f"Embeddings data type: {type(embeddings_data)}")
-         if embeddings_data:
-             print(f"Number of embeddings: {len(embeddings_data)}")
-             # Print sample of keys
-             print("Sample keys:", list(embeddings_data.keys())[:3])
-         # Load embedding models first
-         global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-
-         # Load embeddings data with new safe loader
-         embeddings_data = safe_load_embeddings()
-         if embeddings_data is None:
-             raise HTTPException(status_code=500, detail="Failed to load embeddings data")
-         global_models.embeddings_data = embeddings_data
-
-         # Load remaining models
-         global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
-         global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-         # Load BART models
-         global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
-         global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
-
-         # Load Orca model
-         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
-         global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
-         global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
-
-         # Load translation models
-         global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-         global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-         global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-         global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-
-         # Load Medical NER models
-         global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-         global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
-         # Load URL mapping data
-         try:
-             df = pd.read_excel('finalcleaned_excel_file.xlsx')
-             global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-         except Exception as e:
-             print(f"Error loading URL mapping data: {e}")
-             raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")
-
-         print("All models loaded successfully")
-
-     except Exception as e:
-         print(f"Error during startup: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Failed to initialize application: {str(e)}")
-
-
- # Models and data structures to store loaded models
- class GlobalModels:
-     embedding_model = None
-     cross_encoder = None
-     semantic_model = None
-     tokenizer = None
-     model = None
-     tokenizer_f = None
-     model_f = None
-     ar_to_en_tokenizer = None
-     ar_to_en_model = None
-     en_to_ar_tokenizer = None
-     en_to_ar_model = None
      embeddings_data = None
      file_name_to_url = None
-     bio_tokenizer = None
-     bio_model = None

  global_models = GlobalModels()

- # Download NLTK data
- nltk.download('punkt')
-
- # Pydantic models for request validation
- class QueryInput(BaseModel):
-     query_text: str
-     language_code: int # 0 for Arabic, 1 for English
-     query_type: str # "profile" or "question"
-     previous_qa: Optional[List[Dict[str, str]]] = None
-
- class DocumentResponse(BaseModel):
-     title: str
-     url: str
-     text: str
-     score: float
-
  @app.on_event("startup")
  async def load_models():
-     """Initialize all models and data on startup"""
      try:
-         # Load embedding models
          global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
          global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
          global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

-         # Load BART models
          global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
          global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

-         # Load Orca model
          model_name = "M4-ai/Orca-2.0-Tau-1.8B"
          global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
          global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

-         # Load translation models
          global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
          global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
          global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
          global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

-         # Load Medical NER models
          global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
          global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

-         # Load embeddings data with better error handling
          try:
              with open('embeddings.pkl', 'rb') as file:
                  global_models.embeddings_data = pickle.load(file)
@@ -279,135 +156,38 @@ async def load_models():
              print(f"Error loading embeddings data: {e}")
              raise HTTPException(status_code=500, detail="Failed to load embeddings data.")

-         # Load URL mapping data
-         try:
-             df = pd.read_excel('finalcleaned_excel_file.xlsx')
-             global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-         except Exception as e:
-             print(f"Error loading URL mapping data: {e}")
-             raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")

      except Exception as e:
          print(f"Error loading models: {e}")
          raise HTTPException(status_code=500, detail="Failed to load models.")


- def translate_ar_to_en(text):
-     try:
-         inputs = global_models.ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-         translated_ids = global_models.ar_to_en_model.generate(
-             inputs.input_ids,
-             max_length=512,
-             num_beams=4,
-             early_stopping=True
-         )
-         translated_text = global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-         return translated_text
-     except Exception as e:
-         print(f"Error during Arabic to English translation: {e}")
-         return None
-
- def translate_en_to_ar(text):
-     try:
-         inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-         translated_ids = global_models.en_to_ar_model.generate(
-             inputs.input_ids,
-             max_length=512,
-             num_beams=4,
-             early_stopping=True
-         )
-         translated_text = global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
-         return translated_text
-     except Exception as e:
-         print(f"Error during English to Arabic translation: {e}")
-         return None
-
- def process_query(query_text, language_code):
-     if language_code == 0:
-         return translate_ar_to_en(query_text)
-     return query_text
-
- def embed_query_text(query_text):
-     return global_models.embedding_model.encode([query_text])
-
- def query_embeddings(query_embedding, n_results=5):
-     doc_ids = list(global_models.embeddings_data.keys())
-     doc_embeddings = np.array(list(global_models.embeddings_data.values()))
-     similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
-     top_indices = similarities.argsort()[-n_results:][::-1]
-     return [(doc_ids[i], similarities[i]) for i in top_indices]
-
- def retrieve_document_texts(doc_ids, folder_path='downloaded_articles'):
-     texts = []
-     for doc_id in doc_ids:
-         file_path = os.path.join(folder_path, doc_id)
-         try:
-             with open(file_path, 'r', encoding='utf-8') as file:
-                 soup = BeautifulSoup(file, 'html.parser')
-                 text = soup.get_text(separator=' ', strip=True)
-                 texts.append(text)
-         except FileNotFoundError:
-             texts.append("")
-     return texts
-
- def extract_entities(text):
-     inputs = global_models.bio_tokenizer(text, return_tensors="pt")
-     outputs = global_models.bio_model(**inputs)
-     predictions = torch.argmax(outputs.logits, dim=2)
-     tokens = global_models.bio_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-     return [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]
-
- def create_prompt(question, passage):
-     return f"""
- As a medical expert, you are required to answer the following question based only on the provided passage.
- Do not include any information not present in the passage. Your response should directly reflect the content
- of the passage. Maintain accuracy and relevance to the provided information.
-
- Passage: {passage}
-
- Question: {question}
-
- Answer:
- """
-
- def generate_answer(prompt, max_length=860, temperature=0.2):
-     inputs = global_models.tokenizer_f(prompt, return_tensors="pt", truncation=True)
-
-     start_time = time.time()
-     output_ids = global_models.model_f.generate(
-         inputs.input_ids,
-         max_length=max_length,
-         num_return_sequences=1,
-         temperature=temperature,
-         pad_token_id=global_models.tokenizer_f.eos_token_id
-     )
-     duration = time.time() - start_time
-
-     answer = global_models.tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-     return answer, duration

- def clean_answer(answer):
-     answer_part = answer.split("Answer:")[-1].strip()
-     if not answer_part.endswith('.'):
-         last_period_index = answer_part.rfind('.')
-         if last_period_index != -1:
-             answer_part = answer_part[:last_period_index + 1].strip()
-     return answer_part

  @app.post("/retrieve_documents")
  async def retrieve_documents(input_data: QueryInput):
      try:
-         # Process query
          processed_query = process_query(input_data.query_text, input_data.language_code)
          query_embedding = embed_query_text(processed_query)
          results = query_embeddings(query_embedding)

-         # Get document texts and rerank
          document_ids = [doc_id for doc_id, _ in results]
          document_texts = retrieve_document_texts(document_ids)
          scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])

-         # Prepare response
          documents = []
          for score, doc_id, text in zip(scores, document_ids, document_texts):
              url = global_models.file_name_to_url.get(doc_id, "")
@@ -417,53 +197,30 @@ async def retrieve_documents(input_data: QueryInput):
                  "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
                  "score": float(score)
              })
-
-         return {"status": "success", "documents": documents}
-
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))

- @app.post("/get_answer")
- async def get_answer(input_data: QueryInput):
-     try:
-         # Process query
-         processed_query = process_query(input_data.query_text, input_data.language_code)
-
-         # Get relevant documents
-         query_embedding = embed_query_text(processed_query)
-         results = query_embeddings(query_embedding)
-         document_ids = [doc_id for doc_id, _ in results]
-         document_texts = retrieve_document_texts(document_ids)
-
-         # Extract entities and create context
-         entities = extract_entities(processed_query)
-         context = " ".join(document_texts)
-         enhanced_context = f"{context}\n\nEntities: {', '.join(entities)}"
-
-         # Generate answer
-         prompt = create_prompt(processed_query, enhanced_context)
-         answer, duration = generate_answer(prompt)
-         final_answer = clean_answer(answer)
-
-         # Translate if needed
-         if input_data.language_code == 0:
-             final_answer = translate_en_to_ar(final_answer)
-
-         return {
-             "status": "success",
-             "answer": final_answer,
-             "processing_time": duration
-         }
-
      except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))

- @app.get("/")
- async def root():
-     return {"message": "Server is running"}

- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860)

-

+ from fastapi import FastAPI, HTTPException, Query
  from pydantic import BaseModel
  from typing import List, Optional, Dict
  import pickle

  import pandas as pd
  import time

  app = FastAPI()

+ # ArticleEmbeddingUnpickler and safe_load_embeddings functions
  class ArticleEmbeddingUnpickler(pickle.Unpickler):
      """Custom unpickler for article embeddings with enhanced persistence handling"""
      def find_class(self, module: str, name: str) -> any:

      def persistent_load(self, pid: any) -> str:
          """Enhanced persistent ID handler with better encoding management"""
          try:
              if isinstance(pid, bytes):
                  return pid.decode('utf-8', errors='replace')
              if isinstance(pid, (str, int, float)):

  def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
      """Load embeddings with enhanced error handling, validation, and persistent ID support."""
      def persistent_load(pid):
          print(f"Warning: Persistent ID encountered: {pid}")
          raise ValueError("Persistent IDs are not supported in this application")

      if not isinstance(embeddings_data, dict):
          raise ValueError(f"Invalid data structure: expected dict, got {type(embeddings_data)}")

+     # Process and validate embeddings
      valid_embeddings = {}
      for key, value in embeddings_data.items():
          try:

          raise

  def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
      cleaned_embeddings = {
+         str(key): value
          for key, value in embeddings_dict.items()
      }
      with open(file_path, 'wb') as f:
          pickle.dump(cleaned_embeddings, f, protocol=4)

+
+ # GlobalModels and load_models
  class GlobalModels:
      embedding_model = None
      cross_encoder = None
      ar_to_en_model = None
      en_to_ar_tokenizer = None
      en_to_ar_model = None
      bio_tokenizer = None
      bio_model = None
      embeddings_data = None
      file_name_to_url = None

  global_models = GlobalModels()

  @app.on_event("startup")
  async def load_models():
      try:
          global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
          global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
          global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

          global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
          global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

          model_name = "M4-ai/Orca-2.0-Tau-1.8B"
          global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
          global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

          global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
          global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
          global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
          global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

          global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
          global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

          try:
              with open('embeddings.pkl', 'rb') as file:
                  global_models.embeddings_data = pickle.load(file)

              print(f"Error loading embeddings data: {e}")
              raise HTTPException(status_code=500, detail="Failed to load embeddings data.")

+         df = pd.read_excel('finalcleaned_excel_file.xlsx')
+         global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}

      except Exception as e:
          print(f"Error loading models: {e}")
          raise HTTPException(status_code=500, detail="Failed to load models.")

+ # Query and Document Retrieval Endpoint
+ class QueryInput(BaseModel):
+     query_text: str
+     language_code: int # 0 for Arabic, 1 for English
+     query_type: str # "profile" or "question"
+     previous_qa: Optional[List[Dict[str, str]]] = None

+ class DocumentResponse(BaseModel):
+     title: str
+     url: str
+     text: str
+     score: float

  @app.post("/retrieve_documents")
  async def retrieve_documents(input_data: QueryInput):
      try:
          processed_query = process_query(input_data.query_text, input_data.language_code)
          query_embedding = embed_query_text(processed_query)
          results = query_embeddings(query_embedding)

          document_ids = [doc_id for doc_id, _ in results]
          document_texts = retrieve_document_texts(document_ids)
          scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])

          documents = []
          for score, doc_id, text in zip(scores, document_ids, document_texts):
              url = global_models.file_name_to_url.get(doc_id, "")

                  "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
                  "score": float(score)
              })
+         return documents

      except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error retrieving documents: {str(e)}")
+
+ def process_query(query_text: str, language_code: int) -> str:
+     if language_code == 0:
+         return translate_ar_to_en(query_text) # Translate Arabic to English if required
+     return query_text
+
+ def embed_query_text(query_text: str) -> np.ndarray:
+     return global_models.embedding_model.encode(query_text, convert_to_tensor=True)
+
+ def query_embeddings(query_embedding: np.ndarray, top_n: int = 10) -> List[tuple]:
+     doc_embeddings = list(global_models.embeddings_data.values())
+     document_ids = list(global_models.embeddings_data.keys())
+     similarities = cosine_similarity(query_embedding, doc_embeddings)
+     top_indices = np.argsort(similarities[0])[-top_n:]
+     return [(document_ids[idx], similarities[0][idx]) for idx in reversed(top_indices)]

+ def retrieve_document_texts(document_ids: List[str]) -> List[str]:
+     return [global_models.file_name_to_url[doc_id] for doc_id in document_ids]

+ def translate_en_to_ar(text: str) -> str:
+     # Translation logic here, possibly using `transformers` or another library
+     pass
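
The new version leaves translate_en_to_ar as a stub and no longer defines translate_ar_to_en, which process_query still calls. A minimal sketch of both helpers, reusing the Helsinki-NLP tokenizers and models that load_models() attaches to global_models and mirroring the implementation removed above; the helpers below are not part of this commit:

# Hedged sketch, not part of the commit: assumes the ar_to_en_* / en_to_ar_*
# tokenizers and models that load_models() assigns to global_models.
def translate_ar_to_en(text: str) -> str:
    # Arabic -> English, following the helper removed in this commit
    inputs = global_models.ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    translated_ids = global_models.ar_to_en_model.generate(
        inputs.input_ids, max_length=512, num_beams=4, early_stopping=True
    )
    return global_models.ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)

def translate_en_to_ar(text: str) -> str:
    # English -> Arabic, filling in the `pass` stub along the same lines
    inputs = global_models.en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    translated_ids = global_models.en_to_ar_model.generate(
        inputs.input_ids, max_length=512, num_beams=4, early_stopping=True
    )
    return global_models.en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)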
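
For reference, a client call against the reworked endpoint might look like the sketch below. The base URL and the query text are assumptions (the removed __main__ block served uvicorn on port 7860); with `return documents`, the endpoint now answers with a bare list of title/url/text/score objects instead of the earlier {"status": ..., "documents": ...} wrapper.

# Hypothetical client call; fields follow the QueryInput model above,
# host/port are an assumption based on the removed uvicorn.run(...) line.
import requests

payload = {
    "query_text": "What are the symptoms of diabetes?",  # example query, not from the commit
    "language_code": 1,   # 1 = English, 0 = Arabic
    "query_type": "question",
    "previous_qa": None,
}
resp = requests.post("http://localhost:7860/retrieve_documents", json=payload)
resp.raise_for_status()
for doc in resp.json():   # expected shape: [{"title": ..., "url": ..., "text": ..., "score": ...}, ...]
    print(doc["score"], doc["url"])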