thechaiexperiment committed on
Commit
66387cc
·
1 Parent(s): 8497042

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -9
app.py CHANGED
@@ -7,6 +7,9 @@ import nltk
7
  import torch
8
  import pandas as pd
9
  import requests
 
 
 
10
  from fastapi import FastAPI, HTTPException
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from pydantic import BaseModel
@@ -130,7 +133,7 @@ def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
130
  # Open the safetensors file
131
  with safe_open(embeddings_path, framework="pt") as f:
132
  keys = f.keys()
133
- print(f"Available keys in the .safetensors file: {list(keys)}") # Debugging info
134
 
135
  # Iterate over the keys and load tensors
136
  for key in keys:
@@ -155,6 +158,46 @@ def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
155
  print(f"Error loading embeddings: {e}")
156
  return None
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
160
  """Load document data from HTML articles in a specified folder."""
@@ -195,16 +238,87 @@ def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
195
  data['df'] = pd.DataFrame()
196
  return False
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def load_data():
199
  """Load all required data"""
200
  embeddings_success = load_embeddings()
201
  documents_success = load_documents_data()
202
-
203
- if not embeddings_success:
 
204
  print("Warning: Failed to load embeddings, falling back to basic functionality")
205
- if not documents_success:
206
  print("Warning: Failed to load documents data, falling back to basic functionality")
207
-
208
  return True
209
 
210
  # Initialize application
@@ -248,6 +362,21 @@ def query_embeddings(query_embedding, embeddings_data=None, n_results=5):
248
  print(f"Error in query_embeddings: {e}")
249
  return []
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  def get_page_title(url):
252
  try:
253
  response = requests.get(url)
@@ -280,6 +409,48 @@ def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded
280
  texts.append("")
281
  return texts
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
285
  try:
@@ -646,9 +817,9 @@ async def recipes_endpoint(profile: MedicalProfile):
646
  raise ValueError("Failed to generate query embedding.")
647
 
648
  # Load embeddings and retrieve initial results
649
- embeddings_data = load_embeddings()
650
- folder_path = 'downloaded_articles/downloaded_articles'
651
- initial_results = query_embeddings(query_embedding, embeddings_data, n_results=10)
652
  if not initial_results:
653
  raise ValueError("No relevant recipes found.")
654
 
@@ -656,7 +827,7 @@ async def recipes_endpoint(profile: MedicalProfile):
656
  document_ids = [doc_id for doc_id, _ in initial_results]
657
 
658
  # Retrieve document texts
659
- document_texts = retrieve_document_texts(document_ids, folder_path)
660
  if not document_texts:
661
  raise ValueError("Failed to retrieve document texts.")
662
 
 
7
  import torch
8
  import pandas as pd
9
  import requests
10
+ import zipfile
11
+ import tempfile
12
+ from PyPDF2 import PdfReader
13
  from fastapi import FastAPI, HTTPException
14
  from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel
 
133
  # Open the safetensors file
134
  with safe_open(embeddings_path, framework="pt") as f:
135
  keys = f.keys()
136
+ # print(f"Available keys in the .safetensors file: {list(keys)}") # Debugging info
137
 
138
  # Iterate over the keys and load tensors
139
  for key in keys:
 
158
  print(f"Error loading embeddings: {e}")
159
  return None
160
 
161
def load_recipes_embeddings() -> Optional[Dict[str, np.ndarray]]:
    """Load the recipe embeddings from a safetensors file into NumPy arrays.

    The file ``recipes_embeddings.safetensors`` is looked up in the current
    working directory. When it is missing, a file of the same name is
    downloaded from the Hugging Face Space named by the ``HF_SPACE_ID``
    environment variable (default ``thechaiexperiment/TeaRAG``).

    Returns:
        A dict mapping every key of the safetensors file to its tensor
        converted to a NumPy array. Keys whose tensor cannot be read or
        converted are reported and left out, so the dict may be empty.
        ``None`` is returned when any other step raises (lookup, download or
        opening the file).
    """
    embeddings_filename = 'recipes_embeddings.safetensors'
    try:
        # Locate or download the embeddings file
        embeddings_path = embeddings_filename
        if not os.path.exists(embeddings_path):
            print("File not found locally. Attempting to download from Hugging Face Hub...")
            # Ask the Hub for the same file name that was checked locally.
            # The fallback used to request "embeddings.safetensors", i.e. a
            # different file from the one this loader is named after.
            embeddings_path = hf_hub_download(
                repo_id=os.environ.get('HF_SPACE_ID', 'thechaiexperiment/TeaRAG'),
                filename=embeddings_filename,
                repo_type="space",
            )

        embeddings = {}
        with safe_open(embeddings_path, framework="pt") as f:
            for key in f.keys():
                try:
                    tensor = f.get_tensor(key)
                    # Advisory only: the tensor is still stored below.
                    # NOTE(review): 384 is presumably the encoder's embedding
                    # size -- confirm against the model that built the file.
                    if tensor.shape[0] != 384:
                        print(f"Warning: Tensor for key {key} has unexpected shape {tensor.shape}")

                    embeddings[key] = tensor.numpy()
                except Exception as key_error:
                    # One unreadable key must not discard the other keys.
                    print(f"Failed to process key {key}: {key_error}")

        if embeddings:
            print(f"Successfully loaded {len(embeddings)} embeddings.")
        else:
            print("No embeddings could be loaded. Please check the file format and content.")

        return embeddings

    except Exception as e:
        print(f"Error loading embeddings: {e}")
        return None
201
 
202
  def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
203
  """Load document data from HTML articles in a specified folder."""
 
238
  data['df'] = pd.DataFrame()
239
  return False
240
 
241
def _extract_pdf_text(file_path):
    """Return the text of the PDF at ``file_path``, pages joined by newlines."""
    from PyPDF2 import PdfReader  # Import here to avoid dependency issues

    reader = PdfReader(file_path)
    # Extract each page exactly once; pages that yield no text are skipped.
    page_texts = (page.extract_text() for page in reader.pages)
    return "\n".join(text for text in page_texts if text)


def _load_recipes_from_folder(folder_path):
    """Read every top-level PDF of ``folder_path`` into ``data['df']``."""
    if not os.path.isdir(folder_path):
        print(f"Error: Folder '{folder_path}' not found.")
        return False

    file_names = os.listdir(folder_path)
    # HTML files only take part in the "anything to load?" check below; they
    # are not parsed, so a folder holding only HTML ends up with no documents.
    html_files = [name for name in file_names if name.endswith('.html')]
    pdf_files = [name for name in file_names if name.endswith('.pdf')]

    if not html_files and not pdf_files:
        print(f"No HTML or PDF files found in folder '{folder_path}'.")
        return False

    documents = []
    for file_name in pdf_files:
        try:
            text = _extract_pdf_text(os.path.join(folder_path, file_name))
        except Exception as e:
            # Best effort: an unreadable PDF is reported and skipped.
            print(f"Error reading PDF file {file_name}: {e}")
        else:
            documents.append({"file_name": file_name, "content": text})

    # NOTE(review): load_documents_data also assigns data['df'], so whichever
    # loader runs last decides its content -- confirm that sharing the key is
    # intended before relying on either frame.
    data['df'] = pd.DataFrame(documents)

    if data['df'].empty:
        print("No valid documents loaded.")
        return False

    print(f"Successfully loaded {len(data['df'])} document records.")
    return True


def load_recipes_data(folder_path='pdf kb.zip'):
    """Load recipe PDFs from a folder or a .zip archive into ``data['df']``.

    Args:
        folder_path: Path of a directory, or of a ``.zip`` archive that is
            extracted to a temporary directory for the duration of the call.
            Only files directly inside the directory (or the archive root)
            are considered.

    Returns:
        True when at least one PDF was read; ``data['df']`` then holds one row
        per PDF with the columns ``file_name`` and ``content``. False when
        the path is missing, the archive cannot be extracted, no HTML/PDF file
        is found, or no PDF could be read. On an unexpected error
        ``data['df']`` is reset to an empty DataFrame and False is returned.
    """
    try:
        print("Loading documents data...")

        if not folder_path.endswith('.zip'):
            return _load_recipes_from_folder(folder_path)

        if not os.path.exists(folder_path):
            print(f"Error: .zip file '{folder_path}' not found.")
            return False

        # The context manager removes the extracted files on every exit path.
        with tempfile.TemporaryDirectory() as extract_path:
            try:
                with zipfile.ZipFile(folder_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_path)
                print(f"Extracted .zip file to temporary folder: {extract_path}")
            except Exception as e:
                print(f"Error extracting .zip file: {e}")
                return False

            return _load_recipes_from_folder(extract_path)
    except Exception as e:
        print(f"Error loading documents data: {e}")
        data['df'] = pd.DataFrame()
        return False
311
+
312
  def load_data():
313
  """Load all required data"""
314
  embeddings_success = load_embeddings()
315
  documents_success = load_documents_data()
316
+ recipes_success = load_recipes_data()
317
+ recipes_embeddings_success = load_recipes_embeddings()
318
+ if not recipes_embeddings_success:
319
  print("Warning: Failed to load embeddings, falling back to basic functionality")
320
+ if not recipes_success:
321
  print("Warning: Failed to load documents data, falling back to basic functionality")
 
322
  return True
323
 
324
  # Initialize application
 
362
  print(f"Error in query_embeddings: {e}")
363
  return []
364
 
365
def query_recipes_embeddings(query_embedding, embeddings_data=None, n_results=5):
    """Return the recipe documents most similar to ``query_embedding``.

    Args:
        query_embedding: Embedding of the query, passed unchanged as the first
            argument of ``cosine_similarity``.
            NOTE(review): presumably a 2-D array of shape (1, dim) -- confirm
            against the caller.
        embeddings_data: Mapping of document id to embedding. When ``None`` the
            embeddings are loaded with ``load_recipes_embeddings()``; a mapping
            supplied by the caller is used as is instead of being reloaded.
        n_results: Maximum number of results. Values <= 0 yield an empty list.

    Returns:
        A list of ``(doc_id, similarity)`` tuples ordered from most to least
        similar. An empty list when no embeddings are available or when the
        similarity computation raises.
    """
    # Only hit the disk when the caller did not already provide the data.
    if embeddings_data is None:
        embeddings_data = load_recipes_embeddings()
    if not embeddings_data:
        print("No embeddings data available.")
        return []
    try:
        # Guard first: with n_results == 0 the slice [-0:] below would select
        # every document instead of none.
        if n_results <= 0:
            return []
        doc_ids = list(embeddings_data.keys())
        doc_embeddings = np.array(list(embeddings_data.values()))
        similarities = cosine_similarity(query_embedding, doc_embeddings).flatten()
        # argsort is ascending: keep the last n_results and reverse them.
        top_indices = similarities.argsort()[-n_results:][::-1]
        return [(doc_ids[i], similarities[i]) for i in top_indices]
    except Exception as e:
        print(f"Error in query_recipes_embeddings: {e}")
        return []
379
+
380
  def get_page_title(url):
381
  try:
382
  response = requests.get(url)
 
409
  texts.append("")
410
  return texts
411
 
412
def retrieve_recipes_texts(doc_ids, zip_path='pdf kb.zip'):
    """Return the text of the PDF ``<doc_id>.pdf`` for every id in ``doc_ids``.

    The PDFs are read straight from the archive, one requested member at a
    time, so nothing is extracted to disk.

    Args:
        doc_ids: Document ids; each is looked up as the archive member
            ``f"{doc_id}.pdf"``.
        zip_path: Path of the .zip archive that holds the PDFs.

    Returns:
        A list with one string per id, in the order of ``doc_ids``. The entry
        is the stripped text of the PDF (non-empty pages joined by newlines),
        or ``""`` when the member is missing or cannot be parsed. When the
        archive itself is missing or unreadable, every entry is ``""``.
    """
    import io  # local import: only needed to hand PdfReader a seekable stream

    texts = []

    try:
        if not os.path.exists(zip_path):
            print(f"Error: Zip file not found at '{zip_path}'")
            return ["" for _ in doc_ids]

        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            # Build the member set once so each lookup below is O(1).
            members = set(zip_ref.namelist())

            for doc_id in doc_ids:
                member = f"{doc_id}.pdf"
                if member not in members:
                    print(f"Warning: PDF file not found: {member}")
                    texts.append("")
                    continue

                try:
                    reader = PdfReader(io.BytesIO(zip_ref.read(member)))
                    # A page without extractable text must not discard the
                    # whole document, so empty/None results are skipped.
                    page_texts = (page.extract_text() for page in reader.pages)
                    pdf_text = "\n".join(text for text in page_texts if text)
                    texts.append(pdf_text.strip())
                except Exception as e:
                    # Best effort: keep the result aligned with doc_ids.
                    print(f"Error retrieving text from document {doc_id}: {e}")
                    texts.append("")

    except Exception as e:
        print(f"Error handling zip file: {e}")
        return ["" for _ in doc_ids]
    return texts
454
 
455
  def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
456
  try:
 
817
  raise ValueError("Failed to generate query embedding.")
818
 
819
  # Load embeddings and retrieve initial results
820
+ embeddings_data = load_recipes_embeddings()
821
+ folder_path = 'pdf kb.zip'
822
+ initial_results = query_recipes_embeddings(query_embedding, embeddings_data, n_results=10)
823
  if not initial_results:
824
  raise ValueError("No relevant recipes found.")
825
 
 
827
  document_ids = [doc_id for doc_id, _ in initial_results]
828
 
829
  # Retrieve document texts
830
+ document_texts = retrieve_recipes_texts(document_ids, folder_path)
831
  if not document_texts:
832
  raise ValueError("Failed to retrieve document texts.")
833