Update app.py
app.py CHANGED
@@ -33,10 +33,7 @@ from safetensors.numpy import load_file
 from safetensors.torch import safe_open
 nltk.download('punkt_tab')
 
-# Initialize FastAPI app
 app = FastAPI()
-
-# Add CORS middleware
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
@@ -44,8 +41,6 @@ app.add_middleware(
     allow_methods=["*"],
     allow_headers=["*"],
 )
-
-# Global variables for models and data
 models = {}
 data = {}
 
@@ -68,7 +63,6 @@ class ChatMessage(BaseModel):
     timestamp: str
 
 def init_nltk():
-    """Initialize NLTK resources"""
     try:
         nltk.download('punkt', quiet=True)
         return True
@@ -77,39 +71,25 @@ def init_nltk():
         return False
 
 def load_models():
-    """Initialize all required models"""
     try:
         print("Loading models...")
-
-        # Set device
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Device set to use {device}")
-
-        # Embedding models
         models['embedding_model'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
         models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
         models['semantic_model'] = SentenceTransformer('all-MiniLM-L6-v2')
-
-        # Translation models
         models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
         models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
         models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
         models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
-
-        #Attention model
         models['att_tokenizer'] = AutoTokenizer.from_pretrained("facebook/bart-base")
         models['att_model'] = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
-
-        # NER model
         models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
         models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
         models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
-
-        # LLM model
         model_name = "M4-ai/Orca-2.0-Tau-1.8B"
         models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
         models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
-
         print("Models loaded successfully")
         return True
     except Exception as e:
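For orientation, a minimal sketch of the loading pattern load_models() uses above: fetch checkpoints with from_pretrained / SentenceTransformer and keep them in a dict that stands in for the global models registry. The model names are taken from the diff; everything else is illustrative only.

# Hedged sketch of the load_models() pattern (not the app itself):
# download two of the checkpoints named above and park them in a dict.
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

registry = {}
registry['embedding_model'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
registry['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
registry['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
print("loaded:", sorted(registry))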
@@ -118,7 +98,6 @@ def load_models():
 
 def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
     try:
-        # Locate or download embeddings file
         embeddings_path = 'embeddings.safetensors'
         if not os.path.exists(embeddings_path):
             print("File not found locally. Attempting to download from Hugging Face Hub...")
@@ -128,62 +107,35 @@ def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
                 repo_type="space"
             )
 
-        # Initialize a dictionary to store embeddings
         embeddings = {}
-
-        # Open the safetensors file
         with safe_open(embeddings_path, framework="pt") as f:
             keys = f.keys()
-            #0print(f"Available keys in the .safetensors file: {list(keys)}")  # Debugging info
-
-            # Iterate over the keys and load tensors
             for key in keys:
                 try:
                     tensor = f.get_tensor(key)
                     if not isinstance(tensor, torch.Tensor):
                         raise TypeError(f"Value for key {key} is not a valid PyTorch tensor.")
-
-                    # Convert tensor to NumPy array
                     embeddings[key] = tensor.numpy()
                 except Exception as key_error:
                     print(f"Failed to process key {key}: {key_error}")
-
         if embeddings:
             print("Embeddings successfully loaded.")
         else:
             print("No embeddings could be loaded. Please check the file format and content.")
-
         return embeddings
-
     except Exception as e:
         print(f"Error loading embeddings: {e}")
         return None
 
 def normalize_key(key: str) -> str:
-    """Normalize embedding keys to match metadata IDs."""
     match = re.search(r'file_(\d+)', key)
     if match:
         return match.group(1)
     return key
 
-
-import os
-import numpy as np
-from typing import Optional
-from safetensors.numpy import load_file
-from huggingface_hub import hf_hub_download
-
 def load_recipes_embeddings() -> Optional[np.ndarray]:
-    """
-    Loads recipe embeddings from a .safetensors file, handling local and remote downloads.
-
-    Returns:
-        Optional[np.ndarray]: A numpy array containing all embeddings (shape: (num_recipes, embedding_dim)).
-    """
     try:
         embeddings_path = 'recipes_embeddings.safetensors'
-
-        # Check if file exists locally, otherwise download from Hugging Face Hub
         if not os.path.exists(embeddings_path):
             print("File not found locally. Attempting to download from Hugging Face Hub...")
             embeddings_path = hf_hub_download(
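The key loop above is the standard safetensors reading pattern. A hedged, standalone sketch, assuming a local embeddings.safetensors file like the one the diff refers to:

# Sketch: open a .safetensors file and convert every stored tensor to NumPy,
# mirroring the key loop in load_embeddings().
from safetensors import safe_open

embeddings = {}
with safe_open("embeddings.safetensors", framework="pt") as f:
    for key in f.keys():
        embeddings[key] = f.get_tensor(key).numpy()
print(f"Loaded {len(embeddings)} tensors")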
@@ -191,60 +143,40 @@ def load_recipes_embeddings() -> Optional[np.ndarray]:
                 filename="embeddings.safetensors",
                 repo_type="space"
             )
-
-        # Load the embeddings tensor from the .safetensors file
         embeddings = load_file(embeddings_path)
-
-        # Ensure the key 'embeddings' exists in the file
         if "embeddings" not in embeddings:
             raise ValueError("Key 'embeddings' not found in the safetensors file.")
-
-        # Retrieve the tensor as a numpy array
         tensor = embeddings["embeddings"]
-
-        # Print information about the embeddings
         print(f"Successfully loaded embeddings.")
         print(f"Shape of embeddings: {tensor.shape}")
         print(f"Dtype of embeddings: {tensor.dtype}")
         print(f"First few values of the first embedding: {tensor[0][:5]}")
-
         return tensor
-
     except Exception as e:
         print(f"Error loading embeddings: {e}")
         return None
 
-
-
 def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
-    """Load document data from HTML articles in a specified folder."""
     try:
         print("Loading documents data...")
-        # Check if the folder exists
         if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
             print(f"Error: Folder '{folder_path}' not found")
             return False
-        # List all HTML files in the folder
         html_files = [f for f in os.listdir(folder_path) if f.endswith('.html')]
         if not html_files:
             print(f"No HTML files found in folder '{folder_path}'")
             return False
         documents = []
-        # Iterate through each HTML file and parse the content
         for file_name in html_files:
             file_path = os.path.join(folder_path, file_name)
             try:
                 with open(file_path, 'r', encoding='utf-8') as file:
-                    # Parse the HTML file
                     soup = BeautifulSoup(file, 'html.parser')
-                    # Extract text content (or customize this as per your needs)
                     text = soup.get_text(separator='\n').strip()
                     documents.append({"file_name": file_name, "content": text})
             except Exception as e:
                 print(f"Error reading file {file_name}: {e}")
-
         data['df'] = pd.DataFrame(documents)
-
         if data['df'].empty:
             print("No valid documents loaded.")
             return False
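load_documents_data() is a plain BeautifulSoup pass over a folder of saved HTML articles. A hedged sketch of that step, assuming the folder name used in the diff exists locally:

# Sketch: strip visible text out of every .html file in a folder and build
# the same kind of DataFrame that load_documents_data() stores in data['df'].
import os
import pandas as pd
from bs4 import BeautifulSoup

folder_path = 'downloaded_articles/downloaded_articles'
documents = []
for file_name in os.listdir(folder_path):
    if not file_name.endswith('.html'):
        continue
    with open(os.path.join(folder_path, file_name), 'r', encoding='utf-8') as fh:
        text = BeautifulSoup(fh, 'html.parser').get_text(separator='\n').strip()
    documents.append({"file_name": file_name, "content": text})
print(pd.DataFrame(documents).head())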
@@ -254,34 +186,27 @@ def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
         print(f"Error loading docs: {e}")
         return None
 
-
 def load_data():
-    """Load all required data"""
     embeddings_success = load_embeddings()
     documents_success = load_documents_data()
-
     if not embeddings_success:
         print("Warning: Failed to load embeddings, falling back to basic functionality")
     if not documents_success:
         print("Warning: Failed to load documents data, falling back to basic functionality")
-
     return True
 
-# Initialize application
 print("Initializing application...")
 init_success = load_models() and load_data()
 
 
 def translate_text(text, source_to_target='ar_to_en'):
-    """Translate text between Arabic and English"""
     try:
         if source_to_target == 'ar_to_en':
             tokenizer = models['ar_to_en_tokenizer']
             model = models['ar_to_en_model']
         else:
             tokenizer = models['en_to_ar_tokenizer']
             model = models['en_to_ar_model']
-
         inputs = tokenizer(text, return_tensors="pt", truncation=True)
         outputs = model.generate(**inputs)
         return tokenizer.decode(outputs[0], skip_special_tokens=True)
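translate_text() is the usual MarianMT round trip: tokenize, generate, decode. A hedged sketch with the en-to-ar checkpoint named in the diff and a made-up input sentence:

# Sketch of the translate_text() flow for the English-to-Arabic direction.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
inputs = tokenizer("Drink plenty of water.", return_tensors="pt", truncation=True)
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))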
@@ -309,40 +234,17 @@ def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
         print(f"Error in query_embeddings: {e}")
         return []
 
-from sklearn.metrics.pairwise import cosine_similarity
-import numpy as np
-
 def query_recipes_embeddings(query_embedding, embeddings_data, n_results = 5):
-    """
-    Query the recipes embeddings to find the most similar recipes based on cosine similarity.
-
-    Args:
-        query_embedding (np.ndarray): A 1D numpy array representing the query embedding.
-        n_results (int): Number of top results to return.
-
-    Returns:
-        List[Tuple[int, float]]: A list of tuples containing the indices of the top results and their similarity scores.
-    """
-    # Load embeddings
     embeddings_data = load_recipes_embeddings()
     if embeddings_data is None:
         print("No embeddings data available.")
         return []
-
     try:
-        # Ensure query_embedding is 2D for cosine similarity computation
         if query_embedding.ndim == 1:
             query_embedding = query_embedding.reshape(1, -1)
-
-        # Compute cosine similarity
         similarities = cosine_similarity(query_embedding, embeddings_data).flatten()
-
-        # Get the indices of the top N most similar embeddings
         top_indices = similarities.argsort()[-n_results:][::-1]
-
-        # Return the indices and similarity scores of the top results
         return [(index, similarities[index]) for index in top_indices]
-
     except Exception as e:
         print(f"Error in query_recipes_embeddings: {e}")
         return []
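The retrieval above is cosine similarity plus an argsort over the scores. A hedged, self-contained sketch with random vectors standing in for the stored recipe embeddings (384 dimensions to match all-MiniLM-L6-v2, top 5 as in the code):

# Sketch of the top-k lookup in query_recipes_embeddings().
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
embeddings_data = rng.normal(size=(100, 384))            # (num_recipes, embedding_dim)
query_embedding = rng.normal(size=(384,)).reshape(1, -1)

similarities = cosine_similarity(query_embedding, embeddings_data).flatten()
top_indices = similarities.argsort()[-5:][::-1]
print([(int(i), float(similarities[i])) for i in top_indices])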
@@ -364,12 +266,10 @@ def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded_articles'):
     for doc_id in doc_ids:
         file_path = os.path.join(folder_path, doc_id)
         try:
-            # Check if the file exists
             if not os.path.exists(file_path):
                 print(f"Warning: Document file not found: {file_path}")
                 texts.append("")
                 continue
-            # Read and parse the HTML file
             with open(file_path, 'r', encoding='utf-8') as file:
                 soup = BeautifulSoup(file, 'html.parser')
                 text = soup.get_text(separator=' ', strip=True)
@@ -379,82 +279,71 @@
             texts.append("")
     return texts
 
-import os
-import pandas as pd
-
 def retrieve_rec_texts(
     document_indices,
     folder_path='downloaded_articles/downloaded_articles',
     metadata_path='recipes_metadata.xlsx'
 ):
-    """
-    Retrieve the texts of documents corresponding to the given indices.
-
-    Args:
-        document_indices (List[int]): A list of document indices to retrieve.
-        folder_path (str): Path to the folder containing the article files.
-        metadata_path (str): Path to the metadata file mapping indices to file names.
-
-    Returns:
-        List[str]: A list of document texts corresponding to the given indices.
-    """
     try:
-        # Load metadata file to map indices to original file names
         metadata_df = pd.read_excel(metadata_path)
-
-        # Ensure the metadata file has the required columns
         if "id" not in metadata_df.columns or "original_file_name" not in metadata_df.columns:
             raise ValueError("Metadata file must contain 'id' and 'original_file_name' columns.")
-
-        # Ensure the 'id' column aligns with the embeddings row indices
         metadata_df = metadata_df.sort_values(by="id").reset_index(drop=True)
-
-        # Verify the alignment of metadata with embeddings indices
         if metadata_df.index.max() < max(document_indices):
             raise ValueError("Some document indices exceed the range of metadata.")
-
-        # Retrieve file names for the given indices
         document_texts = []
         for idx in document_indices:
             if idx >= len(metadata_df):
                 print(f"Warning: Index {idx} is out of range for metadata.")
                 continue
-
             original_file_name = metadata_df.iloc[idx]["original_file_name"]
             if not original_file_name:
                 print(f"Warning: No file name found for index {idx}")
                 continue
-
-            # Construct the file path using the original file name
             file_path = os.path.join(folder_path, original_file_name)
-
-            # Check if the file exists and read its content
             if os.path.exists(file_path):
                 with open(file_path, "r", encoding="utf-8") as f:
                     document_texts.append(f.read())
             else:
                 print(f"Warning: File not found at {file_path}")
-
         return document_texts
-
     except Exception as e:
         print(f"Error in retrieve_rec_texts: {e}")
         return []
 
+def retrieve_metadata(document_indices: List[str], metadata_path: str = 'recipes_metadata.xlsx') -> Dict[str, Dict[str, str]]:
+    try:
+        metadata_df = pd.read_excel(metadata_path)
+        required_columns = {'id', 'original_file_name', 'url'}
+        if not required_columns.issubset(metadata_df.columns):
+            raise ValueError(f"Metadata file must contain the following columns: {required_columns}")
+        metadata_mapping = metadata_df.set_index('id')[['original_file_name', 'url']].to_dict('index')
+        result = {doc_id: metadata_mapping.get(doc_id, {}) for doc_id in document_indices}
+        return result
+    except Exception as e:
+        print(f"Error retrieving metadata: {e}")
+        return {}
 
+def retrieve_metadata(document_indices: List[str], metadata_path: str = 'recipes_metadata.xlsx') -> Dict[str, Dict[str, str]]:
+    try:
+        metadata_df = pd.read_excel(metadata_path)
+        required_columns = {'id', 'original_file_name', 'url'}
+        if not required_columns.issubset(metadata_df.columns):
+            raise ValueError(f"Metadata file must contain the following columns: {required_columns}")
+        metadata_mapping = metadata_df.set_index('id')[['original_file_name', 'url']].to_dict('index')
+        result = {doc_id: metadata_mapping.get(doc_id, {}) for doc_id in document_indices}
+        return result
+    except Exception as e:
+        print(f"Error retrieving metadata: {e}")
+        return {}
 
 
 def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
     try:
-        # Prepare pairs for the cross-encoder
         pairs = [(query, doc) for doc in document_texts]
-        # Get scores using the cross-encoder model
         scores = cross_encoder_model.predict(pairs)
-        # Combine scores with document IDs and texts
         scored_documents = list(zip(scores, document_ids, document_texts))
-        # Sort by scores in descending order
         scored_documents.sort(key=lambda x: x[0], reverse=True)
-        # Print reranked results
         print("Reranked results:")
         for idx, (score, doc_id, doc) in enumerate(scored_documents):
             print(f"Rank {idx + 1} (Score: {score:.4f}, Document ID: {doc_id})")
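retrieve_metadata(), added twice in this commit (the second definition is the one that stays bound), boils down to a pandas set_index/to_dict lookup keyed by id. A hedged sketch with an in-memory frame in place of recipes_metadata.xlsx; the ids, file names, and URLs below are placeholders:

# Sketch of the id -> (original_file_name, url) mapping retrieve_metadata() builds.
import pandas as pd

metadata_df = pd.DataFrame({
    "id": [0, 1, 2],
    "original_file_name": ["article_0.html", "article_1.html", "article_2.html"],
    "url": ["https://example.org/0", "https://example.org/1", "https://example.org/2"],
})
metadata_mapping = metadata_df.set_index('id')[['original_file_name', 'url']].to_dict('index')
document_indices = [2, 0]
print({doc_id: metadata_mapping.get(doc_id, {}) for doc_id in document_indices})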
@@ -465,12 +354,9 @@ def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
 
 def extract_entities(text, ner_pipeline=None):
     try:
-        # Use the provided pipeline or default to the model dictionary
         if ner_pipeline is None:
             ner_pipeline = models['ner_pipeline']
-        # Perform NER using the pipeline
         ner_results = ner_pipeline(text)
-        # Extract unique entities that start with "B-"
         entities = {result['word'] for result in ner_results if result['entity'].startswith("B-")}
         return list(entities)
     except Exception as e:
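The cross-encoder re-ranking used by rerank_documents() above (and again in the /api/resources handler below) scores each (query, document) pair jointly and sorts by score. A hedged sketch with the checkpoint named in the diff and two made-up documents:

# Sketch of cross-encoder re-ranking.
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
query = "fluid intake for kidney patients"
document_texts = ["Limit salt and monitor fluid intake.", "How to bake bread at home."]
scores = cross_encoder.predict([(query, doc) for doc in document_texts])
for score, doc in sorted(zip(scores, document_texts), key=lambda x: x[0], reverse=True):
    print(f"{score:.4f}  {doc}")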
@@ -488,22 +374,16 @@ def match_entities(query_entities, sentence_entities):
 
 def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
     relevant_portions = {}
-    # Extract entities from the query
     query_entities = extract_entities(query)
     print(f"Extracted Query Entities: {query_entities}")
     for doc_id, doc_text in enumerate(document_texts):
         sentences = nltk.sent_tokenize(doc_text)
         doc_relevant_portions = []
-        # Extract entities from the entire document
-        #ner_biobert = models['ner_pipeline']
         doc_entities = extract_entities(doc_text)
         print(f"Document {doc_id} Entities: {doc_entities}")
         for i, sentence in enumerate(sentences):
-            # Extract entities from the sentence
             sentence_entities = extract_entities(sentence)
-            # Compute relevance score
             relevance_score = match_entities(query_entities, sentence_entities)
-            # Select sentences with at least `min_query_words` matching entities
             if relevance_score >= min_query_words:
                 start_idx = max(0, i - portion_size // 2)
                 end_idx = min(len(sentences), i + portion_size // 2 + 1)
@@ -511,13 +391,11 @@ def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=1):
                 doc_relevant_portions.append(portion)
                 if len(doc_relevant_portions) >= max_portions:
                     break
-        # Fallback: Include most entity-dense sentences if no relevant portions were found
         if not doc_relevant_portions and len(doc_entities) > 0:
             print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
             sorted_sentences = sorted(sentences, key=lambda s: len(extract_entities(s, ner_biobert)), reverse=True)
             for fallback_sentence in sorted_sentences[:max_portions]:
                 doc_relevant_portions.append(fallback_sentence)
-        # Add the extracted portions to the result dictionary
         relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
     return relevant_portions
 
@@ -537,7 +415,6 @@ def extract_entities(text):
     inputs = biobert_tokenizer(text, return_tensors="pt")
     outputs = biobert_model(**inputs)
     predictions = torch.argmax(outputs.logits, dim=2)
-
     tokens = biobert_tokenizer.convert_ids_to_tokens(inputs.input_ids[0])
     entities = [
         tokens[i]
@@ -568,9 +445,6 @@ def generate_answer(prompt, max_length=860, temperature=0.2):
     tokenizer_f = models['llm_tokenizer']
     model_f = models['llm_model']
     inputs = tokenizer_f(prompt, return_tensors="pt", truncation=True)
-    # Start timing
-    #start_time = time.time()
-    # Generate the output
     output_ids = model_f.generate(
         inputs.input_ids,
         max_length=max_length,
@@ -578,38 +452,27 @@ def generate_answer(prompt, max_length=860, temperature=0.2):
         temperature=temperature,
         pad_token_id=tokenizer_f.eos_token_id
     )
-    # End timing
-    #end_time = time.time()
-    # Calculate the duration
-    #duration = end_time - start_time
-    # Decode the answer
     answer = tokenizer_f.decode(output_ids[0], skip_special_tokens=True)
-
-    passage_keywords = set(prompt.lower().split())  # Adjusted to check keywords in the full prompt
+    passage_keywords = set(prompt.lower().split())
     answer_keywords = set(answer.lower().split())
-    # Verify if the answer aligns with the passage
     if passage_keywords.intersection(answer_keywords):
         return answer
     else:
         return "Sorry, I can't help with that."
 
 def remove_answer_prefix(text):
     prefix = "Answer:"
     if prefix in text:
         return text.split(prefix, 1)[-1].strip()
     return text
 
 def remove_incomplete_sentence(text):
-    # Check if the text ends with a period
     if not text.endswith('.'):
-        # Find the last period or the end of the string
         last_period_index = text.rfind('.')
         if last_period_index != -1:
-            # Remove everything after the last period
             return text[:last_period_index + 1].strip()
     return text
 
-
 @app.get("/")
 async def root():
     return {"message": "Welcome to the FastAPI application! Use the /health endpoint to check health, and /api/query for processing queries."}
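generate_answer() wraps a standard transformers generate() call followed by a keyword-overlap sanity check. A hedged sketch of the generation step with the checkpoint named in the diff; do_sample=True is added here so the temperature argument takes effect, and the prompt is made up:

# Sketch of the generation call inside generate_answer().
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "M4-ai/Orca-2.0-Tau-1.8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
inputs = tokenizer("Answer briefly: why does hydration matter?", return_tensors="pt", truncation=True)
output_ids = model.generate(
    inputs.input_ids,
    max_length=200,
    do_sample=True,
    temperature=0.2,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))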
@@ -630,7 +493,7 @@ async def chat_endpoint(chat_query: ChatQuery):
     try:
         query_text = chat_query.query
         language_code = chat_query.language_code
         query_embedding = embed_query_text(query_text)
         embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results=5)
@@ -671,34 +534,21 @@
 
 @app.post("/api/resources")
 async def resources_endpoint(profile: MedicalProfile):
     try:
-
-        # Build the query text
         query_text = profile.conditions + " " + profile.daily_symptoms
-
         print(f"Generated query text: {query_text}")
-
-        # Generate the query embedding
         query_embedding = embed_query_text(query_text)
         if query_embedding is None:
             raise ValueError("Failed to generate query embedding.")
-
-        # Load embeddings and retrieve initial results
         embeddings_data = load_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_embeddings(query_embedding, embeddings_data, n_results=6)
         if not initial_results:
             raise ValueError("No relevant documents found.")
-
-        # Extract document IDs
         document_ids = [doc_id for doc_id, _ in initial_results]
-
-        # Load document metadata (URL mappings)
         file_path = 'finalcleaned_excel_file.xlsx'
         df = pd.read_excel(file_path)
         file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
-
-        # Map file names to original URLs
         resources = []
         for file_name in document_ids:
             original_url = file_name_to_url.get(file_name, None)
@@ -707,54 +557,34 @@
                 resources.append({"file_name": file_name, "title": title, "url": original_url})
             else:
                 resources.append({"file_name": file_name, "title": "Unknown", "url": None})
-
-        # Retrieve document texts
         document_texts = retrieve_document_texts(document_ids, folder_path)
         if not document_texts:
             raise ValueError("Failed to retrieve document texts.")
-
-        # Perform re-ranking
         cross_encoder = models['cross_encoder']
         scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
         scores = [float(score) for score in scores]
-
-        # Combine scores with resources
         for i, resource in enumerate(resources):
             resource["score"] = scores[i] if i < len(scores) else 0.0
-
-        # Sort resources by score
         resources.sort(key=lambda x: x["score"], reverse=True)
-
-        # Limit response to top 5 resources
         return {"resources": resources[:5], "success": True}
-
     except ValueError as ve:
-        # Handle expected errors
         raise HTTPException(status_code=400, detail=str(ve))
     except Exception as e:
-        # Handle unexpected errors
         print(f"Unexpected error: {e}")
         raise HTTPException(status_code=500, detail="An unexpected error occurred.")
 
-
-
 @app.post("/api/recipes")
 async def recipes_endpoint(profile: MedicalProfile):
     try:
-        # Build the query text for recipes
         recipe_query = (
             f"Recipes foods and meals suitable for someone with: "
             f"{profile.conditions} and experiencing {profile.daily_symptoms}"
         )
         query_text = recipe_query
         print(f"Generated query text: {query_text}")
-
-        # Generate the query embedding
         query_embedding = embed_query_text(query_text)
         if query_embedding is None:
             raise ValueError("Failed to generate query embedding.")
-
-        # Load embeddings and retrieve initial results
         embeddings_data = load_recipes_embeddings()
         folder_path = 'downloaded_articles/downloaded_articles'
         initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results=5)
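A hedged client-side sketch for the /api/resources handler above, assuming the app is served locally on port 7860 (the uvicorn.run line at the bottom of the file) and that MedicalProfile only requires the conditions and daily_symptoms fields the handler reads:

# Sketch: call /api/resources and print the re-ranked resource list.
import requests

payload = {"conditions": "chronic kidney disease", "daily_symptoms": "fatigue and swelling"}
resp = requests.post("http://localhost:7860/api/resources", json=payload, timeout=120)
resp.raise_for_status()
for item in resp.json().get("resources", []):
    print(item.get("score"), item.get("title"), item.get("url"))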
@@ -762,92 +592,26 @@ async def recipes_endpoint(profile: MedicalProfile):
             raise ValueError("No relevant recipes found.")
         print("Initial results (document indices and similarities):")
         print(initial_results)
-
-        # Extract document indices from the results
         document_indices = [doc_id for doc_id, _ in initial_results]
         print("Document indices:", document_indices)
-
-
-
-
-
-
-
-
-
-
-        print("Relevant portions extracted:")
-        print(relevant_portions)
-
-        flattened_relevant_portions = []
-        for doc_id, portions in relevant_portions.items():
-            flattened_relevant_portions.extend(portions)
-        unique_selected_parts = remove_duplicates(flattened_relevant_portions)
-        print("Unique selected parts:")
-        print(unique_selected_parts)
-
-        combined_parts = " ".join(unique_selected_parts)
-        print("Combined text for context:")
-        print(combined_parts)
-
-        context = [query_text] + unique_selected_parts
-        print("Final context for answering:")
-        print(context)
-
-        # Extract entities from the query
-        entities = extract_entities(query_text)
-        print("Extracted entities:")
-        print(entities)
-
-        # Enhance the passage with the extracted entities
-        passage = enhance_passage_with_entities(combined_parts, entities)
-        print("Enhanced passage with entities:")
-        print(passage)
-
-        # Create the prompt for the model
-        prompt = create_prompt(query_text, passage)
-        print("Generated prompt:")
-        print(prompt)
-
-        # Generate the answer from the model
-        answer = generate_answer(prompt)
-        print("Generated answer:")
-        print(answer)
-
-        # Clean up the answer to extract the relevant part
-        answer_part = answer.split("Answer:")[-1].strip()
-        cleaned_answer = remove_answer_prefix(answer_part)
-        print("Cleaned answer:")
-        print(cleaned_answer)
-
-        final_answer = remove_incomplete_sentence(cleaned_answer)
-        print("Final answer:")
-        print(final_answer)
-
-        if language_code == 0:
-            final_answer = translate_en_to_ar(final_answer)
-
-        if final_answer:
-            print("Answer:")
-            print(final_answer)
-        else:
-            print("Sorry, I can't help with that.")
-
-        return {"response": final_answer}
-
+        metadata_path = 'recipes_metadata.xlsx'
+        metadata = retrieve_metadata(document_indices, metadata_path=metadata_path)
+        print(f"Retrieved Metadata: {metadata}")
+        response = {
+            "metadata": [
+                {"id": doc_id, "original_file_name": metadata.get(doc_id, {}).get("original_file_name"), "url": metadata.get(doc_id, {}).get("url")}
+                for doc_id in document_indices
+            ],
+        }
+        return response
     except ValueError as ve:
-        # Handle expected errors
         raise HTTPException(status_code=400, detail=str(ve))
     except Exception as e:
-        # Handle unexpected errors
         print(f"Unexpected error: {e}")
         raise HTTPException(status_code=500, detail="An unexpected error occurred.")
-
-
+
 if not init_success:
     print("Warning: Application initialized with partial functionality")
-
-# For running locally
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=7860)
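And a matching hedged sketch for the reworked /api/recipes handler, which after this commit returns metadata (ids, file names, URLs) for the top-matching recipes instead of a generated answer. Same assumptions as the /api/resources example:

# Sketch: call /api/recipes and print the returned recipe metadata.
import requests

payload = {"conditions": "type 2 diabetes", "daily_symptoms": "high blood sugar after meals"}
resp = requests.post("http://localhost:7860/api/recipes", json=payload, timeout=120)
resp.raise_for_status()
for entry in resp.json().get("metadata", []):
    print(entry.get("id"), entry.get("original_file_name"), entry.get("url"))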