Commit eee7a65 (parent: 8e3b5f7): Update app.py

app.py CHANGED
@@ -1,226 +1,505 @@
-
- from pydantic import BaseModel
- from typing import List, Optional, Dict
  import pickle
  import numpy as np
- from
- from
- from bs4 import BeautifulSoup
- import os
- import nltk
- import torch
  from transformers import (
      AutoTokenizer,
-
      AutoModelForCausalLM,
-     AutoModelForSeq2SeqLM,
      AutoModelForTokenClassification
  )
  import pandas as pd
- import

- app = FastAPI()
-
- # ArticleEmbeddingUnpickler and safe_load_embeddings functions
- class ArticleEmbeddingUnpickler(pickle.Unpickler):
-     """Custom unpickler for article embeddings with enhanced persistence handling"""
-     def find_class(self, module: str, name: str) -> any:
-         if module == 'numpy':
-             return getattr(np, name)
-         if module == 'sentence_transformers.SentenceTransformer':
-             from sentence_transformers import SentenceTransformer
-             return SentenceTransformer
-         return super().find_class(module, name)
-
-     def persistent_load(self, pid: any) -> str:
-         """Enhanced persistent ID handler with better encoding management"""
-         try:
-             if isinstance(pid, bytes):
-                 return pid.decode('utf-8', errors='replace')
-             if isinstance(pid, (str, int, float)):
-                 return str(pid)
-             return repr(pid)
-         except Exception as e:
-             print(f"Warning: Error in persistent_load: {str(e)}")
-             return repr(pid)
-
- def safe_load_embeddings(file_path: str = 'embeddings.pkl') -> Dict[str, np.ndarray]:
-     """Load embeddings with enhanced error handling, validation, and persistent ID support."""
-     def persistent_load(pid):
-         print(f"Warning: Persistent ID encountered: {pid}")
-         raise ValueError("Persistent IDs are not supported in this application")

- …
-             valid_embeddings[key_str] = value
-         except Exception as e:
-             print(f"Error processing embedding for key {key}: {str(e)}")
-             continue
-
-     if not valid_embeddings:
-         raise ValueError("No valid embeddings found in file")
-
-     print(f"Successfully loaded {len(valid_embeddings)} valid embeddings")
-     return valid_embeddings

      except Exception as e:
-         print(f"Error
- …
-         str(key): value
-         for key, value in embeddings_dict.items()
-     }
-     with open(file_path, 'wb') as f:
-         pickle.dump(cleaned_embeddings, f, protocol=4)
-
-
- # GlobalModels and load_models
- class GlobalModels:
-     embedding_model = None
-     cross_encoder = None
-     semantic_model = None
-     tokenizer = None
-     model = None
-     tokenizer_f = None
-     model_f = None
-     ar_to_en_tokenizer = None
-     ar_to_en_model = None
-     en_to_ar_tokenizer = None
-     en_to_ar_model = None
-     bio_tokenizer = None
-     bio_model = None
-     embeddings_data = None
-     file_name_to_url = None
-
- global_models = GlobalModels()
-
- @app.on_event("startup")
- async def load_models():
      try:
- …
-         global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-         global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
- …
-     except (FileNotFoundError, pickle.UnpicklingError) as e:
-         print(f"Error loading embeddings data: {e}")
-         raise HTTPException(status_code=500, detail="Failed to load embeddings data.")
- …
-     query_type: str # "profile" or "question"
-     previous_qa: Optional[List[Dict[str, str]]] = None
- …
- async def retrieve_documents(input_data: QueryInput):
      try:
- …
-         scores = global_models.cross_encoder.predict([(processed_query, doc) for doc in document_texts])
-
-         documents = []
-         for score, doc_id, text in zip(scores, document_ids, document_texts):
-             url = global_models.file_name_to_url.get(doc_id, "")
-             documents.append({
-                 "title": doc_id,
-                 "url": url,
-                 "text": text if input_data.language_code == 1 else translate_en_to_ar(text),
-                 "score": float(score)
-             })
-         return documents

      except Exception as e:
- …

      if language_code == 0:
- …
      return query_text

- def embed_query_text(query_text
- …

+ import os
  import pickle
  import numpy as np
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
  from transformers import (
      AutoTokenizer,
+     AutoModelForSeq2SeqLM,
      AutoModelForCausalLM,
      AutoModelForTokenClassification
  )
+ from sentence_transformers import SentenceTransformer, CrossEncoder
+ from sklearn.metrics.pairwise import cosine_similarity
+ from bs4 import BeautifulSoup
+ import nltk
+ import torch
  import pandas as pd
+ from startup import setup_files

+ app = Flask(__name__)
+ CORS(app)
+ # Environment variables for file paths
+ EMBEDDINGS_PATH = os.environ.get('EMBEDDINGS_PATH', 'data/embeddings.pkl')
+ LINKS_PATH = os.environ.get('LINKS_PATH', 'data/finalcleaned_excel_file.xlsx')

+ def init_app():
+     # Download and extract files if they don't exist
+     if not os.path.exists('downloaded_articles'):
+         setup_files()

+ # Initialize models with proper error handling
+ def initialize_models():
+     try:
+         global embedding_model, cross_encoder, semantic_model
+         global ar_to_en_tokenizer, ar_to_en_model
+         global en_to_ar_tokenizer, en_to_ar_model
+         global tokenizer_f, model_f, bio_tokenizer, bio_model

+         print("Initializing models...")

+         # Basic embedding models
+         embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+         cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
+         semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

+         # Translation models
+         ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
+         en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
+         en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

+         # Medical NER model
+         bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+         bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

+         # LLM model
+         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
+         tokenizer_f = AutoTokenizer.from_pretrained(model_name)
+         model_f = AutoModelForCausalLM.from_pretrained(model_name)

+         nltk.download('punkt', quiet=True)

+         print("Models initialized successfully")
+         return True
      except Exception as e:
+         print(f"Error initializing models: {e}")
+         return False

+ # Load data with error handling
+ def load_data():
      try:
+         global embeddings_data, df

+         print("Loading data files...")

+         # Load embeddings
+         with open(EMBEDDINGS_PATH, 'rb') as file:
+             embeddings_data = pickle.load(file)

+         # Load links data
+         df = pd.read_excel(LINKS_PATH)

+         print("Data loaded successfully")
+         return True
+     except Exception as e:
+         print(f"Error loading data: {e}")
+         return False

+ @app.route('/health', methods=['GET'])
+ def health_check():
+     return jsonify({'status': 'healthy'})

+ @app.route('/api/query', methods=['POST'])
+ def process_query():
+     try:
+         data = request.json
+         if not data or 'query' not in data:
+             return jsonify({'error': 'No query provided', 'success': False}), 400

+         query_text = data['query']
+         language_code = data.get('language_code', 0)

+         # Process query
+         if language_code == 0:
+             query_text = translate_ar_to_en(query_text)

+         # Get embeddings and find relevant documents
+         query_embedding = embedding_model.encode([query_text])
+         initial_results = query_embeddings(query_embedding, embeddings_data)

+         # Process documents
+         document_texts = retrieve_document_texts([doc_id for doc_id, _ in initial_results])
+         relevant_portions = extract_relevant_portions(document_texts, query_text)

+         # Generate answer
+         combined_text = " ".join([item for sublist in relevant_portions.values() for item in sublist])
+         answer = generate_answer(query_text, combined_text)

+         if language_code == 0:
+             answer = translate_en_to_ar(answer)

+         return jsonify({
+             'answer': answer,
+             'success': True
+         })

+     except Exception as e:
+         return jsonify({
+             'error': str(e),
+             'success': False
+         }), 500

+ def translate_ar_to_en(text):
      try:
+         inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True)
+         outputs = ar_to_en_model.generate(**inputs)
+         return ar_to_en_tokenizer.decode(outputs[0], skip_special_tokens=True)
+     except Exception as e:
+         print(f"Translation error (AR->EN): {e}")
+         return text

+ def translate_en_to_ar(text):
+     try:
+         inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True)
+         outputs = en_to_ar_model.generate(**inputs)
+         return en_to_ar_tokenizer.decode(outputs[0], skip_special_tokens=True)
      except Exception as e:
+         print(f"Translation error (EN->AR): {e}")
+         return text

+ language_code = 0

+ query_text = 'How can a patient with chronic kidney disease manage their daily activities and maintain quality of life?' #'symptoms of a heart attack '

+ def process_query(query_text):
      if language_code == 0:
+         # Translate Arabic input to English
+         query_text = translate_ar_to_en(query_text)
      return query_text

+ def embed_query_text(query_text):
+     query_embedding = embedding_model.encode([query_text])
+     return query_embedding

+ def query_embeddings(query_embedding, embeddings_data, n_results=5):
+     doc_ids = list(embeddings_data.keys())
+     doc_embeddings = np.array(list(embeddings_data.values()))
+     similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
+     top_indices = similarities.argsort()[-n_results:][::-1]
+     top_docs = [(doc_ids[i], similarities[i]) for i in top_indices]

+     return top_docs

+ query_embedding = embed_query_text(query_text)  # Embed the query text
+ initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
+ document_ids = [doc_id for doc_id, _ in initial_results]
+ print(document_ids)

+ import pandas as pd
+ import requests
+ from bs4 import BeautifulSoup

+ # Load the Excel file
+ file_path = '/kaggle/input/final-links/finalcleaned_excel_file.xlsx'
+ df = pd.read_excel(file_path)

+ # Create a dictionary mapping file names to URLs
+ # Assuming the DataFrame index corresponds to file names
+ file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
+ def get_page_title(url):
+     try:
+         response = requests.get(url)
+         if response.status_code == 200:
+             soup = BeautifulSoup(response.text, 'html.parser')
+             title = soup.find('title')
+             return title.get_text() if title else "No title found"
+         else:
+             return None
+     except requests.exceptions.RequestException:
+         return None
+ # Example file names
+ file_names = document_ids

+ # Retrieve original URLs
+ for file_name in file_names:
+     original_url = file_name_to_url.get(file_name, None)
+     if original_url:
+         title = get_page_title(original_url)
+         if title:
+             print(f"Title: {title},URL: {original_url}")
+         else:
+             print(f"Name: {file_name}")
+     else:
+         print(f"Name: {file_name}")

+ def retrieve_document_texts(doc_ids, folder_path):
+     texts = []
+     for doc_id in doc_ids:
+         file_path = os.path.join(folder_path, doc_id)
+         try:
+             with open(file_path, 'r', encoding='utf-8') as file:
+                 soup = BeautifulSoup(file, 'html.parser')
+                 text = soup.get_text(separator=' ', strip=True)
+                 texts.append(text)
+         except FileNotFoundError:
+             texts.append("")
+     return texts
+ document_ids = [doc_id for doc_id, _ in initial_results]
+ document_texts = retrieve_document_texts(document_ids, folder_path)

+ # Rerank the results using the CrossEncoder
+ scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
+ scored_documents = list(zip(scores, document_ids, document_texts))
+ scored_documents.sort(key=lambda x: x[0], reverse=True)
+ print("Reranked results:")
+ for idx, (score, doc_id, doc) in enumerate(scored_documents):
+     print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id}")

+ from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
+ import nltk

+ # Load BioBERT model and tokenizer for NER
+ bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+ bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
+ ner_biobert = pipeline("ner", model=bio_model, tokenizer=bio_tokenizer)

+ def extract_entities(text, ner_pipeline):
+     """
+     Extract entities using a NER pipeline.
+     Args:
+         text (str): The text from which to extract entities.
+         ner_pipeline (pipeline): The NER pipeline for entity extraction.
+     Returns:
+         List[str]: A list of unique extracted entities.
+     """
+     ner_results = ner_pipeline(text)
+     entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
+     return list(entities)

+ def match_entities(query_entities, sentence_entities):
+     """
+     Compute the relevance score based on entity matching.
+     Args:
+         query_entities (List[str]): Entities extracted from the query.
+         sentence_entities (List[str]): Entities extracted from the sentence.
+     Returns:
+         float: The relevance score based on entity overlap.
+     """
+     query_set, sentence_set = set(query_entities), set(sentence_entities)
+     matches = query_set.intersection(sentence_set)
+     return len(matches)

+ def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
+     """
+     Extract relevant text portions from documents based on entity matching.
+     Args:
+         document_texts (List[str]): List of document texts.
+         query (str): The query text.
+         max_portions (int): Maximum number of relevant portions to extract per document.
+         portion_size (int): Number of sentences to include in each portion.
+         min_query_words (int): Minimum number of matching entities to consider a sentence relevant.
+     Returns:
+         Dict[str, List[str]]: Relevant portions for each document.
+     """
+     relevant_portions = {}

+     # Extract entities from the query
+     query_entities = extract_entities(query, ner_biobert)
+     print(f"Extracted Query Entities: {query_entities}")

+     for doc_id, doc_text in enumerate(document_texts):
+         sentences = nltk.sent_tokenize(doc_text)  # Split document into sentences
+         doc_relevant_portions = []

+         # Extract entities from the entire document
+         doc_entities = extract_entities(doc_text, ner_biobert)
+         print(f"Document {doc_id} Entities: {doc_entities}")

+         for i, sentence in enumerate(sentences):
+             # Extract entities from the sentence
+             sentence_entities = extract_entities(sentence, ner_biobert)

+             # Compute relevance score
+             relevance_score = match_entities(query_entities, sentence_entities)

+             # Select sentences with at least `min_query_words` matching entities
+             if relevance_score >= min_query_words:
+                 start_idx = max(0, i - portion_size // 2)
+                 end_idx = min(len(sentences), i + portion_size // 2 + 1)
+                 portion = " ".join(sentences[start_idx:end_idx])
+                 doc_relevant_portions.append(portion)

+             if len(doc_relevant_portions) >= max_portions:
+                 break

+         # Add fallback to include the most entity-dense sentences if no results
+         if not doc_relevant_portions and len(doc_entities) > 0:
+             print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
+             sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
+             for fallback_sentence in sorted_sentences[:max_portions]:
+                 doc_relevant_portions.append(fallback_sentence)

+         relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions

+     return relevant_portions

+ # Extract relevant portions based on query and documents
+ relevant_portions = extract_relevant_portions(document_texts, query_text, max_portions=3, portion_size=1, min_query_words=1)

+ for doc_id, portions in relevant_portions.items():
+     print(f"{doc_id}: {portions}")

+ # Remove duplicates from the selected portions
+ def remove_duplicates(selected_parts):
+     unique_sentences = set()
+     unique_selected_parts = []

+     for sentence in selected_parts:
+         if sentence not in unique_sentences:
+             unique_selected_parts.append(sentence)
+             unique_sentences.add(sentence)

+     return unique_selected_parts

+ # Flatten the dictionary of relevant portions (from earlier code)
+ flattened_relevant_portions = []
+ for doc_id, portions in relevant_portions.items():
+     flattened_relevant_portions.extend(portions)

+ # Remove duplicate portions
+ unique_selected_parts = remove_duplicates(flattened_relevant_portions)

+ # Combine the unique parts into a single string of context
+ combined_parts = " ".join(unique_selected_parts)

+ # Construct context as a list: first the query, then the unique selected portions
+ context = [query_text] + unique_selected_parts

+ # Print the context (query + relevant portions)
+ print(context)

+ import pickle

+ with open('/kaggle/input/art-embeddings-pkl/embeddings.pkl', 'rb') as file:
+     data = pickle.load(file)

+ # Print the type of data
+ print(f"Data type: {type(data)}")

+ # Print the first few keys and values from the dictionary
+ print("First few keys and values:")
+ for i, (key, value) in enumerate(data.items()):
+     if i >= 5:  # Limit to printing the first 5 key-value pairs
+         break
+     print(f"Key: {key}, Value: {value}")

+ import pickle
+ import pickletools

+ # Load the pickle file
+ file_path = '/kaggle/input/art-embeddings-pkl/embeddings.pkl'

+ with open(file_path, 'rb') as f:
+     # Read the pickle file
+     data = pickle.load(f)

+ # Check for suspicious or corrupted entries
+ def inspect_pickle(data):
+     for key, value in data.items():
+         if isinstance(value, (str, bytes)):
+             # Try to decode and catch any non-ASCII issues
+             try:
+                 value.decode('ascii')
+             except UnicodeDecodeError as e:
+                 print(f"Non-ASCII entry found in key: {key}")
+                 print(f"Corrupted data: {value} ({e})")
+                 continue

+         if isinstance(value, list) and any(isinstance(v, (list, dict, str, bytes)) for v in value):
+             # Inspect list elements recursively
+             inspect_pickle({f"{key}[{idx}]": v for idx, v in enumerate(value)})

+ # Inspect the data
+ inspect_pickle(data)

+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+ import torch
+ import time

+ # Load Biobert model and tokenizer
+ biobert_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
+ biobert_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

+ def extract_entities(text):
+     inputs = biobert_tokenizer(text, return_tensors="pt")
+     outputs = biobert_model(**inputs)
+     predictions = torch.argmax(outputs.logits, dim=2)
+     tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
+     entities = [tokens[i] for i in range(len(tokens)) if predictions[0][i].item() != 0]  # Assume 0 is the label for non-entity
+     return entities

+ def enhance_passage_with_entities(passage, entities):
+     # Example: Add entities to the passage for better context
+     return f"{passage}\n\nEntities: {', '.join(entities)}"

+ def create_prompt(question, passage):
+     prompt = ("""
+     As a medical expert, you are required to answer the following question based only on the provided passage. Do not include any information not present in the passage. Your response should directly reflect the content of the passage. Maintain accuracy and relevance to the provided information.

+     Passage: {passage}

+     Question: {question}

+     Answer:
+     """)
+     return prompt.format(passage=passage, question=question)

+ def generate_answer(prompt, max_length=860, temperature=0.2):
+     inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)

+     # Start timing
+     start_time = time.time()

+     output_ids = model_f.generate(
+         inputs.input_ids,
+         max_length=max_length,
+         num_return_sequences=1,
+         temperature=temperature,
+         pad_token_id=tokenizer_f.eos_token_id
+     )

+     # End timing
+     end_time = time.time()

+     # Calculate the duration
+     duration = end_time - start_time

+     # Decode the answer
+     answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)

+     passage_keywords = set(passage.lower().split())
+     answer_keywords = set(answer.lower().split())

+     if passage_keywords.intersection(answer_keywords):
+         return answer, duration
+     else:
+         return "Sorry, I can't help with that.", duration

+ # Integrate Biobert model
+ entities = extract_entities(query_text)
+ passage = enhance_passage_with_entities(combined_parts, entities)
+ # Generate answer with the enhanced passage
+ prompt = create_prompt(query_text, passage)
+ answer, generation_time = generate_answer(prompt)
+ print(f"\nTime taken to generate the answer: {generation_time:.2f} seconds")
+ def remove_answer_prefix(text):
+     prefix = "Answer:"
+     if prefix in text:
+         return text.split(prefix)[-1].strip()
+     return text

+ def remove_incomplete_sentence(text):
+     # Check if the text ends with a period
+     if not text.endswith('.'):
+         # Find the last period or the end of the string
+         last_period_index = text.rfind('.')
+         if last_period_index != -1:
+             # Remove everything after the last period
+             return text[:last_period_index + 1].strip()
+     return text
+ # Clean and print the answer
+ answer_part = answer.split("Answer:")[-1].strip()
+ cleaned_answer = remove_answer_prefix(answer_part)
+ final_answer = remove_incomplete_sentence(cleaned_answer)

+ if language_code == 0:
+     final_answer = translate_en_to_ar(final_answer)

+ if final_answer:
+     print("Answer:")
+     print(final_answer)
+ else:
+     print("Sorry, I can't help with that.")
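A minimal client sketch for exercising the two routes added in this commit. It is only a sketch under assumptions: the base URL and port are assumed (7860 is the conventional Hugging Face Spaces port), and it presumes the server process has already run init_app(), initialize_models(), and load_data() so the globals used by /api/query exist.

import requests

BASE_URL = "http://localhost:7860"  # assumed address/port of the running Space

# Liveness probe against the new /health route
print(requests.get(f"{BASE_URL}/health").json())

# Ask a question via /api/query; language_code=0 means Arabic in/out,
# any other value skips the translation steps in the route
payload = {
    "query": "How can a patient with chronic kidney disease manage their daily activities?",
    "language_code": 1,
}
response = requests.post(f"{BASE_URL}/api/query", json=payload)
print(response.json())  # on success: {"answer": "...", "success": true}

Per the route's error handling, a missing query returns {'error': ..., 'success': False} with HTTP 400 and unexpected exceptions return HTTP 500, so checking response.status_code before reading the answer is a reasonable client-side guard.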