thechaiexperiment commited on
Commit
bca5800
·
verified ·
1 Parent(s): f673cee

Update general_rag.py

Browse files
Files changed (1) hide show
  1. general_rag.py +222 -32
general_rag.py CHANGED
@@ -1,24 +1,39 @@
 
 
1
  import os
2
  import re
3
  import numpy as np
 
 
4
  import torch
5
  import pandas as pd
6
  import requests
 
 
 
 
7
  from fastapi import FastAPI, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
  from transformers import (
11
  AutoTokenizer,
12
  AutoModelForSeq2SeqLM,
13
- pipeline
 
 
 
 
14
  )
15
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
16
  from sklearn.metrics.pairwise import cosine_similarity
17
  from bs4 import BeautifulSoup
18
  from huggingface_hub import hf_hub_download
19
- from safetensors.torch import safe_open
20
  from typing import List, Dict, Optional
21
- from openai import OpenAI
 
 
 
22
 
23
  app = FastAPI()
24
  app.add_middleware(
@@ -28,7 +43,6 @@ app.add_middleware(
28
  allow_methods=["*"],
29
  allow_headers=["*"],
30
  )
31
-
32
  models = {}
33
  data = {}
34
 
@@ -39,6 +53,21 @@ class QueryRequest(BaseModel):
39
  class ChatQuery(BaseModel):
40
  query: str
41
  language_code: int = 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def get_completion(prompt: str, model: str = "deepseek/deepseek-prover-v2:free") -> str:
44
  api_key = os.environ.get('OPENROUTER_API_KEY')
@@ -82,11 +111,9 @@ def get_completion(prompt: str, model: str = "deepseek/deepseek-prover-v2:free")
82
 
83
  def load_models():
84
  try:
85
- print("Loading general models...")
86
  device = "cuda" if torch.cuda.is_available() else "cpu"
87
  print(f"Device set to use {device}")
88
-
89
- # General models for all domains
90
  models['embedding_model'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
91
  models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
92
  models['semantic_model'] = SentenceTransformer('all-MiniLM-L6-v2')
@@ -94,11 +121,20 @@ def load_models():
94
  models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
95
  models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
96
  models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
97
-
98
- print("General models loaded successfully")
 
 
 
 
 
 
 
 
 
99
  return True
100
  except Exception as e:
101
- print(f"Error loading general models: {e}")
102
  return False
103
 
104
  def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
@@ -132,18 +168,45 @@ def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
132
  print(f"Error loading embeddings: {e}")
133
  return None
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
136
  try:
137
  print("Loading documents data...")
138
  if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
139
  print(f"Error: Folder '{folder_path}' not found")
140
  return False
141
-
142
  html_files = [f for f in os.listdir(folder_path) if f.endswith('.html')]
143
  if not html_files:
144
  print(f"No HTML files found in folder '{folder_path}'")
145
  return False
146
-
147
  documents = []
148
  for file_name in html_files:
149
  file_path = os.path.join(folder_path, file_name)
@@ -154,18 +217,44 @@ def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
154
  documents.append({"file_name": file_name, "content": text})
155
  except Exception as e:
156
  print(f"Error reading file {file_name}: {e}")
157
-
158
- data['df'] = pd.DataFrame(documents)
159
- if data['df'].empty:
160
- print("No valid documents loaded.")
161
- return False
162
-
163
- print(f"Successfully loaded {len(data['df'])} document records.")
164
- return True
165
  except Exception as e:
166
  print(f"Error loading docs: {e}")
167
  return None
168
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  def embed_query_text(query_text):
170
  embedding = models['embedding_model']
171
  query_embedding = embedding.encode([query_text])
@@ -186,6 +275,33 @@ def query_embeddings(query_embedding, embeddings_data, n_results):
186
  print(f"Error in query_embeddings: {e}")
187
  return []
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded_articles'):
190
  texts = []
191
  for doc_id in doc_ids:
@@ -204,6 +320,58 @@ def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded
204
  texts.append("")
205
  return texts
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
208
  try:
209
  pairs = [(query, doc) for doc in document_texts]
@@ -218,20 +386,40 @@ def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
218
  print(f"Error reranking documents: {e}")
219
  return []
220
 
221
- def translate_text(text, source_to_target='ar_to_en'):
222
  try:
223
- if source_to_target == 'ar_to_en':
224
- tokenizer = models['ar_to_en_tokenizer']
225
- model = models['ar_to_en_model']
226
- else:
227
- tokenizer = models['en_to_ar_tokenizer']
228
- model = models['en_to_ar_model']
229
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
230
- outputs = model.generate(**inputs)
231
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
232
  except Exception as e:
233
- print(f"Translation error: {e}")
234
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  @app.get("/")
237
  async def root():
@@ -248,6 +436,8 @@ async def health_check():
248
  }
249
  return status
250
 
 
 
251
  if __name__ == "__main__":
252
  import uvicorn
253
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ import transformers
2
+ import pickle
3
  import os
4
  import re
5
  import numpy as np
6
+ import torchvision
7
+ import nltk
8
  import torch
9
  import pandas as pd
10
  import requests
11
+ import zipfile
12
+ import tempfile
13
+ from openai import OpenAI
14
+ from PyPDF2 import PdfReader
15
  from fastapi import FastAPI, HTTPException
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel
18
  from transformers import (
19
  AutoTokenizer,
20
  AutoModelForSeq2SeqLM,
21
+ AutoModelForTokenClassification,
22
+ AutoModelForCausalLM,
23
+ pipeline,
24
+ Qwen2Tokenizer,
25
+ BartForConditionalGeneration
26
  )
27
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
28
  from sklearn.metrics.pairwise import cosine_similarity
29
  from bs4 import BeautifulSoup
30
  from huggingface_hub import hf_hub_download
31
+ from safetensors.torch import load_file
32
  from typing import List, Dict, Optional
33
+ from safetensors.numpy import load_file
34
+ from safetensors.torch import safe_open
35
+ nltk.download('punkt_tab')
36
+
37
 
38
  app = FastAPI()
39
  app.add_middleware(
 
43
  allow_methods=["*"],
44
  allow_headers=["*"],
45
  )
 
46
  models = {}
47
  data = {}
48
 
 
53
  class ChatQuery(BaseModel):
54
  query: str
55
  language_code: int = 1
56
+ #conversation_id: str
57
+
58
+ class ChatMessage(BaseModel):
59
+ role: str
60
+ content: str
61
+ timestamp: str
62
+
63
+ def init_nltk():
64
+ try:
65
+ nltk.download('punkt', quiet=True)
66
+ return True
67
+ except Exception as e:
68
+ print(f"Error initializing NLTK: {e}")
69
+ return False
70
+
71
 
72
  def get_completion(prompt: str, model: str = "deepseek/deepseek-prover-v2:free") -> str:
73
  api_key = os.environ.get('OPENROUTER_API_KEY')
 
111
 
112
  def load_models():
113
  try:
114
+ print("Loading models...")
115
  device = "cuda" if torch.cuda.is_available() else "cpu"
116
  print(f"Device set to use {device}")
 
 
117
  models['embedding_model'] = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
118
  models['cross_encoder'] = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
119
  models['semantic_model'] = SentenceTransformer('all-MiniLM-L6-v2')
 
121
  models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
122
  models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
123
  models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
124
+ models['att_tokenizer'] = AutoTokenizer.from_pretrained("facebook/bart-base")
125
+ models['att_model'] = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
126
+ models['bio_tokenizer'] = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
127
+ models['bio_model'] = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
128
+ models['ner_pipeline'] = pipeline("ner", model=models['bio_model'], tokenizer=models['bio_tokenizer'])
129
+ model_name = "M4-ai/Orca-2.0-Tau-1.8B"
130
+ models['llm_tokenizer'] = AutoTokenizer.from_pretrained(model_name)
131
+ models['llm_model'] = AutoModelForCausalLM.from_pretrained(model_name)
132
+ models['gen_tokenizer'] = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-1.7B-Instruct")
133
+ models['gen_model'] = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-1.7B-Instruct")
134
+ print("Models loaded successfully")
135
  return True
136
  except Exception as e:
137
+ print(f"Error loading models: {e}")
138
  return False
139
 
140
  def load_embeddings() -> Optional[Dict[str, np.ndarray]]:
 
168
  print(f"Error loading embeddings: {e}")
169
  return None
170
 
171
+ def normalize_key(key: str) -> str:
172
+ match = re.search(r'file_(\d+)', key)
173
+ if match:
174
+ return match.group(1)
175
+ return key
176
+
177
+ def load_recipes_embeddings() -> Optional[np.ndarray]:
178
+ try:
179
+ embeddings_path = 'recipes_embeddings.safetensors'
180
+ if not os.path.exists(embeddings_path):
181
+ print("File not found locally. Attempting to download from Hugging Face Hub...")
182
+ embeddings_path = hf_hub_download(
183
+ repo_id=os.environ.get('HF_SPACE_ID', 'thechaiexperiment/TeaRAG'),
184
+ filename="embeddings.safetensors",
185
+ repo_type="space"
186
+ )
187
+ embeddings = load_file(embeddings_path)
188
+ if "embeddings" not in embeddings:
189
+ raise ValueError("Key 'embeddings' not found in the safetensors file.")
190
+ tensor = embeddings["embeddings"]
191
+ print(f"Successfully loaded embeddings.")
192
+ print(f"Shape of embeddings: {tensor.shape}")
193
+ print(f"Dtype of embeddings: {tensor.dtype}")
194
+ print(f"First few values of the first embedding: {tensor[0][:5]}")
195
+ return tensor
196
+ except Exception as e:
197
+ print(f"Error loading embeddings: {e}")
198
+ return None
199
+
200
  def load_documents_data(folder_path='downloaded_articles/downloaded_articles'):
201
  try:
202
  print("Loading documents data...")
203
  if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
204
  print(f"Error: Folder '{folder_path}' not found")
205
  return False
 
206
  html_files = [f for f in os.listdir(folder_path) if f.endswith('.html')]
207
  if not html_files:
208
  print(f"No HTML files found in folder '{folder_path}'")
209
  return False
 
210
  documents = []
211
  for file_name in html_files:
212
  file_path = os.path.join(folder_path, file_name)
 
217
  documents.append({"file_name": file_name, "content": text})
218
  except Exception as e:
219
  print(f"Error reading file {file_name}: {e}")
220
+ data['df'] = pd.DataFrame(documents)
221
+ if data['df'].empty:
222
+ print("No valid documents loaded.")
223
+ return False
224
+ print(f"Successfully loaded {len(data['df'])} document records.")
225
+ return True
 
 
226
  except Exception as e:
227
  print(f"Error loading docs: {e}")
228
  return None
229
 
230
+ def load_data():
231
+ embeddings_success = load_embeddings()
232
+ documents_success = load_documents_data()
233
+ if not embeddings_success:
234
+ print("Warning: Failed to load embeddings, falling back to basic functionality")
235
+ if not documents_success:
236
+ print("Warning: Failed to load documents data, falling back to basic functionality")
237
+ return True
238
+
239
+ print("Initializing application...")
240
+ init_success = load_models() and load_data()
241
+
242
+
243
+ def translate_text(text, source_to_target='ar_to_en'):
244
+ try:
245
+ if source_to_target == 'ar_to_en':
246
+ tokenizer = models['ar_to_en_tokenizer']
247
+ model = models['ar_to_en_model']
248
+ else:
249
+ tokenizer = models['en_to_ar_tokenizer']
250
+ model = models['en_to_ar_model']
251
+ inputs = tokenizer(text, return_tensors="pt", truncation=True)
252
+ outputs = model.generate(**inputs)
253
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
254
+ except Exception as e:
255
+ print(f"Translation error: {e}")
256
+ return text
257
+
258
  def embed_query_text(query_text):
259
  embedding = models['embedding_model']
260
  query_embedding = embedding.encode([query_text])
 
275
  print(f"Error in query_embeddings: {e}")
276
  return []
277
 
278
+ def query_recipes_embeddings(query_embedding, embeddings_data, n_results):
279
+ embeddings_data = load_recipes_embeddings()
280
+ if embeddings_data is None:
281
+ print("No embeddings data available.")
282
+ return []
283
+ try:
284
+ if query_embedding.ndim == 1:
285
+ query_embedding = query_embedding.reshape(1, -1)
286
+ similarities = cosine_similarity(query_embedding, embeddings_data).flatten()
287
+ top_indices = similarities.argsort()[-n_results:][::-1]
288
+ return [(index, similarities[index]) for index in top_indices]
289
+ except Exception as e:
290
+ print(f"Error in query_recipes_embeddings: {e}")
291
+ return []
292
+
293
+ def get_page_title(url):
294
+ try:
295
+ response = requests.get(url)
296
+ if response.status_code == 200:
297
+ soup = BeautifulSoup(response.text, 'html.parser')
298
+ title = soup.find('title')
299
+ return title.get_text() if title else "No title found"
300
+ else:
301
+ return None
302
+ except requests.exceptions.RequestException:
303
+ return None
304
+
305
  def retrieve_document_texts(doc_ids, folder_path='downloaded_articles/downloaded_articles'):
306
  texts = []
307
  for doc_id in doc_ids:
 
320
  texts.append("")
321
  return texts
322
 
323
+ def retrieve_rec_texts(
324
+ document_indices,
325
+ folder_path='downloaded_articles/downloaded_articles',
326
+ metadata_path='recipes_metadata.xlsx'
327
+ ):
328
+ try:
329
+ metadata_df = pd.read_excel(metadata_path)
330
+ if "id" not in metadata_df.columns or "original_file_name" not in metadata_df.columns:
331
+ raise ValueError("Metadata file must contain 'id' and 'original_file_name' columns.")
332
+ metadata_df = metadata_df.sort_values(by="id").reset_index(drop=True)
333
+ if metadata_df.index.max() < max(document_indices):
334
+ raise ValueError("Some document indices exceed the range of metadata.")
335
+ document_texts = []
336
+ for idx in document_indices:
337
+ if idx >= len(metadata_df):
338
+ print(f"Warning: Index {idx} is out of range for metadata.")
339
+ continue
340
+ original_file_name = metadata_df.iloc[idx]["original_file_name"]
341
+ if not original_file_name:
342
+ print(f"Warning: No file name found for index {idx}")
343
+ continue
344
+ file_path = os.path.join(folder_path, original_file_name)
345
+ if os.path.exists(file_path):
346
+ with open(file_path, "r", encoding="utf-8") as f:
347
+ document_texts.append(f.read())
348
+ else:
349
+ print(f"Warning: File not found at {file_path}")
350
+ return document_texts
351
+ except Exception as e:
352
+ print(f"Error in retrieve_rec_texts: {e}")
353
+ return []
354
+
355
+ def retrieve_metadata(document_indices: List[int], metadata_path: str = 'recipes_metadata.xlsx') -> Dict[int, Dict[str, str]]:
356
+ try:
357
+ metadata_df = pd.read_excel(metadata_path)
358
+ required_columns = {'id', 'original_file_name', 'url'}
359
+ if not required_columns.issubset(metadata_df.columns):
360
+ raise ValueError(f"Metadata file must contain columns: {required_columns}")
361
+ metadata_df['id'] = metadata_df['id'].astype(int)
362
+ filtered_metadata = metadata_df[metadata_df['id'].isin(document_indices)]
363
+ metadata_dict = {
364
+ int(row['id']): {
365
+ "original_file_name": row['original_file_name'],
366
+ "url": row['url']
367
+ }
368
+ for _, row in filtered_metadata.iterrows()
369
+ }
370
+ return metadata_dict
371
+ except Exception as e:
372
+ print(f"Error retrieving metadata: {e}")
373
+ return {}
374
+
375
  def rerank_documents(query, document_ids, document_texts, cross_encoder_model):
376
  try:
377
  pairs = [(query, doc) for doc in document_texts]
 
386
  print(f"Error reranking documents: {e}")
387
  return []
388
 
389
+ def translate_ar_to_en(text):
390
  try:
391
+ ar_to_en_tokenizer = models['ar_to_en_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
392
+ ar_to_en_model= models['ar_to_en_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
393
+ inputs = ar_to_en_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
394
+ translated_ids = ar_to_en_model.generate(
395
+ inputs.input_ids,
396
+ max_length=512,
397
+ num_beams=4,
398
+ early_stopping=True
399
+ )
400
+ translated_text = ar_to_en_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
401
+ return translated_text
402
  except Exception as e:
403
+ print(f"Error during Arabic to English translation: {e}")
404
+ return None
405
+
406
+ def translate_en_to_ar(text):
407
+ try:
408
+ en_to_ar_tokenizer = models['en_to_ar_tokenizer'] = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
409
+ en_to_ar_model = models['en_to_ar_model'] = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
410
+ inputs = en_to_ar_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
411
+ translated_ids = en_to_ar_model.generate(
412
+ inputs.input_ids,
413
+ max_length=512,
414
+ num_beams=4,
415
+ early_stopping=True
416
+ )
417
+ translated_text = en_to_ar_tokenizer.decode(translated_ids[0], skip_special_tokens=True)
418
+ return translated_text
419
+ except Exception as e:
420
+ print(f"Error during English to Arabic translation: {e}")
421
+ return None
422
+
423
 
424
  @app.get("/")
425
  async def root():
 
436
  }
437
  return status
438
 
439
+ if not init_success:
440
+ print("Warning: Application initialized with partial functionality")
441
  if __name__ == "__main__":
442
  import uvicorn
443
  uvicorn.run(app, host="0.0.0.0", port=7860)