thechaiexperiment committed
Commit 31bad44 · Parent(s): 0250187

Update app.py

Files changed (1)
  1. app.py +140 -426
app.py CHANGED
@@ -6,8 +6,9 @@ from flask_cors import CORS
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
+     AutoModelForTokenClassification,
    AutoModelForCausalLM,
-     AutoModelForTokenClassification
+     pipeline
)
from sentence_transformers import SentenceTransformer, CrossEncoder
from sklearn.metrics.pairwise import cosine_similarity
@@ -15,71 +16,65 @@ from bs4 import BeautifulSoup
import nltk
import torch
import pandas as pd
- from startup import setup_files
-

app = Flask(__name__)
CORS(app)
- # Environment variables for file paths
- EMBEDDINGS_PATH = os.environ.get('EMBEDDINGS_PATH', 'data/embeddings.pkl')
- LINKS_PATH = os.environ.get('LINKS_PATH', 'data/finalcleaned_excel_file.xlsx')

- def init_app():
-     # Download and extract files if they don't exist
-     if not os.path.exists('downloaded_articles'):
-         setup_files()
-
- # Initialize models with proper error handling
- def initialize_models():
+ # Global variables for models and data
+ models = {}
+ data = {}
+
+ def init_nltk():
+     """Initialize NLTK resources"""
    try:
-         global embedding_model, cross_encoder, semantic_model
-         global ar_to_en_tokenizer, ar_to_en_model
-         global en_to_ar_tokenizer, en_to_ar_model
-         global tokenizer_f, model_f, bio_tokenizer, bio_model
-
-         print("Initializing models...")
-
-         # Basic embedding models
-         embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-         cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
-         semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-
+         nltk.download('punkt', quiet=True)
+         return True
+     except Exception as e:
+         print(f"Error initializing NLTK: {e}")
+         return False
+
+ def load_models():
+     """Initialize all required models"""
+     try:
+         print("Loading models...")
+
+         # Embedding models
+         models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
+
        # Translation models
-         ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-         ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-         en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-         en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-
-         # Medical NER model
-         bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-         bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
+         models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+         models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+
+         # NER model
+         models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+         models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+         models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
+
        # LLM model
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
-         tokenizer_f = AutoTokenizer.from_pretrained(model_name)
-         model_f = AutoModelForCausalLM.from_pretrained(model_name)
-
-         nltk.download('punkt', quiet=True)
-
-         print("Models initialized successfully")
+         models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
+         models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
+
+         print("Models loaded successfully")
        return True
    except Exception as e:
-         print(f"Error initializing models: {e}")
+         print(f"Error loading models: {e}")
        return False

- # Load data with error handling
def load_data():
+     """Load embeddings and document data"""
    try:
-         global embeddings_data, df
-
-         print("Loading data files...")
+         print("Loading data...")

        # Load embeddings
-         with open(EMBEDDINGS_PATH, 'rb') as file:
-             embeddings_data = pickle.load(file)
+         with open('embeddings.pkl', 'rb') as f:
+             data['embeddings'] = pickle.load(f)

-         # Load links data
-         df = pd.read_excel(LINKS_PATH)
+         # Load document links
+         data['df'] = pd.read_excel('finalcleaned_excel_file.xlsx')

        print("Data loaded successfully")
        return True
@@ -87,12 +82,84 @@ def load_data():
        print(f"Error loading data: {e}")
        return False

+ def translate_text(text, source_to_target='ar_to_en'):
+     """Translate text between Arabic and English"""
+     try:
+         if source_to_target == 'ar_to_en':
+             tokenizer = models['ar_to_en_tokenizer']
+             model = models['ar_to_en_model']
+         else:
+             tokenizer = models['en_to_ar_tokenizer']
+             model = models['en_to_ar_model']
+
+         inputs = tokenizer(text, return_tensors="pt", truncation=True)
+         outputs = model.generate(**inputs)
+         return tokenizer.decode(outputs[0], skip_special_tokens=True)
+     except Exception as e:
+         print(f"Translation error: {e}")
+         return text
+
+ def query_embeddings(query_embedding, n_results=5):
+     """Find relevant documents using embedding similarity"""
+     doc_ids = list(data['embeddings'].keys())
+     doc_embeddings = np.array(list(data['embeddings'].values()))
+     similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+     top_indices = similarities.argsort()[-n_results:][::-1]
+     return [(doc_ids[i], similarities[i]) for i in top_indices]
+
+ def retrieve_document_text(doc_id):
+     """Retrieve document text from HTML file"""
+     try:
+         with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as file:
+             soup = BeautifulSoup(file, 'html.parser')
+             return soup.get_text(separator=' ', strip=True)
+     except Exception as e:
+         print(f"Error retrieving document {doc_id}: {e}")
+         return ""
+
+ def extract_entities(text):
+     """Extract medical entities from text"""
+     try:
+         results = models['ner_pipeline'](text)
+         return list({result['word'] for result in results if result['entity'].startswith("B-")})
+     except Exception as e:
+         print(f"Error extracting entities: {e}")
+         return []
+
+ def generate_answer(query, context, max_length=860, temperature=0.2):
+     """Generate answer using LLM"""
+     try:
+         prompt = f"""
+ As a medical expert, answer the following question based only on the provided context:
+
+ Context: {context}
+ Question: {query}
+
+ Answer:"""
+
+         inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
+         outputs = models['llm_model'].generate(
+             inputs.input_ids,
+             max_length=max_length,
+             num_return_sequences=1,
+             temperature=temperature,
+             pad_token_id=models['llm_tokenizer'].eos_token_id
+         )
+
+         answer = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
+         return answer.split("Answer:")[-1].strip()
+     except Exception as e:
+         print(f"Error generating answer: {e}")
+         return "Sorry, I couldn't generate an answer at this time."
+
@app.route('/health', methods=['GET'])
def health_check():
+     """Health check endpoint"""
    return jsonify({'status': 'healthy'})

@app.route('/api/query', methods=['POST'])
def process_query():
+     """Main query processing endpoint"""
    try:
        data = request.json
        if not data or 'query' not in data:
@@ -101,24 +168,33 @@ def process_query():
        query_text = data['query']
        language_code = data.get('language_code', 0)

-         # Process query
+         # Translate if Arabic
        if language_code == 0:
-             query_text = translate_ar_to_en(query_text)
+             query_text = translate_text(query_text, 'ar_to_en')

-         # Get embeddings and find relevant documents
-         query_embedding = embedding_model.encode([query_text])
-         initial_results = query_embeddings(query_embedding, embeddings_data)
-
-         # Process documents
-         document_texts = retrieve_document_texts([doc_id for doc_id, _ in initial_results])
-         relevant_portions = extract_relevant_portions(document_texts, query_text)
+         # Get query embedding and find relevant documents
+         query_embedding = models['embedding'].encode([query_text])
+         relevant_docs = query_embeddings(query_embedding)
+
+         # Retrieve and process documents
+         doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
+
+         # Extract entities and generate context
+         query_entities = extract_entities(query_text)
+         contexts = []
+         for text in doc_texts:
+             doc_entities = extract_entities(text)
+             if set(query_entities) & set(doc_entities):
+                 contexts.append(text)
+
+         context = " ".join(contexts[:3]) # Use top 3 most relevant contexts

        # Generate answer
-         combined_text = " ".join([item for sublist in relevant_portions.values() for item in sublist])
-         answer = generate_answer(query_text, combined_text)
+         answer = generate_answer(query_text, context)

+         # Translate back if needed
        if language_code == 0:
-             answer = translate_en_to_ar(answer)
+             answer = translate_text(answer, 'en_to_ar')

        return jsonify({
            'answer': answer,
@@ -131,375 +207,13 @@ def process_query():
            'success': False
        }), 500

- def translate_ar_to_en(text):
-     try:
-         inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True)
-         outputs = ar_to_en_model.generate(**inputs)
-         return ar_to_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
-     except Exception as e:
-         print(f"Translation error (AR->EN): {e}")
-         return text
-
- def translate_en_to_ar(text):
-     try:
-         inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True)
-         outputs = en_to_ar_model.generate(**inputs)
-         return en_to_ar_tokenizer.decode(outputs[0], skip_special_tokens=True)
-     except Exception as e:
-         print(f"Translation error (EN->AR): {e}")
-         return text
-
- language_code = 0
-
- query_text = 'How can a patient with chronic kidney disease manage their daily activities and maintain quality of life?' #'symptoms of a heart attack '
-
- def process_query(query_text):
-     if language_code == 0:
-         # Translate Arabic input to English
-         query_text = translate_ar_to_en(query_text)
-     return query_text
-
- def embed_query_text(query_text):
-     query_embedding = embedding_model.encode([query_text])
-     return query_embedding
-
- def query_embeddings(query_embedding, embeddings_data, n_results=5):
-     doc_ids = list(embeddings_data.keys())
-     doc_embeddings = np.array(list(embeddings_data.values()))
-     similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
-     top_indices = similarities.argsort()[-n_results:][::-1]
-     top_docs = [(doc_ids[i], similarities[i]) for i in top_indices]
-
-     return top_docs
-
- query_embedding = embed_query_text(query_text) # Embed the query text
- initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
- document_ids = [doc_id for doc_id, _ in initial_results]
- print(document_ids)
-
- import pandas as pd
- import requests
- from bs4 import BeautifulSoup
-
- # Load the Excel file
- file_path = '/kaggle/input/final-links/finalcleaned_excel_file.xlsx'
- df = pd.read_excel(file_path)
-
-
- # Create a dictionary mapping file names to URLs
- # Assuming the DataFrame index corresponds to file names
- file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
- def get_page_title(url):
-     try:
-         response = requests.get(url)
-         if response.status_code == 200:
-             soup = BeautifulSoup(response.text, 'html.parser')
-             title = soup.find('title')
-             return title.get_text() if title else "No title found"
-         else:
-             return None
-     except requests.exceptions.RequestException:
-         return None
- # Example file names
- file_names = document_ids
-
- # Retrieve original URLs
- for file_name in file_names:
-     original_url = file_name_to_url.get(file_name, None)
-     if original_url:
-         title = get_page_title(original_url)
-         if title:
-             print(f"Title: {title},URL: {original_url}")
-         else:
-             print(f"Name: {file_name}")
-     else:
-         print(f"Name: {file_name}")
-
- def retrieve_document_texts(doc_ids, folder_path):
-     texts = []
-     for doc_id in doc_ids:
-         file_path = os.path.join(folder_path, doc_id)
-         try:
-             with open(file_path, 'r', encoding='utf-8') as file:
-                 soup = BeautifulSoup(file, 'html.parser')
-                 text = soup.get_text(separator=' ', strip=True)
-                 texts.append(text)
-         except FileNotFoundError:
-             texts.append("")
-     return texts
- document_ids = [doc_id for doc_id, _ in initial_results]
- document_texts = retrieve_document_texts(document_ids, folder_path)
-
- # Rerank the results using the CrossEncoder
- scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
- scored_documents = list(zip(scores, document_ids, document_texts))
- scored_documents.sort(key=lambda x: x[0], reverse=True)
- print("Reranked results:")
- for idx, (score, doc_id, doc) in enumerate(scored_documents):
-     print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id}")
-
- from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
- import nltk
-
- # Load BioBERT model and tokenizer for NER
- bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
- bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
- ner_biobert = pipeline("ner", model=bio_model, tokenizer=bio_tokenizer)
-
- def extract_entities(text, ner_pipeline):
-     """
-     Extract entities using a NER pipeline.
-     Args:
-         text (str): The text from which to extract entities.
-         ner_pipeline (pipeline): The NER pipeline for entity extraction.
-     Returns:
-         List[str]: A list of unique extracted entities.
-     """
-     ner_results = ner_pipeline(text)
-     entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
-     return list(entities)
-
- def match_entities(query_entities, sentence_entities):
-     """
-     Compute the relevance score based on entity matching.
-     Args:
-         query_entities (List[str]): Entities extracted from the query.
-         sentence_entities (List[str]): Entities extracted from the sentence.
-     Returns:
-         float: The relevance score based on entity overlap.
-     """
-     query_set, sentence_set = set(query_entities), set(sentence_entities)
-     matches = query_set.intersection(sentence_set)
-     return len(matches)
-
- def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
-     """
-     Extract relevant text portions from documents based on entity matching.
-     Args:
-         document_texts (List[str]): List of document texts.
-         query (str): The query text.
-         max_portions (int): Maximum number of relevant portions to extract per document.
-         portion_size (int): Number of sentences to include in each portion.
-         min_query_words (int): Minimum number of matching entities to consider a sentence relevant.
-     Returns:
-         Dict[str, List[str]]: Relevant portions for each document.
-     """
-     relevant_portions = {}
-
-     # Extract entities from the query
-     query_entities = extract_entities(query, ner_biobert)
-     print(f"Extracted Query Entities: {query_entities}")
-
-     for doc_id, doc_text in enumerate(document_texts):
-         sentences = nltk.sent_tokenize(doc_text) # Split document into sentences
-         doc_relevant_portions = []
-
-         # Extract entities from the entire document
-         doc_entities = extract_entities(doc_text, ner_biobert)
-         print(f"Document {doc_id} Entities: {doc_entities}")
-
-         for i, sentence in enumerate(sentences):
-             # Extract entities from the sentence
-             sentence_entities = extract_entities(sentence, ner_biobert)
-
-             # Compute relevance score
-             relevance_score = match_entities(query_entities, sentence_entities)
-
-             # Select sentences with at least `min_query_words` matching entities
-             if relevance_score >= min_query_words:
-                 start_idx = max(0, i - portion_size // 2)
-                 end_idx = min(len(sentences), i + portion_size // 2 + 1)
-                 portion = " ".join(sentences[start_idx:end_idx])
-                 doc_relevant_portions.append(portion)
-
-             if len(doc_relevant_portions) >= max_portions:
-                 break
-
-         # Add fallback to include the most entity-dense sentences if no results
-         if not doc_relevant_portions and len(doc_entities) > 0:
-             print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
-             sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
-             for fallback_sentence in sorted_sentences[:max_portions]:
-                 doc_relevant_portions.append(fallback_sentence)
-
-         relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
-
-     return relevant_portions
-
- # Extract relevant portions based on query and documents
- relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=1)
-
- for doc_id, portions in relevant_portions.items():
-     print(f"{doc_id}: {portions}")
-
- # Remove duplicates from the selected portions
- def remove_duplicates(selected_parts):
-     unique_sentences = set()
-     unique_selected_parts = []
-
-     for sentence in selected_parts:
-         if sentence not in unique_sentences:
-             unique_selected_parts.append(sentence)
-             unique_sentences.add(sentence)
-
-     return unique_selected_parts
-
- # Flatten the dictionary of relevant portions (from earlier code)
- flattened_relevant_portions = []
- for doc_id, portions in relevant_portions.items():
-     flattened_relevant_portions.extend(portions)
-
- # Remove duplicate portions
- unique_selected_parts = remove_duplicates(flattened_relevant_portions)
-
- # Combine the unique parts into a single string of context
- combined_parts = " ".join(unique_selected_parts)
-
- # Construct context as a list: first the query, then the unique selected portions
- context = [query_text] + unique_selected_parts
-
- # Print the context (query + relevant portions)
- print(context)
-
- import pickle
-
- with open('/kaggle/input/art-embeddings-pkl/embeddings.pkl', 'rb') as file:
-     data = pickle.load(file)
-
- # Print the type of data
- print(f"Data type: {type(data)}")
-
- # Print the first few keys and values from the dictionary
- print("First few keys and values:")
- for i, (key, value) in enumerate(data.items()):
-     if i >= 5: # Limit to printing the first 5 key-value pairs
-         break
-     print(f"Key: {key}, Value: {value}")
-
- import pickle
- import pickletools
-
- # Load the pickle file
- file_path = '/kaggle/input/art-embeddings-pkl/embeddings.pkl'
-
- with open(file_path, 'rb') as f:
-     # Read the pickle file
-     data = pickle.load(f)
-
- # Check for suspicious or corrupted entries
- def inspect_pickle(data):
-     for key, value in data.items():
-         if isinstance(value, (str, bytes)):
-             # Try to decode and catch any non-ASCII issues
-             try:
-                 value.decode('ascii')
-             except UnicodeDecodeError as e:
-                 print(f"Non-ASCII entry found in key: {key}")
-                 print(f"Corrupted data: {value} ({e})")
-                 continue
-
-         if isinstance(value, list) and any(isinstance(v, (list, dict, str, bytes)) for v in value):
-             # Inspect list elements recursively
-             inspect_pickle({f"{key}[{idx}]": v for idx, v in enumerate(value)})
-
- # Inspect the data
- inspect_pickle(data)
-
- from transformers import AutoTokenizer, AutoModelForTokenClassification
- import torch
- import time
-
- # Load Biobert model and tokenizer
- biobert_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
- biobert_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
- def extract_entities(text):
-     inputs = biobert_tokenizer(text, return_tensors="pt")
-     outputs = biobert_model(**inputs)
-     predictions = torch.argmax(outputs.logits, dim=2)
-     tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-     entities = [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0] # Assume 0 is the label for non-entity
-     return entities
-
- def enhance_passage_with_entities(passage, entities):
-     # Example: Add entities to the passage for better context
-     return f"{passage}\n\nEntities: {', '.join(entities)}"
-
- def create_prompt(question, passage):
-     prompt = ("""
- As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
-
- Passage: {passage}
-
- Question: {question}
-
- Answer:
-     """)
-     return prompt.format(passage=passage, question=question)
-
- def generate_answer(prompt, max_length=860, temperature=0.2):
-     inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
-
-     # Start timing
-     start_time = time.time()
-
-     output_ids = model_f.generate(
-         inputs.input_ids,
-         max_length=max_length,
-         num_return_sequences=1,
-         temperature=temperature,
-         pad_token_id=tokenizer_f.eos_token_id
-     )
-
-     # End timing
-     end_time = time.time()
-
-     # Calculate the duration
-     duration = end_time - start_time
-
-     # Decode the answer
-     answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-
-     passage_keywords = set(passage.lower().split())
-     answer_keywords = set(answer.lower().split())
-
-     if passage_keywords.intersection(answer_keywords):
-         return answer, duration
-     else:
-         return "Sorry, I can't help with that.", duration
-
- # Integrate Biobert model
- entities = extract_entities(query_text)
- passage = enhance_passage_with_entities(combined_parts, entities)
- # Generate answer with the enhanced passage
- prompt = create_prompt(query_text, passage)
- answer, generation_time = generate_answer(prompt)
- print(f"\nTime taken to generate the answer: {generation_time:.2f} seconds")
- def remove_answer_prefix(text):
-     prefix = "Answer:"
-     if prefix in text:
-         return text.split(prefix)[-1].strip()
-     return text
-
- def remove_incomplete_sentence(text):
-     # Check if the text ends with a period
-     if not text.endswith('.'):
-         # Find the last period or the end of the string
-         last_period_index = text.rfind('.')
-         if last_period_index != -1:
-             # Remove everything after the last period
-             return text[:last_period_index + 1].strip()
-     return text
- # Clean and print the answer
- answer_part = answer.split("Answer:")[-1].strip()
- cleaned_answer = remove_answer_prefix(answer_part)
- final_answer = remove_incomplete_sentence(cleaned_answer)
-
- if language_code == 0:
-     final_answer = translate_en_to_ar(final_answer)
-
- if final_answer:
-     print("Answer:")
-     print(final_answer)
- else:
-     print("Sorry, I can't help with that.")
+ # Initialize everything when the app starts
+ print("Initializing application...")
+ init_success = init_nltk() and load_models() and load_data()
+
+ if not init_success:
+     print("Failed to initialize application")
+     exit(1)
+
+ if __name__ == "__main__":
+     app.run(host='0.0.0.0', port=7860)
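
A quick way to exercise the updated endpoint once the app is running, shown as a minimal client sketch rather than part of the commit: it assumes the requests library is installed and the server is reachable on localhost:7860 as configured in app.run above, and it reuses the sample question from the removed notebook-style code; the success response is only known from this diff to carry an 'answer' field.

# Hypothetical client call against the /api/query endpoint defined in app.py.
# Assumes the Flask app is running locally on port 7860 (see app.run above).
import requests

payload = {
    "query": "How can a patient with chronic kidney disease manage their daily activities and maintain quality of life?",
    "language_code": 1,  # process_query only translates when language_code == 0 (Arabic input)
}
resp = requests.post("http://localhost:7860/api/query", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["answer"])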