Commit f377404 (1 parent: 8ce8fc9)
Update app.py
app.py CHANGED
@@ -38,6 +38,10 @@ def load_models():
     try:
         print("Loading models...")
 
+        # Set device
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Device set to use {device}")
+
         # Embedding models
         models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
         models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
@@ -64,24 +68,78 @@ def load_models():
         print(f"Error loading models: {e}")
         return False
 
-def
-    """Load embeddings
+def load_embeddings():
+    """Load embeddings with robust error handling"""
     try:
-        print("Loading
+        print("Loading embeddings...")
+        embeddings_path = 'embeddings.pkl'
 
+        if not os.path.exists(embeddings_path):
+            print(f"Error: {embeddings_path} not found")
+            return False
+
+        # Custom unpickler to handle potential compatibility issues
+        class CustomUnpickler(pickle.Unpickler):
+            def find_class(self, module, name):
+                if module == "__main__":
+                    module = "numpy"
+                return super().find_class(module, name)
+
+        with open(embeddings_path, 'rb') as f:
+            try:
+                data['embeddings'] = pickle.load(f)
+            except Exception as e:
+                print(f"Standard unpickling failed, trying custom unpickler: {e}")
+                f.seek(0)
+                try:
+                    data['embeddings'] = CustomUnpickler(f).load()
+                except Exception as e:
+                    print(f"Custom unpickler failed: {e}")
+                    data['embeddings'] = {}
+                    return False
 
+        if not isinstance(data['embeddings'], dict):
+            print("Error: Embeddings data is not in expected format")
+            data['embeddings'] = {}
+            return False
+
+        print(f"Successfully loaded {len(data['embeddings'])} embeddings")
+        return True
+    except Exception as e:
+        print(f"Error loading embeddings: {e}")
+        data['embeddings'] = {}
+        return False
+
+def load_documents_data():
+    """Load document data with error handling"""
+    try:
+        print("Loading documents data...")
+        docs_path = 'finalcleaned_excel_file.xlsx'
 
+        if not os.path.exists(docs_path):
+            print(f"Error: {docs_path} not found")
+            return False
+
+        data['df'] = pd.read_excel(docs_path)
+        print(f"Successfully loaded {len(data['df'])} document records")
         return True
     except Exception as e:
-        print(f"Error loading data: {e}")
+        print(f"Error loading documents data: {e}")
+        data['df'] = pd.DataFrame()
         return False
 
+def load_data():
+    """Load all required data"""
+    embeddings_success = load_embeddings()
+    documents_success = load_documents_data()
+
+    if not embeddings_success:
+        print("Warning: Failed to load embeddings, falling back to basic functionality")
+    if not documents_success:
+        print("Warning: Failed to load documents data, falling back to basic functionality")
+
+    return True
+
 def translate_text(text, source_to_target='ar_to_en'):
     """Translate text between Arabic and English"""
     try:
@@ -99,26 +157,8 @@ def translate_text(text, source_to_target='ar_to_en'):
         print(f"Translation error: {e}")
         return text
 
-def query_embeddings(query_embedding, n_results=5):
-    """Find relevant documents using embedding similarity"""
-    doc_ids = list(data['embeddings'].keys())
-    doc_embeddings = np.array(list(data['embeddings'].values()))
-    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
-    top_indices = similarities.argsort()[-n_results:][::-1]
-    return [(doc_ids[i], similarities[i]) for i in top_indices]
-
-def retrieve_document_text(doc_id):
-    """Retrieve document text from HTML file"""
-    try:
-        with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as file:
-            soup = BeautifulSoup(file, 'html.parser')
-            return soup.get_text(separator=' ', strip=True)
-    except Exception as e:
-        print(f"Error retrieving document {doc_id}: {e}")
-        return ""
-
 def extract_entities(text):
-    """Extract medical entities from text"""
+    """Extract medical entities from text using NER"""
     try:
         results = models['ner_pipeline'](text)
         return list({result['word'] for result in results if result['entity'].startswith("B-")})
@@ -130,37 +170,101 @@ def generate_answer(query, context, max_length=860, temperature=0.2):
     """Generate answer using LLM"""
     try:
         prompt = f"""
-        As a medical expert, answer the following question based
+        As a medical expert, please provide a clear and accurate answer to the following question based solely on the provided context.
 
         Context: {context}
-        Question: {query}
 
+        Question: {query}
 
+        Answer: Let me help you with accurate information from reliable medical sources."""
+
         inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
+
+        with torch.no_grad():
+            outputs = models['llm_model'].generate(
+                inputs.input_ids,
+                max_length=max_length,
+                num_return_sequences=1,
+                temperature=temperature,
+                do_sample=True,
+                top_p=0.9,
+                pad_token_id=models['llm_tokenizer'].eos_token_id
+            )
+
+        response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
+
+        # Clean up the response
+        if "Answer:" in response:
+            response = response.split("Answer:")[-1].strip()
+
+        # Remove incomplete sentences at the end
+        sentences = nltk.sent_tokenize(response)
+        if sentences:
+            return " ".join(sentences)
+        return response
+
     except Exception as e:
         print(f"Error generating answer: {e}")
-        return "
+        return "I apologize, but I'm unable to generate an answer at this time. Please try again later."
+
+def query_embeddings(query_embedding, n_results=5):
+    """Find relevant documents using embedding similarity"""
+    if not data['embeddings']:
+        return []
+
+    try:
+        doc_ids = list(data['embeddings'].keys())
+        doc_embeddings = np.array(list(data['embeddings'].values()))
+        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+        top_indices = similarities.argsort()[-n_results:][::-1]
+        return [(doc_ids[i], similarities[i]) for i in top_indices]
+    except Exception as e:
+        print(f"Error in query_embeddings: {e}")
+        return []
+
+def retrieve_document_text(doc_id):
+    """Retrieve document text from HTML file"""
+    try:
+        file_path = os.path.join('downloaded_articles', doc_id)
+        if not os.path.exists(file_path):
+            print(f"Warning: Document file not found: {file_path}")
+            return ""
+
+        with open(file_path, 'r', encoding='utf-8') as file:
+            soup = BeautifulSoup(file, 'html.parser')
+            return soup.get_text(separator=' ', strip=True)
+    except Exception as e:
+        print(f"Error retrieving document {doc_id}: {e}")
+        return ""
+
+def rerank_documents(query, doc_texts):
+    """Rerank documents using cross-encoder"""
+    try:
+        pairs = [(query, doc) for doc in doc_texts]
+        scores = models['cross_encoder'].predict(pairs)
+        return scores
+    except Exception as e:
+        print(f"Error reranking documents: {e}")
+        return np.zeros(len(doc_texts))
 
 @app.route('/health', methods=['GET'])
 def health_check():
     """Health check endpoint"""
+    status = {
+        'status': 'healthy',
+        'models_loaded': bool(models),
+        'embeddings_loaded': bool(data.get('embeddings')),
+        'documents_loaded': not data.get('df', pd.DataFrame()).empty
+    }
+    return jsonify(status)
 
 @app.route('/api/query', methods=['POST'])
 def process_query():
     """Main query processing endpoint"""
     try:
+        if not request.is_json:
+            return jsonify({'error': 'Request must be JSON', 'success': False}), 400
+
         data = request.json
         if not data or 'query' not in data:
             return jsonify({'error': 'No query provided', 'success': False}), 400
@@ -168,40 +272,67 @@ def process_query():
         query_text = data['query']
         language_code = data.get('language_code', 0)
 
-        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
-
-        # Extract entities and generate context
-        query_entities = extract_entities(query_text)
-        contexts = []
-        for text in doc_texts:
-            doc_entities = extract_entities(text)
-            if set(query_entities) & set(doc_entities):
-                contexts.append(text)
-
-        context = " ".join(contexts[:3]) # Use top 3 most relevant contexts
-
-        # Generate answer
-        answer = generate_answer(query_text, context)
-
-        # Translate back if needed
-        if language_code == 0:
-            answer = translate_text(answer, 'en_to_ar')
+        # Basic response if no models or data are loaded
+        if not models or not data.get('embeddings'):
+            return jsonify({
+                'answer': 'The system is currently initializing. Please try again in a few minutes.',
+                'success': False
+            }), 503
 
+        # Process query with available models and data
+        try:
+            # Handle Arabic queries
+            if language_code == 0:
+                query_text = translate_text(query_text, 'ar_to_en')
+
+            # Get query embedding and find relevant documents
+            query_embedding = models['embedding'].encode([query_text])
+            relevant_docs = query_embeddings(query_embedding)
+
+            if not relevant_docs:
+                return jsonify({
+                    'answer': 'No relevant information found. Please try a different query.',
+                    'success': True
+                })
+
+            # Retrieve and process documents
+            doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]
+            doc_texts = [text for text in doc_texts if text.strip()]
+
+            if not doc_texts:
+                return jsonify({
+                    'answer': 'Unable to retrieve relevant documents. Please try again.',
+                    'success': True
+                })
+
+            # Rerank documents
+            rerank_scores = rerank_documents(query_text, doc_texts)
+            ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
+
+            # Combine top documents
+            context = " ".join(ranked_texts[:3])
+
+            # Generate answer
+            answer = generate_answer(query_text, context)
+
+            # Translate answer back to Arabic if needed
+            if language_code == 0:
+                answer = translate_text(answer, 'en_to_ar')
+
+            return jsonify({
+                'answer': answer,
+                'success': True
+            })
+
+        except Exception as e:
+            print(f"Error processing query: {e}")
+            return jsonify({
+                'error': 'An error occurred while processing your query',
+                'success': False
+            }), 500
 
     except Exception as e:
+        print(f"Error in process_query: {e}")
         return jsonify({
             'error': str(e),
             'success': False
@@ -212,9 +343,7 @@ print("Initializing application...")
 init_success = init_nltk() and load_models() and load_data()
 
 if not init_success:
-    print("
-    exit(1)
+    print("Warning: Application initialized with partial functionality")
 
 if __name__ == "__main__":
-    app.run(host='0.0.0.0', port=7860)
-
+    app.run(host='0.0.0.0', port=7860)
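For reference, a minimal client-side sketch of exercising the endpoints touched by this commit. It assumes the app is reachable at http://localhost:7860 (the port passed to app.run above) and uses only the request/response fields visible in the diff ('query', 'language_code', 'answer', 'success', 'error'); the base URL and the example question are illustrative, not part of the commit.

import requests

BASE_URL = "http://localhost:7860"  # assumed local URL; adjust for a deployed Space

# Health check: reports whether models, embeddings, and documents loaded
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Query endpoint: JSON body with 'query' and optional 'language_code'
# (language_code == 0 is treated as Arabic, translated to English and back)
payload = {"query": "ما هي أعراض فقر الدم؟", "language_code": 0}  # hypothetical question: "What are the symptoms of anemia?"
resp = requests.post(f"{BASE_URL}/api/query", json=payload, timeout=120)
result = resp.json()
if result.get("success"):
    print(result["answer"])
else:
    print("Request failed:", result.get("error") or result.get("answer"))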