wakeupmh committed
Commit 6f43c31 · 1 Parent(s): a226fb9

Files changed (2):
  1. app.py +5 -3
  2. faiss_index/index.py +16 -24
app.py CHANGED
@@ -41,11 +41,13 @@ def load_rag_dataset(dataset_dir="rag_dataset"):
         dataset_path = os.path.join(dataset_dir, "dataset")
         index_path = os.path.join(dataset_dir, "embeddings.faiss")
 
+        if not os.path.exists(dataset_path) or not os.path.exists(index_path):
+            raise ValueError("Dataset or index not found")
+
         dataset = load_from_disk(dataset_path)
+        index = faiss.read_index(index_path)
 
-        # Add FAISS index back to dataset
-        dataset.load_faiss_index('embeddings', index_path)
-
+        logging.info("Successfully loaded dataset and index")
         return dataset, dataset_path, index_path
     except Exception as e:
         st.error(f"Error loading dataset: {str(e)}\n{traceback.format_exc()}")
faiss_index/index.py CHANGED
@@ -6,7 +6,6 @@ import os
 from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
 import torch
 import logging
-from datasets.utils.file_utils import DownloadConfig
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -40,25 +39,17 @@ def fetch_arxiv_papers(query, max_results=10):
 def build_faiss_index(papers, dataset_dir="rag_dataset"):
     """Build and save dataset with FAISS index for RAG"""
     try:
-        # Create dataset
-        dataset = Dataset.from_dict({
-            "id": [p["id"] for p in papers],
-            "text": [p["text"] for p in papers],
-            "title": [p["title"] for p in papers],
-        })
-
-        logging.info(f"Created dataset with {len(dataset)} papers")
-
         # Initialize DPR encoder
         ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
         ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
         # Create embeddings in batches
+        texts = [p["text"] for p in papers]
         embeddings = []
         batch_size = 8
 
-        for i in range(0, len(dataset), batch_size):
-            batch = dataset[i:i + batch_size]["text"]
+        for i in range(0, len(texts), batch_size):
+            batch = texts[i:i + batch_size]
             inputs = ctx_tokenizer(
                 batch,
                 max_length=512,
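This hunk swaps Dataset slicing for a plain Python list before batching, and the diff context cuts off inside the tokenizer call. A plausible reconstruction of the full embedding loop, assuming the standard DPR pattern (padding and truncation in the tokenizer, pooler_output as the document vector); the stand-in papers list and the tokenizer arguments past max_length are assumptions, not part of this commit:

import numpy as np
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

# Stand-in for the output of fetch_arxiv_papers defined earlier in this file
papers = [{"id": "1", "text": "An example abstract.", "title": "Example"}]

texts = [p["text"] for p in papers]
embeddings = []
batch_size = 8

for i in range(0, len(texts), batch_size):
    batch = texts[i:i + batch_size]
    inputs = ctx_tokenizer(
        batch,
        max_length=512,
        padding=True,       # assumed: pad each batch to its longest member
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = ctx_encoder(**inputs)
    # pooler_output holds one 768-d DPR vector per document
    embeddings.append(outputs.pooler_output.cpu().numpy())

embeddings = np.vstack(embeddings)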
@@ -75,28 +66,29 @@ def build_faiss_index(papers, dataset_dir="rag_dataset"):
         embeddings = np.vstack(embeddings)
         logging.info(f"Created embeddings with shape {embeddings.shape}")
 
-        # Create FAISS index
+        # Create dataset with embeddings
+        dataset = Dataset.from_dict({
+            "id": [p["id"] for p in papers],
+            "text": [p["text"] for p in papers],
+            "title": [p["title"] for p in papers],
+            "embeddings": [emb.tolist() for emb in embeddings],
+        })
+
+        logging.info(f"Created dataset with {len(dataset)} papers")
+
+        # Create FAISS index from embeddings
         dimension = embeddings.shape[1]
         index = faiss.IndexFlatL2(dimension)
         index.add(embeddings.astype(np.float32))
 
         # Save everything
         os.makedirs(dataset_dir, exist_ok=True)
-
-        # Add embeddings to dataset
-        dataset = dataset.add_faiss_index(
-            column='embeddings',
-            custom_index=index,
-            device=0 if torch.cuda.is_available() else -1
-        )
-        dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
-
-        # Save dataset and index
         dataset_path = os.path.join(dataset_dir, "dataset")
         index_path = os.path.join(dataset_dir, "embeddings.faiss")
 
+        # Save dataset and index
         dataset.save_to_disk(dataset_path)
-        dataset.get_index('embeddings').save(index_path)
+        faiss.write_index(index, index_path)
 
         logging.info(f"Saved dataset to {dataset_path}")
         logging.info(f"Saved index to {index_path}")
 