thechaiexperiment committed
Commit f377404 · 1 Parent(s): 8ce8fc9

Update app.py

Files changed (1): app.py +207 -78
app.py CHANGED
@@ -38,6 +38,10 @@ def load_models():
     try:
         print("Loading models...")
 
+        # Set device
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Device set to use {device}")
+
         # Embedding models
         models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
         models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
@@ -64,24 +68,78 @@ def load_models():
         print(f"Error loading models: {e}")
         return False
 
-def load_data():
-    """Load embeddings and document data"""
+def load_embeddings():
+    """Load embeddings with robust error handling"""
     try:
-        print("Loading data...")
+        print("Loading embeddings...")
+        embeddings_path = 'embeddings.pkl'
 
-        # Load embeddings
-        with open('embeddings.pkl', 'rb') as f:
-            data['embeddings'] = pickle.load(f)
+        if not os.path.exists(embeddings_path):
+            print(f"Error: {embeddings_path} not found")
+            return False
+
+        # Custom unpickler to handle potential compatibility issues
+        class CustomUnpickler(pickle.Unpickler):
+            def find_class(self, module, name):
+                if module == "__main__":
+                    module = "numpy"
+                return super().find_class(module, name)
+
+        with open(embeddings_path, 'rb') as f:
+            try:
+                data['embeddings'] = pickle.load(f)
+            except Exception as e:
+                print(f"Standard unpickling failed, trying custom unpickler: {e}")
+                f.seek(0)
+                try:
+                    data['embeddings'] = CustomUnpickler(f).load()
+                except Exception as e:
+                    print(f"Custom unpickler failed: {e}")
+                    data['embeddings'] = {}
+                    return False
 
-        # Load document links
-        data['df'] = pd.read_excel('finalcleaned_excel_file.xlsx')
+        if not isinstance(data['embeddings'], dict):
+            print("Error: Embeddings data is not in expected format")
+            data['embeddings'] = {}
+            return False
+
+        print(f"Successfully loaded {len(data['embeddings'])} embeddings")
+        return True
+    except Exception as e:
+        print(f"Error loading embeddings: {e}")
+        data['embeddings'] = {}
+        return False
+
+def load_documents_data():
+    """Load document data with error handling"""
+    try:
+        print("Loading documents data...")
+        docs_path = 'finalcleaned_excel_file.xlsx'
 
-        print("Data loaded successfully")
+        if not os.path.exists(docs_path):
+            print(f"Error: {docs_path} not found")
+            return False
+
+        data['df'] = pd.read_excel(docs_path)
+        print(f"Successfully loaded {len(data['df'])} document records")
         return True
     except Exception as e:
-        print(f"Error loading data: {e}")
+        print(f"Error loading documents data: {e}")
+        data['df'] = pd.DataFrame()
        return False
 
+def load_data():
+    """Load all required data"""
+    embeddings_success = load_embeddings()
+    documents_success = load_documents_data()
+
+    if not embeddings_success:
+        print("Warning: Failed to load embeddings, falling back to basic functionality")
+    if not documents_success:
+        print("Warning: Failed to load documents data, falling back to basic functionality")
+
+    return True
+
 def translate_text(text, source_to_target='ar_to_en'):
     """Translate text between Arabic and English"""
     try:
@@ -99,26 +157,8 @@ def translate_text(text, source_to_target='ar_to_en'):
         print(f"Translation error: {e}")
         return text
 
-def query_embeddings(query_embedding, n_results=5):
-    """Find relevant documents using embedding similarity"""
-    doc_ids = list(data['embeddings'].keys())
-    doc_embeddings = np.array(list(data['embeddings'].values()))
-    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
-    top_indices = similarities.argsort()[-n_results:][::-1]
-    return [(doc_ids[i], similarities[i]) for i in top_indices]
-
-def retrieve_document_text(doc_id):
-    """Retrieve document text from HTML file"""
-    try:
-        with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as file:
-            soup = BeautifulSoup(file, 'html.parser')
-            return soup.get_text(separator=' ', strip=True)
-    except Exception as e:
-        print(f"Error retrieving document {doc_id}: {e}")
-        return ""
-
 def extract_entities(text):
-    """Extract medical entities from text"""
+    """Extract medical entities from text using NER"""
     try:
         results = models['ner_pipeline'](text)
         return list({result['word'] for result in results if result['entity'].startswith("B-")})
@@ -130,37 +170,101 @@ def generate_answer(query, context, max_length=860, temperature=0.2):
     """Generate answer using LLM"""
     try:
         prompt = f"""
-As a medical expert, answer the following question based only on the provided context:
+As a medical expert, please provide a clear and accurate answer to the following question based solely on the provided context.
 
 Context: {context}
+
 Question: {query}
 
-Answer:"""
+Answer: Let me help you with accurate information from reliable medical sources."""
+
         inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
-        outputs = models['llm_model'].generate(
-            inputs.input_ids,
-            max_length=max_length,
-            num_return_sequences=1,
-            temperature=temperature,
-            pad_token_id=models['llm_tokenizer'].eos_token_id
-        )
-
-        answer = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
-        return answer.split("Answer:")[-1].strip()
+
+        with torch.no_grad():
+            outputs = models['llm_model'].generate(
+                inputs.input_ids,
+                max_length=max_length,
+                num_return_sequences=1,
+                temperature=temperature,
+                do_sample=True,
+                top_p=0.9,
+                pad_token_id=models['llm_tokenizer'].eos_token_id
+            )
+
+        response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
+
+        # Clean up the response
+        if "Answer:" in response:
+            response = response.split("Answer:")[-1].strip()
+
+        # Remove incomplete sentences at the end
+        sentences = nltk.sent_tokenize(response)
+        if sentences:
+            return " ".join(sentences)
+        return response
+
     except Exception as e:
         print(f"Error generating answer: {e}")
-        return "Sorry, I couldn't generate an answer at this time."
+        return "I apologize, but I'm unable to generate an answer at this time. Please try again later."
+
+def query_embeddings(query_embedding, n_results=5):
+    """Find relevant documents using embedding similarity"""
+    if not data['embeddings']:
+        return []
+
+    try:
+        doc_ids = list(data['embeddings'].keys())
+        doc_embeddings = np.array(list(data['embeddings'].values()))
+        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+        top_indices = similarities.argsort()[-n_results:][::-1]
+        return [(doc_ids[i], similarities[i]) for i in top_indices]
+    except Exception as e:
+        print(f"Error in query_embeddings: {e}")
+        return []
+
+def retrieve_document_text(doc_id):
+    """Retrieve document text from HTML file"""
+    try:
+        file_path = os.path.join('downloaded_articles', doc_id)
+        if not os.path.exists(file_path):
+            print(f"Warning: Document file not found: {file_path}")
+            return ""
+
+        with open(file_path, 'r', encoding='utf-8') as file:
+            soup = BeautifulSoup(file, 'html.parser')
+            return soup.get_text(separator=' ', strip=True)
+    except Exception as e:
+        print(f"Error retrieving document {doc_id}: {e}")
+        return ""
+
+def rerank_documents(query, doc_texts):
+    """Rerank documents using cross-encoder"""
+    try:
+        pairs = [(query, doc) for doc in doc_texts]
+        scores = models['cross_encoder'].predict(pairs)
+        return scores
+    except Exception as e:
+        print(f"Error reranking documents: {e}")
+        return np.zeros(len(doc_texts))
 
 @app.route('/health', methods=['GET'])
 def health_check():
     """Health check endpoint"""
-    return jsonify({'status': 'healthy'})
+    status = {
+        'status': 'healthy',
+        'models_loaded': bool(models),
+        'embeddings_loaded': bool(data.get('embeddings')),
+        'documents_loaded': not data.get('df', pd.DataFrame()).empty
+    }
+    return jsonify(status)
 
 @app.route('/api/query', methods=['POST'])
 def process_query():
     """Main query processing endpoint"""
     try:
+        if not request.is_json:
+            return jsonify({'error': 'Request must be JSON', 'success': False}), 400
+
         data = request.json
         if not data or 'query' not in data:
             return jsonify({'error': 'No query provided', 'success': False}), 400
@@ -168,40 +272,67 @@ def process_query():
         query_text = data['query']
         language_code = data.get('language_code', 0)
 
-        # Translate if Arabic
-        if language_code == 0:
-            query_text = translate_text(query_text, 'ar_to_en')
-
-        # Get query embedding and find relevant documents
-        query_embedding = models['embedding'].encode([query_text])
-        relevant_docs = query_embeddings(query_embedding)
-
-        # Retrieve and process documents
-        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
-
-        # Extract entities and generate context
-        query_entities = extract_entities(query_text)
-        contexts = []
-        for text in doc_texts:
-            doc_entities = extract_entities(text)
-            if set(query_entities) & set(doc_entities):
-                contexts.append(text)
-
-        context = " ".join(contexts[:3])  # Use top 3 most relevant contexts
-
-        # Generate answer
-        answer = generate_answer(query_text, context)
-
-        # Translate back if needed
-        if language_code == 0:
-            answer = translate_text(answer, 'en_to_ar')
-
-        return jsonify({
-            'answer': answer,
-            'success': True
-        })
+        # Basic response if no models or data are loaded
+        if not models or not data.get('embeddings'):
+            return jsonify({
+                'answer': 'The system is currently initializing. Please try again in a few minutes.',
+                'success': False
+            }), 503
+
+        # Process query with available models and data
+        try:
+            # Handle Arabic queries
+            if language_code == 0:
+                query_text = translate_text(query_text, 'ar_to_en')
+
+            # Get query embedding and find relevant documents
+            query_embedding = models['embedding'].encode([query_text])
+            relevant_docs = query_embeddings(query_embedding)
+
+            if not relevant_docs:
+                return jsonify({
+                    'answer': 'No relevant information found. Please try a different query.',
+                    'success': True
+                })
+
+            # Retrieve and process documents
+            doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
+            doc_texts = [text for text in doc_texts if text.strip()]
+
+            if not doc_texts:
+                return jsonify({
+                    'answer': 'Unable to retrieve relevant documents. Please try again.',
+                    'success': True
+                })
+
+            # Rerank documents
+            rerank_scores = rerank_documents(query_text, doc_texts)
+            ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
+
+            # Combine top documents
+            context = " ".join(ranked_texts[:3])
+
+            # Generate answer
+            answer = generate_answer(query_text, context)
+
+            # Translate answer back to Arabic if needed
+            if language_code == 0:
+                answer = translate_text(answer, 'en_to_ar')
+
+            return jsonify({
+                'answer': answer,
+                'success': True
+            })
+
+        except Exception as e:
+            print(f"Error processing query: {e}")
+            return jsonify({
+                'error': 'An error occurred while processing your query',
+                'success': False
+            }), 500
 
     except Exception as e:
+        print(f"Error in process_query: {e}")
         return jsonify({
             'error': str(e),
             'success': False
@@ -212,9 +343,7 @@ print("Initializing application...")
 init_success = init_nltk() and load_models() and load_data()
 
 if not init_success:
-    print("Failed to initialize application")
-    exit(1)
+    print("Warning: Application initialized with partial functionality")
 
 if __name__ == "__main__":
     app.run(host='0.0.0.0', port=7860)
-
 
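For reference, a minimal client-side sketch of the two endpoints this commit touches; it is not part of the commit itself. It assumes the Space is reachable at http://localhost:7860 (the host and port passed to app.run above), that the requests package is installed on the client, and the Arabic sample query is purely illustrative.

    import requests  # assumed available on the client; not part of this commit

    BASE_URL = "http://localhost:7860"  # default host/port from app.run above

    # /health now reports which resources were actually loaded
    print(requests.get(f"{BASE_URL}/health").json())

    # /api/query expects JSON with 'query' and an optional 'language_code'
    # (0 = Arabic, translated internally; any other value is passed through as English)
    payload = {"query": "ما هي أعراض مرض السكري؟", "language_code": 0}
    resp = requests.post(f"{BASE_URL}/api/query", json=payload)
    print(resp.json())  # {'answer': ..., 'success': True} once loaded; HTTP 503 while still initializing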