thechaiexperiment committed
Commit eee7a65 · 1 Parent(s): 8e3b5f7

Update app.py

Files changed (1)
  1. app.py  +468 -189
app.py CHANGED
@@ -1,226 +1,505 @@
- from fastapi import FastAPI, HTTPException, Query
- from pydantic import BaseModel
- from typing import List, Optional, Dict
  import pickle
  import numpy as np
- from sklearn.metrics.pairwise import cosine_similarity
- from sentence_transformers import SentenceTransformer, CrossEncoder, util
- from bs4 import BeautifulSoup
- import os
- import nltk
- import torch
  from transformers import (
      AutoTokenizer,
-     BartForConditionalGeneration,
      AutoModelForCausalLM,
-     AutoModelForSeq2SeqLM,
      AutoModelForTokenClassification
  )
  import pandas as pd
- import time
-
- app = FastAPI()
-
- # ArticleEmbeddingUnpickler and safe_load_embeddings functions
- class ArticleEmbeddingUnpickler(pickle.Unpickler):
-     """Custom unpickler for article embeddings with enhanced persistence handling"""
-     def find_class(self, module: str, name: str) -> any:
-         if module == 'numpy':
-             return getattr(np, name)
-         if module == 'sentence_transformers.SentenceTransformer':
-             from sentence_transformers import SentenceTransformer
-             return SentenceTransformer
-         return super().find_class(module, name)
-
-     def persistent_load(self, pid: any) -> str:
-         """Enhanced persistent ID handler with better encoding management"""
-         try:
-             if isinstance(pid, bytes):
-                 return pid.decode('utf-8', errors='replace')
-             if isinstance(pid, (str, int, float)):
-                 return str(pid)
-             return repr(pid)
-         except Exception as e:
-             print(f"Warning: Error in persistent_load: {str(e)}")
-             return repr(pid)
-
- def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
-     """Load embeddings with enhanced error handling, validation, and persistent ID support."""
-     def persistent_load(pid):
-         print(f"Warning: Persistent ID encountered: {pid}")
-         raise ValueError("Persistent IDs are not supported in this application")
-
-     try:
-         if not os.path.exists(file_path):
-             raise FileNotFoundError(f"Embeddings file not found at {file_path}")
-
-         with open(file_path, 'rb') as file:
-             unpickler = pickle.Unpickler(file)
-             unpickler.persistent_load = persistent_load  # Assign custom handler
-             embeddings_data = unpickler.load()
-
-         if not isinstance(embeddings_data, dict):
-             raise ValueError(f"Invalid data structure: expected dict, got {type(embeddings_data)}")
-
-         # Process and validate embeddings
-         valid_embeddings = {}
-         for key, value in embeddings_data.items():
-             try:
-                 key_str = str(key).encode('ascii', errors='replace').decode('ascii').strip()
-                 if not key_str:
-                     print(f"Skipping invalid or empty key: {key}")
-                     continue
-
-                 if isinstance(value, list):
-                     value = np.array(value, dtype=np.float32)
-                 elif isinstance(value, np.ndarray):
-                     value = value.astype(np.float32)
-                 else:
-                     print(f"Skipping invalid embedding type for key {key_str}: {type(value)}")
-                     continue
-
-                 if value.ndim != 1 or np.isnan(value).any() or np.isinf(value).any():
-                     print(f"Skipping embedding with invalid shape or values for key {key_str}")
-                     continue
-
-                 valid_embeddings[key_str] = value
-             except Exception as e:
-                 print(f"Error processing embedding for key {key}: {str(e)}")
-                 continue
-
-         if not valid_embeddings:
-             raise ValueError("No valid embeddings found in file")
-
-         print(f"Successfully loaded {len(valid_embeddings)} valid embeddings")
-         return valid_embeddings
      except Exception as e:
-         print(f"Error loading embeddings: {str(e)}")
-         raise
-
- def safe_save_embeddings(embeddings_dict, file_path='embeddings.pkl'):
-     cleaned_embeddings = {
-         str(key): value
-         for key, value in embeddings_dict.items()
-     }
-     with open(file_path, 'wb') as f:
-         pickle.dump(cleaned_embeddings, f, protocol=4)
-
-
- # GlobalModels and load_models
- class GlobalModels:
-     embedding_model = None
-     cross_encoder = None
-     semantic_model = None
-     tokenizer = None
-     model = None
-     tokenizer_f = None
-     model_f = None
-     ar_to_en_tokenizer = None
-     ar_to_en_model = None
-     en_to_ar_tokenizer = None
-     en_to_ar_model = None
-     bio_tokenizer = None
-     bio_model = None
-     embeddings_data = None
-     file_name_to_url = None
-
- global_models = GlobalModels()
-
- @app.on_event("startup")
- async def load_models():
      try:
-         global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-         global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
-         global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
-
-         global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
-         global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
-
-         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
-         global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
-         global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
-
-         global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-         global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
-         global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-         global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-
-         global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
-         global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
-
-         try:
-             with open('embeddings.pkl', 'rb') as file:
-                 global_models.embeddings_data = pickle.load(file)
-         except (FileNotFoundError, pickle.UnpicklingError) as e:
-             print(f"Error loading embeddings data: {e}")
-             raise HTTPException(status_code=500, detail="Failed to load embeddings data.")
-
-         df = pd.read_excel('finalcleaned_excel_file.xlsx')
-         global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-
-     except Exception as e:
-         print(f"Error loading models: {e}")
-         raise HTTPException(status_code=500, detail="Failed to load models.")
-
- # Query and Document Retrieval Endpoint
- class QueryInput(BaseModel):
-     query_text: str
-     language_code: int  # 0 for Arabic, 1 for English
-     query_type: str  # "profile" or "question"
-     previous_qa: Optional[List[Dict[str, str]]] = None
-
- class DocumentResponse(BaseModel):
-     title: str
-     url: str
-     text: str
-     score: float
-
- @app.post("/retrieve_documents")
- async def retrieve_documents(input_data: QueryInput):
      try:
-         processed_query = process_query(input_data.query_text, input_data.language_code)
-         query_embedding = embed_query_text(processed_query)
-         results = query_embeddings(query_embedding)
-
-         document_ids = [doc_id for doc_id, _ in results]
-         document_texts = retrieve_document_texts(document_ids)
-         scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])
-
-         documents = []
-         for score, doc_id, text in zip(scores, document_ids, document_texts):
-             url = global_models.file_name_to_url.get(doc_id, "")
-             documents.append({
-                 "title": doc_id,
-                 "url": url,
-                 "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
-                 "score": float(score)
-             })
-         return documents
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error retrieving documents: {str(e)}")
-
- def process_query(query_text: str, language_code: int) -> str:
      if language_code == 0:
-         return translate_ar_to_en(query_text)  # Translate Arabic to English if required
      return query_text
-
- def embed_query_text(query_text: str) -> np.ndarray:
-     return global_models.embedding_model.encode(query_text, convert_to_tensor=True)
-
- def query_embeddings(query_embedding: np.ndarray, top_n: int = 10) -> List[tuple]:
-     doc_embeddings = list(global_models.embeddings_data.values())
-     document_ids = list(global_models.embeddings_data.keys())
-     similarities = cosine_similarity(query_embedding, doc_embeddings)
-     top_indices = np.argsort(similarities[0])[-top_n:]
-     return [(document_ids[idx], similarities[0][idx]) for idx in reversed(top_indices)]
- def retrieve_document_texts(document_ids: List[str]) -> List[str]:
-     return [global_models.file_name_to_url[doc_id] for doc_id in document_ids]
-
- def translate_en_to_ar(text: str) -> str:
-     # Translation logic here, possibly using `transformers` or another library
-     pass
+ import os
  import pickle
  import numpy as np
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
  from transformers import (
      AutoTokenizer,
+     AutoModelForSeq2SeqLM,
      AutoModelForCausalLM,
      AutoModelForTokenClassification
  )
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ from sklearn.metrics.pairwise import cosine_similarity
+ from bs4 import BeautifulSoup
+ import nltk
+ import torch
  import pandas as pd
+ from startup import setup_files
+
+ app = Flask(__name__)
+ CORS(app)
+ # Environment variables for file paths
+ EMBEDDINGS_PATH = os.environ.get('EMBEDDINGS_PATH', 'data/embeddings.pkl')
+ LINKS_PATH = os.environ.get('LINKS_PATH', 'data/finalcleaned_excel_file.xlsx')
+
+ def init_app():
+     # Download and extract files if they don't exist
+     if not os.path.exists('downloaded_articles'):
+         setup_files()
+
+ # Initialize models with proper error handling
+ def initialize_models():
+     try:
+         global embedding_model, cross_encoder, semantic_model
+         global ar_to_en_tokenizer, ar_to_en_model
+         global en_to_ar_tokenizer, en_to_ar_model
+         global tokenizer_f, model_f, bio_tokenizer, bio_model
+
+         print("Initializing models...")
+
+         # Basic embedding models
+         embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
+         semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
+
+         # Translation models
+         ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+         en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+
+         # Medical NER model
+         bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+         bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+
+         # LLM model
+         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
+         tokenizer_f = AutoTokenizer.from_pretrained(model_name)
+         model_f = AutoModelForCausalLM.from_pretrained(model_name)
+
+         nltk.download('punkt', quiet=True)
+
+         print("Models initialized successfully")
+         return True
      except Exception as e:
+         print(f"Error initializing models: {e}")
+         return False
+
+ # Load data with error handling
+ def load_data():
      try:
+         global embeddings_data, df
+
+         print("Loading data files...")
+
+         # Load embeddings
+         with open(EMBEDDINGS_PATH, 'rb') as file:
+             embeddings_data = pickle.load(file)
+
+         # Load links data
+         df = pd.read_excel(LINKS_PATH)
+
+         print("Data loaded successfully")
+         return True
+     except Exception as e:
+         print(f"Error loading data: {e}")
+         return False
+
+ @app.route('/health', methods=['GET'])
+ def health_check():
+     return jsonify({'status': 'healthy'})
+
+ @app.route('/api/query', methods=['POST'])
+ def process_query():
+     try:
+         data = request.json
+         if not data or 'query' not in data:
+             return jsonify({'error': 'No query provided', 'success': False}), 400
+
+         query_text = data['query']
+         language_code = data.get('language_code', 0)
+
+         # Process query
+         if language_code == 0:
+             query_text = translate_ar_to_en(query_text)
+
+         # Get embeddings and find relevant documents
+         query_embedding = embedding_model.encode([query_text])
+         initial_results = query_embeddings(query_embedding, embeddings_data)
+
+         # Process documents
+         document_texts = retrieve_document_texts([doc_id for doc_id, _ in initial_results])
+         relevant_portions = extract_relevant_portions(document_texts, query_text)
+
+         # Generate answer
+         combined_text = " ".join([item for sublist in relevant_portions.values() for item in sublist])
+         answer = generate_answer(query_text, combined_text)
+
+         if language_code == 0:
+             answer = translate_en_to_ar(answer)
+
+         return jsonify({
+             'answer': answer,
+             'success': True
+         })
+
+     except Exception as e:
+         return jsonify({
+             'error': str(e),
+             'success': False
+         }), 500
+
+ def translate_ar_to_en(text):
      try:
+         inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True)
+         outputs = ar_to_en_model.generate(**inputs)
+         return ar_to_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
+     except Exception as e:
+         print(f"Translation error (AR->EN): {e}")
+         return text
+
+ def translate_en_to_ar(text):
+     try:
+         inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True)
+         outputs = en_to_ar_model.generate(**inputs)
+         return en_to_ar_tokenizer.decode(outputs[0], skip_special_tokens=True)
      except Exception as e:
+         print(f"Translation error (EN->AR): {e}")
+         return text
+
+ language_code = 0
+ query_text = 'How can a patient with chronic kidney disease manage their daily activities and maintain quality of life?'  # 'symptoms of a heart attack'
+
+ def process_query(query_text):
      if language_code == 0:
+         # Translate Arabic input to English
+         query_text = translate_ar_to_en(query_text)
      return query_text
+
+ def embed_query_text(query_text):
+     query_embedding = embedding_model.encode([query_text])
+     return query_embedding
+
+ def query_embeddings(query_embedding, embeddings_data, n_results=5):
+     doc_ids = list(embeddings_data.keys())
+     doc_embeddings = np.array(list(embeddings_data.values()))
+     similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+     top_indices = similarities.argsort()[-n_results:][::-1]
+     top_docs = [(doc_ids[i], similarities[i]) for i in top_indices]
+
+     return top_docs
+
+ query_embedding = embed_query_text(query_text)  # Embed the query text
+ initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
+ document_ids = [doc_id for doc_id, _ in initial_results]
+ print(document_ids)
+
+ import pandas as pd
+ import requests
+ from bs4 import BeautifulSoup
+
+ # Load the Excel file
+ file_path = '/kaggle/input/final-links/finalcleaned_excel_file.xlsx'
+ df = pd.read_excel(file_path)
+
+ # Create a dictionary mapping file names to URLs
+ # Assuming the DataFrame index corresponds to file names
+ file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
+
+ def get_page_title(url):
+     try:
+         response = requests.get(url)
+         if response.status_code == 200:
+             soup = BeautifulSoup(response.text, 'html.parser')
+             title = soup.find('title')
+             return title.get_text() if title else "No title found"
+         else:
+             return None
+     except requests.exceptions.RequestException:
+         return None
+
+ # Example file names
+ file_names = document_ids
+
+ # Retrieve original URLs
+ for file_name in file_names:
+     original_url = file_name_to_url.get(file_name, None)
+     if original_url:
+         title = get_page_title(original_url)
+         if title:
+             print(f"Title: {title}, URL: {original_url}")
+         else:
+             print(f"Name: {file_name}")
+     else:
+         print(f"Name: {file_name}")
+
+ def retrieve_document_texts(doc_ids, folder_path):
+     texts = []
+     for doc_id in doc_ids:
+         file_path = os.path.join(folder_path, doc_id)
+         try:
+             with open(file_path, 'r', encoding='utf-8') as file:
+                 soup = BeautifulSoup(file, 'html.parser')
+                 text = soup.get_text(separator=' ', strip=True)
+                 texts.append(text)
+         except FileNotFoundError:
+             texts.append("")
+     return texts
+
+ document_ids = [doc_id for doc_id, _ in initial_results]
+ document_texts = retrieve_document_texts(document_ids, folder_path)
+
+ # Rerank the results using the CrossEncoder
+ scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+ scored_documents = list(zip(scores, document_ids, document_texts))
+ scored_documents.sort(key=lambda x: x[0], reverse=True)
+ print("Reranked results:")
+ for idx, (score, doc_id, doc) in enumerate(scored_documents):
+     print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id})")
+
+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
+ import nltk
+ # Load BioBERT model and tokenizer for NER
+ bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+ bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+ ner_biobert = pipeline("ner", model=bio_model, tokenizer=bio_tokenizer)
+
+ def extract_entities(text, ner_pipeline):
+     """
+     Extract entities using a NER pipeline.
+     Args:
+         text (str): The text from which to extract entities.
+         ner_pipeline (pipeline): The NER pipeline for entity extraction.
+     Returns:
+         List[str]: A list of unique extracted entities.
+     """
+     ner_results = ner_pipeline(text)
+     entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
+     return list(entities)
+
+ def match_entities(query_entities, sentence_entities):
+     """
+     Compute the relevance score based on entity matching.
+     Args:
+         query_entities (List[str]): Entities extracted from the query.
+         sentence_entities (List[str]): Entities extracted from the sentence.
+     Returns:
+         float: The relevance score based on entity overlap.
+     """
+     query_set, sentence_set = set(query_entities), set(sentence_entities)
+     matches = query_set.intersection(sentence_set)
+     return len(matches)
+
+ def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
+     """
+     Extract relevant text portions from documents based on entity matching.
+     Args:
+         document_texts (List[str]): List of document texts.
+         query (str): The query text.
+         max_portions (int): Maximum number of relevant portions to extract per document.
+         portion_size (int): Number of sentences to include in each portion.
+         min_query_words (int): Minimum number of matching entities to consider a sentence relevant.
+     Returns:
+         Dict[str, List[str]]: Relevant portions for each document.
+     """
+     relevant_portions = {}
+
+     # Extract entities from the query
+     query_entities = extract_entities(query, ner_biobert)
+     print(f"Extracted Query Entities: {query_entities}")
+
+     for doc_id, doc_text in enumerate(document_texts):
+         sentences = nltk.sent_tokenize(doc_text)  # Split document into sentences
+         doc_relevant_portions = []
+
+         # Extract entities from the entire document
+         doc_entities = extract_entities(doc_text, ner_biobert)
+         print(f"Document {doc_id} Entities: {doc_entities}")
+
+         for i, sentence in enumerate(sentences):
+             # Extract entities from the sentence
+             sentence_entities = extract_entities(sentence, ner_biobert)
+
+             # Compute relevance score
+             relevance_score = match_entities(query_entities, sentence_entities)
+
+             # Select sentences with at least `min_query_words` matching entities
+             if relevance_score >= min_query_words:
+                 start_idx = max(0, i - portion_size // 2)
+                 end_idx = min(len(sentences), i + portion_size // 2 + 1)
+                 portion = " ".join(sentences[start_idx:end_idx])
+                 doc_relevant_portions.append(portion)
+
+                 if len(doc_relevant_portions) >= max_portions:
+                     break
+
+         # Add fallback to include the most entity-dense sentences if no results
+         if not doc_relevant_portions and len(doc_entities) > 0:
+             print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
+             sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
+             for fallback_sentence in sorted_sentences[:max_portions]:
+                 doc_relevant_portions.append(fallback_sentence)
+
+         relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
+
+     return relevant_portions
+
+ # Extract relevant portions based on query and documents
+ relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=1)
+
+ for doc_id, portions in relevant_portions.items():
+     print(f"{doc_id}: {portions}")
+
+ # Remove duplicates from the selected portions
+ def remove_duplicates(selected_parts):
+     unique_sentences = set()
+     unique_selected_parts = []
+
+     for sentence in selected_parts:
+         if sentence not in unique_sentences:
+             unique_selected_parts.append(sentence)
+             unique_sentences.add(sentence)
+
+     return unique_selected_parts
+
+ # Flatten the dictionary of relevant portions (from earlier code)
+ flattened_relevant_portions = []
+ for doc_id, portions in relevant_portions.items():
+     flattened_relevant_portions.extend(portions)
+
+ # Remove duplicate portions
+ unique_selected_parts = remove_duplicates(flattened_relevant_portions)
+
+ # Combine the unique parts into a single string of context
+ combined_parts = " ".join(unique_selected_parts)
+
+ # Construct context as a list: first the query, then the unique selected portions
+ context = [query_text] + unique_selected_parts
+
+ # Print the context (query + relevant portions)
+ print(context)
+
+ import pickle
+
+ with open('/kaggle/input/art-embeddings-pkl/embeddings.pkl', 'rb') as file:
+     data = pickle.load(file)
+
+ # Print the type of data
+ print(f"Data type: {type(data)}")
+
+ # Print the first few keys and values from the dictionary
+ print("First few keys and values:")
+ for i, (key, value) in enumerate(data.items()):
+     if i >= 5:  # Limit to printing the first 5 key-value pairs
+         break
+     print(f"Key: {key}, Value: {value}")
+
+ import pickle
+ import pickletools
+
+ # Load the pickle file
+ file_path = '/kaggle/input/art-embeddings-pkl/embeddings.pkl'
+
+ with open(file_path, 'rb') as f:
+     # Read the pickle file
+     data = pickle.load(f)
+
+ # Check for suspicious or corrupted entries
+ def inspect_pickle(data):
+     for key, value in data.items():
+         if isinstance(value, (str, bytes)):
+             # Try to decode and catch any non-ASCII issues
+             try:
+                 value.decode('ascii')
+             except UnicodeDecodeError as e:
+                 print(f"Non-ASCII entry found in key: {key}")
+                 print(f"Corrupted data: {value} ({e})")
+                 continue
+
+         if isinstance(value, list) and any(isinstance(v, (list, dict, str, bytes)) for v in value):
+             # Inspect list elements recursively
+             inspect_pickle({f"{key}[{idx}]": v for idx, v in enumerate(value)})
+
+ # Inspect the data
+ inspect_pickle(data)
+
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+ import torch
+ import time
+
+ # Load Biobert model and tokenizer
+ biobert_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+ biobert_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+
+ def extract_entities(text):
+     inputs = biobert_tokenizer(text, return_tensors="pt")
+     outputs = biobert_model(**inputs)
+     predictions = torch.argmax(outputs.logits, dim=2)
+     tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+     entities = [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]  # Assume 0 is the label for non-entity
+     return entities
+
+ def enhance_passage_with_entities(passage, entities):
+     # Example: Add entities to the passage for better context
+     return f"{passage}\n\nEntities: {', '.join(entities)}"
+
+ def create_prompt(question, passage):
+     prompt = ("""
+     As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.
+
+     Passage: {passage}
+
+     Question: {question}
+
+     Answer:
+     """)
+     return prompt.format(passage=passage, question=question)
+
+ def generate_answer(prompt, max_length=860, temperature=0.2):
+     inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
+
+     # Start timing
+     start_time = time.time()
+
+     output_ids = model_f.generate(
+         inputs.input_ids,
+         max_length=max_length,
+         num_return_sequences=1,
+         temperature=temperature,
+         pad_token_id=tokenizer_f.eos_token_id
+     )
+
+     # End timing
+     end_time = time.time()
+
+     # Calculate the duration
+     duration = end_time - start_time
+
+     # Decode the answer
+     answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
+
+     passage_keywords = set(passage.lower().split())
+     answer_keywords = set(answer.lower().split())
+
+     if passage_keywords.intersection(answer_keywords):
+         return answer, duration
+     else:
+         return "Sorry, I can't help with that.", duration
+
+ # Integrate Biobert model
+ entities = extract_entities(query_text)
+ passage = enhance_passage_with_entities(combined_parts, entities)
+ # Generate answer with the enhanced passage
+ prompt = create_prompt(query_text, passage)
+ answer, generation_time = generate_answer(prompt)
+ print(f"\nTime taken to generate the answer: {generation_time:.2f} seconds")
+
+ def remove_answer_prefix(text):
+     prefix = "Answer:"
+     if prefix in text:
+         return text.split(prefix)[-1].strip()
+     return text
+
+ def remove_incomplete_sentence(text):
+     # Check if the text ends with a period
+     if not text.endswith('.'):
+         # Find the last period or the end of the string
+         last_period_index = text.rfind('.')
+         if last_period_index != -1:
+             # Remove everything after the last period
+             return text[:last_period_index + 1].strip()
+     return text
+
+ # Clean and print the answer
+ answer_part = answer.split("Answer:")[-1].strip()
+ cleaned_answer = remove_answer_prefix(answer_part)
+ final_answer = remove_incomplete_sentence(cleaned_answer)
+
+ if language_code == 0:
+     final_answer = translate_en_to_ar(final_answer)
+
+ if final_answer:
+     print("Answer:")
+     print(final_answer)
+ else:
+     print("Sorry, I can't help with that.")