Spaces:
Sleeping
Sleeping
Commit
·
88b8fb2
1
Parent(s):
c77cccd
Update app.py
Browse files
app.py
CHANGED
@@ -170,54 +170,10 @@ def translate_text(text, source_to_target='ar_to_en'):
|
|
170 |
print(f"Translation error: {e}")
|
171 |
return text
|
172 |
|
173 |
-
def
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
return list({result['word'] for result in results if result['entity'].startswith("B-")})
|
178 |
-
except Exception as e:
|
179 |
-
print(f"Error extracting entities: {e}")
|
180 |
-
return []
|
181 |
-
|
182 |
-
def generate_answer(query, context, max_length=860, temperature=0.2):
|
183 |
-
"""Generate answer using LLM"""
|
184 |
-
try:
|
185 |
-
prompt = f"""
|
186 |
-
As a medical expert, please provide a clear and accurate answer to the following question based solely on the provided context.
|
187 |
-
|
188 |
-
Context: {context}
|
189 |
-
|
190 |
-
Question: {query}
|
191 |
-
|
192 |
-
Answer: Let me help you with accurate information from reliable medical sources."""
|
193 |
-
|
194 |
-
inputs = models['llm_tokenizer'](prompt, return_tensors="pt", truncation=True)
|
195 |
-
|
196 |
-
with torch.no_grad():
|
197 |
-
outputs = models['llm_model'].generate(
|
198 |
-
inputs.input_ids,
|
199 |
-
max_length=max_length,
|
200 |
-
num_return_sequences=1,
|
201 |
-
temperature=temperature,
|
202 |
-
do_sample=True,
|
203 |
-
top_p=0.9,
|
204 |
-
pad_token_id=models['llm_tokenizer'].eos_token_id
|
205 |
-
)
|
206 |
-
|
207 |
-
response = models['llm_tokenizer'].decode(outputs[0], skip_special_tokens=True)
|
208 |
-
|
209 |
-
if "Answer:" in response:
|
210 |
-
response = response.split("Answer:")[-1].strip()
|
211 |
-
|
212 |
-
sentences = nltk.sent_tokenize(response)
|
213 |
-
if sentences:
|
214 |
-
return " ".join(sentences)
|
215 |
-
return response
|
216 |
-
|
217 |
-
except Exception as e:
|
218 |
-
print(f"Error generating answer: {e}")
|
219 |
-
return "I apologize, but I'm unable to generate an answer at this time. Please try again later."
|
220 |
-
|
221 |
def query_embeddings(query_embedding, n_results=5):
|
222 |
"""Find relevant documents using embedding similarity"""
|
223 |
if not data['embeddings']:
|
@@ -258,6 +214,160 @@ def rerank_documents(query, doc_texts):
|
|
258 |
print(f"Error reranking documents: {e}")
|
259 |
return np.zeros(len(doc_texts))
|
260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
261 |
@app.get("/")
|
262 |
async def root():
|
263 |
return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
|
@@ -287,8 +397,8 @@ async def chat_endpoint(chat_query: ChatQuery):
|
|
287 |
rerank_scores = rerank_documents(query_text, doc_texts)
|
288 |
ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
|
289 |
|
290 |
-
context =
|
291 |
-
answer =
|
292 |
|
293 |
return {
|
294 |
"response": answer,
|
|
|
170 |
print(f"Translation error: {e}")
|
171 |
return text
|
172 |
|
173 |
+
def embed_query_text(query_text):
    """Encode one query string with the module-level ``embedding_model``.

    The query is wrapped in a one-element list before it is handed to
    ``embedding_model.encode``; whatever the encoder returns is passed back
    to the caller unchanged.
    """
    return embedding_model.encode([query_text])
|
176 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
def query_embeddings(query_embedding, n_results=5):
|
178 |
"""Find relevant documents using embedding similarity"""
|
179 |
if not data['embeddings']:
|
|
|
214 |
print(f"Error reranking documents: {e}")
|
215 |
return np.zeros(len(doc_texts))
|
216 |
|
217 |
+
def extract_entities(text):
    """Extract medical entities from text using NER.

    Runs ``models['ner_pipeline']`` on ``text`` and keeps the ``word`` of
    every result whose ``entity`` tag starts with ``"B-"``.

    Returns:
        A list of the unique words in first-seen order, or an empty list if
        the pipeline call or result handling raises.
    """
    try:
        results = models['ner_pipeline'](text)
        # dict.fromkeys de-duplicates like a set but keeps first-seen order,
        # so the returned list is the same on every run.
        return list(dict.fromkeys(
            result['word']
            for result in results
            if result['entity'].startswith("B-")
        ))
    except Exception as e:
        # Best-effort: callers treat "no entities" and "NER failed" alike.
        print(f"Error extracting entities: {e}")
        return []
|
225 |
+
def match_entities(query_entities, sentence_entities):
    """Count the distinct entities that appear in both collections.

    Args:
        query_entities: Iterable of hashable entities taken from the query.
        sentence_entities: Iterable of hashable entities taken from a sentence.

    Returns:
        The number of distinct entities present in both inputs; repeated
        entries on either side are counted once.
    """
    return len(set(query_entities) & set(sentence_entities))
|
229 |
+
|
230 |
+
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
    """Select, per document, the sentence windows that share entities with the query.

    Args:
        document_texts: Iterable of document strings.
        query: Query string; its entities are compared with each sentence's.
        max_portions: Maximum number of portions kept per document.
        portion_size: Width, in sentences, of the window built around a
            matching sentence (``portion_size // 2`` sentences on each side).
        min_query_words: Minimum number of entities a sentence must share
            with the query to be selected.

    Returns:
        A dict mapping ``"Document_<index>"`` to the list of portions chosen
        for that document. When no sentence matches but the document has
        entities, the sentences with the most entities are used instead.
    """
    relevant_portions = {}

    # extract_entities takes a single argument; the extra `ner_biobert`
    # argument that used to be passed here is not defined anywhere in view.
    query_entities = extract_entities(query)
    print(f"Extracted Query Entities: {query_entities}")

    half_window = portion_size // 2

    for doc_id, doc_text in enumerate(document_texts):
        sentences = nltk.sent_tokenize(doc_text)  # Split document into sentences
        doc_relevant_portions = []

        # Entities of the whole document decide whether the fallback applies.
        doc_entities = extract_entities(doc_text)
        print(f"Document {doc_id} Entities: {doc_entities}")

        # Per-sentence entity counts, kept so the fallback below does not
        # have to run NER a second time on every sentence.
        entity_counts = []

        for i, sentence in enumerate(sentences):
            sentence_entities = extract_entities(sentence)
            entity_counts.append(len(sentence_entities))

            relevance_score = match_entities(query_entities, sentence_entities)

            # Select sentences with at least `min_query_words` matching entities
            if relevance_score >= min_query_words:
                start_idx = max(0, i - half_window)
                end_idx = min(len(sentences), i + half_window + 1)
                doc_relevant_portions.append(" ".join(sentences[start_idx:end_idx]))
                if len(doc_relevant_portions) >= max_portions:
                    break

        # Fallback: nothing matched, so take the most entity-dense sentences.
        # The loop above only stops early after a selection, so reaching this
        # branch means entity_counts covers every sentence.
        if not doc_relevant_portions and len(doc_entities) > 0:
            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
            ranked_indices = sorted(
                range(len(sentences)),
                key=lambda idx: entity_counts[idx],
                reverse=True,
            )
            doc_relevant_portions.extend(
                sentences[idx] for idx in ranked_indices[:max_portions]
            )

        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions

    return relevant_portions
|
270 |
+
def remove_duplicates(selected_parts):
    """Drop repeated entries from ``selected_parts``, keeping first-seen order.

    Args:
        selected_parts: Iterable of hashable items (text portions).

    Returns:
        A new list containing each distinct item once, in the order of its
        first occurrence.
    """
    # dict keys are unique and keep insertion order.
    return list(dict.fromkeys(selected_parts))
|
280 |
+
|
281 |
+
# NOTE(review): these statements execute at import time and read
# `relevant_portions` and `query_text`, neither of which is assigned in the
# visible part of this file -- confirm both exist at module scope.

# Flatten the per-document portion lists into one list
flattened_relevant_portions = [
    portion
    for portions in relevant_portions.values()
    for portion in portions
]

# Remove duplicate portions
unique_selected_parts = remove_duplicates(flattened_relevant_portions)

# Combine the unique parts into a single string of context
combined_parts = " ".join(unique_selected_parts)

# Construct context as a list: first the query, then the unique selected portions
context = [query_text] + unique_selected_parts
|
294 |
+
|
295 |
+
def extract_entities(text):
    """Return the tokens of ``text`` that the BioBERT classifier labels as entities.

    Tokenizes ``text`` with ``biobert_tokenizer``, runs ``biobert_model``,
    takes the arg-max label per token and keeps every token whose predicted
    label id is non-zero.

    NOTE(review): label id 0 is assumed to be the non-entity label -- confirm
    against the model's label map.
    NOTE(review): this file defines ``extract_entities`` twice; the later
    definition replaces the earlier one at import time -- confirm which one
    is intended.

    Returns:
        A list of token strings, in input order.
    """
    # truncation=True keeps over-long texts within the tokenizer's limit
    # instead of feeding the model more positions than it supports.
    inputs = biobert_tokenizer(text, return_tensors="pt", truncation=True)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = biobert_model(**inputs)

    label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()
    tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
    return [token for token, label_id in zip(tokens, label_ids) if label_id != 0]
|
302 |
+
|
303 |
+
def enhance_passage_with_entities(passage, entities):
    """Append a comma-separated entity list to ``passage``.

    Args:
        passage: The passage text.
        entities: Iterable of entity strings to list after the passage.

    Returns:
        ``passage`` followed by a blank line and ``"Entities: a, b, ..."``.
        When ``entities`` is empty the passage is returned unchanged, so no
        dangling ``"Entities: "`` label is added.
    """
    if not entities:
        return passage
    return f"{passage}\n\nEntities: {', '.join(entities)}"
|
306 |
+
|
307 |
+
def create_prompt(question, passage):
    """Build the LLM prompt that asks for an answer grounded in ``passage``.

    Args:
        question: The user's question.
        passage: The context text the answer must be based on.

    Returns:
        The instruction text followed by ``Passage:``, ``Question:`` and a
        trailing ``Answer:`` line, separated by blank lines.
    """
    # Explicit "\n" pieces keep the prompt's whitespace independent of how
    # this function happens to be indented in the source file.
    template = (
        "\n"
        "As a medical expert, you are required to answer the following question "
        "based only on the provided passage. "
        "Do not include any information not present in the passage. "
        "Your response should directly reflect the content of the passage. "
        "Maintain accuracy and relevance to the provided information.\n"
        "\n"
        "Passage: {passage}\n"
        "\n"
        "Question: {question}\n"
        "\n"
        "Answer:\n"
    )
    return template.format(passage=passage, question=question)
|
318 |
+
|
319 |
+
def generate_answer(prompt, max_length=860, temperature=0.2, passage=None):
    """Generate an answer for ``prompt`` and time the generation call.

    Args:
        prompt: Full prompt text fed to ``tokenizer_f`` / ``model_f``.
        max_length: Forwarded to ``model_f.generate`` as ``max_length``.
        temperature: Forwarded to ``model_f.generate`` as ``temperature``.
        passage: Text the decoded answer is checked against. Defaults to
            ``prompt`` when omitted. The body previously read a ``passage``
            name that was neither a parameter nor a local.

    Returns:
        A ``(answer, duration)`` tuple. ``answer`` is the decoded output when
        it shares at least one lower-cased, whitespace-separated word with
        ``passage``, otherwise a fixed refusal message. ``duration`` is the
        generation time in seconds.

    NOTE(review): ``temperature`` is passed without ``do_sample=True``; it
    presumably has no effect under greedy decoding -- confirm against the
    generation settings in use.
    """
    if passage is None:
        passage = prompt

    inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)

    # perf_counter is monotonic, so the measured duration cannot go negative
    # or jump when the wall clock is adjusted.
    start_time = time.perf_counter()
    output_ids = model_f.generate(
        inputs.input_ids,
        max_length=max_length,
        num_return_sequences=1,
        temperature=temperature,
        pad_token_id=tokenizer_f.eos_token_id,
    )
    duration = time.perf_counter() - start_time

    # Decode the answer
    answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)

    # Accept the answer only if it has some word overlap with the passage.
    passage_keywords = set(passage.lower().split())
    answer_keywords = set(answer.lower().split())

    if passage_keywords & answer_keywords:
        return answer, duration
    return "Sorry, I can't help with that.", duration
|
349 |
+
|
350 |
+
def remove_answer_prefix(text):
    """Return what follows the last ``"Answer:"`` marker in ``text``.

    Args:
        text: Model output that may contain one or more ``"Answer:"`` markers.

    Returns:
        The text after the last marker with surrounding whitespace stripped,
        or ``text`` unchanged when the marker is absent.
    """
    _, marker, tail = text.rpartition("Answer:")
    # rpartition yields an empty separator when the marker was not found.
    return tail.strip() if marker else text
|
355 |
+
|
356 |
+
def remove_incomplete_sentence(text):
    """Trim a trailing unfinished sentence from ``text``.

    Args:
        text: Generated text that may have been cut off mid-sentence.

    Returns:
        ``text`` unchanged when it already ends with ``.``, ``!`` or ``?``,
        or when it contains no sentence end at all. Otherwise the text up to
        and including the last sentence end, with surrounding whitespace
        stripped. A ``.`` between two digits (as in ``3.5``) is treated as a
        decimal point rather than a sentence end.
    """
    terminators = (".", "!", "?")
    if text.endswith(terminators):
        return text

    # Walk backwards to the last real sentence end.
    for idx in range(len(text) - 1, -1, -1):
        char = text[idx]
        if char not in terminators:
            continue
        is_decimal_point = (
            char == "."
            and 0 < idx < len(text) - 1
            and text[idx - 1].isdigit()
            and text[idx + 1].isdigit()
        )
        if not is_decimal_point:
            # Remove everything after the last sentence end
            return text[:idx + 1].strip()
    return text
|
365 |
+
|
366 |
+
# NOTE(review): these statements execute at import time and read `answer`,
# which is not assigned at module scope in the visible part of this file --
# confirm it exists before this point.

# Keep only what follows the last "Answer:" marker, then tidy it up.
answer_part = answer.rpartition("Answer:")[2].strip()
cleaned_answer = remove_answer_prefix(answer_part)
final_answer = remove_incomplete_sentence(cleaned_answer)
|
369 |
+
|
370 |
+
|
371 |
@app.get("/")
async def root():
    """Return the static welcome payload served at the application root."""
    welcome_message = (
        "Welcome to the FastAPI application! Use the /health endpoint to check health, "
        "and /api/query for processing queries."
    )
    return {"message": welcome_message}
|
|
|
397 |
rerank_scores = rerank_documents(query_text, doc_texts)
|
398 |
ranked_texts = [text for _, text in sorted(zip(rerank_scores, doc_texts), reverse=True)]
|
399 |
|
400 |
+
context = [query_text] + unique_selected_parts
|
401 |
+
answer = remove_incomplete_sentence(query_text, context)
|
402 |
|
403 |
return {
|
404 |
"response": answer,
|