Commit 31bad44 · Update app.py
Parent(s): 0250187
app.py
CHANGED
@@ -6,8 +6,9 @@ from flask_cors import CORS
 from transformers import (
     AutoTokenizer,
     AutoModelForSeq2SeqLM,
+    AutoModelForTokenClassification,
     AutoModelForCausalLM,
-
+    pipeline
 )
 from sentence_transformers import SentenceTransformer, CrossEncoder
 from sklearn.metrics.pairwise import cosine_similarity
@@ -15,71 +16,65 @@ from bs4 import BeautifulSoup
 import nltk
 import torch
 import pandas as pd
-from startup import setup_files
-

 app = Flask(__name__)
 CORS(app)
-# Environment variables for file paths
-EMBEDDINGS_PATH = os.environ.get('EMBEDDINGS_PATH', 'data/embeddings.pkl')
-LINKS_PATH = os.environ.get('LINKS_PATH', 'data/finalcleaned_excel_file.xlsx')

-
-
-
-setup_files()
+# Global variables for models and data
+models = {}
+data = {}

-
-
+def init_nltk():
+    """Initialize NLTK resources"""
     try:
-
-
-
-
+        nltk.download('punkt', quiet=True)
+        return True
+    except Exception as e:
+        print(f"Error initializing NLTK: {e}")
+        return False

-
+def load_models():
+    """Initialize all required models"""
+    try:
+        print("Loading models...")
+
+        # Embedding models
+        models['embedding'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+        models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)

-        # Basic embedding models
-        embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-        cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
-        semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-
         # Translation models
-        ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-        ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-        en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-        en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-
-        #
-        bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-        bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
+        models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+        models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+        models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+        models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+
+        # NER model
+        models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+        models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+        models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
+
         # LLM model
         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
-
-
-
-        nltk.download('punkt', quiet=True)
+        models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
+        models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)

-        print("Models
+        print("Models loaded successfully")
         return True
     except Exception as e:
-        print(f"Error
+        print(f"Error loading models: {e}")
         return False

-# Load data with error handling
 def load_data():
+    """Load embeddings and document data"""
     try:
-
-
-        print("Loading data files...")
+        print("Loading data...")

         # Load embeddings
-        with open(
-
+        with open('embeddings.pkl', 'rb') as f:
+            data['embeddings'] = pickle.load(f)

-        # Load links
-        df = pd.read_excel(
+        # Load document links
+        data['df'] = pd.read_excel('finalcleaned_excel_file.xlsx')

         print("Data loaded successfully")
         return True
@@ -87,12 +82,84 @@ def load_data():
         print(f"Error loading data: {e}")
         return False

+def translate_text(text, source_to_target='ar_to_en'):
+    """Translate text between Arabic and English"""
+    try:
+        if source_to_target == 'ar_to_en':
+            tokenizer = models['ar_to_en_tokenizer']
+            model = models['ar_to_en_model']
+        else:
+            tokenizer = models['en_to_ar_tokenizer']
+            model = models['en_to_ar_model']
+
+        inputs = tokenizer(text, return_tensors="pt", truncation=True)
+        outputs = model.generate(**inputs)
+        return tokenizer.decode(outputs[0], skip_special_tokens=True)
+    except Exception as e:
+        print(f"Translation error: {e}")
+        return text
+
+def query_embeddings(query_embedding, n_results=5):
+    """Find relevant documents using embedding similarity"""
+    doc_ids = list(data['embeddings'].keys())
+    doc_embeddings = np.array(list(data['embeddings'].values()))
+    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+    top_indices = similarities.argsort()[-n_results:][::-1]
+    return [(doc_ids[i], similarities[i]) for i in top_indices]
+
+def retrieve_document_text(doc_id):
+    """Retrieve document text from HTML file"""
+    try:
+        with open(f"downloaded_articles/{doc_id}", 'r', encoding='utf-8') as file:
+            soup = BeautifulSoup(file, 'html.parser')
+            return soup.get_text(separator=' ', strip=True)
+    except Exception as e:
+        print(f"Error retrieving document {doc_id}: {e}")
+        return ""
+
+def extract_entities(text):
+    """Extract medical entities from text"""
+    try:
+        results = models['ner_pipeline'](text)
+        return list({result['word'] for result in results if result['entity'].startswith("B-")})
+    except Exception as e:
+        print(f"Error extracting entities: {e}")
+        return []
+
+def generate_answer(query, context, max_length=860, temperature=0.2):
+    """Generate answer using LLM"""
+    try:
+        prompt = f"""
+        As a medical expert, answer the following question based only on the provided context:
+
+        Context: {context}
+        Question: {query}
+
+        Answer:"""
+
+        inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
+        outputs = models['llm_model'].generate(
+            inputs.input_ids,
+            max_length=max_length,
+            num_return_sequences=1,
+            temperature=temperature,
+            pad_token_id=models['llm_tokenizer'].eos_token_id
+        )
+
+        answer = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
+        return answer.split("Answer:")[-1].strip()
+    except Exception as e:
+        print(f"Error generating answer: {e}")
+        return "Sorry, I couldn't generate an answer at this time."
+
 @app.route('/health', methods=['GET'])
 def health_check():
+    """Health check endpoint"""
     return jsonify({'status': 'healthy'})

 @app.route('/api/query', methods=['POST'])
 def process_query():
+    """Main query processing endpoint"""
     try:
         data = request.json
         if not data or 'query' not in data:
@@ -101,24 +168,33 @@ def process_query():
         query_text = data['query']
         language_code = data.get('language_code', 0)

-        #
+        # Translate if Arabic
         if language_code == 0:
-            query_text =
+            query_text = translate_text(query_text, 'ar_to_en')

-        # Get
-        query_embedding =
-
+        # Get query embedding and find relevant documents
+        query_embedding = models['embedding'].encode([query_text])
+        relevant_docs = query_embeddings(query_embedding)
+
+        # Retrieve and process documents
+        doc_texts = [retrieve_document_text(doc_id) for doc_id, _ in relevant_docs]

-        #
-
-
+        # Extract entities and generate context
+        query_entities = extract_entities(query_text)
+        contexts = []
+        for text in doc_texts:
+            doc_entities = extract_entities(text)
+            if set(query_entities) & set(doc_entities):
+                contexts.append(text)
+
+        context = " ".join(contexts[:3])  # Use top 3 most relevant contexts

         # Generate answer
-
-        answer = generate_answer(query_text, combined_text)
+        answer = generate_answer(query_text, context)

+        # Translate back if needed
         if language_code == 0:
-            answer =
+            answer = translate_text(answer, 'en_to_ar')

         return jsonify({
             'answer': answer,
@@ -131,375 +207,13 @@ def process_query():
             'success': False
         }), 500

-
-
-
-        outputs = ar_to_en_model.generate(**inputs)
-        return ar_to_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    except Exception as e:
-        print(f"Translation error (AR->EN): {e}")
-        return text
-
-def translate_en_to_ar(text):
-    try:
-        inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True)
-        outputs = en_to_ar_model.generate(**inputs)
-        return en_to_ar_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    except Exception as e:
-        print(f"Translation error (EN->AR): {e}")
-        return text
-
-language_code = 0
-
-query_text = 'How can a patient with chronic kidney disease manage their daily activities and maintain quality of life?' #'symptoms of a heart attack '
-
-def process_query(query_text):
-    if language_code == 0:
-        # Translate Arabic input to English
-        query_text = translate_ar_to_en(query_text)
-    return query_text
-
-def embed_query_text(query_text):
-    query_embedding = embedding_model.encode([query_text])
-    return query_embedding
-
-def query_embeddings(query_embedding, embeddings_data, n_results=5):
-    doc_ids = list(embeddings_data.keys())
-    doc_embeddings = np.array(list(embeddings_data.values()))
-    similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
-    top_indices = similarities.argsort()[-n_results:][::-1]
-    top_docs = [(doc_ids[i], similarities[i]) for i in top_indices]
-
-    return top_docs
-
-query_embedding = embed_query_text(query_text) # Embed the query text
-initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
-document_ids = [doc_id for doc_id, _ in initial_results]
-print(document_ids)
-
-import pandas as pd
-import requests
-from bs4 import BeautifulSoup
-
-# Load the Excel file
-file_path = '/kaggle/input/final-links/finalcleaned_excel_file.xlsx'
-df = pd.read_excel(file_path)
-
-
-# Create a dictionary mapping file names to URLs
-# Assuming the DataFrame index corresponds to file names
-file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-def get_page_title(url):
-    try:
-        response = requests.get(url)
-        if response.status_code == 200:
-            soup = BeautifulSoup(response.text, 'html.parser')
-            title = soup.find('title')
-            return title.get_text() if title else "No title found"
-        else:
-            return None
-    except requests.exceptions.RequestException:
-        return None
-# Example file names
-file_names = document_ids
-
-# Retrieve original URLs
-for file_name in file_names:
-    original_url = file_name_to_url.get(file_name, None)
-    if original_url:
-        title = get_page_title(original_url)
-        if title:
-            print(f"Title: {title},URL: {original_url}")
-        else:
-            print(f"Name: {file_name}")
-    else:
-        print(f"Name: {file_name}")
-
-def retrieve_document_texts(doc_ids, folder_path):
-    texts = []
-    for doc_id in doc_ids:
-        file_path = os.path.join(folder_path, doc_id)
-        try:
-            with open(file_path, 'r', encoding='utf-8') as file:
-                soup = BeautifulSoup(file, 'html.parser')
-                text = soup.get_text(separator=' ', strip=True)
-                texts.append(text)
-        except FileNotFoundError:
-            texts.append("")
-    return texts
-document_ids = [doc_id for doc_id, _ in initial_results]
-document_texts = retrieve_document_texts(document_ids, folder_path)
-
-# Rerank the results using the CrossEncoder
-scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
-scored_documents = list(zip(scores, document_ids, document_texts))
-scored_documents.sort(key=lambda x: x[0], reverse=True)
-print("Reranked results:")
-for idx, (score, doc_id, doc) in enumerate(scored_documents):
-    print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id}")
-
-from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
-import nltk
-
-# Load BioBERT model and tokenizer for NER
-bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-ner_biobert = pipeline("ner", model=bio_model, tokenizer=bio_tokenizer)
-
-def extract_entities(text, ner_pipeline):
-    """
-    Extract entities using a NER pipeline.
-    Args:
-        text (str): The text from which to extract entities.
-        ner_pipeline (pipeline): The NER pipeline for entity extraction.
-    Returns:
-        List[str]: A list of unique extracted entities.
-    """
-    ner_results = ner_pipeline(text)
-    entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
-    return list(entities)
-
-def match_entities(query_entities, sentence_entities):
-    """
-    Compute the relevance score based on entity matching.
-    Args:
-        query_entities (List[str]): Entities extracted from the query.
-        sentence_entities (List[str]): Entities extracted from the sentence.
-    Returns:
-        float: The relevance score based on entity overlap.
-    """
-    query_set, sentence_set = set(query_entities), set(sentence_entities)
-    matches = query_set.intersection(sentence_set)
-    return len(matches)
-
-def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
-    """
-    Extract relevant text portions from documents based on entity matching.
-    Args:
-        document_texts (List[str]): List of document texts.
-        query (str): The query text.
-        max_portions (int): Maximum number of relevant portions to extract per document.
-        portion_size (int): Number of sentences to include in each portion.
-        min_query_words (int): Minimum number of matching entities to consider a sentence relevant.
-    Returns:
-        Dict[str, List[str]]: Relevant portions for each document.
-    """
-    relevant_portions = {}
-
-    # Extract entities from the query
-    query_entities = extract_entities(query, ner_biobert)
-    print(f"Extracted Query Entities: {query_entities}")
-
-    for doc_id, doc_text in enumerate(document_texts):
-        sentences = nltk.sent_tokenize(doc_text) # Split document into sentences
-        doc_relevant_portions = []
-
-        # Extract entities from the entire document
-        doc_entities = extract_entities(doc_text, ner_biobert)
-        print(f"Document {doc_id} Entities: {doc_entities}")
-
-        for i, sentence in enumerate(sentences):
-            # Extract entities from the sentence
-            sentence_entities = extract_entities(sentence, ner_biobert)
-
-            # Compute relevance score
-            relevance_score = match_entities(query_entities, sentence_entities)
-
-            # Select sentences with at least `min_query_words` matching entities
-            if relevance_score >= min_query_words:
-                start_idx = max(0, i - portion_size // 2)
-                end_idx = min(len(sentences), i + portion_size // 2 + 1)
-                portion = " ".join(sentences[start_idx:end_idx])
-                doc_relevant_portions.append(portion)
-
-                if len(doc_relevant_portions) >= max_portions:
-                    break
-
-        # Add fallback to include the most entity-dense sentences if no results
-        if not doc_relevant_portions and len(doc_entities) > 0:
-            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
-            sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
-            for fallback_sentence in sorted_sentences[:max_portions]:
-                doc_relevant_portions.append(fallback_sentence)
-
-        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
-
-    return relevant_portions
-
-# Extract relevant portions based on query and documents
-relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=1)
-
-for doc_id, portions in relevant_portions.items():
-    print(f"{doc_id}: {portions}")
-
-# Remove duplicates from the selected portions
-def remove_duplicates(selected_parts):
-    unique_sentences = set()
-    unique_selected_parts = []
-
-    for sentence in selected_parts:
-        if sentence not in unique_sentences:
-            unique_selected_parts.append(sentence)
-            unique_sentences.add(sentence)
-
-    return unique_selected_parts
-
-# Flatten the dictionary of relevant portions (from earlier code)
-flattened_relevant_portions = []
-for doc_id, portions in relevant_portions.items():
-    flattened_relevant_portions.extend(portions)
-
-# Remove duplicate portions
-unique_selected_parts = remove_duplicates(flattened_relevant_portions)
-
-# Combine the unique parts into a single string of context
-combined_parts = " ".join(unique_selected_parts)
-
-# Construct context as a list: first the query, then the unique selected portions
-context = [query_text] + unique_selected_parts
-
-# Print the context (query + relevant portions)
-print(context)
-
-import pickle
-
-with open('/kaggle/input/art-embeddings-pkl/embeddings.pkl', 'rb') as file:
-    data = pickle.load(file)
-
-# Print the type of data
-print(f"Data type: {type(data)}")
-
-# Print the first few keys and values from the dictionary
-print("First few keys and values:")
-for i, (key, value) in enumerate(data.items()):
-    if i >= 5: # Limit to printing the first 5 key-value pairs
-        break
-    print(f"Key: {key}, Value: {value}")
-
-import pickle
-import pickletools
-
-# Load the pickle file
-file_path = '/kaggle/input/art-embeddings-pkl/embeddings.pkl'
-
-with open(file_path, 'rb') as f:
-    # Read the pickle file
-    data = pickle.load(f)
-
-# Check for suspicious or corrupted entries
-def inspect_pickle(data):
-    for key, value in data.items():
-        if isinstance(value, (str, bytes)):
-            # Try to decode and catch any non-ASCII issues
-            try:
-                value.decode('ascii')
-            except UnicodeDecodeError as e:
-                print(f"Non-ASCII entry found in key: {key}")
-                print(f"Corrupted data: {value} ({e})")
-                continue
-
-        if isinstance(value, list) and any(isinstance(v, (list, dict, str, bytes)) for v in value):
-            # Inspect list elements recursively
-            inspect_pickle({f"{key}[{idx}]": v for idx, v in enumerate(value)})
-
-# Inspect the data
-inspect_pickle(data)
-
-from transformers import AutoTokenizer, AutoModelForTokenClassification
-import torch
-import time
-
-# Load Biobert model and tokenizer
-biobert_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-biobert_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
-def extract_entities(text):
-    inputs = biobert_tokenizer(text, return_tensors="pt")
-    outputs = biobert_model(**inputs)
-    predictions = torch.argmax(outputs.logits, dim=2)
-    tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
-    entities = [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0] # Assume 0 is the label for non-entity
-    return entities
-
-def enhance_passage_with_entities(passage, entities):
-    # Example: Add entities to the passage for better context
-    return f"{passage}\n\nEntities: {', '.join(entities)}"
-
-def create_prompt(question, passage):
-    prompt = ("""
-    As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
-
-    Passage: {passage}
-
-    Question: {question}
-
-    Answer:
-    """)
-    return prompt.format(passage=passage, question=question)
-
-def generate_answer(prompt, max_length=860, temperature=0.2):
-    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
-
-    # Start timing
-    start_time = time.time()
-
-    output_ids = model_f.generate(
-        inputs.input_ids,
-        max_length=max_length,
-        num_return_sequences=1,
-        temperature=temperature,
-        pad_token_id=tokenizer_f.eos_token_id
-    )
-
-    # End timing
-    end_time = time.time()
-
-    # Calculate the duration
-    duration = end_time - start_time
-
-    # Decode the answer
-    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-
-    passage_keywords = set(passage.lower().split())
-    answer_keywords = set(answer.lower().split())
-
-    if passage_keywords.intersection(answer_keywords):
-        return answer, duration
-    else:
-        return "Sorry, I can't help with that.", duration
-
-# Integrate Biobert model
-entities = extract_entities(query_text)
-passage = enhance_passage_with_entities(combined_parts, entities)
-# Generate answer with the enhanced passage
-prompt = create_prompt(query_text, passage)
-answer, generation_time = generate_answer(prompt)
-print(f"\nTime taken to generate the answer: {generation_time:.2f} seconds")
-def remove_answer_prefix(text):
-    prefix = "Answer:"
-    if prefix in text:
-        return text.split(prefix)[-1].strip()
-    return text
-
-def remove_incomplete_sentence(text):
-    # Check if the text ends with a period
-    if not text.endswith('.'):
-        # Find the last period or the end of the string
-        last_period_index = text.rfind('.')
-        if last_period_index != -1:
-            # Remove everything after the last period
-            return text[:last_period_index + 1].strip()
-    return text
-# Clean and print the answer
-answer_part = answer.split("Answer:")[-1].strip()
-cleaned_answer = remove_answer_prefix(answer_part)
-final_answer = remove_incomplete_sentence(cleaned_answer)
+# Initialize everything when the app starts
+print("Initializing application...")
+init_success = init_nltk() and load_models() and load_data()

-if
-
+if not init_success:
+    print("Failed to initialize application")
+    exit(1)

-if
-
-    print(final_answer)
-else:
-    print("Sorry, I can't help with that.")
+if __name__ == "__main__":
+    app.run(host='0.0.0.0', port=7860)
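For reference, a minimal client sketch for exercising the endpoints introduced above. The route names, JSON fields, and the port in app.run() come from the diff; the base URL, the example question, and the use of language_code=1 for English input are assumptions for illustration only.

# Hypothetical client sketch; assumes the Space is running locally on the
# port used in app.run() (0.0.0.0:7860). Adjust BASE_URL for a deployed Space.
import requests

BASE_URL = "http://localhost:7860"  # assumption: local run

# Health check endpoint added in this commit
print(requests.get(f"{BASE_URL}/health").json())  # expected: {'status': 'healthy'}

# Main query endpoint; language_code=1 skips the Arabic translation branch
payload = {"query": "How can a patient with chronic kidney disease manage daily activities?",
           "language_code": 1}
response = requests.post(f"{BASE_URL}/api/query", json=payload)
print(response.json().get("answer"))

In process_query(), language_code == 0 triggers the Arabic-to-English translation of the query and the English-to-Arabic translation of the answer, so Arabic clients would send language_code: 0 instead.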